src/Pure/sorts.ML
author wenzelm
Sat Aug 20 23:35:30 2011 +0200 (2011-08-20)
changeset 44338 700008399ee5
parent 43792 d5803c3d537a
child 45595 fe57d786fd5b
permissions -rw-r--r--
refined Graph implementation: more abstract/scalable Graph.Keys instead of plain lists -- order of adjacency is now standardized wrt. Key.ord;
     1 (*  Title:      Pure/sorts.ML
     2     Author:     Markus Wenzel and Stefan Berghofer, TU Muenchen
     3 
     4 The order-sorted algebra of type classes.
     5 
     6 Classes denote (possibly empty) collections of types that are
     7 partially ordered by class inclusion. They are represented
     8 symbolically by strings.
     9 
    10 Sorts are intersections of finitely many classes. They are represented
    11 by lists of classes.  Normal forms of sorts are sorted lists of
    12 minimal classes (wrt. current class inclusion).
    13 *)
    14 
(*Interface of the order-sorted algebra of type classes.*)
signature SORTS =
sig
  (*sorts kept as canonical ordered lists (order: Term_Ord.sort_ord)*)
  val make: sort list -> sort Ord_List.T
  val subset: sort Ord_List.T * sort Ord_List.T -> bool
  val union: sort Ord_List.T -> sort Ord_List.T -> sort Ord_List.T
  val subtract: sort Ord_List.T -> sort Ord_List.T -> sort Ord_List.T
  val remove_sort: sort -> sort Ord_List.T -> sort Ord_List.T
  val insert_sort: sort -> sort Ord_List.T -> sort Ord_List.T
  (*collect the sorts of all type variables occurring in types / terms*)
  val insert_typ: typ -> sort Ord_List.T -> sort Ord_List.T
  val insert_typs: typ list -> sort Ord_List.T -> sort Ord_List.T
  val insert_term: term -> sort Ord_List.T -> sort Ord_List.T
  val insert_terms: term list -> sort Ord_List.T -> sort Ord_List.T
  (*the algebra: class graph plus type constructor arities*)
  type algebra
  val classes_of: algebra -> serial Graph.T
  val arities_of: algebra -> (class * sort list) list Symtab.table
  val all_classes: algebra -> class list
  val super_classes: algebra -> class -> class list
  (*class and sort relations; sort normal forms*)
  val class_less: algebra -> class * class -> bool
  val class_le: algebra -> class * class -> bool
  val sort_eq: algebra -> sort * sort -> bool
  val sort_le: algebra -> sort * sort -> bool
  val sorts_le: algebra -> sort list * sort list -> bool
  val inter_sort: algebra -> sort * sort -> sort
  val minimize_sort: algebra -> sort -> sort
  val complete_sort: algebra -> sort -> sort
  val minimal_sorts: algebra -> sort list -> sort Ord_List.T
  val certify_class: algebra -> class -> class    (*exception TYPE*)
  val certify_sort: algebra -> sort -> sort       (*exception TYPE*)
  (*building algebras; declaration errors are reported via error*)
  val add_class: Proof.context -> class * class list -> algebra -> algebra
  val add_classrel: Proof.context -> class * class -> algebra -> algebra
  val add_arities: Proof.context -> string * (class * sort list) list -> algebra -> algebra
  val empty_algebra: algebra
  val merge_algebra: Proof.context -> algebra * algebra -> algebra
  val subalgebra: Proof.context -> (class -> bool) -> (class * string -> sort list option)
    -> algebra -> (sort -> sort) * algebra
  (*sorts of types; failures are signalled by CLASS_ERROR*)
  type class_error
  val class_error: Proof.context -> class_error -> string
  exception CLASS_ERROR of class_error
  val mg_domain: algebra -> string -> sort -> sort list   (*exception CLASS_ERROR*)
  val meet_sort: algebra -> typ * sort
    -> sort Vartab.table -> sort Vartab.table   (*exception CLASS_ERROR*)
  val meet_sort_typ: algebra -> typ * sort -> typ -> typ   (*exception CLASS_ERROR*)
  val of_sort: algebra -> typ * sort -> bool
  (*replay a derivation of T :: S via user-supplied callbacks*)
  val of_sort_derivation: algebra ->
    {class_relation: typ -> 'a * class -> class -> 'a,
     type_constructor: string * typ list -> ('a * class) list list -> class -> 'a,
     type_variable: typ -> ('a * class) list} ->
    typ * sort -> 'a list   (*exception CLASS_ERROR*)
  val classrel_derivation: algebra ->
    ('a * class -> class -> 'a) -> 'a * class -> class -> 'a  (*exception CLASS_ERROR*)
  val witness_sorts: algebra -> string list -> (typ * sort) list -> sort list -> (typ * sort) list
end;
    67 
    68 structure Sorts: SORTS =
    69 struct
    70 
    71 
    72 (** ordered lists of sorts **)
    73 
    74 val make = Ord_List.make Term_Ord.sort_ord;
    75 val subset = Ord_List.subset Term_Ord.sort_ord;
    76 val union = Ord_List.union Term_Ord.sort_ord;
    77 val subtract = Ord_List.subtract Term_Ord.sort_ord;
    78 
    79 val remove_sort = Ord_List.remove Term_Ord.sort_ord;
    80 val insert_sort = Ord_List.insert Term_Ord.sort_ord;
    81 
    82 fun insert_typ (TFree (_, S)) Ss = insert_sort S Ss
    83   | insert_typ (TVar (_, S)) Ss = insert_sort S Ss
    84   | insert_typ (Type (_, Ts)) Ss = insert_typs Ts Ss
    85 and insert_typs [] Ss = Ss
    86   | insert_typs (T :: Ts) Ss = insert_typs Ts (insert_typ T Ss);
    87 
    88 fun insert_term (Const (_, T)) Ss = insert_typ T Ss
    89   | insert_term (Free (_, T)) Ss = insert_typ T Ss
    90   | insert_term (Var (_, T)) Ss = insert_typ T Ss
    91   | insert_term (Bound _) Ss = Ss
    92   | insert_term (Abs (_, T, t)) Ss = insert_term t (insert_typ T Ss)
    93   | insert_term (t $ u) Ss = insert_term t (insert_term u Ss);
    94 
    95 fun insert_terms [] Ss = Ss
    96   | insert_terms (t :: ts) Ss = insert_terms ts (insert_term t Ss);
    97 
    98 
    99 
   100 (** order-sorted algebra **)
   101 
(*
  classes: graph representing the class declarations together with the
    proper subclass relation, which needs to be transitive and acyclic.

  arities: table of association lists of all type arities; (t, ars)
    means that type constructor t has the arities ars; an element
    (c, Ss) of ars represents the arity t::(Ss)c.  "Coregularity" of
    the arities structure requires that for any two declarations
    t::(Ss1)c1 and t::(Ss2)c2 with c1 <= c2, also Ss1 <= Ss2 holds.
*)
   112 
   113 datatype algebra = Algebra of
   114  {classes: serial Graph.T,
   115   arities: (class * sort list) list Symtab.table};
   116 
   117 fun classes_of (Algebra {classes, ...}) = classes;
   118 fun arities_of (Algebra {arities, ...}) = arities;
   119 
   120 fun make_algebra (classes, arities) =
   121   Algebra {classes = classes, arities = arities};
   122 
   123 fun map_classes f (Algebra {classes, arities}) = make_algebra (f classes, arities);
   124 fun map_arities f (Algebra {classes, arities}) = make_algebra (classes, f arities);
   125 
   126 
   127 (* classes *)
   128 
   129 fun all_classes (Algebra {classes, ...}) = Graph.all_preds classes (Graph.maximals classes);
   130 
   131 val super_classes = Graph.immediate_succs o classes_of;
   132 
   133 
   134 (* class relations *)
   135 
   136 val class_less = Graph.is_edge o classes_of;
   137 fun class_le algebra (c1, c2) = c1 = c2 orelse class_less algebra (c1, c2);
   138 
   139 
   140 (* sort relations *)
   141 
   142 fun sort_le algebra (S1, S2) =
   143   S1 = S2 orelse forall (fn c2 => exists (fn c1 => class_le algebra (c1, c2)) S1) S2;
   144 
   145 fun sorts_le algebra (Ss1, Ss2) =
   146   ListPair.all (sort_le algebra) (Ss1, Ss2);
   147 
   148 fun sort_eq algebra (S1, S2) =
   149   sort_le algebra (S1, S2) andalso sort_le algebra (S2, S1);
   150 
   151 
   152 (* intersection *)
   153 
   154 fun inter_class algebra c S =
   155   let
   156     fun intr [] = [c]
   157       | intr (S' as c' :: c's) =
   158           if class_le algebra (c', c) then S'
   159           else if class_le algebra (c, c') then intr c's
   160           else c' :: intr c's
   161   in intr S end;
   162 
   163 fun inter_sort algebra (S1, S2) =
   164   sort_strings (fold (inter_class algebra) S1 S2);
   165 
   166 
   167 (* normal forms *)
   168 
   169 fun minimize_sort _ [] = []
   170   | minimize_sort _ (S as [_]) = S
   171   | minimize_sort algebra S =
   172       filter (fn c => not (exists (fn c' => class_less algebra (c', c)) S)) S
   173       |> sort_distinct string_ord;
   174 
   175 fun complete_sort algebra =
   176   Graph.all_succs (classes_of algebra) o minimize_sort algebra;
   177 
   178 fun minimal_sorts algebra raw_sorts =
   179   let
   180     fun le S1 S2 = sort_le algebra (S1, S2);
   181     val sorts = make (map (minimize_sort algebra) raw_sorts);
   182   in sorts |> filter_out (fn S => exists (fn S' => le S' S andalso not (le S S')) sorts) end;
   183 
   184 
   185 (* certify *)
   186 
   187 fun certify_class algebra c =
   188   #1 (Graph.get_entry (classes_of algebra) c)
   189     handle Graph.UNDEF _ => raise TYPE ("Undeclared class: " ^ quote c, [], []);
   190 
   191 fun certify_sort classes = map (certify_class classes);
   192 
   193 
   194 
   195 (** build algebras **)
   196 
   197 (* classes *)
   198 
   199 fun err_dup_class c = error ("Duplicate declaration of class: " ^ quote c);
   200 
   201 fun err_cyclic_classes ctxt css =
   202   error (cat_lines (map (fn cs =>
   203     "Cycle in class relation: " ^ Syntax.string_of_classrel ctxt cs) css));
   204 
   205 fun add_class ctxt (c, cs) = map_classes (fn classes =>
   206   let
   207     val classes' = classes |> Graph.new_node (c, serial ())
   208       handle Graph.DUP dup => err_dup_class dup;
   209     val classes'' = classes' |> fold Graph.add_edge_trans_acyclic (map (pair c) cs)
   210       handle Graph.CYCLES css => err_cyclic_classes ctxt css;
   211   in classes'' end);
   212 
   213 
   214 (* arities *)
   215 
local

(*optional suffix naming the two classes involved in a conflict*)
fun for_classes _ NONE = ""
  | for_classes ctxt (SOME (c1, c2)) = " for classes " ^ Syntax.string_of_classrel ctxt [c1, c2];

(*error reporting two conflicting arities t::(Ss)c and t::(Ss')c'*)
fun err_conflict ctxt t cc (c, Ss) (c', Ss') =
  error ("Conflict of type arities" ^ for_classes ctxt cc ^ ":\n  " ^
    Syntax.string_of_arity ctxt (t, Ss, [c]) ^ " and\n  " ^
    Syntax.string_of_arity ctxt (t, Ss', [c']));

(*Check the new arity (c, Ss) of t against every existing entry of ars for
  coregularity (c <= c' requires Ss <= Ss', and symmetrically).  The first
  violation found is reported as an error; otherwise (c, Ss) is prepended.*)
fun coregular ctxt algebra t (c, Ss) ars =
  let
    fun conflict (c', Ss') =
      if class_le algebra (c, c') andalso not (sorts_le algebra (Ss, Ss')) then
        SOME ((c, c'), (c', Ss'))
      else if class_le algebra (c', c) andalso not (sorts_le algebra (Ss', Ss)) then
        SOME ((c', c), (c', Ss'))
      else NONE;
  in
    (case get_first conflict ars of
      SOME ((c1, c2), (c', Ss')) => err_conflict ctxt t (SOME (c1, c2)) (c, Ss) (c', Ss')
    | NONE => (c, Ss) :: ars)
  end;

(*the arity (c, Ss) together with the same domain Ss for each super_classes of c*)
fun complete algebra (c, Ss) = map (rpair Ss) (c :: super_classes algebra c);

(*Insert one arity into ars.  If c already has an entry Ss', the entry with
  the larger argument sorts (w.r.t. sorts_le) is kept; incomparable domains
  for the same class are an error.*)
fun insert ctxt algebra t (c, Ss) ars =
  (case AList.lookup (op =) ars c of
    NONE => coregular ctxt algebra t (c, Ss) ars
  | SOME Ss' =>
      if sorts_le algebra (Ss, Ss') then ars
      else if sorts_le algebra (Ss', Ss)
      (*new entry replaces the old one, re-checking coregularity*)
      then coregular ctxt algebra t (c, Ss) (remove (op =) (c, Ss') ars)
      else err_conflict ctxt t NONE (c, Ss) (c, Ss'));

in

(*insert a list of arities of t, processed from the right*)
fun insert_ars ctxt algebra t = fold_rev (insert ctxt algebra t);

(*complete each arity in ars with superclasses and insert the results into
  the arities already stored for t in the table*)
fun insert_complete_ars ctxt algebra (t, ars) arities =
  let val ars' =
    Symtab.lookup_list arities t
    |> fold_rev (insert_ars ctxt algebra t) (map (complete algebra) ars);
  in Symtab.update (t, ars') arities end;

(*add arities of one type constructor, checked against the algebra's classes*)
fun add_arities ctxt arg algebra =
  algebra |> map_arities (insert_complete_ars ctxt algebra arg);

(*insert all entries of an arity table into an accumulating table*)
fun add_arities_table ctxt algebra =
  Symtab.fold (fn (t, ars) => insert_complete_ars ctxt algebra (t, ars));

end;
   268 
   269 
   270 (* classrel *)
   271 
   272 fun rebuild_arities ctxt algebra = algebra |> map_arities (fn arities =>
   273   Symtab.empty
   274   |> add_arities_table ctxt algebra arities);
   275 
   276 fun add_classrel ctxt rel = rebuild_arities ctxt o map_classes (fn classes =>
   277   classes |> Graph.add_edge_trans_acyclic rel
   278     handle Graph.CYCLES css => err_cyclic_classes ctxt css);
   279 
   280 
   281 (* empty and merge *)
   282 
   283 val empty_algebra = make_algebra (Graph.empty, Symtab.empty);
   284 
   285 fun merge_algebra ctxt
   286    (Algebra {classes = classes1, arities = arities1},
   287     Algebra {classes = classes2, arities = arities2}) =
   288   let
   289     val classes' = Graph.merge_trans_acyclic (op =) (classes1, classes2)
   290       handle Graph.DUP c => err_dup_class c
   291         | Graph.CYCLES css => err_cyclic_classes ctxt css;
   292     val algebra0 = make_algebra (classes', Symtab.empty);
   293     val arities' =
   294       (case (pointer_eq (classes1, classes2), pointer_eq (arities1, arities2)) of
   295         (true, true) => arities1
   296       | (true, false) =>  (*no completion*)
   297           (arities1, arities2) |> Symtab.join (fn t => fn (ars1, ars2) =>
   298             if pointer_eq (ars1, ars2) then raise Symtab.SAME
   299             else insert_ars ctxt algebra0 t ars2 ars1)
   300       | (false, true) =>  (*unary completion*)
   301           Symtab.empty
   302           |> add_arities_table ctxt algebra0 arities1
   303       | (false, false) => (*binary completion*)
   304           Symtab.empty
   305           |> add_arities_table ctxt algebra0 arities1
   306           |> add_arities_table ctxt algebra0 arities2);
   307   in make_algebra (classes', arities') end;
   308 
   309 
   310 (* algebra projections *)  (* FIXME potentially violates abstract type integrity *)
   311 
   312 fun subalgebra ctxt P sargs (algebra as Algebra {classes, arities}) =
   313   let
   314     val restrict_sort = minimize_sort algebra o filter P o Graph.all_succs classes;
   315     fun restrict_arity t (c, Ss) =
   316       if P c then
   317         (case sargs (c, t) of
   318           SOME sorts =>
   319             SOME (c, Ss |> map2 (curry (inter_sort algebra)) sorts |> map restrict_sort)
   320         | NONE => NONE)
   321       else NONE;
   322     val classes' = classes |> Graph.subgraph P;
   323     val arities' = arities |> Symtab.map (map_filter o restrict_arity);
   324   in (restrict_sort, rebuild_arities ctxt (make_algebra (classes', arities'))) end;
   325 
   326 
   327 
   328 (** sorts of types **)
   329 
   330 (* errors -- performance tuning via delayed message composition *)
   331 
   332 datatype class_error =
   333   No_Classrel of class * class |
   334   No_Arity of string * class |
   335   No_Subsort of sort * sort;
   336 
   337 fun class_error ctxt (No_Classrel (c1, c2)) =
   338       "No class relation " ^ Syntax.string_of_classrel ctxt [c1, c2]
   339   | class_error ctxt (No_Arity (a, c)) =
   340       "No type arity " ^ Syntax.string_of_arity ctxt (a, [], [c])
   341   | class_error ctxt (No_Subsort (S1, S2)) =
   342       "Cannot derive subsort relation " ^
   343         Syntax.string_of_sort ctxt S1 ^ " < " ^ Syntax.string_of_sort ctxt S2;
   344 
   345 exception CLASS_ERROR of class_error;
   346 
   347 
   348 (* mg_domain *)
   349 
   350 fun mg_domain algebra a S =
   351   let
   352     val arities = arities_of algebra;
   353     fun dom c =
   354       (case AList.lookup (op =) (Symtab.lookup_list arities a) c of
   355         NONE => raise CLASS_ERROR (No_Arity (a, c))
   356       | SOME Ss => Ss);
   357     fun dom_inter c Ss = ListPair.map (inter_sort algebra) (dom c, Ss);
   358   in
   359     (case S of
   360       [] => raise Fail "Unknown domain of empty intersection"
   361     | c :: cs => fold dom_inter cs (dom c))
   362   end;
   363 
   364 
   365 (* meet_sort *)
   366 
   367 fun meet_sort algebra =
   368   let
   369     fun inters S S' = inter_sort algebra (S, S');
   370     fun meet _ [] = I
   371       | meet (TFree (_, S)) S' =
   372           if sort_le algebra (S, S') then I
   373           else raise CLASS_ERROR (No_Subsort (S, S'))
   374       | meet (TVar (v, S)) S' =
   375           if sort_le algebra (S, S') then I
   376           else Vartab.map_default (v, S) (inters S')
   377       | meet (Type (a, Ts)) S = fold2 meet Ts (mg_domain algebra a S);
   378   in uncurry meet end;
   379 
   380 fun meet_sort_typ algebra (T, S) =
   381   let val tab = meet_sort algebra (T, S) Vartab.empty;
   382   in Term.map_type_tvar (fn (v, _) => TVar (v, (the o Vartab.lookup tab) v)) end;
   383 
   384 
   385 (* of_sort *)
   386 
   387 fun of_sort algebra =
   388   let
   389     fun ofS (_, []) = true
   390       | ofS (TFree (_, S), S') = sort_le algebra (S, S')
   391       | ofS (TVar (_, S), S') = sort_le algebra (S, S')
   392       | ofS (Type (a, Ts), S) =
   393           let val Ss = mg_domain algebra a S in
   394             ListPair.all ofS (Ts, Ss)
   395           end handle CLASS_ERROR _ => false;
   396   in ofS end;
   397 
   398 
   399 (* animating derivations *)
   400 
   401 fun of_sort_derivation algebra {class_relation, type_constructor, type_variable} =
   402   let
   403     val arities = arities_of algebra;
   404 
   405     fun weaken T D1 S2 =
   406       let val S1 = map snd D1 in
   407         if S1 = S2 then map fst D1
   408         else
   409           S2 |> map (fn c2 =>
   410             (case D1 |> find_first (fn (_, c1) => class_le algebra (c1, c2)) of
   411               SOME d1 => class_relation T d1 c2
   412             | NONE => raise CLASS_ERROR (No_Subsort (S1, S2))))
   413       end;
   414 
   415     fun derive (_, []) = []
   416       | derive (Type (a, Us), S) =
   417           let
   418             val Ss = mg_domain algebra a S;
   419             val dom = map2 (fn U => fn S => derive (U, S) ~~ S) Us Ss;
   420           in
   421             S |> map (fn c =>
   422               let
   423                 val Ss' = the (AList.lookup (op =) (Symtab.lookup_list arities a) c);
   424                 val dom' = map (fn ((U, d), S') => weaken U d S' ~~ S') ((Us ~~ dom) ~~ Ss');
   425               in type_constructor (a, Us) dom' c end)
   426           end
   427       | derive (T, S) = weaken T (type_variable T) S;
   428   in derive end;
   429 
   430 fun classrel_derivation algebra class_relation =
   431   let
   432     fun path (x, c1 :: c2 :: cs) = path (class_relation (x, c1) c2, c2 :: cs)
   433       | path (x, _) = x;
   434   in
   435     fn (x, c1) => fn c2 =>
   436       (case Graph.irreducible_paths (classes_of algebra) (c1, c2) of
   437         [] => raise CLASS_ERROR (No_Classrel (c1, c2))
   438       | cs :: _ => path (x, cs))
   439   end;
   440 
   441 
   442 (* witness_sorts *)
   443 
(*Search type witnesses for the given sorts: a witness (T, S) is taken from
  hyps when its sort is below S, or built as Type (t, Ts) for a type
  constructor t from types whose domain sorts all have witnesses.  Sorts
  without a witness are dropped from the result.*)
fun witness_sorts algebra types hyps sorts =
  let
    fun le S1 S2 = sort_le algebra (S1, S2);
    (*reuse witness T of S1 for S2 if S1 <= S2*)
    fun get S2 (T, S1) = if le S1 S2 then SOME (T, S2) else NONE;
    (*most general domain, or NONE if t has no suitable arity*)
    fun mg_dom t S = SOME (mg_domain algebra t S) handle CLASS_ERROR _ => NONE;

    (*Find a witness for S, threading the caches (solved, failed); path holds
       the sorts currently being searched further up the recursion.*)
    fun witn_sort _ [] solved_failed = (SOME (propT, []), solved_failed)
      | witn_sort path S (solved, failed) =
          (*give up if S is below some sort that already failed*)
          if exists (le S) failed then (NONE, (solved, failed))
          else
            (case get_first (get S) solved of
              SOME w => (SOME w, (solved, failed))
            | NONE =>
                (case get_first (get S) hyps of
                  SOME w => (SOME w, (w :: solved, failed))
                | NONE => witn_types path types S (solved, failed)))

    and witn_sorts path x = fold_map (witn_sort path) x

    (*try the type constructors in order; S is recorded as failed when none works*)
    and witn_types _ [] S (solved, failed) = (NONE, (solved, S :: failed))
      | witn_types path (t :: ts) S solved_failed =
          (case mg_dom t S of
            SOME SS =>
              (*do not descend into stronger args (achieving termination)*)
              if exists (fn D => le D S orelse exists (le D) path) SS then
                witn_types path ts S solved_failed
              else
                let val (ws, (solved', failed')) = witn_sorts (S :: path) SS solved_failed in
                  if forall is_some ws then
                    let val w = (Type (t, map (#1 o the) ws), S)
                    in (SOME w, (w :: solved', failed')) end
                  else witn_types path ts S (solved', failed')
                end
          | NONE => witn_types path ts S solved_failed);

  in map_filter I (#1 (witn_sorts [] sorts ([], []))) end;
   480 
   481 end;