src/Pure/sorts.ML
author wenzelm
Tue Jun 01 22:19:17 2010 +0200 (2010-06-01)
changeset 37248 8e8e5f9d1441
parent 36429 9d6b3be996d4
child 39020 ac0f24f850c9
permissions -rw-r--r--
arities: no need to maintain original codomain (cf. f795c1164708) -- completion happens in axclass.ML;
misc tuning;
     1 (*  Title:      Pure/sorts.ML
     2     Author:     Markus Wenzel and Stefan Berghofer, TU Muenchen
     3 
     4 The order-sorted algebra of type classes.
     5 
     6 Classes denote (possibly empty) collections of types that are
     7 partially ordered by class inclusion. They are represented
     8 symbolically by strings.
     9 
    10 Sorts are intersections of finitely many classes. They are represented
    11 by lists of classes.  Normal forms of sorts are sorted lists of
    12 minimal classes (wrt. current class inclusion).
    13 *)
    14 
signature SORTS =
sig
  (*ordered lists of sorts*)
  val make: sort list -> sort OrdList.T
  val subset: sort OrdList.T * sort OrdList.T -> bool
  val union: sort OrdList.T -> sort OrdList.T -> sort OrdList.T
  val subtract: sort OrdList.T -> sort OrdList.T -> sort OrdList.T
  val remove_sort: sort -> sort OrdList.T -> sort OrdList.T
  val insert_sort: sort -> sort OrdList.T -> sort OrdList.T
  (*collect the sorts of type variables occurring in types / terms*)
  val insert_typ: typ -> sort OrdList.T -> sort OrdList.T
  val insert_typs: typ list -> sort OrdList.T -> sort OrdList.T
  val insert_term: term -> sort OrdList.T -> sort OrdList.T
  val insert_terms: term list -> sort OrdList.T -> sort OrdList.T
  (*order-sorted algebra: class graph together with type arities*)
  type algebra
  val classes_of: algebra -> serial Graph.T
  val arities_of: algebra -> (class * sort list) list Symtab.table
  val all_classes: algebra -> class list
  val super_classes: algebra -> class -> class list
  (*class and sort relations*)
  val class_less: algebra -> class * class -> bool
  val class_le: algebra -> class * class -> bool
  val sort_eq: algebra -> sort * sort -> bool
  val sort_le: algebra -> sort * sort -> bool
  val sorts_le: algebra -> sort list * sort list -> bool
  (*intersection, normal forms, certification*)
  val inter_sort: algebra -> sort * sort -> sort
  val minimize_sort: algebra -> sort -> sort
  val complete_sort: algebra -> sort -> sort
  val minimal_sorts: algebra -> sort list -> sort OrdList.T
  val certify_class: algebra -> class -> class    (*exception TYPE*)
  val certify_sort: algebra -> sort -> sort       (*exception TYPE*)
  (*building and combining algebras*)
  val add_class: Pretty.pp -> class * class list -> algebra -> algebra
  val add_classrel: Pretty.pp -> class * class -> algebra -> algebra
  val add_arities: Pretty.pp -> string * (class * sort list) list -> algebra -> algebra
  val empty_algebra: algebra
  val merge_algebra: Pretty.pp -> algebra * algebra -> algebra
  val subalgebra: Pretty.pp -> (class -> bool) -> (class * string -> sort list option)
    -> algebra -> (sort -> sort) * algebra
  (*errors arising when computing sorts of types (message built on demand)*)
  type class_error
  val class_error: Pretty.pp -> class_error -> string
  exception CLASS_ERROR of class_error
  (*sorts of types*)
  val mg_domain: algebra -> string -> sort -> sort list   (*exception CLASS_ERROR*)
  val meet_sort: algebra -> typ * sort
    -> sort Vartab.table -> sort Vartab.table   (*exception CLASS_ERROR*)
  val meet_sort_typ: algebra -> typ * sort -> typ -> typ   (*exception CLASS_ERROR*)
  val of_sort: algebra -> typ * sort -> bool
  (*replay a sort derivation via user-supplied callbacks*)
  val of_sort_derivation: algebra ->
    {class_relation: typ -> 'a * class -> class -> 'a,
     type_constructor: string * typ list -> ('a * class) list list -> class -> 'a,
     type_variable: typ -> ('a * class) list} ->
    typ * sort -> 'a list   (*exception CLASS_ERROR*)
  val classrel_derivation: algebra ->
    ('a * class -> class -> 'a) -> 'a * class -> class -> 'a  (*exception CLASS_ERROR*)
  val witness_sorts: algebra -> string list -> (typ * sort) list -> sort list -> (typ * sort) list
end;
    67 
    68 structure Sorts: SORTS =
    69 struct
    70 
    71 
    72 (** ordered lists of sorts **)
    73 
    74 val make = OrdList.make Term_Ord.sort_ord;
    75 val subset = OrdList.subset Term_Ord.sort_ord;
    76 val union = OrdList.union Term_Ord.sort_ord;
    77 val subtract = OrdList.subtract Term_Ord.sort_ord;
    78 
    79 val remove_sort = OrdList.remove Term_Ord.sort_ord;
    80 val insert_sort = OrdList.insert Term_Ord.sort_ord;
    81 
    82 fun insert_typ (TFree (_, S)) Ss = insert_sort S Ss
    83   | insert_typ (TVar (_, S)) Ss = insert_sort S Ss
    84   | insert_typ (Type (_, Ts)) Ss = insert_typs Ts Ss
    85 and insert_typs [] Ss = Ss
    86   | insert_typs (T :: Ts) Ss = insert_typs Ts (insert_typ T Ss);
    87 
    88 fun insert_term (Const (_, T)) Ss = insert_typ T Ss
    89   | insert_term (Free (_, T)) Ss = insert_typ T Ss
    90   | insert_term (Var (_, T)) Ss = insert_typ T Ss
    91   | insert_term (Bound _) Ss = Ss
    92   | insert_term (Abs (_, T, t)) Ss = insert_term t (insert_typ T Ss)
    93   | insert_term (t $ u) Ss = insert_term t (insert_term u Ss);
    94 
    95 fun insert_terms [] Ss = Ss
    96   | insert_terms (t :: ts) Ss = insert_terms ts (insert_term t Ss);
    97 
    98 
    99 
   100 (** order-sorted algebra **)
   101 
(*
  classes: graph representing class declarations together with the proper
    subclass relation, which needs to be transitive and acyclic.

  arities: table of association lists of all type arities; (t, ars)
    means that type constructor t has the arities ars; an element
    (c, Ss) of ars represents the arity t::(Ss)c.  "Coregularity" of
    the arities structure requires that for any two declarations
    t::(Ss1)c1 and t::(Ss2)c2 with c1 <= c2, the relation Ss1 <= Ss2 holds.
*)
   112 
   113 datatype algebra = Algebra of
   114  {classes: serial Graph.T,
   115   arities: (class * sort list) list Symtab.table};
   116 
   117 fun classes_of (Algebra {classes, ...}) = classes;
   118 fun arities_of (Algebra {arities, ...}) = arities;
   119 
   120 fun make_algebra (classes, arities) =
   121   Algebra {classes = classes, arities = arities};
   122 
   123 fun map_classes f (Algebra {classes, arities}) = make_algebra (f classes, arities);
   124 fun map_arities f (Algebra {classes, arities}) = make_algebra (classes, f arities);
   125 
   126 
   127 (* classes *)
   128 
   129 fun all_classes (Algebra {classes, ...}) = Graph.all_preds classes (Graph.maximals classes);
   130 
   131 val super_classes = Graph.imm_succs o classes_of;
   132 
   133 
   134 (* class relations *)
   135 
   136 val class_less = Graph.is_edge o classes_of;
   137 fun class_le algebra (c1, c2) = c1 = c2 orelse class_less algebra (c1, c2);
   138 
   139 
   140 (* sort relations *)
   141 
   142 fun sort_le algebra (S1, S2) =
   143   S1 = S2 orelse forall (fn c2 => exists (fn c1 => class_le algebra (c1, c2)) S1) S2;
   144 
   145 fun sorts_le algebra (Ss1, Ss2) =
   146   ListPair.all (sort_le algebra) (Ss1, Ss2);
   147 
   148 fun sort_eq algebra (S1, S2) =
   149   sort_le algebra (S1, S2) andalso sort_le algebra (S2, S1);
   150 
   151 
   152 (* intersection *)
   153 
   154 fun inter_class algebra c S =
   155   let
   156     fun intr [] = [c]
   157       | intr (S' as c' :: c's) =
   158           if class_le algebra (c', c) then S'
   159           else if class_le algebra (c, c') then intr c's
   160           else c' :: intr c's
   161   in intr S end;
   162 
   163 fun inter_sort algebra (S1, S2) =
   164   sort_strings (fold (inter_class algebra) S1 S2);
   165 
   166 
   167 (* normal forms *)
   168 
   169 fun minimize_sort _ [] = []
   170   | minimize_sort _ (S as [_]) = S
   171   | minimize_sort algebra S =
   172       filter (fn c => not (exists (fn c' => class_less algebra (c', c)) S)) S
   173       |> sort_distinct string_ord;
   174 
   175 fun complete_sort algebra =
   176   Graph.all_succs (classes_of algebra) o minimize_sort algebra;
   177 
   178 fun minimal_sorts algebra raw_sorts =
   179   let
   180     fun le S1 S2 = sort_le algebra (S1, S2);
   181     val sorts = make (map (minimize_sort algebra) raw_sorts);
   182   in sorts |> filter_out (fn S => exists (fn S' => le S' S andalso not (le S S')) sorts) end;
   183 
   184 
   185 (* certify *)
   186 
   187 fun certify_class algebra c =
   188   if can (Graph.get_node (classes_of algebra)) c then c
   189   else raise TYPE ("Undeclared class: " ^ quote c, [], []);
   190 
   191 fun certify_sort classes = map (certify_class classes);
   192 
   193 
   194 
   195 (** build algebras **)
   196 
   197 (* classes *)
   198 
   199 fun err_dup_class c = error ("Duplicate declaration of class: " ^ quote c);
   200 
   201 fun err_cyclic_classes pp css =
   202   error (cat_lines (map (fn cs =>
   203     "Cycle in class relation: " ^ Pretty.string_of_classrel pp cs) css));
   204 
   205 fun add_class pp (c, cs) = map_classes (fn classes =>
   206   let
   207     val classes' = classes |> Graph.new_node (c, serial ())
   208       handle Graph.DUP dup => err_dup_class dup;
   209     val classes'' = classes' |> fold Graph.add_edge_trans_acyclic (map (pair c) cs)
   210       handle Graph.CYCLES css => err_cyclic_classes pp css;
   211   in classes'' end);
   212 
   213 
   214 (* arities *)
   215 
   216 local
   217 
   218 fun for_classes _ NONE = ""
   219   | for_classes pp (SOME (c1, c2)) =
   220       " for classes " ^ Pretty.string_of_classrel pp [c1, c2];
   221 
   222 fun err_conflict pp t cc (c, Ss) (c', Ss') =
   223   error ("Conflict of type arities" ^ for_classes pp cc ^ ":\n  " ^
   224     Pretty.string_of_arity pp (t, Ss, [c]) ^ " and\n  " ^
   225     Pretty.string_of_arity pp (t, Ss', [c']));
   226 
   227 fun coregular pp algebra t (c, Ss) ars =
   228   let
   229     fun conflict (c', Ss') =
   230       if class_le algebra (c, c') andalso not (sorts_le algebra (Ss, Ss')) then
   231         SOME ((c, c'), (c', Ss'))
   232       else if class_le algebra (c', c) andalso not (sorts_le algebra (Ss', Ss)) then
   233         SOME ((c', c), (c', Ss'))
   234       else NONE;
   235   in
   236     (case get_first conflict ars of
   237       SOME ((c1, c2), (c', Ss')) => err_conflict pp t (SOME (c1, c2)) (c, Ss) (c', Ss')
   238     | NONE => (c, Ss) :: ars)
   239   end;
   240 
   241 fun complete algebra (c, Ss) = map (rpair Ss) (c :: super_classes algebra c);
   242 
   243 fun insert pp algebra t (c, Ss) ars =
   244   (case AList.lookup (op =) ars c of
   245     NONE => coregular pp algebra t (c, Ss) ars
   246   | SOME Ss' =>
   247       if sorts_le algebra (Ss, Ss') then ars
   248       else if sorts_le algebra (Ss', Ss)
   249       then coregular pp algebra t (c, Ss) (remove (op =) (c, Ss') ars)
   250       else err_conflict pp t NONE (c, Ss) (c, Ss'));
   251 
   252 in
   253 
   254 fun insert_ars pp algebra t = fold_rev (insert pp algebra t);
   255 
   256 fun insert_complete_ars pp algebra (t, ars) arities =
   257   let val ars' =
   258     Symtab.lookup_list arities t
   259     |> fold_rev (insert_ars pp algebra t) (map (complete algebra) ars);
   260   in Symtab.update (t, ars') arities end;
   261 
   262 fun add_arities pp arg algebra =
   263   algebra |> map_arities (insert_complete_ars pp algebra arg);
   264 
   265 fun add_arities_table pp algebra =
   266   Symtab.fold (fn (t, ars) => insert_complete_ars pp algebra (t, ars));
   267 
   268 end;
   269 
   270 
   271 (* classrel *)
   272 
   273 fun rebuild_arities pp algebra = algebra |> map_arities (fn arities =>
   274   Symtab.empty
   275   |> add_arities_table pp algebra arities);
   276 
   277 fun add_classrel pp rel = rebuild_arities pp o map_classes (fn classes =>
   278   classes |> Graph.add_edge_trans_acyclic rel
   279     handle Graph.CYCLES css => err_cyclic_classes pp css);
   280 
   281 
   282 (* empty and merge *)
   283 
   284 val empty_algebra = make_algebra (Graph.empty, Symtab.empty);
   285 
   286 fun merge_algebra pp
   287    (Algebra {classes = classes1, arities = arities1},
   288     Algebra {classes = classes2, arities = arities2}) =
   289   let
   290     val classes' = Graph.merge_trans_acyclic (op =) (classes1, classes2)
   291       handle Graph.DUP c => err_dup_class c
   292         | Graph.CYCLES css => err_cyclic_classes pp css;
   293     val algebra0 = make_algebra (classes', Symtab.empty);
   294     val arities' =
   295       (case (pointer_eq (classes1, classes2), pointer_eq (arities1, arities2)) of
   296         (true, true) => arities1
   297       | (true, false) =>  (*no completion*)
   298           (arities1, arities2) |> Symtab.join (fn t => fn (ars1, ars2) =>
   299             if pointer_eq (ars1, ars2) then raise Symtab.SAME
   300             else insert_ars pp algebra0 t ars2 ars1)
   301       | (false, true) =>  (*unary completion*)
   302           Symtab.empty
   303           |> add_arities_table pp algebra0 arities1
   304       | (false, false) => (*binary completion*)
   305           Symtab.empty
   306           |> add_arities_table pp algebra0 arities1
   307           |> add_arities_table pp algebra0 arities2);
   308   in make_algebra (classes', arities') end;
   309 
   310 
   311 (* algebra projections *)  (* FIXME potentially violates abstract type integrity *)
   312 
   313 fun subalgebra pp P sargs (algebra as Algebra {classes, arities}) =
   314   let
   315     val restrict_sort = minimize_sort algebra o filter P o Graph.all_succs classes;
   316     fun restrict_arity t (c, Ss) =
   317       if P c then
   318         (case sargs (c, t) of
   319           SOME sorts =>
   320             SOME (c, Ss |> map2 (curry (inter_sort algebra)) sorts |> map restrict_sort)
   321         | NONE => NONE)
   322       else NONE;
   323     val classes' = classes |> Graph.subgraph P;
   324     val arities' = arities |> Symtab.map' (map_filter o restrict_arity);
   325   in (restrict_sort, rebuild_arities pp (make_algebra (classes', arities'))) end;
   326 
   327 
   328 
   329 (** sorts of types **)
   330 
   331 (* errors -- performance tuning via delayed message composition *)
   332 
   333 datatype class_error =
   334   No_Classrel of class * class |
   335   No_Arity of string * class |
   336   No_Subsort of sort * sort;
   337 
   338 fun class_error pp (No_Classrel (c1, c2)) =
   339       "No class relation " ^ Pretty.string_of_classrel pp [c1, c2]
   340   | class_error pp (No_Arity (a, c)) =
   341       "No type arity " ^ Pretty.string_of_arity pp (a, [], [c])
   342   | class_error pp (No_Subsort (S1, S2)) =
   343      "Cannot derive subsort relation " ^ Pretty.string_of_sort pp S1
   344        ^ " < " ^ Pretty.string_of_sort pp S2;
   345 
   346 exception CLASS_ERROR of class_error;
   347 
   348 
   349 (* mg_domain *)
   350 
   351 fun mg_domain algebra a S =
   352   let
   353     val arities = arities_of algebra;
   354     fun dom c =
   355       (case AList.lookup (op =) (Symtab.lookup_list arities a) c of
   356         NONE => raise CLASS_ERROR (No_Arity (a, c))
   357       | SOME Ss => Ss);
   358     fun dom_inter c Ss = ListPair.map (inter_sort algebra) (dom c, Ss);
   359   in
   360     (case S of
   361       [] => raise Fail "Unknown domain of empty intersection"
   362     | c :: cs => fold dom_inter cs (dom c))
   363   end;
   364 
   365 
   366 (* meet_sort *)
   367 
   368 fun meet_sort algebra =
   369   let
   370     fun inters S S' = inter_sort algebra (S, S');
   371     fun meet _ [] = I
   372       | meet (TFree (_, S)) S' =
   373           if sort_le algebra (S, S') then I
   374           else raise CLASS_ERROR (No_Subsort (S, S'))
   375       | meet (TVar (v, S)) S' =
   376           if sort_le algebra (S, S') then I
   377           else Vartab.map_default (v, S) (inters S')
   378       | meet (Type (a, Ts)) S = fold2 meet Ts (mg_domain algebra a S);
   379   in uncurry meet end;
   380 
   381 fun meet_sort_typ algebra (T, S) =
   382   let val tab = meet_sort algebra (T, S) Vartab.empty;
   383   in Term.map_type_tvar (fn (v, _) => TVar (v, (the o Vartab.lookup tab) v)) end;
   384 
   385 
   386 (* of_sort *)
   387 
   388 fun of_sort algebra =
   389   let
   390     fun ofS (_, []) = true
   391       | ofS (TFree (_, S), S') = sort_le algebra (S, S')
   392       | ofS (TVar (_, S), S') = sort_le algebra (S, S')
   393       | ofS (Type (a, Ts), S) =
   394           let val Ss = mg_domain algebra a S in
   395             ListPair.all ofS (Ts, Ss)
   396           end handle CLASS_ERROR _ => false;
   397   in ofS end;
   398 
   399 
   400 (* animating derivations *)
   401 
   402 fun of_sort_derivation algebra {class_relation, type_constructor, type_variable} =
   403   let
   404     val arities = arities_of algebra;
   405 
   406     fun weaken T D1 S2 =
   407       let val S1 = map snd D1 in
   408         if S1 = S2 then map fst D1
   409         else
   410           S2 |> map (fn c2 =>
   411             (case D1 |> find_first (fn (_, c1) => class_le algebra (c1, c2)) of
   412               SOME d1 => class_relation T d1 c2
   413             | NONE => raise CLASS_ERROR (No_Subsort (S1, S2))))
   414       end;
   415 
   416     fun derive (_, []) = []
   417       | derive (T as Type (a, Us), S) =
   418           let
   419             val Ss = mg_domain algebra a S;
   420             val dom = map2 (fn U => fn S => derive (U, S) ~~ S) Us Ss;
   421           in
   422             S |> map (fn c =>
   423               let
   424                 val Ss' = the (AList.lookup (op =) (Symtab.lookup_list arities a) c);
   425                 val dom' = map (fn ((U, d), S') => weaken U d S' ~~ S') ((Us ~~ dom) ~~ Ss');
   426               in type_constructor (a, Us) dom' c end)
   427           end
   428       | derive (T, S) = weaken T (type_variable T) S;
   429   in derive end;
   430 
   431 fun classrel_derivation algebra class_relation =
   432   let
   433     fun path (x, c1 :: c2 :: cs) = path (class_relation (x, c1) c2, c2 :: cs)
   434       | path (x, _) = x;
   435   in
   436     fn (x, c1) => fn c2 =>
   437       (case Graph.irreducible_paths (classes_of algebra) (c1, c2) of
   438         [] => raise CLASS_ERROR (No_Classrel (c1, c2))
   439       | cs :: _ => path (x, cs))
   440   end;
   441 
   442 
   443 (* witness_sorts *)
   444 
   445 fun witness_sorts algebra types hyps sorts =
   446   let
   447     fun le S1 S2 = sort_le algebra (S1, S2);
   448     fun get S2 (T, S1) = if le S1 S2 then SOME (T, S2) else NONE;
   449     fun mg_dom t S = SOME (mg_domain algebra t S) handle CLASS_ERROR _ => NONE;
   450 
   451     fun witn_sort _ [] solved_failed = (SOME (propT, []), solved_failed)
   452       | witn_sort path S (solved, failed) =
   453           if exists (le S) failed then (NONE, (solved, failed))
   454           else
   455             (case get_first (get S) solved of
   456               SOME w => (SOME w, (solved, failed))
   457             | NONE =>
   458                 (case get_first (get S) hyps of
   459                   SOME w => (SOME w, (w :: solved, failed))
   460                 | NONE => witn_types path types S (solved, failed)))
   461 
   462     and witn_sorts path x = fold_map (witn_sort path) x
   463 
   464     and witn_types _ [] S (solved, failed) = (NONE, (solved, S :: failed))
   465       | witn_types path (t :: ts) S solved_failed =
   466           (case mg_dom t S of
   467             SOME SS =>
   468               (*do not descend into stronger args (achieving termination)*)
   469               if exists (fn D => le D S orelse exists (le D) path) SS then
   470                 witn_types path ts S solved_failed
   471               else
   472                 let val (ws, (solved', failed')) = witn_sorts (S :: path) SS solved_failed in
   473                   if forall is_some ws then
   474                     let val w = (Type (t, map (#1 o the) ws), S)
   475                     in (SOME w, (w :: solved', failed')) end
   476                   else witn_types path ts S (solved', failed')
   477                 end
   478           | NONE => witn_types path ts S solved_failed);
   479 
   480   in map_filter I (#1 (witn_sorts [] sorts ([], []))) end;
   481 
   482 end;