src/Pure/sorts.ML
author haftmann
Wed Feb 18 19:18:33 2009 +0100 (2009-02-18)
changeset 29972 aee7610106fd
parent 29269 5c25a2012975
child 30060 672012330c4e
permissions -rw-r--r--
sort instances wrt. to class hierarchy
     1 (*  Title:      Pure/sorts.ML
     2     Author:     Markus Wenzel and Stefan Berghofer, TU Muenchen
     3 
     4 The order-sorted algebra of type classes.
     5 
     6 Classes denote (possibly empty) collections of types that are
     7 partially ordered by class inclusion. They are represented
     8 symbolically by strings.
     9 
    10 Sorts are intersections of finitely many classes. They are represented
    11 by lists of classes.  Normal forms of sorts are sorted lists of
    12 minimal classes (wrt. current class inclusion).
    13 *)
    14 
signature SORTS =
sig
  (*ordered lists of sorts (ordered by TermOrd.sort_ord)*)
  val make: sort list -> sort OrdList.T
  val subset: sort OrdList.T * sort OrdList.T -> bool
  val union: sort OrdList.T -> sort OrdList.T -> sort OrdList.T
  val subtract: sort OrdList.T -> sort OrdList.T -> sort OrdList.T
  val remove_sort: sort -> sort OrdList.T -> sort OrdList.T
  val insert_sort: sort -> sort OrdList.T -> sort OrdList.T
  (*insert the sorts occurring in types / terms*)
  val insert_typ: typ -> sort OrdList.T -> sort OrdList.T
  val insert_typs: typ list -> sort OrdList.T -> sort OrdList.T
  val insert_term: term -> sort OrdList.T -> sort OrdList.T
  val insert_terms: term list -> sort OrdList.T -> sort OrdList.T
  (*order-sorted algebra: class graph plus type arities*)
  type algebra
  val rep_algebra: algebra ->
   {classes: serial Graph.T,
    arities: (class * (class * sort list)) list Symtab.table}
  val all_classes: algebra -> class list
  val super_classes: algebra -> class -> class list
  (*class and sort relations*)
  val class_less: algebra -> class * class -> bool
  val class_le: algebra -> class * class -> bool
  val sort_eq: algebra -> sort * sort -> bool
  val sort_le: algebra -> sort * sort -> bool
  val sorts_le: algebra -> sort list * sort list -> bool
  val inter_sort: algebra -> sort * sort -> sort
  (*normal forms and certification*)
  val minimize_sort: algebra -> sort -> sort
  val complete_sort: algebra -> sort -> sort
  val minimal_sorts: algebra -> sort list -> sort OrdList.T
  val certify_class: algebra -> class -> class    (*exception TYPE*)
  val certify_sort: algebra -> sort -> sort       (*exception TYPE*)
  (*building and merging algebras*)
  val add_class: Pretty.pp -> class * class list -> algebra -> algebra
  val add_classrel: Pretty.pp -> class * class -> algebra -> algebra
  val add_arities: Pretty.pp -> string * (class * sort list) list -> algebra -> algebra
  val empty_algebra: algebra
  val merge_algebra: Pretty.pp -> algebra * algebra -> algebra
  (*projections of an algebra*)
  val classrels_of: algebra -> (class * class list) list
  val instances_of: algebra -> (string * class) list
  val subalgebra: Pretty.pp -> (class -> bool) -> (class * string -> sort list)
    -> algebra -> (sort -> sort) * algebra
  (*class errors; the message is composed only on demand*)
  type class_error
  val class_error: Pretty.pp -> class_error -> string
  exception CLASS_ERROR of class_error
  (*sorts of types*)
  val mg_domain: algebra -> string -> sort -> sort list   (*exception CLASS_ERROR*)
  val meet_sort: algebra -> typ * sort
    -> sort Vartab.table -> sort Vartab.table   (*exception CLASS_ERROR*)
  val meet_sort_typ: algebra -> typ * sort -> typ -> typ   (*exception CLASS_ERROR*)
  val of_sort: algebra -> typ * sort -> bool
  val weaken: algebra -> ('a * class -> class -> 'a) -> 'a * class -> class -> 'a
  val of_sort_derivation: Pretty.pp -> algebra ->
    {class_relation: 'a * class -> class -> 'a,
     type_constructor: string -> ('a * class) list list -> class -> 'a,
     type_variable: typ -> ('a * class) list} ->
    typ * sort -> 'a list   (*exception CLASS_ERROR*)
  val witness_sorts: algebra -> string list -> sort list -> sort list -> (typ * sort) list
end;
    69 
    70 structure Sorts: SORTS =
    71 struct
    72 
    73 
    74 (** ordered lists of sorts **)
    75 
    76 val make = OrdList.make TermOrd.sort_ord;
    77 val op subset = OrdList.subset TermOrd.sort_ord;
    78 val op union = OrdList.union TermOrd.sort_ord;
    79 val subtract = OrdList.subtract TermOrd.sort_ord;
    80 
    81 val remove_sort = OrdList.remove TermOrd.sort_ord;
    82 val insert_sort = OrdList.insert TermOrd.sort_ord;
    83 
    84 fun insert_typ (TFree (_, S)) Ss = insert_sort S Ss
    85   | insert_typ (TVar (_, S)) Ss = insert_sort S Ss
    86   | insert_typ (Type (_, Ts)) Ss = insert_typs Ts Ss
    87 and insert_typs [] Ss = Ss
    88   | insert_typs (T :: Ts) Ss = insert_typs Ts (insert_typ T Ss);
    89 
    90 fun insert_term (Const (_, T)) Ss = insert_typ T Ss
    91   | insert_term (Free (_, T)) Ss = insert_typ T Ss
    92   | insert_term (Var (_, T)) Ss = insert_typ T Ss
    93   | insert_term (Bound _) Ss = Ss
    94   | insert_term (Abs (_, T, t)) Ss = insert_term t (insert_typ T Ss)
    95   | insert_term (t $ u) Ss = insert_term t (insert_term u Ss);
    96 
    97 fun insert_terms [] Ss = Ss
    98   | insert_terms (t :: ts) Ss = insert_terms ts (insert_term t Ss);
    99 
   100 
   101 
   102 (** order-sorted algebra **)
   103 
   104 (*
   105   classes: graph representing class declarations together with proper
   106     subclass relation, which needs to be transitive and acyclic.
   107 
   108   arities: table of association lists of all type arities; (t, ars)
   109     means that type constructor t has the arities ars; an element
   110     (c, (c0, Ss)) of ars represents the arity t::(Ss)c being derived
   111     via c0 <= c.  "Coregularity" of the arities structure requires
   112     that for any two declarations t::(Ss1)c1 and t::(Ss2)c2 such that
   113     c1 <= c2 holds Ss1 <= Ss2.
   114 *)
   115 
   116 datatype algebra = Algebra of
   117  {classes: serial Graph.T,
   118   arities: (class * (class * sort list)) list Symtab.table};
   119 
   120 fun rep_algebra (Algebra args) = args;
   121 
   122 val classes_of = #classes o rep_algebra;
   123 val arities_of = #arities o rep_algebra;
   124 
   125 fun make_algebra (classes, arities) =
   126   Algebra {classes = classes, arities = arities};
   127 
   128 fun map_classes f (Algebra {classes, arities}) = make_algebra (f classes, arities);
   129 fun map_arities f (Algebra {classes, arities}) = make_algebra (classes, f arities);
   130 
   131 
   132 (* classes *)
   133 
   134 fun all_classes (Algebra {classes, ...}) = Graph.all_preds classes (Graph.maximals classes);
   135 
   136 val super_classes = Graph.imm_succs o classes_of;
   137 
   138 
   139 (* class relations *)
   140 
   141 val class_less = Graph.is_edge o classes_of;
   142 fun class_le algebra (c1, c2) = c1 = c2 orelse class_less algebra (c1, c2);
   143 
   144 
   145 (* sort relations *)
   146 
   147 fun sort_le algebra (S1, S2) =
   148   S1 = S2 orelse forall (fn c2 => exists (fn c1 => class_le algebra (c1, c2)) S1) S2;
   149 
   150 fun sorts_le algebra (Ss1, Ss2) =
   151   ListPair.all (sort_le algebra) (Ss1, Ss2);
   152 
   153 fun sort_eq algebra (S1, S2) =
   154   sort_le algebra (S1, S2) andalso sort_le algebra (S2, S1);
   155 
   156 
   157 (* intersection *)
   158 
   159 fun inter_class algebra c S =
   160   let
   161     fun intr [] = [c]
   162       | intr (S' as c' :: c's) =
   163           if class_le algebra (c', c) then S'
   164           else if class_le algebra (c, c') then intr c's
   165           else c' :: intr c's
   166   in intr S end;
   167 
   168 fun inter_sort algebra (S1, S2) =
   169   sort_strings (fold (inter_class algebra) S1 S2);
   170 
   171 
   172 (* normal forms *)
   173 
   174 fun minimize_sort _ [] = []
   175   | minimize_sort _ (S as [_]) = S
   176   | minimize_sort algebra S =
   177       filter (fn c => not (exists (fn c' => class_less algebra (c', c)) S)) S
   178       |> sort_distinct string_ord;
   179 
   180 fun complete_sort algebra =
   181   Graph.all_succs (classes_of algebra) o minimize_sort algebra;
   182 
   183 fun minimal_sorts algebra raw_sorts =
   184   let
   185     fun le S1 S2 = sort_le algebra (S1, S2);
   186     val sorts = make (map (minimize_sort algebra) raw_sorts);
   187   in sorts |> filter_out (fn S => exists (fn S' => le S' S andalso not (le S S')) sorts) end;
   188 
   189 
   190 (* certify *)
   191 
   192 fun certify_class algebra c =
   193   if can (Graph.get_node (classes_of algebra)) c then c
   194   else raise TYPE ("Undeclared class: " ^ quote c, [], []);
   195 
   196 fun certify_sort classes = minimize_sort classes o map (certify_class classes);
   197 
   198 
   199 
   200 (** build algebras **)
   201 
   202 (* classes *)
   203 
   204 fun err_dup_class c = error ("Duplicate declaration of class: " ^ quote c);
   205 
   206 fun err_cyclic_classes pp css =
   207   error (cat_lines (map (fn cs =>
   208     "Cycle in class relation: " ^ Pretty.string_of_classrel pp cs) css));
   209 
   210 fun add_class pp (c, cs) = map_classes (fn classes =>
   211   let
   212     val classes' = classes |> Graph.new_node (c, serial ())
   213       handle Graph.DUP dup => err_dup_class dup;
   214     val classes'' = classes' |> fold Graph.add_edge_trans_acyclic (map (pair c) cs)
   215       handle Graph.CYCLES css => err_cyclic_classes pp css;
   216   in classes'' end);
   217 
   218 
   219 (* arities *)
   220 
local

(*optional " for classes ..." part of the conflict message*)
fun for_classes _ NONE = ""
  | for_classes pp (SOME (c1, c2)) =
      " for classes " ^ Pretty.string_of_classrel pp [c1, c2];

(*report two conflicting arities of type constructor t (raises via error)*)
fun err_conflict pp t cc (c, Ss) (c', Ss') =
  error ("Conflict of type arities" ^ for_classes pp cc ^ ":\n  " ^
    Pretty.string_of_arity pp (t, Ss, [c]) ^ " and\n  " ^
    Pretty.string_of_arity pp (t, Ss', [c']));

(*Add entry (c, (c0, Ss)) in front of ars after checking coregularity
  against every existing entry (c', (_, Ss')): c <= c' requires
  sorts_le (Ss, Ss'), and c' <= c requires sorts_le (Ss', Ss).
  The first violating entry in list order is reported via err_conflict.*)
fun coregular pp algebra t (c, (c0, Ss)) ars =
  let
    fun conflict (c', (_, Ss')) =
      if class_le algebra (c, c') andalso not (sorts_le algebra (Ss, Ss')) then
        SOME ((c, c'), (c', Ss'))
      else if class_le algebra (c', c) andalso not (sorts_le algebra (Ss', Ss)) then
        SOME ((c', c), (c', Ss'))
      else NONE;
  in
    (case get_first conflict ars of
      SOME ((c1, c2), (c', Ss')) => err_conflict pp t (SOME (c1, c2)) (c, Ss) (c', Ss')
    | NONE => (c, (c0, Ss)) :: ars)
  end;

(*entries (c, (c0, Ss)) for the origin class c0 itself and for every class
  returned by super_classes algebra c0*)
fun complete algebra (c0, Ss) = map (rpair (c0, Ss)) (c0 :: super_classes algebra c0);

(*Insert one entry for class c.  If c already has an entry with domain Ss':
  keep ars unchanged when sorts_le (Ss, Ss'); replace the old entry when
  sorts_le (Ss', Ss); otherwise report a conflict.*)
fun insert pp algebra t (c, (c0, Ss)) ars =
  (case AList.lookup (op =) ars c of
    NONE => coregular pp algebra t (c, (c0, Ss)) ars
  | SOME (_, Ss') =>
      if sorts_le algebra (Ss, Ss') then ars
      else if sorts_le algebra (Ss', Ss) then
        coregular pp algebra t (c, (c0, Ss))
          (*drop the old entry for c before re-inserting*)
          (filter_out (fn (c'', (_, Ss'')) => c = c'' andalso Ss'' = Ss') ars)
      else err_conflict pp t NONE (c, Ss) (c, Ss'));

(*Complete each arity (c0, Ss) of t and insert all resulting entries into the
  existing list for t (fold_rev: processed starting from the last one),
  then store the result in the table.*)
fun insert_ars pp algebra (t, ars) arities =
  let val ars' =
    Symtab.lookup_list arities t
    |> fold_rev (fold_rev (insert pp algebra t)) (map (complete algebra) ars)
  in Symtab.update (t, ars') arities end;

in

(*declare arities t :: (Ss)c0 for the given list of (c0, Ss)*)
fun add_arities pp arg algebra = algebra |> map_arities (insert_ars pp algebra arg);

(*Re-insert every entry of a given arities table into an accumulator table;
  only the (c0, Ss) origin of each entry is kept (map snd), so entries are
  completed again with respect to algebra.*)
fun add_arities_table pp algebra =
  Symtab.fold (fn (t, ars) => insert_ars pp algebra (t, map snd ars));

end;
   272 
   273 
   274 (* classrel *)
   275 
   276 fun rebuild_arities pp algebra = algebra |> map_arities (fn arities =>
   277   Symtab.empty
   278   |> add_arities_table pp algebra arities);
   279 
   280 fun add_classrel pp rel = rebuild_arities pp o map_classes (fn classes =>
   281   classes |> Graph.add_edge_trans_acyclic rel
   282     handle Graph.CYCLES css => err_cyclic_classes pp css);
   283 
   284 
   285 (* empty and merge *)
   286 
   287 val empty_algebra = make_algebra (Graph.empty, Symtab.empty);
   288 
   289 fun merge_algebra pp
   290    (Algebra {classes = classes1, arities = arities1},
   291     Algebra {classes = classes2, arities = arities2}) =
   292   let
   293     val classes' = Graph.merge_trans_acyclic (op =) (classes1, classes2)
   294       handle Graph.DUP c => err_dup_class c
   295           | Graph.CYCLES css => err_cyclic_classes pp css;
   296     val algebra0 = make_algebra (classes', Symtab.empty);
   297     val arities' = Symtab.empty
   298       |> add_arities_table pp algebra0 arities1
   299       |> add_arities_table pp algebra0 arities2;
   300   in make_algebra (classes', arities') end;
   301 
   302 
   303 (* algebra projections *)
   304 
   305 val sorted_classes = rev o flat o Graph.strong_conn o classes_of;
   306 
   307 fun classrels_of algebra = 
   308   map (fn c => (c, Graph.imm_succs (classes_of algebra) c)) (sorted_classes algebra);
   309 
   310 fun instances_of algebra =
   311   let
   312     val arities = arities_of algebra;
   313     val all_classes = sorted_classes algebra;
   314     fun sort_classes cs = filter (member (op = o apsnd fst) cs)
   315       all_classes;
   316   in
   317     Symtab.fold (fn (a, cs) => append ((map (pair a) o sort_classes) cs))
   318       arities []
   319   end;
   320 
   321 fun subalgebra pp P sargs (algebra as Algebra {classes, arities}) =
   322   let
   323     val restrict_sort = minimize_sort algebra o filter P o Graph.all_succs classes;
   324     fun restrict_arity tyco (c, (_, Ss)) =
   325       if P c then
   326         SOME (c, (c, Ss |> map2 (curry (inter_sort algebra)) (sargs (c, tyco))
   327           |> map restrict_sort))
   328       else NONE;
   329     val classes' = classes |> Graph.subgraph P;
   330     val arities' = arities |> Symtab.map' (map_filter o restrict_arity);
   331   in (restrict_sort, rebuild_arities pp (make_algebra (classes', arities'))) end;
   332 
   333 
   334 
   335 (** sorts of types **)
   336 
   337 (* errors -- delayed message composition *)
   338 
   339 datatype class_error =
   340   NoClassrel of class * class |
   341   NoArity of string * class |
   342   NoSubsort of sort * sort;
   343 
   344 fun class_error pp (NoClassrel (c1, c2)) =
   345       "No class relation " ^ Pretty.string_of_classrel pp [c1, c2]
   346   | class_error pp (NoArity (a, c)) =
   347       "No type arity " ^ Pretty.string_of_arity pp (a, [], [c])
   348   | class_error pp (NoSubsort (S1, S2)) =
   349      "Cannot derive subsort relation " ^ Pretty.string_of_sort pp S1
   350        ^ " < " ^ Pretty.string_of_sort pp S2;
   351 
   352 exception CLASS_ERROR of class_error;
   353 
   354 
   355 (* mg_domain *)
   356 
   357 fun mg_domain algebra a S =
   358   let
   359     val arities = arities_of algebra;
   360     fun dom c =
   361       (case AList.lookup (op =) (Symtab.lookup_list arities a) c of
   362         NONE => raise CLASS_ERROR (NoArity (a, c))
   363       | SOME (_, Ss) => Ss);
   364     fun dom_inter c Ss = ListPair.map (inter_sort algebra) (dom c, Ss);
   365   in
   366     (case S of
   367       [] => raise Fail "Unknown domain of empty intersection"
   368     | c :: cs => fold dom_inter cs (dom c))
   369   end;
   370 
   371 
   372 (* meet_sort *)
   373 
   374 fun meet_sort algebra =
   375   let
   376     fun inters S S' = inter_sort algebra (S, S');
   377     fun meet _ [] = I
   378       | meet (TFree (_, S)) S' =
   379           if sort_le algebra (S, S') then I
   380           else raise CLASS_ERROR (NoSubsort (S, S'))
   381       | meet (TVar (v, S)) S' =
   382           if sort_le algebra (S, S') then I
   383           else Vartab.map_default (v, S) (inters S')
   384       | meet (Type (a, Ts)) S = fold2 meet Ts (mg_domain algebra a S);
   385   in uncurry meet end;
   386 
   387 fun meet_sort_typ algebra (T, S) =
   388   let
   389     val tab = meet_sort algebra (T, S) Vartab.empty;
   390   in Term.map_type_tvar (fn (v, _) =>
   391     TVar (v, (the o Vartab.lookup tab) v))
   392   end;
   393 
   394 
   395 (* of_sort *)
   396 
   397 fun of_sort algebra =
   398   let
   399     fun ofS (_, []) = true
   400       | ofS (TFree (_, S), S') = sort_le algebra (S, S')
   401       | ofS (TVar (_, S), S') = sort_le algebra (S, S')
   402       | ofS (Type (a, Ts), S) =
   403           let val Ss = mg_domain algebra a S in
   404             ListPair.all ofS (Ts, Ss)
   405           end handle CLASS_ERROR _ => false;
   406   in ofS end;
   407 
   408 
   409 (* animating derivations *)
   410 
   411 fun weaken algebra class_relation =
   412   let
   413     fun path (x, c1 :: c2 :: cs) = path (class_relation (x, c1) c2, c2 :: cs)
   414       | path (x, _) = x;
   415   in fn (x, c1) => fn c2 =>
   416     (case Graph.irreducible_paths (classes_of algebra) (c1, c2) of
   417       [] => raise CLASS_ERROR (NoClassrel (c1, c2))
   418     | cs :: _ => path (x, cs))
   419   end;
   420 
(*Build derivations of T :: S from the supplied callbacks and return one
  result per class of S, in the order of S.
  type_variable T: derivations (paired with their classes) available for a
    non-Type type T;
  type_constructor a dom c0: combine argument derivations into a derivation
    for class c0 via an arity of a;
  class_relation (x, c1) c2: lift x from c1 to c2 (used through weaken).
  Raises CLASS_ERROR: NoArity (mg_domain), NoSubsort (weakens),
  NoClassrel (weaken).
  NOTE(review): pp is not used in this body.*)
fun of_sort_derivation pp algebra {class_relation, type_constructor, type_variable} =
  let
    val weaken = weaken algebra class_relation;
    val arities = arities_of algebra;

    (*for each class c2 of S2, weaken the first (d1, c1) of S1 with c1 <= c2*)
    fun weakens S1 S2 = S2 |> map (fn c2 =>
      (case S1 |> find_first (fn (_, c1) => class_le algebra (c1, c2)) of
        SOME d1 => weaken d1 c2
      | NONE => raise CLASS_ERROR (NoSubsort (map #2 S1, S2))));

    fun derive _ [] = []
      | derive (Type (a, Ts)) S =
          let
            (*argument sorts required for a :: S and derivations for them*)
            val Ss = mg_domain algebra a S;
            val dom = map2 (fn T => fn S => derive T S ~~ S) Ts Ss;
          in
            S |> map (fn c =>
              let
                (*the entry exists: mg_domain already looked up every class of S;
                  it was derived from class c0 with domain Ss'*)
                val (c0, Ss') = the (AList.lookup (op =) (Symtab.lookup_list arities a) c);
                (*weaken the argument derivations to the domain Ss'*)
                val dom' = map2 (fn d => fn S' => weakens d S' ~~ S') dom Ss';
              in weaken (type_constructor a dom' c0, c0) c end)
          end
      | derive T S = weakens (type_variable T) S;
  in uncurry derive end;
   445 
   446 
   447 (* witness_sorts *)
   448 
(*Search a witness type for each of the given sorts: a hypothesis sort from
  hyps (as TFree "'hyp"), or a type built from the constructors in types.
  Returns (witness, sort) only for the sorts where the search succeeded.*)
fun witness_sorts algebra types hyps sorts =
  let
    fun le S1 S2 = sort_le algebra (S1, S2);
    (*reuse an already found witness (T, S1) if S1 <= S2*)
    fun get_solved S2 (T, S1) = if le S1 S2 then SOME (T, S2) else NONE;
    (*a hypothesis sort S1 <= S2 yields a free type variable witness*)
    fun get_hyp S2 S1 = if le S1 S2 then SOME (TFree ("'hyp", S1), S2) else NONE;
    (*mg_domain with CLASS_ERROR mapped to NONE*)
    fun mg_dom t S = SOME (mg_domain algebra t S) handle CLASS_ERROR _ => NONE;

    (*solved and failed are caches threaded through the search;
      path holds the sorts currently being searched on the way down;
      the empty sort is witnessed by propT*)
    fun witn_sort _ [] solved_failed = (SOME (propT, []), solved_failed)
      | witn_sort path S (solved, failed) =
          (*give up on S if it is sort_le some sort that already failed*)
          if exists (le S) failed then (NONE, (solved, failed))
          else
            (*try cached witnesses, then hypotheses, then type constructors*)
            (case get_first (get_solved S) solved of
              SOME w => (SOME w, (solved, failed))
            | NONE =>
                (case get_first (get_hyp S) hyps of
                  SOME w => (SOME w, (w :: solved, failed))
                | NONE => witn_types path types S (solved, failed)))

    and witn_sorts path x = fold_map (witn_sort path) x

    (*no constructor left: record S as failed*)
    and witn_types _ [] S (solved, failed) = (NONE, (solved, S :: failed))
      | witn_types path (t :: ts) S solved_failed =
          (case mg_dom t S of
            SOME SS =>
              (*do not descend into stronger args (achieving termination)*)
              if exists (fn D => le D S orelse exists (le D) path) SS then
                witn_types path ts S solved_failed
              else
                (*Type (t, ...) witnesses S only if every argument sort has a witness*)
                let val (ws, (solved', failed')) = witn_sorts (S :: path) SS solved_failed in
                  if forall is_some ws then
                    let val w = (Type (t, map (#1 o the) ws), S)
                    in (SOME w, (w :: solved', failed')) end
                  else witn_types path ts S (solved', failed')
                end
          | NONE => witn_types path ts S solved_failed);

  in map_filter I (#1 (witn_sorts [] sorts ([], []))) end;
   486 
   487 end;