src/Pure/sorts.ML
author wenzelm
Wed Dec 31 15:30:10 2008 +0100 (2008-12-31)
changeset 29269 5c25a2012975
parent 28922 ac2c34cad840
child 29972 aee7610106fd
child 30240 5b25fee0362c
permissions -rw-r--r--
moved term order operations to structure TermOrd (cf. Pure/term_ord.ML);
tuned signature of structure Term;
     1 (*  Title:      Pure/sorts.ML
     2     Author:     Markus Wenzel and Stefan Berghofer, TU Muenchen
     3 
     4 The order-sorted algebra of type classes.
     5 
     6 Classes denote (possibly empty) collections of types that are
     7 partially ordered by class inclusion. They are represented
     8 symbolically by strings.
     9 
    10 Sorts are intersections of finitely many classes. They are represented
    11 by lists of classes.  Normal forms of sorts are sorted lists of
    12 minimal classes (wrt. current class inclusion).
    13 *)
    14 
(*Interface of the order-sorted algebra of type classes: canonical ordered
  lists of sorts, the algebra of classes and type arities, and sort checking
  of types.*)
signature SORTS =
sig
  (* ordered lists of sorts *)
  val make: sort list -> sort OrdList.T
  val subset: sort OrdList.T * sort OrdList.T -> bool
  val union: sort OrdList.T -> sort OrdList.T -> sort OrdList.T
  val subtract: sort OrdList.T -> sort OrdList.T -> sort OrdList.T
  val remove_sort: sort -> sort OrdList.T -> sort OrdList.T
  val insert_sort: sort -> sort OrdList.T -> sort OrdList.T
  (*insert the sorts of all type variables occurring in types, or in the
    types of terms*)
  val insert_typ: typ -> sort OrdList.T -> sort OrdList.T
  val insert_typs: typ list -> sort OrdList.T -> sort OrdList.T
  val insert_term: term -> sort OrdList.T -> sort OrdList.T
  val insert_terms: term list -> sort OrdList.T -> sort OrdList.T

  (* the algebra: class graph and table of type arities *)
  type algebra
  val rep_algebra: algebra ->
   {classes: serial Graph.T,
    arities: (class * (class * sort list)) list Symtab.table}

  (* class and sort relations *)
  val all_classes: algebra -> class list
  val super_classes: algebra -> class -> class list
  val class_less: algebra -> class * class -> bool
  val class_le: algebra -> class * class -> bool
  val sort_eq: algebra -> sort * sort -> bool
  val sort_le: algebra -> sort * sort -> bool
  val sorts_le: algebra -> sort list * sort list -> bool
  val inter_sort: algebra -> sort * sort -> sort

  (* normal forms and certification *)
  val minimize_sort: algebra -> sort -> sort
  val complete_sort: algebra -> sort -> sort
  val minimal_sorts: algebra -> sort list -> sort OrdList.T
  val certify_class: algebra -> class -> class    (*exception TYPE*)
  val certify_sort: algebra -> sort -> sort       (*exception TYPE*)

  (* building algebras *)
  val add_class: Pretty.pp -> class * class list -> algebra -> algebra
  val add_classrel: Pretty.pp -> class * class -> algebra -> algebra
  val add_arities: Pretty.pp -> string * (class * sort list) list -> algebra -> algebra
  val empty_algebra: algebra
  val merge_algebra: Pretty.pp -> algebra * algebra -> algebra

  (* projections; subalgebra also returns the induced sort restriction *)
  val classrels_of: algebra -> (class * class list) list
  val instances_of: algebra -> (string * class) list
  val subalgebra: Pretty.pp -> (class -> bool) -> (class * string -> sort list)
    -> algebra -> (sort -> sort) * algebra

  (* class errors, with delayed message composition *)
  type class_error
  val class_error: Pretty.pp -> class_error -> string
  exception CLASS_ERROR of class_error

  (* sorts of types *)
  val mg_domain: algebra -> string -> sort -> sort list   (*exception CLASS_ERROR*)
  val meet_sort: algebra -> typ * sort
    -> sort Vartab.table -> sort Vartab.table   (*exception CLASS_ERROR*)
  val meet_sort_typ: algebra -> typ * sort -> typ -> typ   (*exception CLASS_ERROR*)
  val of_sort: algebra -> typ * sort -> bool
  val weaken: algebra -> ('a * class -> class -> 'a) -> 'a * class -> class -> 'a
  val of_sort_derivation: Pretty.pp -> algebra ->
    {class_relation: 'a * class -> class -> 'a,
     type_constructor: string -> ('a * class) list list -> class -> 'a,
     type_variable: typ -> ('a * class) list} ->
    typ * sort -> 'a list   (*exception CLASS_ERROR*)
  val witness_sorts: algebra -> string list -> sort list -> sort list -> (typ * sort) list
end;
    69 
    70 structure Sorts: SORTS =
    71 struct
    72 
    73 
    74 (** ordered lists of sorts **)
    75 
    76 val make = OrdList.make TermOrd.sort_ord;
    77 val op subset = OrdList.subset TermOrd.sort_ord;
    78 val op union = OrdList.union TermOrd.sort_ord;
    79 val subtract = OrdList.subtract TermOrd.sort_ord;
    80 
    81 val remove_sort = OrdList.remove TermOrd.sort_ord;
    82 val insert_sort = OrdList.insert TermOrd.sort_ord;
    83 
    84 fun insert_typ (TFree (_, S)) Ss = insert_sort S Ss
    85   | insert_typ (TVar (_, S)) Ss = insert_sort S Ss
    86   | insert_typ (Type (_, Ts)) Ss = insert_typs Ts Ss
    87 and insert_typs [] Ss = Ss
    88   | insert_typs (T :: Ts) Ss = insert_typs Ts (insert_typ T Ss);
    89 
    90 fun insert_term (Const (_, T)) Ss = insert_typ T Ss
    91   | insert_term (Free (_, T)) Ss = insert_typ T Ss
    92   | insert_term (Var (_, T)) Ss = insert_typ T Ss
    93   | insert_term (Bound _) Ss = Ss
    94   | insert_term (Abs (_, T, t)) Ss = insert_term t (insert_typ T Ss)
    95   | insert_term (t $ u) Ss = insert_term t (insert_term u Ss);
    96 
    97 fun insert_terms [] Ss = Ss
    98   | insert_terms (t :: ts) Ss = insert_terms ts (insert_term t Ss);
    99 
   100 
   101 
   102 (** order-sorted algebra **)
   103 
   104 (*
   105   classes: graph representing class declarations together with proper
   106     subclass relation, which needs to be transitive and acyclic.
   107 
   108   arities: table of association lists of all type arities; (t, ars)
   109     means that type constructor t has the arities ars; an element
   110     (c, (c0, Ss)) of ars represents the arity t::(Ss)c being derived
    111     via c0 <= c.  "Coregularity" of the arities structure requires
    112     that for any two declarations t::(Ss1)c1 and t::(Ss2)c2 with
    113     c1 <= c2, also Ss1 <= Ss2 holds.
   114 *)
   115 
   116 datatype algebra = Algebra of
   117  {classes: serial Graph.T,
   118   arities: (class * (class * sort list)) list Symtab.table};
   119 
   120 fun rep_algebra (Algebra args) = args;
   121 
   122 val classes_of = #classes o rep_algebra;
   123 val arities_of = #arities o rep_algebra;
   124 
   125 fun make_algebra (classes, arities) =
   126   Algebra {classes = classes, arities = arities};
   127 
   128 fun map_classes f (Algebra {classes, arities}) = make_algebra (f classes, arities);
   129 fun map_arities f (Algebra {classes, arities}) = make_algebra (classes, f arities);
   130 
   131 
   132 (* classes *)
   133 
   134 fun all_classes (Algebra {classes, ...}) = Graph.all_preds classes (Graph.maximals classes);
   135 
   136 val super_classes = Graph.imm_succs o classes_of;
   137 
   138 
   139 (* class relations *)
   140 
   141 val class_less = Graph.is_edge o classes_of;
   142 fun class_le algebra (c1, c2) = c1 = c2 orelse class_less algebra (c1, c2);
   143 
   144 
   145 (* sort relations *)
   146 
   147 fun sort_le algebra (S1, S2) =
   148   S1 = S2 orelse forall (fn c2 => exists (fn c1 => class_le algebra (c1, c2)) S1) S2;
   149 
   150 fun sorts_le algebra (Ss1, Ss2) =
   151   ListPair.all (sort_le algebra) (Ss1, Ss2);
   152 
   153 fun sort_eq algebra (S1, S2) =
   154   sort_le algebra (S1, S2) andalso sort_le algebra (S2, S1);
   155 
   156 
   157 (* intersection *)
   158 
   159 fun inter_class algebra c S =
   160   let
   161     fun intr [] = [c]
   162       | intr (S' as c' :: c's) =
   163           if class_le algebra (c', c) then S'
   164           else if class_le algebra (c, c') then intr c's
   165           else c' :: intr c's
   166   in intr S end;
   167 
   168 fun inter_sort algebra (S1, S2) =
   169   sort_strings (fold (inter_class algebra) S1 S2);
   170 
   171 
   172 (* normal forms *)
   173 
   174 fun minimize_sort _ [] = []
   175   | minimize_sort _ (S as [_]) = S
   176   | minimize_sort algebra S =
   177       filter (fn c => not (exists (fn c' => class_less algebra (c', c)) S)) S
   178       |> sort_distinct string_ord;
   179 
   180 fun complete_sort algebra =
   181   Graph.all_succs (classes_of algebra) o minimize_sort algebra;
   182 
   183 fun minimal_sorts algebra raw_sorts =
   184   let
   185     fun le S1 S2 = sort_le algebra (S1, S2);
   186     val sorts = make (map (minimize_sort algebra) raw_sorts);
   187   in sorts |> filter_out (fn S => exists (fn S' => le S' S andalso not (le S S')) sorts) end;
   188 
   189 
   190 (* certify *)
   191 
   192 fun certify_class algebra c =
   193   if can (Graph.get_node (classes_of algebra)) c then c
   194   else raise TYPE ("Undeclared class: " ^ quote c, [], []);
   195 
   196 fun certify_sort classes = minimize_sort classes o map (certify_class classes);
   197 
   198 
   199 
   200 (** build algebras **)
   201 
   202 (* classes *)
   203 
   204 fun err_dup_class c = error ("Duplicate declaration of class: " ^ quote c);
   205 
   206 fun err_cyclic_classes pp css =
   207   error (cat_lines (map (fn cs =>
   208     "Cycle in class relation: " ^ Pretty.string_of_classrel pp cs) css));
   209 
   210 fun add_class pp (c, cs) = map_classes (fn classes =>
   211   let
   212     val classes' = classes |> Graph.new_node (c, serial ())
   213       handle Graph.DUP dup => err_dup_class dup;
   214     val classes'' = classes' |> fold Graph.add_edge_trans_acyclic (map (pair c) cs)
   215       handle Graph.CYCLES css => err_cyclic_classes pp css;
   216   in classes'' end);
   217 
   218 
   219 (* arities *)
   220 
local

(*suffix of an error message naming the offending class pair, if any*)
fun for_classes _ NONE = ""
  | for_classes pp (SOME (c1, c2)) =
      " for classes " ^ Pretty.string_of_classrel pp [c1, c2];

(*report two incompatible arities t::(Ss)c and t::(Ss')c'*)
fun err_conflict pp t cc (c, Ss) (c', Ss') =
  error ("Conflict of type arities" ^ for_classes pp cc ^ ":\n  " ^
    Pretty.string_of_arity pp (t, Ss, [c]) ^ " and\n  " ^
    Pretty.string_of_arity pp (t, Ss', [c']));

(*Prepend the entry (c, (c0, Ss)) to ars unless it violates coregularity
  against an existing entry (c', (_, Ss')): c <= c' requires Ss <= Ss', and
  c' <= c requires Ss' <= Ss.  The first violation found is reported.*)
fun coregular pp algebra t (c, (c0, Ss)) ars =
  let
    fun conflict (c', (_, Ss')) =
      if class_le algebra (c, c') andalso not (sorts_le algebra (Ss, Ss')) then
        SOME ((c, c'), (c', Ss'))
      else if class_le algebra (c', c) andalso not (sorts_le algebra (Ss', Ss)) then
        SOME ((c', c), (c', Ss'))
      else NONE;
  in
    (case get_first conflict ars of
      SOME ((c1, c2), (c', Ss')) => err_conflict pp t (SOME (c1, c2)) (c, Ss) (c', Ss')
    | NONE => (c, (c0, Ss)) :: ars)
  end;

(*Expand an arity declared for class c0 into entries for c0 and each of its
  super_classes; every entry records c0 as the class it was derived from.*)
fun complete algebra (c0, Ss) = map (rpair (c0, Ss)) (c0 :: super_classes algebra c0);

(*Insert one entry for class c.  If ars already has an entry for c with domain
  Ss', keep ars when sorts_le (Ss, Ss'), replace that entry when
  sorts_le (Ss', Ss) (re-checking coregularity), and report a conflict when the
  two domains are incomparable.*)
fun insert pp algebra t (c, (c0, Ss)) ars =
  (case AList.lookup (op =) ars c of
    NONE => coregular pp algebra t (c, (c0, Ss)) ars
  | SOME (_, Ss') =>
      if sorts_le algebra (Ss, Ss') then ars
      else if sorts_le algebra (Ss', Ss) then
        coregular pp algebra t (c, (c0, Ss))
          (filter_out (fn (c'', (_, Ss'')) => c = c'' andalso Ss'' = Ss') ars)
      else err_conflict pp t NONE (c, Ss) (c, Ss'));

(*Add the completed entries of all given arities (c0, Ss) to the table entry
  of type constructor t.*)
fun insert_ars pp algebra (t, ars) arities =
  let val ars' =
    Symtab.lookup_list arities t
    |> fold_rev (fold_rev (insert pp algebra t)) (map (complete algebra) ars)
  in Symtab.update (t, ars') arities end;

in

(*add declared arities of one type constructor to the algebra*)
fun add_arities pp arg algebra = algebra |> map_arities (insert_ars pp algebra arg);

(*Re-insert every arity of a table into another table w.r.t. the given algebra.
  Only the (c0, Ss) part of each entry is kept (map snd), so entries are
  completed again against the algebra's current class relation.*)
fun add_arities_table pp algebra =
  Symtab.fold (fn (t, ars) => insert_ars pp algebra (t, map snd ars));

end;
   272 
   273 
   274 (* classrel *)
   275 
   276 fun rebuild_arities pp algebra = algebra |> map_arities (fn arities =>
   277   Symtab.empty
   278   |> add_arities_table pp algebra arities);
   279 
   280 fun add_classrel pp rel = rebuild_arities pp o map_classes (fn classes =>
   281   classes |> Graph.add_edge_trans_acyclic rel
   282     handle Graph.CYCLES css => err_cyclic_classes pp css);
   283 
   284 
   285 (* empty and merge *)
   286 
   287 val empty_algebra = make_algebra (Graph.empty, Symtab.empty);
   288 
   289 fun merge_algebra pp
   290    (Algebra {classes = classes1, arities = arities1},
   291     Algebra {classes = classes2, arities = arities2}) =
   292   let
   293     val classes' = Graph.merge_trans_acyclic (op =) (classes1, classes2)
   294       handle Graph.DUP c => err_dup_class c
   295           | Graph.CYCLES css => err_cyclic_classes pp css;
   296     val algebra0 = make_algebra (classes', Symtab.empty);
   297     val arities' = Symtab.empty
   298       |> add_arities_table pp algebra0 arities1
   299       |> add_arities_table pp algebra0 arities2;
   300   in make_algebra (classes', arities') end;
   301 
   302 
   303 (* algebra projections *)
   304 
   305 fun classrels_of (Algebra {classes, ...}) =
   306   map (fn [c] => (c, Graph.imm_succs classes c)) (rev (Graph.strong_conn classes));
   307 
   308 fun instances_of (Algebra {arities, ...}) =
   309   Symtab.fold (fn (a, cs) => append (map (pair a o fst) cs)) arities [];
   310 
   311 fun subalgebra pp P sargs (algebra as Algebra {classes, arities}) =
   312   let
   313     val restrict_sort = minimize_sort algebra o filter P o Graph.all_succs classes;
   314     fun restrict_arity tyco (c, (_, Ss)) =
   315       if P c then
   316         SOME (c, (c, Ss |> map2 (curry (inter_sort algebra)) (sargs (c, tyco))
   317           |> map restrict_sort))
   318       else NONE;
   319     val classes' = classes |> Graph.subgraph P;
   320     val arities' = arities |> Symtab.map' (map_filter o restrict_arity);
   321   in (restrict_sort, rebuild_arities pp (make_algebra (classes', arities'))) end;
   322 
   323 
   324 
   325 (** sorts of types **)
   326 
   327 (* errors -- delayed message composition *)
   328 
   329 datatype class_error =
   330   NoClassrel of class * class |
   331   NoArity of string * class |
   332   NoSubsort of sort * sort;
   333 
   334 fun class_error pp (NoClassrel (c1, c2)) =
   335       "No class relation " ^ Pretty.string_of_classrel pp [c1, c2]
   336   | class_error pp (NoArity (a, c)) =
   337       "No type arity " ^ Pretty.string_of_arity pp (a, [], [c])
   338   | class_error pp (NoSubsort (S1, S2)) =
   339      "Cannot derive subsort relation " ^ Pretty.string_of_sort pp S1
   340        ^ " < " ^ Pretty.string_of_sort pp S2;
   341 
   342 exception CLASS_ERROR of class_error;
   343 
   344 
   345 (* mg_domain *)
   346 
   347 fun mg_domain algebra a S =
   348   let
   349     val arities = arities_of algebra;
   350     fun dom c =
   351       (case AList.lookup (op =) (Symtab.lookup_list arities a) c of
   352         NONE => raise CLASS_ERROR (NoArity (a, c))
   353       | SOME (_, Ss) => Ss);
   354     fun dom_inter c Ss = ListPair.map (inter_sort algebra) (dom c, Ss);
   355   in
   356     (case S of
   357       [] => raise Fail "Unknown domain of empty intersection"
   358     | c :: cs => fold dom_inter cs (dom c))
   359   end;
   360 
   361 
   362 (* meet_sort *)
   363 
   364 fun meet_sort algebra =
   365   let
   366     fun inters S S' = inter_sort algebra (S, S');
   367     fun meet _ [] = I
   368       | meet (TFree (_, S)) S' =
   369           if sort_le algebra (S, S') then I
   370           else raise CLASS_ERROR (NoSubsort (S, S'))
   371       | meet (TVar (v, S)) S' =
   372           if sort_le algebra (S, S') then I
   373           else Vartab.map_default (v, S) (inters S')
   374       | meet (Type (a, Ts)) S = fold2 meet Ts (mg_domain algebra a S);
   375   in uncurry meet end;
   376 
   377 fun meet_sort_typ algebra (T, S) =
   378   let
   379     val tab = meet_sort algebra (T, S) Vartab.empty;
   380   in Term.map_type_tvar (fn (v, _) =>
   381     TVar (v, (the o Vartab.lookup tab) v))
   382   end;
   383 
   384 
   385 (* of_sort *)
   386 
   387 fun of_sort algebra =
   388   let
   389     fun ofS (_, []) = true
   390       | ofS (TFree (_, S), S') = sort_le algebra (S, S')
   391       | ofS (TVar (_, S), S') = sort_le algebra (S, S')
   392       | ofS (Type (a, Ts), S) =
   393           let val Ss = mg_domain algebra a S in
   394             ListPair.all ofS (Ts, Ss)
   395           end handle CLASS_ERROR _ => false;
   396   in ofS end;
   397 
   398 
   399 (* animating derivations *)
   400 
   401 fun weaken algebra class_relation =
   402   let
   403     fun path (x, c1 :: c2 :: cs) = path (class_relation (x, c1) c2, c2 :: cs)
   404       | path (x, _) = x;
   405   in fn (x, c1) => fn c2 =>
   406     (case Graph.irreducible_paths (classes_of algebra) (c1, c2) of
   407       [] => raise CLASS_ERROR (NoClassrel (c1, c2))
   408     | cs :: _ => path (x, cs))
   409   end;
   410 
   411 fun of_sort_derivation pp algebra {class_relation, type_constructor, type_variable} =
   412   let
   413     val weaken = weaken algebra class_relation;
   414     val arities = arities_of algebra;
   415 
   416     fun weakens S1 S2 = S2 |> map (fn c2 =>
   417       (case S1 |> find_first (fn (_, c1) => class_le algebra (c1, c2)) of
   418         SOME d1 => weaken d1 c2
   419       | NONE => raise CLASS_ERROR (NoSubsort (map #2 S1, S2))));
   420 
   421     fun derive _ [] = []
   422       | derive (Type (a, Ts)) S =
   423           let
   424             val Ss = mg_domain algebra a S;
   425             val dom = map2 (fn T => fn S => derive T S ~~ S) Ts Ss;
   426           in
   427             S |> map (fn c =>
   428               let
   429                 val (c0, Ss') = the (AList.lookup (op =) (Symtab.lookup_list arities a) c);
   430                 val dom' = map2 (fn d => fn S' => weakens d S' ~~ S') dom Ss';
   431               in weaken (type_constructor a dom' c0, c0) c end)
   432           end
   433       | derive T S = weakens (type_variable T) S;
   434   in uncurry derive end;
   435 
   436 
   437 (* witness_sorts *)
   438 
(*For each sort in sorts, try to find a witness type of that sort.  Candidates
  are a TFree "'hyp" carrying a hypothesis sort from hyps below the requested
  sort, or a type built from the constructors in types.  The empty sort is
  witnessed by propT.  Sorts without a witness are omitted from the result.*)
fun witness_sorts algebra types hyps sorts =
  let
    fun le S1 S2 = sort_le algebra (S1, S2);
    (*reuse an earlier witness whose sort is below S2, relabelled with S2*)
    fun get_solved S2 (T, S1) = if le S1 S2 then SOME (T, S2) else NONE;
    fun get_hyp S2 S1 = if le S1 S2 then SOME (TFree ("'hyp", S1), S2) else NONE;
    fun mg_dom t S = SOME (mg_domain algebra t S) handle CLASS_ERROR _ => NONE;

    (*solved: witnesses found so far; failed: sorts known to have no witness
      (any S below a failed sort fails as well); path: sorts being solved above*)
    fun witn_sort _ [] solved_failed = (SOME (propT, []), solved_failed)
      | witn_sort path S (solved, failed) =
          if exists (le S) failed then (NONE, (solved, failed))
          else
            (case get_first (get_solved S) solved of
              SOME w => (SOME w, (solved, failed))
            | NONE =>
                (case get_first (get_hyp S) hyps of
                  SOME w => (SOME w, (w :: solved, failed))
                | NONE => witn_types path types S (solved, failed)))

    and witn_sorts path x = fold_map (witn_sort path) x

    (*try the type constructors in turn; S is recorded as failed if none works*)
    and witn_types _ [] S (solved, failed) = (NONE, (solved, S :: failed))
      | witn_types path (t :: ts) S solved_failed =
          (case mg_dom t S of
            SOME SS =>
              (*do not descend into stronger args (achieving termination)*)
              if exists (fn D => le D S orelse exists (le D) path) SS then
                witn_types path ts S solved_failed
              else
                let val (ws, (solved', failed')) = witn_sorts (S :: path) SS solved_failed in
                  if forall is_some ws then
                    let val w = (Type (t, map (#1 o the) ws), S)
                    in (SOME w, (w :: solved', failed')) end
                  else witn_types path ts S (solved', failed')
                end
          | NONE => witn_types path ts S solved_failed);

  in map_filter I (#1 (witn_sorts [] sorts ([], []))) end;
   476 
   477 end;