src/Pure/sorts.ML
author wenzelm
Thu Oct 16 22:44:32 2008 +0200 (2008-10-16)
changeset 28623 de573f2e5389
parent 28374 27f1b5cc5f9b
child 28665 98aba9fc90f6
permissions -rw-r--r--
added make, minimal_sorts;
     1 (*  Title:      Pure/sorts.ML
     2     ID:         $Id$
     3     Author:     Markus Wenzel and Stefan Berghofer, TU Muenchen
     4 
     5 The order-sorted algebra of type classes.
     6 
     7 Classes denote (possibly empty) collections of types that are
     8 partially ordered by class inclusion. They are represented
     9 symbolically by strings.
    10 
    11 Sorts are intersections of finitely many classes. They are represented
    12 by lists of classes.  Normal forms of sorts are sorted lists of
    13 minimal classes (wrt. current class inclusion).
    14 *)
    15 
    16 signature SORTS =
    17 sig
    18   val make: sort list -> sort OrdList.T
    19   val subset: sort OrdList.T * sort OrdList.T -> bool
    20   val union: sort OrdList.T -> sort OrdList.T -> sort OrdList.T
    21   val subtract: sort OrdList.T -> sort OrdList.T -> sort OrdList.T
    22   val remove_sort: sort -> sort OrdList.T -> sort OrdList.T
    23   val insert_sort: sort -> sort OrdList.T -> sort OrdList.T
    24   val insert_typ: typ -> sort OrdList.T -> sort OrdList.T
    25   val insert_typs: typ list -> sort OrdList.T -> sort OrdList.T
    26   val insert_term: term -> sort OrdList.T -> sort OrdList.T
    27   val insert_terms: term list -> sort OrdList.T -> sort OrdList.T
    28   type algebra
    29   val rep_algebra: algebra ->
    30    {classes: serial Graph.T,
    31     arities: (class * (class * sort list)) list Symtab.table}
    32   val all_classes: algebra -> class list
    33   val super_classes: algebra -> class -> class list
    34   val class_less: algebra -> class * class -> bool
    35   val class_le: algebra -> class * class -> bool
    36   val sort_eq: algebra -> sort * sort -> bool
    37   val sort_le: algebra -> sort * sort -> bool
    38   val sorts_le: algebra -> sort list * sort list -> bool
    39   val inter_sort: algebra -> sort * sort -> sort
    40   val minimize_sort: algebra -> sort -> sort
    41   val complete_sort: algebra -> sort -> sort
    42   val minimal_sorts: algebra -> sort list -> sort OrdList.T
    43   val certify_class: algebra -> class -> class    (*exception TYPE*)
    44   val certify_sort: algebra -> sort -> sort       (*exception TYPE*)
    45   val add_class: Pretty.pp -> class * class list -> algebra -> algebra
    46   val add_classrel: Pretty.pp -> class * class -> algebra -> algebra
    47   val add_arities: Pretty.pp -> string * (class * sort list) list -> algebra -> algebra
    48   val empty_algebra: algebra
    49   val merge_algebra: Pretty.pp -> algebra * algebra -> algebra
    50   val subalgebra: Pretty.pp -> (class -> bool) -> (class * string -> sort list)
    51     -> algebra -> (sort -> sort) * algebra
    52   type class_error
    53   val class_error: Pretty.pp -> class_error -> string
    54   exception CLASS_ERROR of class_error
    55   val mg_domain: algebra -> string -> sort -> sort list   (*exception CLASS_ERROR*)
    56   val meet_sort: algebra -> typ * sort -> sort Vartab.table -> sort Vartab.table
    57   val of_sort: algebra -> typ * sort -> bool
    58   val weaken: algebra -> ('a * class -> class -> 'a) -> 'a * class -> class -> 'a
    59   val of_sort_derivation: Pretty.pp -> algebra ->
    60     {class_relation: 'a * class -> class -> 'a,
    61      type_constructor: string -> ('a * class) list list -> class -> 'a,
    62      type_variable: typ -> ('a * class) list} ->
    63     typ * sort -> 'a list   (*exception CLASS_ERROR*)
    64   val witness_sorts: algebra -> string list -> sort list -> sort list -> (typ * sort) list
    65 end;
    66 
    67 structure Sorts: SORTS =
    68 struct
    69 
    70 
    71 (** ordered lists of sorts **)
    72 
    73 val make = OrdList.make Term.sort_ord;
    74 val op subset = OrdList.subset Term.sort_ord;
    75 val op union = OrdList.union Term.sort_ord;
    76 val subtract = OrdList.subtract Term.sort_ord;
    77 
    78 val remove_sort = OrdList.remove Term.sort_ord;
    79 val insert_sort = OrdList.insert Term.sort_ord;
    80 
    81 fun insert_typ (TFree (_, S)) Ss = insert_sort S Ss
    82   | insert_typ (TVar (_, S)) Ss = insert_sort S Ss
    83   | insert_typ (Type (_, Ts)) Ss = insert_typs Ts Ss
    84 and insert_typs [] Ss = Ss
    85   | insert_typs (T :: Ts) Ss = insert_typs Ts (insert_typ T Ss);
    86 
    87 fun insert_term (Const (_, T)) Ss = insert_typ T Ss
    88   | insert_term (Free (_, T)) Ss = insert_typ T Ss
    89   | insert_term (Var (_, T)) Ss = insert_typ T Ss
    90   | insert_term (Bound _) Ss = Ss
    91   | insert_term (Abs (_, T, t)) Ss = insert_term t (insert_typ T Ss)
    92   | insert_term (t $ u) Ss = insert_term t (insert_term u Ss);
    93 
    94 fun insert_terms [] Ss = Ss
    95   | insert_terms (t :: ts) Ss = insert_terms ts (insert_term t Ss);
    96 
    97 
    98 
    99 (** order-sorted algebra **)
   100 
(*
  classes: graph representing class declarations together with the proper
    subclass relation, which needs to be transitive and acyclic.

  arities: table of association lists of all type arities; (t, ars)
    means that type constructor t has the arities ars; an element
    (c, (c0, Ss)) of ars represents the arity t::(Ss)c being derived
    via c0 <= c.  "Coregularity" of the arities structure requires
    that for any two declarations t::(Ss1)c1 and t::(Ss2)c2 with
    c1 <= c2, also Ss1 <= Ss2 holds.
*)
   112 
   113 datatype algebra = Algebra of
   114  {classes: serial Graph.T,
   115   arities: (class * (class * sort list)) list Symtab.table};
   116 
   117 fun rep_algebra (Algebra args) = args;
   118 
   119 val classes_of = #classes o rep_algebra;
   120 val arities_of = #arities o rep_algebra;
   121 
   122 fun make_algebra (classes, arities) =
   123   Algebra {classes = classes, arities = arities};
   124 
   125 fun map_classes f (Algebra {classes, arities}) = make_algebra (f classes, arities);
   126 fun map_arities f (Algebra {classes, arities}) = make_algebra (classes, f arities);
   127 
   128 
   129 (* classes *)
   130 
   131 fun all_classes (Algebra {classes, ...}) = Graph.all_preds classes (Graph.maximals classes);
   132 
   133 val super_classes = Graph.imm_succs o classes_of;
   134 
   135 
   136 (* class relations *)
   137 
   138 val class_less = Graph.is_edge o classes_of;
   139 fun class_le algebra (c1, c2) = c1 = c2 orelse class_less algebra (c1, c2);
   140 
   141 
   142 (* sort relations *)
   143 
   144 fun sort_le algebra (S1, S2) =
   145   S1 = S2 orelse forall (fn c2 => exists (fn c1 => class_le algebra (c1, c2)) S1) S2;
   146 
   147 fun sorts_le algebra (Ss1, Ss2) =
   148   ListPair.all (sort_le algebra) (Ss1, Ss2);
   149 
   150 fun sort_eq algebra (S1, S2) =
   151   sort_le algebra (S1, S2) andalso sort_le algebra (S2, S1);
   152 
   153 
   154 (* intersection *)
   155 
   156 fun inter_class algebra c S =
   157   let
   158     fun intr [] = [c]
   159       | intr (S' as c' :: c's) =
   160           if class_le algebra (c', c) then S'
   161           else if class_le algebra (c, c') then intr c's
   162           else c' :: intr c's
   163   in intr S end;
   164 
   165 fun inter_sort algebra (S1, S2) =
   166   sort_strings (fold (inter_class algebra) S1 S2);
   167 
   168 
   169 (* normal forms *)
   170 
   171 fun minimize_sort _ [] = []
   172   | minimize_sort _ (S as [_]) = S
   173   | minimize_sort algebra S =
   174       filter (fn c => not (exists (fn c' => class_less algebra (c', c)) S)) S
   175       |> sort_distinct string_ord;
   176 
   177 fun complete_sort algebra =
   178   Graph.all_succs (classes_of algebra) o minimize_sort algebra;
   179 
   180 fun minimal_sorts algebra raw_sorts =
   181   let
   182     fun le S1 S2 = sort_le algebra (S1, S2);
   183     val sorts = make (map (minimize_sort algebra) raw_sorts);
   184   in sorts |> filter_out (fn S => exists (fn S' => le S' S andalso not (le S S')) sorts) end;
   185 
   186 
   187 (* certify *)
   188 
   189 fun certify_class algebra c =
   190   if can (Graph.get_node (classes_of algebra)) c then c
   191   else raise TYPE ("Undeclared class: " ^ quote c, [], []);
   192 
   193 fun certify_sort classes = minimize_sort classes o map (certify_class classes);
   194 
   195 
   196 
   197 (** build algebras **)
   198 
   199 (* classes *)
   200 
   201 fun err_dup_class c = error ("Duplicate declaration of class: " ^ quote c);
   202 
   203 fun err_cyclic_classes pp css =
   204   error (cat_lines (map (fn cs =>
   205     "Cycle in class relation: " ^ Pretty.string_of_classrel pp cs) css));
   206 
   207 fun add_class pp (c, cs) = map_classes (fn classes =>
   208   let
   209     val classes' = classes |> Graph.new_node (c, serial ())
   210       handle Graph.DUP dup => err_dup_class dup;
   211     val classes'' = classes' |> fold Graph.add_edge_trans_acyclic (map (pair c) cs)
   212       handle Graph.CYCLES css => err_cyclic_classes pp css;
   213   in classes'' end);
   214 
   215 
   216 (* arities *)
   217 
local

(*optional " for classes ..." suffix for arity conflict messages*)
fun for_classes _ NONE = ""
  | for_classes pp (SOME (c1, c2)) =
      " for classes " ^ Pretty.string_of_classrel pp [c1, c2];

(*report the incompatible arities t::(Ss)c and t::(Ss')c'; cc optionally names
  the class pair whose ordering exposes the conflict*)
fun err_conflict pp t cc (c, Ss) (c', Ss') =
  error ("Conflict of type arities" ^ for_classes pp cc ^ ":\n  " ^
    Pretty.string_of_arity pp (t, Ss, [c]) ^ " and\n  " ^
    Pretty.string_of_arity pp (t, Ss', [c']));

(*prepend the entry (c, (c0, Ss)) to ars after a coregularity check against
  every existing entry (c', (_, Ss')): c <= c' requires Ss <= Ss' and c' <= c
  requires Ss' <= Ss.  The first violation found in ars is reported.*)
fun coregular pp algebra t (c, (c0, Ss)) ars =
  let
    fun conflict (c', (_, Ss')) =
      if class_le algebra (c, c') andalso not (sorts_le algebra (Ss, Ss')) then
        SOME ((c, c'), (c', Ss'))
      else if class_le algebra (c', c) andalso not (sorts_le algebra (Ss', Ss)) then
        SOME ((c', c), (c', Ss'))
      else NONE;
  in
    (case get_first conflict ars of
      SOME ((c1, c2), (c', Ss')) => err_conflict pp t (SOME (c1, c2)) (c, Ss) (c', Ss')
    | NONE => (c, (c0, Ss)) :: ars)
  end;

(*expand a declaration for class c0 into entries (c, (c0, Ss)) for c0 itself and
  for each of its super_classes (immediate successors in the class graph)*)
fun complete algebra (c0, Ss) = map (rpair (c0, Ss)) (c0 :: super_classes algebra c0);

(*insert one completed entry for class c into ars:
   - no entry for c yet: add it, subject to the coregularity check;
   - existing domain Ss' with Ss <= Ss': ars is returned unchanged;
   - Ss' <= Ss: the old entry is removed and the new one added (checked);
   - incomparable domains: error*)
fun insert pp algebra t (c, (c0, Ss)) ars =
  (case AList.lookup (op =) ars c of
    NONE => coregular pp algebra t (c, (c0, Ss)) ars
  | SOME (_, Ss') =>
      if sorts_le algebra (Ss, Ss') then ars
      else if sorts_le algebra (Ss', Ss) then
        coregular pp algebra t (c, (c0, Ss))
          (filter_out (fn (c'', (_, Ss'')) => c = c'' andalso Ss'' = Ss') ars)
      else err_conflict pp t NONE (c, Ss) (c, Ss'));

(*add the declarations ars (pairs of class and domain) for type constructor t to
  the arities table: each declaration is completed, and all resulting entries are
  inserted into the existing list for t using fold_rev*)
fun insert_ars pp algebra (t, ars) arities =
  let val ars' =
    Symtab.lookup_list arities t
    |> fold_rev (fold_rev (insert pp algebra t)) (map (complete algebra) ars)
  in Symtab.update (t, ars') arities end;

in

(*declare arities for a single type constructor*)
fun add_arities pp arg algebra = algebra |> map_arities (insert_ars pp algebra arg);

(*re-insert every arity of a table into an accumulator table.  Only the recorded
  (c0, Ss) part of each entry is kept, and completions are re-derived w.r.t.
  the given algebra.*)
fun add_arities_table pp algebra =
  Symtab.fold (fn (t, ars) => insert_ars pp algebra (t, map snd ars));

end;
   269 
   270 
   271 (* classrel *)
   272 
   273 fun rebuild_arities pp algebra = algebra |> map_arities (fn arities =>
   274   Symtab.empty
   275   |> add_arities_table pp algebra arities);
   276 
   277 fun add_classrel pp rel = rebuild_arities pp o map_classes (fn classes =>
   278   classes |> Graph.add_edge_trans_acyclic rel
   279     handle Graph.CYCLES css => err_cyclic_classes pp css);
   280 
   281 
   282 (* empty and merge *)
   283 
   284 val empty_algebra = make_algebra (Graph.empty, Symtab.empty);
   285 
   286 fun merge_algebra pp
   287    (Algebra {classes = classes1, arities = arities1},
   288     Algebra {classes = classes2, arities = arities2}) =
   289   let
   290     val classes' = Graph.merge_trans_acyclic (op =) (classes1, classes2)
   291       handle Graph.DUP c => err_dup_class c
   292           | Graph.CYCLES css => err_cyclic_classes pp css;
   293     val algebra0 = make_algebra (classes', Symtab.empty);
   294     val arities' = Symtab.empty
   295       |> add_arities_table pp algebra0 arities1
   296       |> add_arities_table pp algebra0 arities2;
   297   in make_algebra (classes', arities') end;
   298 
   299 
   300 (* subalgebra *)
   301 
   302 fun subalgebra pp P sargs (algebra as Algebra {classes, arities}) =
   303   let
   304     val restrict_sort = minimize_sort algebra o filter P o Graph.all_succs classes;
   305     fun restrict_arity tyco (c, (_, Ss)) =
   306       if P c then
   307         SOME (c, (c, Ss |> map2 (curry (inter_sort algebra)) (sargs (c, tyco))
   308           |> map restrict_sort))
   309       else NONE;
   310     val classes' = classes |> Graph.subgraph P;
   311     val arities' = arities |> Symtab.map' (map_filter o restrict_arity);
   312   in (restrict_sort, rebuild_arities pp (make_algebra (classes', arities'))) end;
   313 
   314 
   315 
   316 (** sorts of types **)
   317 
   318 (* errors -- delayed message composition *)
   319 
   320 datatype class_error =
   321   NoClassrel of class * class |
   322   NoArity of string * class |
   323   NoSubsort of sort * sort;
   324 
   325 fun class_error pp (NoClassrel (c1, c2)) =
   326       "No class relation " ^ Pretty.string_of_classrel pp [c1, c2]
   327   | class_error pp (NoArity (a, c)) =
   328       "No type arity " ^ Pretty.string_of_arity pp (a, [], [c])
   329   | class_error pp (NoSubsort (S1, S2)) =
   330      "Cannot derive subsort relation " ^ Pretty.string_of_sort pp S1
   331        ^ " < " ^ Pretty.string_of_sort pp S2;
   332 
   333 exception CLASS_ERROR of class_error;
   334 
   335 
   336 (* mg_domain *)
   337 
   338 fun mg_domain algebra a S =
   339   let
   340     val arities = arities_of algebra;
   341     fun dom c =
   342       (case AList.lookup (op =) (Symtab.lookup_list arities a) c of
   343         NONE => raise CLASS_ERROR (NoArity (a, c))
   344       | SOME (_, Ss) => Ss);
   345     fun dom_inter c Ss = ListPair.map (inter_sort algebra) (dom c, Ss);
   346   in
   347     (case S of
   348       [] => raise Fail "Unknown domain of empty intersection"
   349     | c :: cs => fold dom_inter cs (dom c))
   350   end;
   351 
   352 
   353 (* meet_sort *)
   354 
   355 fun meet_sort algebra =
   356   let
   357     fun inters S S' = inter_sort algebra (S, S');
   358     fun meet _ [] = I
   359       | meet (TFree (_, S)) S' =
   360           if sort_le algebra (S, S') then I
   361           else raise CLASS_ERROR (NoSubsort (S, S'))
   362       | meet (TVar (v, S)) S' =
   363           if sort_le algebra (S, S') then I
   364           else Vartab.map_default (v, S) (inters S')
   365       | meet (Type (a, Ts)) S = fold2 meet Ts (mg_domain algebra a S);
   366   in uncurry meet end;
   367 
   368 
   369 (* of_sort *)
   370 
   371 fun of_sort algebra =
   372   let
   373     fun ofS (_, []) = true
   374       | ofS (TFree (_, S), S') = sort_le algebra (S, S')
   375       | ofS (TVar (_, S), S') = sort_le algebra (S, S')
   376       | ofS (Type (a, Ts), S) =
   377           let val Ss = mg_domain algebra a S in
   378             ListPair.all ofS (Ts, Ss)
   379           end handle CLASS_ERROR _ => false;
   380   in ofS end;
   381 
   382 
   383 (* animating derivations *)
   384 
   385 fun weaken algebra class_relation =
   386   let
   387     fun path (x, c1 :: c2 :: cs) = path (class_relation (x, c1) c2, c2 :: cs)
   388       | path (x, _) = x;
   389   in fn (x, c1) => fn c2 =>
   390     (case Graph.irreducible_paths (classes_of algebra) (c1, c2) of
   391       [] => raise CLASS_ERROR (NoClassrel (c1, c2))
   392     | cs :: _ => path (x, cs))
   393   end;
   394 
   395 fun of_sort_derivation pp algebra {class_relation, type_constructor, type_variable} =
   396   let
   397     val weaken = weaken algebra class_relation;
   398     val arities = arities_of algebra;
   399 
   400     fun weakens S1 S2 = S2 |> map (fn c2 =>
   401       (case S1 |> find_first (fn (_, c1) => class_le algebra (c1, c2)) of
   402         SOME d1 => weaken d1 c2
   403       | NONE => raise CLASS_ERROR (NoSubsort (map #2 S1, S2))));
   404 
   405     fun derive _ [] = []
   406       | derive (Type (a, Ts)) S =
   407           let
   408             val Ss = mg_domain algebra a S;
   409             val dom = map2 (fn T => fn S => derive T S ~~ S) Ts Ss;
   410           in
   411             S |> map (fn c =>
   412               let
   413                 val (c0, Ss') = the (AList.lookup (op =) (Symtab.lookup_list arities a) c);
   414                 val dom' = map2 (fn d => fn S' => weakens d S' ~~ S') dom Ss';
   415               in weaken (type_constructor a dom' c0, c0) c end)
   416           end
   417       | derive T S = weakens (type_variable T) S;
   418   in uncurry derive end;
   419 
   420 
   421 (* witness_sorts *)
   422 
(*search, for each sort of sorts, a witness type of that sort built from the type
  constructors in types, with the hypothetical sorts hyps available as type
  variables 'hyp.  The result lists (witness, requested sort) only for those sorts
  where a witness was found.*)
fun witness_sorts algebra types hyps sorts =
  let
    fun le S1 S2 = sort_le algebra (S1, S2);
    (*reuse a solved witness (T, S1) for S2 if S1 <= S2*)
    fun get_solved S2 (T, S1) = if le S1 S2 then SOME (T, S2) else NONE;
    (*use hypothesis S1 for S2 if S1 <= S2, witnessed by TFree ("'hyp", S1)*)
    fun get_hyp S2 S1 = if le S1 S2 then SOME (TFree ("'hyp", S1), S2) else NONE;
    (*mg_domain, with CLASS_ERROR turned into NONE*)
    fun mg_dom t S = SOME (mg_domain algebra t S) handle CLASS_ERROR _ => NONE;

    (*witness a single sort; path holds the sorts currently being solved further up,
      and the (solved, failed) caches are threaded through all calls.  The empty
      sort is witnessed by propT.*)
    fun witn_sort _ [] solved_failed = (SOME (propT, []), solved_failed)
      | witn_sort path S (solved, failed) =
          (*S <= some sort that already failed: give up immediately*)
          if exists (le S) failed then (NONE, (solved, failed))
          else
            (case get_first (get_solved S) solved of
              SOME w => (SOME w, (solved, failed))
            | NONE =>
                (case get_first (get_hyp S) hyps of
                  SOME w => (SOME w, (w :: solved, failed))
                | NONE => witn_types path types S (solved, failed)))

    and witn_sorts path x = fold_map (witn_sort path) x

    (*try the type constructors in order; when none succeeds, S is recorded as failed*)
    and witn_types _ [] S (solved, failed) = (NONE, (solved, S :: failed))
      | witn_types path (t :: ts) S solved_failed =
          (case mg_dom t S of
            SOME SS =>
              (*do not descend into stronger args (achieving termination)*)
              if exists (fn D => le D S orelse exists (le D) path) SS then
                witn_types path ts S solved_failed
              else
                let val (ws, (solved', failed')) = witn_sorts (S :: path) SS solved_failed in
                  if forall is_some ws then
                    let val w = (Type (t, map (#1 o the) ws), S)
                    in (SOME w, (w :: solved', failed')) end
                  else witn_types path ts S (solved', failed')
                end
          | NONE => witn_types path ts S solved_failed);

  in map_filter I (#1 (witn_sorts [] sorts ([], []))) end;
   460 
   461 end;