src/Pure/sorts.ML
author wenzelm
Thu Jul 05 00:06:24 2007 +0200 (2007-07-05)
changeset 23585 f07ef41ffb87
parent 22570 f166a5416b3f
child 23655 d2d1138e0ddc
permissions -rw-r--r--
sort_le: tuned eq case;
     1 (*  Title:      Pure/sorts.ML
     2     ID:         $Id$
     3     Author:     Markus Wenzel and Stefan Berghofer, TU Muenchen
     4 
     5 The order-sorted algebra of type classes.
     6 
     7 Classes denote (possibly empty) collections of types that are
     8 partially ordered by class inclusion. They are represented
     9 symbolically by strings.
    10 
    11 Sorts are intersections of finitely many classes. They are represented
    12 by lists of classes.  Normal forms of sorts are sorted lists of
    13 minimal classes (wrt. current class inclusion).
    14 *)
    15 
    16 signature SORTS =
    17 sig
    18   val union: sort list -> sort list -> sort list
    19   val subtract: sort list -> sort list -> sort list
    20   val remove_sort: sort -> sort list -> sort list
    21   val insert_sort: sort -> sort list -> sort list
    22   val insert_typ: typ -> sort list -> sort list
    23   val insert_typs: typ list -> sort list -> sort list
    24   val insert_term: term -> sort list -> sort list
    25   val insert_terms: term list -> sort list -> sort list
    26   type algebra
    27   val rep_algebra: algebra ->
    28    {classes: serial Graph.T,
    29     arities: (class * (class * sort list)) list Symtab.table}
    30   val all_classes: algebra -> class list
    31   val minimal_classes: algebra -> class list
    32   val super_classes: algebra -> class -> class list
    33   val class_less: algebra -> class * class -> bool
    34   val class_le: algebra -> class * class -> bool
    35   val sort_eq: algebra -> sort * sort -> bool
    36   val sort_le: algebra -> sort * sort -> bool
    37   val sorts_le: algebra -> sort list * sort list -> bool
    38   val inter_sort: algebra -> sort * sort -> sort
    39   val certify_class: algebra -> class -> class    (*exception TYPE*)
    40   val certify_sort: algebra -> sort -> sort       (*exception TYPE*)
    41   val add_class: Pretty.pp -> class * class list -> algebra -> algebra
    42   val add_classrel: Pretty.pp -> class * class -> algebra -> algebra
    43   val add_arities: Pretty.pp -> string * (class * sort list) list -> algebra -> algebra
    44   val empty_algebra: algebra
    45   val merge_algebra: Pretty.pp -> algebra * algebra -> algebra
    46   val subalgebra: Pretty.pp -> (class -> bool) -> (class * string -> sort list)
    47     -> algebra -> (sort -> sort) * algebra
    48   type class_error
    49   val msg_class_error: Pretty.pp -> class_error -> string
    50   val class_error: Pretty.pp -> class_error -> 'a
    51   exception CLASS_ERROR of class_error
    52   val mg_domain: algebra -> string -> sort -> sort list   (*exception CLASS_ERROR*)
    53   val of_sort: algebra -> typ * sort -> bool
    54   val of_sort_derivation: Pretty.pp -> algebra ->
    55     {class_relation: 'a * class -> class -> 'a,
    56      type_constructor: string -> ('a * class) list list -> class -> 'a,
    57      type_variable: typ -> ('a * class) list} ->
    58     typ * sort -> 'a list   (*exception CLASS_ERROR*)
    59   val witness_sorts: algebra -> string list -> sort list -> sort list -> (typ * sort) list
    60 end;
    61 
    62 structure Sorts: SORTS =
    63 struct
    64 
    65 
    66 (** ordered lists of sorts **)
    67 
    68 val op union = OrdList.union Term.sort_ord;
    69 val subtract = OrdList.subtract Term.sort_ord;
    70 
    71 val remove_sort = OrdList.remove Term.sort_ord;
    72 val insert_sort = OrdList.insert Term.sort_ord;
    73 
    74 fun insert_typ (TFree (_, S)) Ss = insert_sort S Ss
    75   | insert_typ (TVar (_, S)) Ss = insert_sort S Ss
    76   | insert_typ (Type (_, Ts)) Ss = insert_typs Ts Ss
    77 and insert_typs [] Ss = Ss
    78   | insert_typs (T :: Ts) Ss = insert_typs Ts (insert_typ T Ss);
    79 
    80 fun insert_term (Const (_, T)) Ss = insert_typ T Ss
    81   | insert_term (Free (_, T)) Ss = insert_typ T Ss
    82   | insert_term (Var (_, T)) Ss = insert_typ T Ss
    83   | insert_term (Bound _) Ss = Ss
    84   | insert_term (Abs (_, T, t)) Ss = insert_term t (insert_typ T Ss)
    85   | insert_term (t $ u) Ss = insert_term t (insert_term u Ss);
    86 
    87 fun insert_terms [] Ss = Ss
    88   | insert_terms (t :: ts) Ss = insert_terms ts (insert_term t Ss);
    89 
    90 
    91 
    92 (** order-sorted algebra **)
    93 
    94 (*
    95   classes: graph representing class declarations together with proper
    96     subclass relation, which needs to be transitive and acyclic.
    97 
    98   arities: table of association lists of all type arities; (t, ars)
    99     means that type constructor t has the arities ars; an element
   100     (c, (c0, Ss)) of ars represents the arity t::(Ss)c being derived
   101     via c0 <= c.  "Coregularity" of the arities structure requires
   102     that for any two declarations t::(Ss1)c1 and t::(Ss2)c2 such that
   103     c1 <= c2 holds Ss1 <= Ss2.
   104 *)
   105 
   106 datatype algebra = Algebra of
   107  {classes: serial Graph.T,
   108   arities: (class * (class * sort list)) list Symtab.table};
   109 
   110 fun rep_algebra (Algebra args) = args;
   111 
   112 val classes_of = #classes o rep_algebra;
   113 val arities_of = #arities o rep_algebra;
   114 
   115 fun make_algebra (classes, arities) =
   116   Algebra {classes = classes, arities = arities};
   117 
   118 fun map_classes f (Algebra {classes, arities}) = make_algebra (f classes, arities);
   119 fun map_arities f (Algebra {classes, arities}) = make_algebra (classes, f arities);
   120 
   121 
   122 (* classes *)
   123 
   124 fun all_classes (Algebra {classes, ...}) = Graph.all_preds classes (Graph.maximals classes);
   125 
   126 val minimal_classes = Graph.minimals o classes_of;
   127 val super_classes = Graph.imm_succs o classes_of;
   128 
   129 
   130 (* class relations *)
   131 
   132 val class_less = Graph.is_edge o classes_of;
   133 fun class_le algebra (c1, c2) = c1 = c2 orelse class_less algebra (c1, c2);
   134 
   135 
   136 (* sort relations *)
   137 
   138 fun sort_le algebra (S1, S2) =
   139   S1 = S2 orelse forall (fn c2 => exists (fn c1 => class_le algebra (c1, c2)) S1) S2;
   140 
   141 fun sorts_le algebra (Ss1, Ss2) =
   142   ListPair.all (sort_le algebra) (Ss1, Ss2);
   143 
   144 fun sort_eq algebra (S1, S2) =
   145   sort_le algebra (S1, S2) andalso sort_le algebra (S2, S1);
   146 
   147 
   148 (* intersection *)
   149 
   150 fun inter_class algebra c S =
   151   let
   152     fun intr [] = [c]
   153       | intr (S' as c' :: c's) =
   154           if class_le algebra (c', c) then S'
   155           else if class_le algebra (c, c') then intr c's
   156           else c' :: intr c's
   157   in intr S end;
   158 
   159 fun inter_sort algebra (S1, S2) =
   160   sort_strings (fold (inter_class algebra) S1 S2);
   161 
   162 
   163 (* normal form *)
   164 
   165 fun norm_sort _ [] = []
   166   | norm_sort _ (S as [_]) = S
   167   | norm_sort algebra S =
   168       filter (fn c => not (exists (fn c' => class_less algebra (c', c)) S)) S
   169       |> sort_distinct string_ord;
   170 
   171 
   172 (* certify *)
   173 
   174 fun certify_class algebra c =
   175   if can (Graph.get_node (classes_of algebra)) c then c
   176   else raise TYPE ("Undeclared class: " ^ quote c, [], []);
   177 
   178 fun certify_sort classes = norm_sort classes o map (certify_class classes);
   179 
   180 
   181 
   182 (** build algebras **)
   183 
   184 (* classes *)
   185 
   186 fun err_dup_classes cs =
   187   error ("Duplicate declaration of class(es): " ^ commas_quote cs);
   188 
   189 fun err_cyclic_classes pp css =
   190   error (cat_lines (map (fn cs =>
   191     "Cycle in class relation: " ^ Pretty.string_of_classrel pp cs) css));
   192 
   193 fun add_class pp (c, cs) = map_classes (fn classes =>
   194   let
   195     val classes' = classes |> Graph.new_node (c, serial ())
   196       handle Graph.DUP dup => err_dup_classes [dup];
   197     val classes'' = classes' |> fold Graph.add_edge_trans_acyclic (map (pair c) cs)
   198       handle Graph.CYCLES css => err_cyclic_classes pp css;
   199   in classes'' end);
   200 
   201 
   202 (* arities *)
   203 
   204 local
   205 
   206 fun for_classes _ NONE = ""
   207   | for_classes pp (SOME (c1, c2)) =
   208       " for classes " ^ Pretty.string_of_classrel pp [c1, c2];
   209 
   210 fun err_conflict pp t cc (c, Ss) (c', Ss') =
   211   error ("Conflict of type arities" ^ for_classes pp cc ^ ":\n  " ^
   212     Pretty.string_of_arity pp (t, Ss, [c]) ^ " and\n  " ^
   213     Pretty.string_of_arity pp (t, Ss', [c']));
   214 
   215 fun coregular pp algebra t (c, (c0, Ss)) ars =
   216   let
   217     fun conflict (c', (_, Ss')) =
   218       if class_le algebra (c, c') andalso not (sorts_le algebra (Ss, Ss')) then
   219         SOME ((c, c'), (c', Ss'))
   220       else if class_le algebra (c', c) andalso not (sorts_le algebra (Ss', Ss)) then
   221         SOME ((c', c), (c', Ss'))
   222       else NONE;
   223   in
   224     (case get_first conflict ars of
   225       SOME ((c1, c2), (c', Ss')) => err_conflict pp t (SOME (c1, c2)) (c, Ss) (c', Ss')
   226     | NONE => (c, (c0, Ss)) :: ars)
   227   end;
   228 
   229 fun complete algebra (c0, Ss) = map (rpair (c0, Ss)) (c0 :: super_classes algebra c0);
   230 
   231 fun insert pp algebra t (c, (c0, Ss)) ars =
   232   (case AList.lookup (op =) ars c of
   233     NONE => coregular pp algebra t (c, (c0, Ss)) ars
   234   | SOME (_, Ss') =>
   235       if sorts_le algebra (Ss, Ss') then ars
   236       else if sorts_le algebra (Ss', Ss) then
   237         coregular pp algebra t (c, (c0, Ss))
   238           (filter_out (fn (c'', (_, Ss'')) => c = c'' andalso Ss'' = Ss') ars)
   239       else err_conflict pp t NONE (c, Ss) (c, Ss'));
   240 
   241 fun insert_ars pp algebra (t, ars) arities =
   242   let val ars' =
   243     Symtab.lookup_list arities t
   244     |> fold_rev (fold_rev (insert pp algebra t)) (map (complete algebra) ars)
   245   in Symtab.update (t, ars') arities end;
   246 
   247 in
   248 
   249 fun add_arities pp arg algebra = algebra |> map_arities (insert_ars pp algebra arg);
   250 
   251 fun add_arities_table pp algebra =
   252   Symtab.fold (fn (t, ars) => insert_ars pp algebra (t, map snd ars));
   253 
   254 end;
   255 
   256 
   257 (* classrel *)
   258 
   259 fun rebuild_arities pp algebra = algebra |> map_arities (fn arities =>
   260   Symtab.empty
   261   |> add_arities_table pp algebra arities);
   262 
   263 fun add_classrel pp rel = rebuild_arities pp o map_classes (fn classes =>
   264   classes |> Graph.add_edge_trans_acyclic rel
   265     handle Graph.CYCLES css => err_cyclic_classes pp css);
   266 
   267 
   268 (* empty and merge *)
   269 
   270 val empty_algebra = make_algebra (Graph.empty, Symtab.empty);
   271 
   272 fun merge_algebra pp
   273    (Algebra {classes = classes1, arities = arities1},
   274     Algebra {classes = classes2, arities = arities2}) =
   275   let
   276     val classes' = Graph.merge_trans_acyclic (op =) (classes1, classes2)
   277       handle Graph.DUPS cs => err_dup_classes cs
   278           | Graph.CYCLES css => err_cyclic_classes pp css;
   279     val algebra0 = make_algebra (classes', Symtab.empty);
   280     val arities' = Symtab.empty
   281       |> add_arities_table pp algebra0 arities1
   282       |> add_arities_table pp algebra0 arities2;
   283   in make_algebra (classes', arities') end;
   284 
   285 
   286 (* subalgebra *)
   287 
   288 fun subalgebra pp P sargs (algebra as Algebra {classes, arities}) =
   289   let
   290     val restrict_sort = norm_sort algebra o filter P o Graph.all_succs classes;
   291     fun restrict_arity tyco (c, (_, Ss)) =
   292       if P c then
   293         SOME (c, (c, Ss |> map2 (curry (inter_sort algebra)) (sargs (c, tyco))
   294           |> map restrict_sort))
   295       else NONE;
   296     val classes' = classes |> Graph.subgraph P;
   297     val arities' = arities |> Symtab.map' (map_filter o restrict_arity);
   298   in (restrict_sort, rebuild_arities pp (make_algebra (classes', arities'))) end;
   299 
   300 
   301 
   302 (** sorts of types **)
   303 
   304 (* errors *)
   305 
   306 datatype class_error = NoClassrel of class * class | NoArity of string * class;
   307 
   308 fun msg_class_error pp (NoClassrel (c1, c2)) =
   309       "No class relation " ^ Pretty.string_of_classrel pp [c1, c2]
   310   | msg_class_error pp (NoArity (a, c)) =
   311       "No type arity " ^ Pretty.string_of_arity pp (a, [], [c]);
   312 
   313 fun class_error pp = error o msg_class_error pp;
   314 
   315 exception CLASS_ERROR of class_error;
   316 
   317 
   318 (* mg_domain *)
   319 
   320 fun mg_domain algebra a S =
   321   let
   322     val arities = arities_of algebra;
   323     fun dom c =
   324       (case AList.lookup (op =) (Symtab.lookup_list arities a) c of
   325         NONE => raise CLASS_ERROR (NoArity (a, c))
   326       | SOME (_, Ss) => Ss);
   327     fun dom_inter c Ss = ListPair.map (inter_sort algebra) (dom c, Ss);
   328   in
   329     (case S of
   330       [] => raise Fail "Unknown domain of empty intersection"
   331     | c :: cs => fold dom_inter cs (dom c))
   332   end;
   333 
   334 
   335 (* of_sort *)
   336 
   337 fun of_sort algebra =
   338   let
   339     fun ofS (_, []) = true
   340       | ofS (TFree (_, S), S') = sort_le algebra (S, S')
   341       | ofS (TVar (_, S), S') = sort_le algebra (S, S')
   342       | ofS (Type (a, Ts), S) =
   343           let val Ss = mg_domain algebra a S in
   344             ListPair.all ofS (Ts, Ss)
   345           end handle CLASS_ERROR _ => false;
   346   in ofS end;
   347 
   348 
   349 (* of_sort_derivation *)
   350 
   351 fun of_sort_derivation pp algebra {class_relation, type_constructor, type_variable} =
   352   let
   353     val {classes, arities} = rep_algebra algebra;
   354     fun weaken_path (x, c1 :: c2 :: cs) =
   355           weaken_path (class_relation (x, c1) c2, c2 :: cs)
   356       | weaken_path (x, _) = x;
   357     fun weaken (x, c1) c2 =
   358       (case Graph.irreducible_paths classes (c1, c2) of
   359         [] => raise CLASS_ERROR (NoClassrel (c1, c2))
   360       | cs :: _ => weaken_path (x, cs));
   361 
   362     fun weakens S1 S2 = S2 |> map (fn c2 =>
   363       (case S1 |> find_first (fn (_, c1) => class_le algebra (c1, c2)) of
   364         SOME d1 => weaken d1 c2
   365       | NONE => error ("Cannot derive subsort relation " ^
   366           Pretty.string_of_sort pp (map #2 S1) ^ " < " ^ Pretty.string_of_sort pp S2)));
   367 
   368     fun derive _ [] = []
   369       | derive (Type (a, Ts)) S =
   370           let
   371             val Ss = mg_domain algebra a S;
   372             val dom = map2 (fn T => fn S => derive T S ~~ S) Ts Ss;
   373           in
   374             S |> map (fn c =>
   375               let
   376                 val (c0, Ss') = the (AList.lookup (op =) (Symtab.lookup_list arities a) c);
   377                 val dom' = map2 (fn d => fn S' => weakens d S' ~~ S') dom Ss';
   378               in weaken (type_constructor a dom' c0, c0) c end)
   379           end
   380       | derive T S = weakens (type_variable T) S;
   381   in uncurry derive end;
   382 
   383 
   384 (* witness_sorts *)
   385 
   386 fun witness_sorts algebra types hyps sorts =
   387   let
   388     fun le S1 S2 = sort_le algebra (S1, S2);
   389     fun get_solved S2 (T, S1) = if le S1 S2 then SOME (T, S2) else NONE;
   390     fun get_hyp S2 S1 = if le S1 S2 then SOME (TFree ("'hyp", S1), S2) else NONE;
   391     fun mg_dom t S = SOME (mg_domain algebra t S) handle CLASS_ERROR _ => NONE;
   392 
   393     fun witn_sort _ [] solved_failed = (SOME (propT, []), solved_failed)
   394       | witn_sort path S (solved, failed) =
   395           if exists (le S) failed then (NONE, (solved, failed))
   396           else
   397             (case get_first (get_solved S) solved of
   398               SOME w => (SOME w, (solved, failed))
   399             | NONE =>
   400                 (case get_first (get_hyp S) hyps of
   401                   SOME w => (SOME w, (w :: solved, failed))
   402                 | NONE => witn_types path types S (solved, failed)))
   403 
   404     and witn_sorts path x = fold_map (witn_sort path) x
   405 
   406     and witn_types _ [] S (solved, failed) = (NONE, (solved, S :: failed))
   407       | witn_types path (t :: ts) S solved_failed =
   408           (case mg_dom t S of
   409             SOME SS =>
   410               (*do not descend into stronger args (achieving termination)*)
   411               if exists (fn D => le D S orelse exists (le D) path) SS then
   412                 witn_types path ts S solved_failed
   413               else
   414                 let val (ws, (solved', failed')) = witn_sorts (S :: path) SS solved_failed in
   415                   if forall is_some ws then
   416                     let val w = (Type (t, map (#1 o the) ws), S)
   417                     in (SOME w, (w :: solved', failed')) end
   418                   else witn_types path ts S (solved', failed')
   419                 end
   420           | NONE => witn_types path ts S solved_failed);
   421 
   422   in map_filter I (#1 (witn_sorts [] sorts ([], []))) end;
   423 
   424 end;