src/Pure/sorts.ML
author wenzelm
Fri May 28 20:41:23 2010 +0200 (2010-05-28)
changeset 37174 6feaab4fc27d
parent 36429 9d6b3be996d4
child 37248 8e8e5f9d1441
permissions -rw-r--r--
assume given SCALA_HOME, e.g. from component settings or external setup;
     1 (*  Title:      Pure/sorts.ML
     2     Author:     Markus Wenzel and Stefan Berghofer, TU Muenchen
     3 
     4 The order-sorted algebra of type classes.
     5 
     6 Classes denote (possibly empty) collections of types that are
     7 partially ordered by class inclusion. They are represented
     8 symbolically by strings.
     9 
    10 Sorts are intersections of finitely many classes. They are represented
    11 by lists of classes.  Normal forms of sorts are sorted lists of
    12 minimal classes (wrt. current class inclusion).
    13 *)
    14 
signature SORTS =
sig
  (*ordered lists of sorts, ordered by Term_Ord.sort_ord*)
  val make: sort list -> sort OrdList.T
  val subset: sort OrdList.T * sort OrdList.T -> bool
  val union: sort OrdList.T -> sort OrdList.T -> sort OrdList.T
  val subtract: sort OrdList.T -> sort OrdList.T -> sort OrdList.T
  val remove_sort: sort -> sort OrdList.T -> sort OrdList.T
  val insert_sort: sort -> sort OrdList.T -> sort OrdList.T
  (*insert the sorts of all type variables (TFree/TVar) occurring in types/terms*)
  val insert_typ: typ -> sort OrdList.T -> sort OrdList.T
  val insert_typs: typ list -> sort OrdList.T -> sort OrdList.T
  val insert_term: term -> sort OrdList.T -> sort OrdList.T
  val insert_terms: term list -> sort OrdList.T -> sort OrdList.T
  (*order-sorted algebra: class graph together with a table of type arities*)
  type algebra
  val classes_of: algebra -> serial Graph.T
  val arities_of: algebra -> (class * (class * sort list)) list Symtab.table
  val all_classes: algebra -> class list
  val super_classes: algebra -> class -> class list
  (*class and sort relations wrt. the class graph*)
  val class_less: algebra -> class * class -> bool
  val class_le: algebra -> class * class -> bool
  val sort_eq: algebra -> sort * sort -> bool
  val sort_le: algebra -> sort * sort -> bool
  val sorts_le: algebra -> sort list * sort list -> bool
  (*intersection and normal forms of sorts*)
  val inter_sort: algebra -> sort * sort -> sort
  val minimize_sort: algebra -> sort -> sort
  val complete_sort: algebra -> sort -> sort
  val minimal_sorts: algebra -> sort list -> sort OrdList.T
  val certify_class: algebra -> class -> class    (*exception TYPE*)
  val certify_sort: algebra -> sort -> sort       (*exception TYPE*)
  (*building, merging and restricting algebras*)
  val add_class: Pretty.pp -> class * class list -> algebra -> algebra
  val add_classrel: Pretty.pp -> class * class -> algebra -> algebra
  val add_arities: Pretty.pp -> string * (class * sort list) list -> algebra -> algebra
  val empty_algebra: algebra
  val merge_algebra: Pretty.pp -> algebra * algebra -> algebra
  val subalgebra: Pretty.pp -> (class -> bool) -> (class * string -> sort list option)
    -> algebra -> (sort -> sort) * algebra
  (*errors of the type-level operations below; messages are composed on demand*)
  type class_error
  val class_error: Pretty.pp -> class_error -> string
  exception CLASS_ERROR of class_error
  (*sorts of types*)
  val mg_domain: algebra -> string -> sort -> sort list   (*exception CLASS_ERROR*)
  val meet_sort: algebra -> typ * sort
    -> sort Vartab.table -> sort Vartab.table   (*exception CLASS_ERROR*)
  val meet_sort_typ: algebra -> typ * sort -> typ -> typ   (*exception CLASS_ERROR*)
  val of_sort: algebra -> typ * sort -> bool
  (*replaying derivations with user-supplied callbacks*)
  val of_sort_derivation: algebra ->
    {class_relation: typ -> 'a * class -> class -> 'a,
     type_constructor: string * typ list -> ('a * class) list list -> class -> 'a,
     type_variable: typ -> ('a * class) list} ->
    typ * sort -> 'a list   (*exception CLASS_ERROR*)
  val classrel_derivation: algebra ->
    ('a * class -> class -> 'a) -> 'a * class -> class -> 'a  (*exception CLASS_ERROR*)
  val witness_sorts: algebra -> string list -> (typ * sort) list -> sort list -> (typ * sort) list
end;
    67 
    68 structure Sorts: SORTS =
    69 struct
    70 
    71 
    72 (** ordered lists of sorts **)
    73 
    74 val make = OrdList.make Term_Ord.sort_ord;
    75 val subset = OrdList.subset Term_Ord.sort_ord;
    76 val union = OrdList.union Term_Ord.sort_ord;
    77 val subtract = OrdList.subtract Term_Ord.sort_ord;
    78 
    79 val remove_sort = OrdList.remove Term_Ord.sort_ord;
    80 val insert_sort = OrdList.insert Term_Ord.sort_ord;
    81 
    82 fun insert_typ (TFree (_, S)) Ss = insert_sort S Ss
    83   | insert_typ (TVar (_, S)) Ss = insert_sort S Ss
    84   | insert_typ (Type (_, Ts)) Ss = insert_typs Ts Ss
    85 and insert_typs [] Ss = Ss
    86   | insert_typs (T :: Ts) Ss = insert_typs Ts (insert_typ T Ss);
    87 
    88 fun insert_term (Const (_, T)) Ss = insert_typ T Ss
    89   | insert_term (Free (_, T)) Ss = insert_typ T Ss
    90   | insert_term (Var (_, T)) Ss = insert_typ T Ss
    91   | insert_term (Bound _) Ss = Ss
    92   | insert_term (Abs (_, T, t)) Ss = insert_term t (insert_typ T Ss)
    93   | insert_term (t $ u) Ss = insert_term t (insert_term u Ss);
    94 
    95 fun insert_terms [] Ss = Ss
    96   | insert_terms (t :: ts) Ss = insert_terms ts (insert_term t Ss);
    97 
    98 
    99 
   100 (** order-sorted algebra **)
   101 
   102 (*
   103   classes: graph representing class declarations together with proper
   104     subclass relation, which needs to be transitive and acyclic.
   105 
   106   arities: table of association lists of all type arities; (t, ars)
   107     means that type constructor t has the arities ars; an element
   108     (c, (c0, Ss)) of ars represents the arity t::(Ss)c being derived
    via c0 <= c.  "Coregularity" of the arities structure requires
    that for any two declarations t::(Ss1)c1 and t::(Ss2)c2 with
    c1 <= c2, also Ss1 <= Ss2 holds.
   112 *)
   113 
(*classes: class graph; each node carries a serial number obtained at
  declaration time (see add_class).
  arities: for each type constructor, the list of (c, (c0, Ss)) entries
  as explained in the comment above.*)
datatype algebra = Algebra of
 {classes: serial Graph.T,
  arities: (class * (class * sort list)) list Symtab.table};
   117 
   118 fun classes_of (Algebra {classes, ...}) = classes;
   119 fun arities_of (Algebra {arities, ...}) = arities;
   120 
   121 fun make_algebra (classes, arities) =
   122   Algebra {classes = classes, arities = arities};
   123 
   124 fun map_classes f (Algebra {classes, arities}) = make_algebra (f classes, arities);
   125 fun map_arities f (Algebra {classes, arities}) = make_algebra (classes, f arities);
   126 
   127 
   128 (* classes *)
   129 
   130 fun all_classes (Algebra {classes, ...}) = Graph.all_preds classes (Graph.maximals classes);
   131 
   132 val super_classes = Graph.imm_succs o classes_of;
   133 
   134 
   135 (* class relations *)
   136 
   137 val class_less = Graph.is_edge o classes_of;
   138 fun class_le algebra (c1, c2) = c1 = c2 orelse class_less algebra (c1, c2);
   139 
   140 
   141 (* sort relations *)
   142 
   143 fun sort_le algebra (S1, S2) =
   144   S1 = S2 orelse forall (fn c2 => exists (fn c1 => class_le algebra (c1, c2)) S1) S2;
   145 
   146 fun sorts_le algebra (Ss1, Ss2) =
   147   ListPair.all (sort_le algebra) (Ss1, Ss2);
   148 
   149 fun sort_eq algebra (S1, S2) =
   150   sort_le algebra (S1, S2) andalso sort_le algebra (S2, S1);
   151 
   152 
   153 (* intersection *)
   154 
   155 fun inter_class algebra c S =
   156   let
   157     fun intr [] = [c]
   158       | intr (S' as c' :: c's) =
   159           if class_le algebra (c', c) then S'
   160           else if class_le algebra (c, c') then intr c's
   161           else c' :: intr c's
   162   in intr S end;
   163 
   164 fun inter_sort algebra (S1, S2) =
   165   sort_strings (fold (inter_class algebra) S1 S2);
   166 
   167 
   168 (* normal forms *)
   169 
   170 fun minimize_sort _ [] = []
   171   | minimize_sort _ (S as [_]) = S
   172   | minimize_sort algebra S =
   173       filter (fn c => not (exists (fn c' => class_less algebra (c', c)) S)) S
   174       |> sort_distinct string_ord;
   175 
   176 fun complete_sort algebra =
   177   Graph.all_succs (classes_of algebra) o minimize_sort algebra;
   178 
   179 fun minimal_sorts algebra raw_sorts =
   180   let
   181     fun le S1 S2 = sort_le algebra (S1, S2);
   182     val sorts = make (map (minimize_sort algebra) raw_sorts);
   183   in sorts |> filter_out (fn S => exists (fn S' => le S' S andalso not (le S S')) sorts) end;
   184 
   185 
   186 (* certify *)
   187 
   188 fun certify_class algebra c =
   189   if can (Graph.get_node (classes_of algebra)) c then c
   190   else raise TYPE ("Undeclared class: " ^ quote c, [], []);
   191 
   192 fun certify_sort classes = map (certify_class classes);
   193 
   194 
   195 
   196 (** build algebras **)
   197 
   198 (* classes *)
   199 
   200 fun err_dup_class c = error ("Duplicate declaration of class: " ^ quote c);
   201 
   202 fun err_cyclic_classes pp css =
   203   error (cat_lines (map (fn cs =>
   204     "Cycle in class relation: " ^ Pretty.string_of_classrel pp cs) css));
   205 
   206 fun add_class pp (c, cs) = map_classes (fn classes =>
   207   let
   208     val classes' = classes |> Graph.new_node (c, serial ())
   209       handle Graph.DUP dup => err_dup_class dup;
   210     val classes'' = classes' |> fold Graph.add_edge_trans_acyclic (map (pair c) cs)
   211       handle Graph.CYCLES css => err_cyclic_classes pp css;
   212   in classes'' end);
   213 
   214 
   215 (* arities *)
   216 
   217 local
   218 
   219 fun for_classes _ NONE = ""
   220   | for_classes pp (SOME (c1, c2)) =
   221       " for classes " ^ Pretty.string_of_classrel pp [c1, c2];
   222 
   223 fun err_conflict pp t cc (c, Ss) (c', Ss') =
   224   error ("Conflict of type arities" ^ for_classes pp cc ^ ":\n  " ^
   225     Pretty.string_of_arity pp (t, Ss, [c]) ^ " and\n  " ^
   226     Pretty.string_of_arity pp (t, Ss', [c']));
   227 
   228 fun coregular pp algebra t (c, (c0, Ss)) ars =
   229   let
   230     fun conflict (c', (_, Ss')) =
   231       if class_le algebra (c, c') andalso not (sorts_le algebra (Ss, Ss')) then
   232         SOME ((c, c'), (c', Ss'))
   233       else if class_le algebra (c', c) andalso not (sorts_le algebra (Ss', Ss)) then
   234         SOME ((c', c), (c', Ss'))
   235       else NONE;
   236   in
   237     (case get_first conflict ars of
   238       SOME ((c1, c2), (c', Ss')) => err_conflict pp t (SOME (c1, c2)) (c, Ss) (c', Ss')
   239     | NONE => (c, (c0, Ss)) :: ars)
   240   end;
   241 
   242 fun complete algebra (c0, Ss) = map (rpair (c0, Ss)) (c0 :: super_classes algebra c0);
   243 
   244 fun insert pp algebra t (c, (c0, Ss)) ars =
   245   (case AList.lookup (op =) ars c of
   246     NONE => coregular pp algebra t (c, (c0, Ss)) ars
   247   | SOME (_, Ss') =>
   248       if sorts_le algebra (Ss, Ss') then ars
   249       else if sorts_le algebra (Ss', Ss) then
   250         coregular pp algebra t (c, (c0, Ss))
   251           (filter_out (fn (c'', (_, Ss'')) => c = c'' andalso Ss'' = Ss') ars)
   252       else err_conflict pp t NONE (c, Ss) (c, Ss'));
   253 
   254 in
   255 
   256 fun insert_ars pp algebra t = fold_rev (insert pp algebra t);
   257 
   258 fun insert_complete_ars pp algebra (t, ars) arities =
   259   let val ars' =
   260     Symtab.lookup_list arities t
   261     |> fold_rev (insert_ars pp algebra t) (map (complete algebra) ars);
   262   in Symtab.update (t, ars') arities end;
   263 
   264 fun add_arities pp arg algebra =
   265   algebra |> map_arities (insert_complete_ars pp algebra arg);
   266 
   267 fun add_arities_table pp algebra =
   268   Symtab.fold (fn (t, ars) => insert_complete_ars pp algebra (t, map snd ars));
   269 
   270 end;
   271 
   272 
   273 (* classrel *)
   274 
   275 fun rebuild_arities pp algebra = algebra |> map_arities (fn arities =>
   276   Symtab.empty
   277   |> add_arities_table pp algebra arities);
   278 
   279 fun add_classrel pp rel = rebuild_arities pp o map_classes (fn classes =>
   280   classes |> Graph.add_edge_trans_acyclic rel
   281     handle Graph.CYCLES css => err_cyclic_classes pp css);
   282 
   283 
   284 (* empty and merge *)
   285 
   286 val empty_algebra = make_algebra (Graph.empty, Symtab.empty);
   287 
   288 fun merge_algebra pp
   289    (Algebra {classes = classes1, arities = arities1},
   290     Algebra {classes = classes2, arities = arities2}) =
   291   let
   292     val classes' = Graph.merge_trans_acyclic (op =) (classes1, classes2)
   293       handle Graph.DUP c => err_dup_class c
   294         | Graph.CYCLES css => err_cyclic_classes pp css;
   295     val algebra0 = make_algebra (classes', Symtab.empty);
   296     val arities' =
   297       (case (pointer_eq (classes1, classes2), pointer_eq (arities1, arities2)) of
   298         (true, true) => arities1
   299       | (true, false) =>  (*no completion*)
   300           (arities1, arities2) |> Symtab.join (fn t => fn (ars1, ars2) =>
   301             if pointer_eq (ars1, ars2) then raise Symtab.SAME
   302             else insert_ars pp algebra0 t ars2 ars1)
   303       | (false, true) =>  (*unary completion*)
   304           Symtab.empty
   305           |> add_arities_table pp algebra0 arities1
   306       | (false, false) => (*binary completion*)
   307           Symtab.empty
   308           |> add_arities_table pp algebra0 arities1
   309           |> add_arities_table pp algebra0 arities2);
   310   in make_algebra (classes', arities') end;
   311 
   312 
   313 (* algebra projections *)
   314 
   315 fun subalgebra pp P sargs (algebra as Algebra {classes, arities}) =
   316   let
   317     val restrict_sort = minimize_sort algebra o filter P o Graph.all_succs classes;
   318     fun restrict_arity tyco (c, (_, Ss)) =
   319       if P c then case sargs (c, tyco)
   320        of SOME sorts => SOME (c, (c, Ss |> map2 (curry (inter_sort algebra)) sorts
   321           |> map restrict_sort))
   322         | NONE => NONE
   323       else NONE;
   324     val classes' = classes |> Graph.subgraph P;
   325     val arities' = arities |> Symtab.map' (map_filter o restrict_arity);
   326   in (restrict_sort, rebuild_arities pp (make_algebra (classes', arities'))) end;
   327 
   328 
   329 
   330 (** sorts of types **)
   331 
   332 (* errors -- performance tuning via delayed message composition *)
   333 
   334 datatype class_error =
   335   No_Classrel of class * class |
   336   No_Arity of string * class |
   337   No_Subsort of sort * sort;
   338 
   339 fun class_error pp (No_Classrel (c1, c2)) =
   340       "No class relation " ^ Pretty.string_of_classrel pp [c1, c2]
   341   | class_error pp (No_Arity (a, c)) =
   342       "No type arity " ^ Pretty.string_of_arity pp (a, [], [c])
   343   | class_error pp (No_Subsort (S1, S2)) =
   344      "Cannot derive subsort relation " ^ Pretty.string_of_sort pp S1
   345        ^ " < " ^ Pretty.string_of_sort pp S2;
   346 
   347 exception CLASS_ERROR of class_error;
   348 
   349 
   350 (* mg_domain *)
   351 
   352 fun mg_domain algebra a S =
   353   let
   354     val arities = arities_of algebra;
   355     fun dom c =
   356       (case AList.lookup (op =) (Symtab.lookup_list arities a) c of
   357         NONE => raise CLASS_ERROR (No_Arity (a, c))
   358       | SOME (_, Ss) => Ss);
   359     fun dom_inter c Ss = ListPair.map (inter_sort algebra) (dom c, Ss);
   360   in
   361     (case S of
   362       [] => raise Fail "Unknown domain of empty intersection"
   363     | c :: cs => fold dom_inter cs (dom c))
   364   end;
   365 
   366 
   367 (* meet_sort *)
   368 
   369 fun meet_sort algebra =
   370   let
   371     fun inters S S' = inter_sort algebra (S, S');
   372     fun meet _ [] = I
   373       | meet (TFree (_, S)) S' =
   374           if sort_le algebra (S, S') then I
   375           else raise CLASS_ERROR (No_Subsort (S, S'))
   376       | meet (TVar (v, S)) S' =
   377           if sort_le algebra (S, S') then I
   378           else Vartab.map_default (v, S) (inters S')
   379       | meet (Type (a, Ts)) S = fold2 meet Ts (mg_domain algebra a S);
   380   in uncurry meet end;
   381 
   382 fun meet_sort_typ algebra (T, S) =
   383   let
   384     val tab = meet_sort algebra (T, S) Vartab.empty;
   385   in Term.map_type_tvar (fn (v, _) =>
   386     TVar (v, (the o Vartab.lookup tab) v))
   387   end;
   388 
   389 
   390 (* of_sort *)
   391 
   392 fun of_sort algebra =
   393   let
   394     fun ofS (_, []) = true
   395       | ofS (TFree (_, S), S') = sort_le algebra (S, S')
   396       | ofS (TVar (_, S), S') = sort_le algebra (S, S')
   397       | ofS (Type (a, Ts), S) =
   398           let val Ss = mg_domain algebra a S in
   399             ListPair.all ofS (Ts, Ss)
   400           end handle CLASS_ERROR _ => false;
   401   in ofS end;
   402 
   403 
   404 (* animating derivations *)
   405 
(*replay a derivation of T : S via user-supplied callbacks, producing one result
  per class of S:
    type_variable T -- evidence (paired with its class) for a type variable T;
    type_constructor (a, Us) dom c0 -- evidence for a(Us) : c0 from argument evidence;
    class_relation T (x, c1) c2 -- lift evidence x for c1 to c2.
  Raises CLASS_ERROR if a required arity or subsort relation is missing.*)
fun of_sort_derivation algebra {class_relation, type_constructor, type_variable} =
  let
    val arities = arities_of algebra;

    (*turn evidence D1 (pairs of evidence and class) into evidence for sort S2:
      unchanged if the classes coincide, otherwise each c2 of S2 is reached via
      class_relation from the first entry of D1 whose class is below c2*)
    fun weaken T D1 S2 =
      let val S1 = map snd D1 in
        if S1 = S2 then map fst D1
        else
          S2 |> map (fn c2 =>
            (case D1 |> find_first (fn (_, c1) => class_le algebra (c1, c2)) of
              SOME d1 => class_relation T d1 c2
            | NONE => raise CLASS_ERROR (No_Subsort (S1, S2))))
      end;

    fun derive (_, []) = []
      | derive (T as Type (a, Us), S) =
          let
            (*mg_domain raises CLASS_ERROR unless every class of S has an arity
              for a, so the "the" lookup below cannot fail*)
            val Ss = mg_domain algebra a S;
            (*evidence for each argument wrt. its most general domain sort*)
            val dom = map2 (fn U => fn S => derive (U, S) ~~ S) Us Ss;
          in
            S |> map (fn c =>
              let
                val (c0, Ss') = the (AList.lookup (op =) (Symtab.lookup_list arities a) c);
                (*weaken argument evidence to the domain Ss' of the arity used for c*)
                val dom' = map (fn ((U, d), S') => weaken U d S' ~~ S') ((Us ~~ dom) ~~ Ss');
              in class_relation T (type_constructor (a, Us) dom' c0, c0) c end)
          end
      (*type variables: start from the evidence supplied by type_variable*)
      | derive (T, S) = weaken T (type_variable T) S;
  in derive end;
   434 
   435 fun classrel_derivation algebra class_relation =
   436   let
   437     fun path (x, c1 :: c2 :: cs) = path (class_relation (x, c1) c2, c2 :: cs)
   438       | path (x, _) = x;
   439   in
   440     fn (x, c1) => fn c2 =>
   441       (case Graph.irreducible_paths (classes_of algebra) (c1, c2) of
   442         [] => raise CLASS_ERROR (No_Classrel (c1, c2))
   443       | cs :: _ => path (x, cs))
   444   end;
   445 
   446 
   447 (* witness_sorts *)
   448 
(*search witnesses (T, S) for the given sorts S: the empty sort is witnessed by
  propT, otherwise by a hypothesis (T, S1) from hyps with S1 <= S, or by a type
  constructor from types applied to witnesses of its argument sorts.  Sorts for
  which no witness is found are omitted from the result.*)
fun witness_sorts algebra types hyps sorts =
  let
    fun le S1 S2 = sort_le algebra (S1, S2);
    (*reuse a witness (T, S1) for S2 whenever S1 <= S2*)
    fun get S2 (T, S1) = if le S1 S2 then SOME (T, S2) else NONE;
    fun mg_dom t S = SOME (mg_domain algebra t S) handle CLASS_ERROR _ => NONE;

    (*solved: witnesses found so far; failed: sorts known to have no witness;
      path: sorts currently being witnessed further up the recursion*)
    fun witn_sort _ [] solved_failed = (SOME (propT, []), solved_failed)
      | witn_sort path S (solved, failed) =
          (*S is below a sort that already failed: it fails as well*)
          if exists (le S) failed then (NONE, (solved, failed))
          else
            (case get_first (get S) solved of
              SOME w => (SOME w, (solved, failed))
            | NONE =>
                (case get_first (get S) hyps of
                  SOME w => (SOME w, (w :: solved, failed))
                | NONE => witn_types path types S (solved, failed)))

    and witn_sorts path x = fold_map (witn_sort path) x

    (*try the type constructors in order; S fails if none of them works*)
    and witn_types _ [] S (solved, failed) = (NONE, (solved, S :: failed))
      | witn_types path (t :: ts) S solved_failed =
          (case mg_dom t S of
            SOME SS =>
              (*do not descend into stronger args (achieving termination)*)
              if exists (fn D => le D S orelse exists (le D) path) SS then
                witn_types path ts S solved_failed
              else
                let val (ws, (solved', failed')) = witn_sorts (S :: path) SS solved_failed in
                  if forall is_some ws then
                    let val w = (Type (t, map (#1 o the) ws), S)
                    in (SOME w, (w :: solved', failed')) end
                  else witn_types path ts S (solved', failed')
                end
          | NONE => witn_types path ts S solved_failed);

  in map_filter I (#1 (witn_sorts [] sorts ([], []))) end;
   485 
   486 end;