src/Pure/sorts.ML
author wenzelm
Sun Apr 13 16:40:05 2008 +0200 (2008-04-13)
changeset 26639 75ea79a50326
parent 26517 ef036a63f6e9
child 26994 197af793312c
permissions -rw-r--r--
removed unused minimal_classes;
class_error: produce message only (formerly msg_class_error);
tuned;
     1 (*  Title:      Pure/sorts.ML
     2     ID:         $Id$
     3     Author:     Markus Wenzel and Stefan Berghofer, TU Muenchen
     4 
     5 The order-sorted algebra of type classes.
     6 
     7 Classes denote (possibly empty) collections of types that are
     8 partially ordered by class inclusion. They are represented
     9 symbolically by strings.
    10 
    11 Sorts are intersections of finitely many classes. They are represented
    12 by lists of classes.  Normal forms of sorts are sorted lists of
    13 minimal classes (wrt. current class inclusion).
    14 *)
    15 
    16 signature SORTS =
    17 sig
    18   val union: sort list -> sort list -> sort list
    19   val subtract: sort list -> sort list -> sort list
    20   val remove_sort: sort -> sort list -> sort list
    21   val insert_sort: sort -> sort list -> sort list
    22   val insert_typ: typ -> sort list -> sort list
    23   val insert_typs: typ list -> sort list -> sort list
    24   val insert_term: term -> sort list -> sort list
    25   val insert_terms: term list -> sort list -> sort list
    26   type algebra
    27   val rep_algebra: algebra ->
    28    {classes: serial Graph.T,
    29     arities: (class * (class * sort list)) list Symtab.table}
    30   val all_classes: algebra -> class list
    31   val super_classes: algebra -> class -> class list
    32   val class_less: algebra -> class * class -> bool
    33   val class_le: algebra -> class * class -> bool
    34   val sort_eq: algebra -> sort * sort -> bool
    35   val sort_le: algebra -> sort * sort -> bool
    36   val sorts_le: algebra -> sort list * sort list -> bool
    37   val inter_sort: algebra -> sort * sort -> sort
    38   val minimize_sort: algebra -> sort -> sort
    39   val complete_sort: algebra -> sort -> sort
    40   val certify_class: algebra -> class -> class    (*exception TYPE*)
    41   val certify_sort: algebra -> sort -> sort       (*exception TYPE*)
    42   val add_class: Pretty.pp -> class * class list -> algebra -> algebra
    43   val add_classrel: Pretty.pp -> class * class -> algebra -> algebra
    44   val add_arities: Pretty.pp -> string * (class * sort list) list -> algebra -> algebra
    45   val empty_algebra: algebra
    46   val merge_algebra: Pretty.pp -> algebra * algebra -> algebra
    47   val subalgebra: Pretty.pp -> (class -> bool) -> (class * string -> sort list)
    48     -> algebra -> (sort -> sort) * algebra
    49   type class_error
    50   val class_error: Pretty.pp -> class_error -> string
    51   exception CLASS_ERROR of class_error
    52   val mg_domain: algebra -> string -> sort -> sort list   (*exception CLASS_ERROR*)
    53   val meet_sort: algebra -> typ * sort -> sort Vartab.table -> sort Vartab.table
    54   val of_sort: algebra -> typ * sort -> bool
    55   val of_sort_derivation: Pretty.pp -> algebra ->
    56     {class_relation: 'a * class -> class -> 'a,
    57      type_constructor: string -> ('a * class) list list -> class -> 'a,
    58      type_variable: typ -> ('a * class) list} ->
    59     typ * sort -> 'a list   (*exception CLASS_ERROR*)
    60   val witness_sorts: algebra -> string list -> sort list -> sort list -> (typ * sort) list
    61 end;
    62 
    63 structure Sorts: SORTS =
    64 struct
    65 
    66 
    67 (** ordered lists of sorts **)
    68 
    69 val op union = OrdList.union Term.sort_ord;
    70 val subtract = OrdList.subtract Term.sort_ord;
    71 
    72 val remove_sort = OrdList.remove Term.sort_ord;
    73 val insert_sort = OrdList.insert Term.sort_ord;
    74 
    75 fun insert_typ (TFree (_, S)) Ss = insert_sort S Ss
    76   | insert_typ (TVar (_, S)) Ss = insert_sort S Ss
    77   | insert_typ (Type (_, Ts)) Ss = insert_typs Ts Ss
    78 and insert_typs [] Ss = Ss
    79   | insert_typs (T :: Ts) Ss = insert_typs Ts (insert_typ T Ss);
    80 
    81 fun insert_term (Const (_, T)) Ss = insert_typ T Ss
    82   | insert_term (Free (_, T)) Ss = insert_typ T Ss
    83   | insert_term (Var (_, T)) Ss = insert_typ T Ss
    84   | insert_term (Bound _) Ss = Ss
    85   | insert_term (Abs (_, T, t)) Ss = insert_term t (insert_typ T Ss)
    86   | insert_term (t $ u) Ss = insert_term t (insert_term u Ss);
    87 
    88 fun insert_terms [] Ss = Ss
    89   | insert_terms (t :: ts) Ss = insert_terms ts (insert_term t Ss);
    90 
    91 
    92 
    93 (** order-sorted algebra **)
    94 
    95 (*
    96   classes: graph representing class declarations together with proper
    97     subclass relation, which needs to be transitive and acyclic.
    98 
    99   arities: table of association lists of all type arities; (t, ars)
   100     means that type constructor t has the arities ars; an element
   101     (c, (c0, Ss)) of ars represents the arity t::(Ss)c being derived
   102     via c0 <= c.  "Coregularity" of the arities structure requires
   103     that for any two declarations t::(Ss1)c1 and t::(Ss2)c2 such that
   104     c1 <= c2 holds Ss1 <= Ss2.
   105 *)
   106 
   107 datatype algebra = Algebra of
   108  {classes: serial Graph.T,
   109   arities: (class * (class * sort list)) list Symtab.table};
   110 
   111 fun rep_algebra (Algebra args) = args;
   112 
   113 val classes_of = #classes o rep_algebra;
   114 val arities_of = #arities o rep_algebra;
   115 
   116 fun make_algebra (classes, arities) =
   117   Algebra {classes = classes, arities = arities};
   118 
   119 fun map_classes f (Algebra {classes, arities}) = make_algebra (f classes, arities);
   120 fun map_arities f (Algebra {classes, arities}) = make_algebra (classes, f arities);
   121 
   122 
   123 (* classes *)
   124 
   125 fun all_classes (Algebra {classes, ...}) = Graph.all_preds classes (Graph.maximals classes);
   126 
   127 val super_classes = Graph.imm_succs o classes_of;
   128 
   129 
   130 (* class relations *)
   131 
   132 val class_less = Graph.is_edge o classes_of;
   133 fun class_le algebra (c1, c2) = c1 = c2 orelse class_less algebra (c1, c2);
   134 
   135 
   136 (* sort relations *)
   137 
   138 fun sort_le algebra (S1, S2) =
   139   S1 = S2 orelse forall (fn c2 => exists (fn c1 => class_le algebra (c1, c2)) S1) S2;
   140 
   141 fun sorts_le algebra (Ss1, Ss2) =
   142   ListPair.all (sort_le algebra) (Ss1, Ss2);
   143 
   144 fun sort_eq algebra (S1, S2) =
   145   sort_le algebra (S1, S2) andalso sort_le algebra (S2, S1);
   146 
   147 
   148 (* intersection *)
   149 
   150 fun inter_class algebra c S =
   151   let
   152     fun intr [] = [c]
   153       | intr (S' as c' :: c's) =
   154           if class_le algebra (c', c) then S'
   155           else if class_le algebra (c, c') then intr c's
   156           else c' :: intr c's
   157   in intr S end;
   158 
   159 fun inter_sort algebra (S1, S2) =
   160   sort_strings (fold (inter_class algebra) S1 S2);
   161 
   162 
   163 (* normal forms *)
   164 
   165 fun minimize_sort _ [] = []
   166   | minimize_sort _ (S as [_]) = S
   167   | minimize_sort algebra S =
   168       filter (fn c => not (exists (fn c' => class_less algebra (c', c)) S)) S
   169       |> sort_distinct string_ord;
   170 
   171 fun complete_sort algebra =
   172   Graph.all_succs (classes_of algebra) o minimize_sort algebra;
   173 
   174 
   175 (* certify *)
   176 
   177 fun certify_class algebra c =
   178   if can (Graph.get_node (classes_of algebra)) c then c
   179   else raise TYPE ("Undeclared class: " ^ quote c, [], []);
   180 
   181 fun certify_sort classes = minimize_sort classes o map (certify_class classes);
   182 
   183 
   184 
   185 (** build algebras **)
   186 
   187 (* classes *)
   188 
   189 fun err_dup_class c = error ("Duplicate declaration of class: " ^ quote c);
   190 
   191 fun err_cyclic_classes pp css =
   192   error (cat_lines (map (fn cs =>
   193     "Cycle in class relation: " ^ Pretty.string_of_classrel pp cs) css));
   194 
   195 fun add_class pp (c, cs) = map_classes (fn classes =>
   196   let
   197     val classes' = classes |> Graph.new_node (c, serial ())
   198       handle Graph.DUP dup => err_dup_class dup;
   199     val classes'' = classes' |> fold Graph.add_edge_trans_acyclic (map (pair c) cs)
   200       handle Graph.CYCLES css => err_cyclic_classes pp css;
   201   in classes'' end);
   202 
   203 
   204 (* arities *)
   205 
   206 local
   207 
   208 fun for_classes _ NONE = ""
   209   | for_classes pp (SOME (c1, c2)) =
   210       " for classes " ^ Pretty.string_of_classrel pp [c1, c2];
   211 
   212 fun err_conflict pp t cc (c, Ss) (c', Ss') =
   213   error ("Conflict of type arities" ^ for_classes pp cc ^ ":\n  " ^
   214     Pretty.string_of_arity pp (t, Ss, [c]) ^ " and\n  " ^
   215     Pretty.string_of_arity pp (t, Ss', [c']));
   216 
   217 fun coregular pp algebra t (c, (c0, Ss)) ars =
   218   let
   219     fun conflict (c', (_, Ss')) =
   220       if class_le algebra (c, c') andalso not (sorts_le algebra (Ss, Ss')) then
   221         SOME ((c, c'), (c', Ss'))
   222       else if class_le algebra (c', c) andalso not (sorts_le algebra (Ss', Ss)) then
   223         SOME ((c', c), (c', Ss'))
   224       else NONE;
   225   in
   226     (case get_first conflict ars of
   227       SOME ((c1, c2), (c', Ss')) => err_conflict pp t (SOME (c1, c2)) (c, Ss) (c', Ss')
   228     | NONE => (c, (c0, Ss)) :: ars)
   229   end;
   230 
   231 fun complete algebra (c0, Ss) = map (rpair (c0, Ss)) (c0 :: super_classes algebra c0);
   232 
   233 fun insert pp algebra t (c, (c0, Ss)) ars =
   234   (case AList.lookup (op =) ars c of
   235     NONE => coregular pp algebra t (c, (c0, Ss)) ars
   236   | SOME (_, Ss') =>
   237       if sorts_le algebra (Ss, Ss') then ars
   238       else if sorts_le algebra (Ss', Ss) then
   239         coregular pp algebra t (c, (c0, Ss))
   240           (filter_out (fn (c'', (_, Ss'')) => c = c'' andalso Ss'' = Ss') ars)
   241       else err_conflict pp t NONE (c, Ss) (c, Ss'));
   242 
   243 fun insert_ars pp algebra (t, ars) arities =
   244   let val ars' =
   245     Symtab.lookup_list arities t
   246     |> fold_rev (fold_rev (insert pp algebra t)) (map (complete algebra) ars)
   247   in Symtab.update (t, ars') arities end;
   248 
   249 in
   250 
   251 fun add_arities pp arg algebra = algebra |> map_arities (insert_ars pp algebra arg);
   252 
   253 fun add_arities_table pp algebra =
   254   Symtab.fold (fn (t, ars) => insert_ars pp algebra (t, map snd ars));
   255 
   256 end;
   257 
   258 
   259 (* classrel *)
   260 
   261 fun rebuild_arities pp algebra = algebra |> map_arities (fn arities =>
   262   Symtab.empty
   263   |> add_arities_table pp algebra arities);
   264 
   265 fun add_classrel pp rel = rebuild_arities pp o map_classes (fn classes =>
   266   classes |> Graph.add_edge_trans_acyclic rel
   267     handle Graph.CYCLES css => err_cyclic_classes pp css);
   268 
   269 
   270 (* empty and merge *)
   271 
   272 val empty_algebra = make_algebra (Graph.empty, Symtab.empty);
   273 
   274 fun merge_algebra pp
   275    (Algebra {classes = classes1, arities = arities1},
   276     Algebra {classes = classes2, arities = arities2}) =
   277   let
   278     val classes' = Graph.merge_trans_acyclic (op =) (classes1, classes2)
   279       handle Graph.DUP c => err_dup_class c
   280           | Graph.CYCLES css => err_cyclic_classes pp css;
   281     val algebra0 = make_algebra (classes', Symtab.empty);
   282     val arities' = Symtab.empty
   283       |> add_arities_table pp algebra0 arities1
   284       |> add_arities_table pp algebra0 arities2;
   285   in make_algebra (classes', arities') end;
   286 
   287 
   288 (* subalgebra *)
   289 
   290 fun subalgebra pp P sargs (algebra as Algebra {classes, arities}) =
   291   let
   292     val restrict_sort = minimize_sort algebra o filter P o Graph.all_succs classes;
   293     fun restrict_arity tyco (c, (_, Ss)) =
   294       if P c then
   295         SOME (c, (c, Ss |> map2 (curry (inter_sort algebra)) (sargs (c, tyco))
   296           |> map restrict_sort))
   297       else NONE;
   298     val classes' = classes |> Graph.subgraph P;
   299     val arities' = arities |> Symtab.map' (map_filter o restrict_arity);
   300   in (restrict_sort, rebuild_arities pp (make_algebra (classes', arities'))) end;
   301 
   302 
   303 
   304 (** sorts of types **)
   305 
   306 (* errors -- delayed message composition *)
   307 
   308 datatype class_error =
   309   NoClassrel of class * class |
   310   NoArity of string * class |
   311   NoSubsort of sort * sort;
   312 
   313 fun class_error pp (NoClassrel (c1, c2)) =
   314       "No class relation " ^ Pretty.string_of_classrel pp [c1, c2]
   315   | class_error pp (NoArity (a, c)) =
   316       "No type arity " ^ Pretty.string_of_arity pp (a, [], [c])
   317   | class_error pp (NoSubsort (s1, s2)) =
   318       Pretty.string_of_sort pp s1 ^ " is not a subsort of " ^ Pretty.string_of_sort pp s2;
   319 
   320 exception CLASS_ERROR of class_error;
   321 
   322 
   323 (* mg_domain *)
   324 
   325 fun mg_domain algebra a S =
   326   let
   327     val arities = arities_of algebra;
   328     fun dom c =
   329       (case AList.lookup (op =) (Symtab.lookup_list arities a) c of
   330         NONE => raise CLASS_ERROR (NoArity (a, c))
   331       | SOME (_, Ss) => Ss);
   332     fun dom_inter c Ss = ListPair.map (inter_sort algebra) (dom c, Ss);
   333   in
   334     (case S of
   335       [] => raise Fail "Unknown domain of empty intersection"
   336     | c :: cs => fold dom_inter cs (dom c))
   337   end;
   338 
   339 
   340 (* meet_sort *)
   341 
   342 fun meet_sort algebra =
   343   let
   344     fun inters S S' = inter_sort algebra (S, S');
   345     fun meet _ [] = I
   346       | meet (TFree (_, S)) S' =
   347           if sort_le algebra (S, S') then I
   348           else raise CLASS_ERROR (NoSubsort (S, S'))
   349       | meet (TVar (v, S)) S' =
   350           if sort_le algebra (S, S') then I
   351           else Vartab.map_default (v, S) (inters S')
   352       | meet (Type (a, Ts)) S = fold2 meet Ts (mg_domain algebra a S);
   353   in uncurry meet end;
   354 
   355 
   356 (* of_sort *)
   357 
   358 fun of_sort algebra =
   359   let
   360     fun ofS (_, []) = true
   361       | ofS (TFree (_, S), S') = sort_le algebra (S, S')
   362       | ofS (TVar (_, S), S') = sort_le algebra (S, S')
   363       | ofS (Type (a, Ts), S) =
   364           let val Ss = mg_domain algebra a S in
   365             ListPair.all ofS (Ts, Ss)
   366           end handle CLASS_ERROR _ => false;
   367   in ofS end;
   368 
   369 
   370 (* of_sort_derivation *)
   371 
   372 fun of_sort_derivation pp algebra {class_relation, type_constructor, type_variable} =
   373   let
   374     val {classes, arities} = rep_algebra algebra;
   375     fun weaken_path (x, c1 :: c2 :: cs) =
   376           weaken_path (class_relation (x, c1) c2, c2 :: cs)
   377       | weaken_path (x, _) = x;
   378     fun weaken (x, c1) c2 =
   379       (case Graph.irreducible_paths classes (c1, c2) of
   380         [] => raise CLASS_ERROR (NoClassrel (c1, c2))
   381       | cs :: _ => weaken_path (x, cs));
   382 
   383     fun weakens S1 S2 = S2 |> map (fn c2 =>
   384       (case S1 |> find_first (fn (_, c1) => class_le algebra (c1, c2)) of
   385         SOME d1 => weaken d1 c2
   386       | NONE => error ("Cannot derive subsort relation " ^
   387           Pretty.string_of_sort pp (map #2 S1) ^ " < " ^ Pretty.string_of_sort pp S2)));
   388 
   389     fun derive _ [] = []
   390       | derive (Type (a, Ts)) S =
   391           let
   392             val Ss = mg_domain algebra a S;
   393             val dom = map2 (fn T => fn S => derive T S ~~ S) Ts Ss;
   394           in
   395             S |> map (fn c =>
   396               let
   397                 val (c0, Ss') = the (AList.lookup (op =) (Symtab.lookup_list arities a) c);
   398                 val dom' = map2 (fn d => fn S' => weakens d S' ~~ S') dom Ss';
   399               in weaken (type_constructor a dom' c0, c0) c end)
   400           end
   401       | derive T S = weakens (type_variable T) S;
   402   in uncurry derive end;
   403 
   404 
   405 (* witness_sorts *)
   406 
   407 fun witness_sorts algebra types hyps sorts =
   408   let
   409     fun le S1 S2 = sort_le algebra (S1, S2);
   410     fun get_solved S2 (T, S1) = if le S1 S2 then SOME (T, S2) else NONE;
   411     fun get_hyp S2 S1 = if le S1 S2 then SOME (TFree ("'hyp", S1), S2) else NONE;
   412     fun mg_dom t S = SOME (mg_domain algebra t S) handle CLASS_ERROR _ => NONE;
   413 
   414     fun witn_sort _ [] solved_failed = (SOME (propT, []), solved_failed)
   415       | witn_sort path S (solved, failed) =
   416           if exists (le S) failed then (NONE, (solved, failed))
   417           else
   418             (case get_first (get_solved S) solved of
   419               SOME w => (SOME w, (solved, failed))
   420             | NONE =>
   421                 (case get_first (get_hyp S) hyps of
   422                   SOME w => (SOME w, (w :: solved, failed))
   423                 | NONE => witn_types path types S (solved, failed)))
   424 
   425     and witn_sorts path x = fold_map (witn_sort path) x
   426 
   427     and witn_types _ [] S (solved, failed) = (NONE, (solved, S :: failed))
   428       | witn_types path (t :: ts) S solved_failed =
   429           (case mg_dom t S of
   430             SOME SS =>
   431               (*do not descend into stronger args (achieving termination)*)
   432               if exists (fn D => le D S orelse exists (le D) path) SS then
   433                 witn_types path ts S solved_failed
   434               else
   435                 let val (ws, (solved', failed')) = witn_sorts (S :: path) SS solved_failed in
   436                   if forall is_some ws then
   437                     let val w = (Type (t, map (#1 o the) ws), S)
   438                     in (SOME w, (w :: solved', failed')) end
   439                   else witn_types path ts S (solved', failed')
   440                 end
   441           | NONE => witn_types path ts S solved_failed);
   442 
   443   in map_filter I (#1 (witn_sorts [] sorts ([], []))) end;
   444 
   445 end;