src/Pure/sorts.ML
author wenzelm
Sun Nov 11 20:31:46 2012 +0100 (2012-11-11)
changeset 50081 9b92ee8dec98
parent 48272 db75b4005d9a
child 61262 7bd1eb4b056e
permissions -rw-r--r--
tuned;
     1 (*  Title:      Pure/sorts.ML
     2     Author:     Markus Wenzel and Stefan Berghofer, TU Muenchen
     3 
     4 The order-sorted algebra of type classes.
     5 
     6 Classes denote (possibly empty) collections of types that are
     7 partially ordered by class inclusion. They are represented
     8 symbolically by strings.
     9 
    10 Sorts are intersections of finitely many classes. They are represented
    11 by lists of classes.  Normal forms of sorts are sorted lists of
    12 minimal classes (wrt. current class inclusion).
    13 *)
    14 
signature SORTS =
sig
  (*lists of sorts kept strictly ordered wrt. Term_Ord.sort_ord*)
  val make: sort list -> sort Ord_List.T
  val subset: sort Ord_List.T * sort Ord_List.T -> bool
  val union: sort Ord_List.T -> sort Ord_List.T -> sort Ord_List.T
  val subtract: sort Ord_List.T -> sort Ord_List.T -> sort Ord_List.T
  val remove_sort: sort -> sort Ord_List.T -> sort Ord_List.T
  val insert_sort: sort -> sort Ord_List.T -> sort Ord_List.T
  (*accumulate the sorts of the type variables (TFree/TVar) occurring in types/terms*)
  val insert_typ: typ -> sort Ord_List.T -> sort Ord_List.T
  val insert_typs: typ list -> sort Ord_List.T -> sort Ord_List.T
  val insert_term: term -> sort Ord_List.T -> sort Ord_List.T
  val insert_terms: term list -> sort Ord_List.T -> sort Ord_List.T
  (*order-sorted algebra: class graph plus table of type arities*)
  type algebra
  val classes_of: algebra -> serial Graph.T
  val arities_of: algebra -> (class * sort list) list Symtab.table
  val all_classes: algebra -> class list
  val super_classes: algebra -> class -> class list
  (*class and sort relations induced by the class graph*)
  val class_less: algebra -> class * class -> bool
  val class_le: algebra -> class * class -> bool
  val sort_eq: algebra -> sort * sort -> bool
  val sort_le: algebra -> sort * sort -> bool
  val sorts_le: algebra -> sort list * sort list -> bool
  (*intersection and normal forms of sorts*)
  val inter_sort: algebra -> sort * sort -> sort
  val minimize_sort: algebra -> sort -> sort
  val complete_sort: algebra -> sort -> sort
  val minimal_sorts: algebra -> sort list -> sort Ord_List.T
  (*building and combining algebras; the Context.pretty argument is used for error messages*)
  val add_class: Context.pretty -> class * class list -> algebra -> algebra
  val add_classrel: Context.pretty -> class * class -> algebra -> algebra
  val add_arities: Context.pretty -> string * (class * sort list) list -> algebra -> algebra
  val empty_algebra: algebra
  val merge_algebra: Context.pretty -> algebra * algebra -> algebra
  val subalgebra: Context.pretty -> (class -> bool) -> (class * string -> sort list option)
    -> algebra -> (sort -> sort) * algebra
  (*symbolic errors, rendered to text only on demand*)
  type class_error
  val class_error: Context.pretty -> class_error -> string
  exception CLASS_ERROR of class_error
  (*sorts of types*)
  val has_instance: algebra -> string -> sort -> bool
  val mg_domain: algebra -> string -> sort -> sort list   (*exception CLASS_ERROR*)
  val meet_sort: algebra -> typ * sort
    -> sort Vartab.table -> sort Vartab.table   (*exception CLASS_ERROR*)
  val meet_sort_typ: algebra -> typ * sort -> typ -> typ   (*exception CLASS_ERROR*)
  val of_sort: algebra -> typ * sort -> bool
  (*replay sort derivations through caller-supplied operations*)
  val of_sort_derivation: algebra ->
    {class_relation: typ -> 'a * class -> class -> 'a,
     type_constructor: string * typ list -> ('a * class) list list -> class -> 'a,
     type_variable: typ -> ('a * class) list} ->
    typ * sort -> 'a list   (*exception CLASS_ERROR*)
  val classrel_derivation: algebra ->
    ('a * class -> class -> 'a) -> 'a * class -> class -> 'a  (*exception CLASS_ERROR*)
  (*search for witness types of given sorts*)
  val witness_sorts: algebra -> string list -> (typ * sort) list -> sort list -> (typ * sort) list
end;
    66 
    67 structure Sorts: SORTS =
    68 struct
    69 
    70 
    71 (** ordered lists of sorts **)
    72 
    73 val make = Ord_List.make Term_Ord.sort_ord;
    74 val subset = Ord_List.subset Term_Ord.sort_ord;
    75 val union = Ord_List.union Term_Ord.sort_ord;
    76 val subtract = Ord_List.subtract Term_Ord.sort_ord;
    77 
    78 val remove_sort = Ord_List.remove Term_Ord.sort_ord;
    79 val insert_sort = Ord_List.insert Term_Ord.sort_ord;
    80 
    81 fun insert_typ (TFree (_, S)) Ss = insert_sort S Ss
    82   | insert_typ (TVar (_, S)) Ss = insert_sort S Ss
    83   | insert_typ (Type (_, Ts)) Ss = insert_typs Ts Ss
    84 and insert_typs [] Ss = Ss
    85   | insert_typs (T :: Ts) Ss = insert_typs Ts (insert_typ T Ss);
    86 
    87 fun insert_term (Const (_, T)) Ss = insert_typ T Ss
    88   | insert_term (Free (_, T)) Ss = insert_typ T Ss
    89   | insert_term (Var (_, T)) Ss = insert_typ T Ss
    90   | insert_term (Bound _) Ss = Ss
    91   | insert_term (Abs (_, T, t)) Ss = insert_term t (insert_typ T Ss)
    92   | insert_term (t $ u) Ss = insert_term t (insert_term u Ss);
    93 
    94 fun insert_terms [] Ss = Ss
    95   | insert_terms (t :: ts) Ss = insert_terms ts (insert_term t Ss);
    96 
    97 
    98 
    99 (** order-sorted algebra **)
   100 
   101 (*
   102   classes: graph representing class declarations together with proper
   103     subclass relation, which needs to be transitive and acyclic.
   104 
   105   arities: table of association lists of all type arities; (t, ars)
   106     means that type constructor t has the arities ars; an element
   107     (c, Ss) of ars represents the arity t::(Ss)c.  "Coregularity" of
    108     the arities structure requires that for any two declarations
    109     t::(Ss1)c1 and t::(Ss2)c2 with c1 <= c2, also Ss1 <= Ss2 holds.
   110 *)
   111 
   112 datatype algebra = Algebra of
   113  {classes: serial Graph.T,
   114   arities: (class * sort list) list Symtab.table};
   115 
   116 fun classes_of (Algebra {classes, ...}) = classes;
   117 fun arities_of (Algebra {arities, ...}) = arities;
   118 
   119 fun make_algebra (classes, arities) =
   120   Algebra {classes = classes, arities = arities};
   121 
   122 fun map_classes f (Algebra {classes, arities}) = make_algebra (f classes, arities);
   123 fun map_arities f (Algebra {classes, arities}) = make_algebra (classes, f arities);
   124 
   125 
   126 (* classes *)
   127 
   128 fun all_classes (Algebra {classes, ...}) = Graph.all_preds classes (Graph.maximals classes);
   129 
   130 val super_classes = Graph.immediate_succs o classes_of;
   131 
   132 
   133 (* class relations *)
   134 
   135 val class_less = Graph.is_edge o classes_of;
   136 fun class_le algebra (c1, c2) = c1 = c2 orelse class_less algebra (c1, c2);
   137 
   138 
   139 (* sort relations *)
   140 
   141 fun sort_le algebra (S1, S2) =
   142   S1 = S2 orelse forall (fn c2 => exists (fn c1 => class_le algebra (c1, c2)) S1) S2;
   143 
   144 fun sorts_le algebra (Ss1, Ss2) =
   145   ListPair.all (sort_le algebra) (Ss1, Ss2);
   146 
   147 fun sort_eq algebra (S1, S2) =
   148   sort_le algebra (S1, S2) andalso sort_le algebra (S2, S1);
   149 
   150 
   151 (* intersection *)
   152 
   153 fun inter_class algebra c S =
   154   let
   155     fun intr [] = [c]
   156       | intr (S' as c' :: c's) =
   157           if class_le algebra (c', c) then S'
   158           else if class_le algebra (c, c') then intr c's
   159           else c' :: intr c's
   160   in intr S end;
   161 
   162 fun inter_sort algebra (S1, S2) =
   163   sort_strings (fold (inter_class algebra) S1 S2);
   164 
   165 
   166 (* normal forms *)
   167 
   168 fun minimize_sort _ [] = []
   169   | minimize_sort _ (S as [_]) = S
   170   | minimize_sort algebra S =
   171       filter (fn c => not (exists (fn c' => class_less algebra (c', c)) S)) S
   172       |> sort_distinct string_ord;
   173 
   174 fun complete_sort algebra =
   175   Graph.all_succs (classes_of algebra) o minimize_sort algebra;
   176 
   177 fun minimal_sorts algebra raw_sorts =
   178   let
   179     fun le S1 S2 = sort_le algebra (S1, S2);
   180     val sorts = make (map (minimize_sort algebra) raw_sorts);
   181   in sorts |> filter_out (fn S => exists (fn S' => le S' S andalso not (le S S')) sorts) end;
   182 
   183 
   184 
   185 (** build algebras **)
   186 
   187 (* classes *)
   188 
   189 fun err_dup_class c = error ("Duplicate declaration of class: " ^ quote c);
   190 
   191 fun err_cyclic_classes pp css =
   192   error (cat_lines (map (fn cs =>
   193     "Cycle in class relation: " ^ Syntax.string_of_classrel (Syntax.init_pretty pp) cs) css));
   194 
   195 fun add_class pp (c, cs) = map_classes (fn classes =>
   196   let
   197     val classes' = classes |> Graph.new_node (c, serial ())
   198       handle Graph.DUP dup => err_dup_class dup;
   199     val classes'' = classes' |> fold Graph.add_edge_trans_acyclic (map (pair c) cs)
   200       handle Graph.CYCLES css => err_cyclic_classes pp css;
   201   in classes'' end);
   202 
   203 
   204 (* arities *)
   205 
   206 local
   207 
   208 fun for_classes _ NONE = ""
   209   | for_classes ctxt (SOME (c1, c2)) = " for classes " ^ Syntax.string_of_classrel ctxt [c1, c2];
   210 
   211 fun err_conflict pp t cc (c, Ss) (c', Ss') =
   212   let val ctxt = Syntax.init_pretty pp in
   213     error ("Conflict of type arities" ^ for_classes ctxt cc ^ ":\n  " ^
   214       Syntax.string_of_arity ctxt (t, Ss, [c]) ^ " and\n  " ^
   215       Syntax.string_of_arity ctxt (t, Ss', [c']))
   216   end;
   217 
   218 fun coregular pp algebra t (c, Ss) ars =
   219   let
   220     fun conflict (c', Ss') =
   221       if class_le algebra (c, c') andalso not (sorts_le algebra (Ss, Ss')) then
   222         SOME ((c, c'), (c', Ss'))
   223       else if class_le algebra (c', c) andalso not (sorts_le algebra (Ss', Ss)) then
   224         SOME ((c', c), (c', Ss'))
   225       else NONE;
   226   in
   227     (case get_first conflict ars of
   228       SOME ((c1, c2), (c', Ss')) => err_conflict pp t (SOME (c1, c2)) (c, Ss) (c', Ss')
   229     | NONE => (c, Ss) :: ars)
   230   end;
   231 
   232 fun complete algebra (c, Ss) = map (rpair Ss) (c :: super_classes algebra c);
   233 
   234 fun insert pp algebra t (c, Ss) ars =
   235   (case AList.lookup (op =) ars c of
   236     NONE => coregular pp algebra t (c, Ss) ars
   237   | SOME Ss' =>
   238       if sorts_le algebra (Ss, Ss') then ars
   239       else if sorts_le algebra (Ss', Ss)
   240       then coregular pp algebra t (c, Ss) (remove (op =) (c, Ss') ars)
   241       else err_conflict pp t NONE (c, Ss) (c, Ss'));
   242 
   243 in
   244 
   245 fun insert_ars pp algebra t = fold_rev (insert pp algebra t);
   246 
   247 fun insert_complete_ars pp algebra (t, ars) arities =
   248   let val ars' =
   249     Symtab.lookup_list arities t
   250     |> fold_rev (insert_ars pp algebra t) (map (complete algebra) ars);
   251   in Symtab.update (t, ars') arities end;
   252 
   253 fun add_arities pp arg algebra =
   254   algebra |> map_arities (insert_complete_ars pp algebra arg);
   255 
   256 fun add_arities_table pp algebra =
   257   Symtab.fold (fn (t, ars) => insert_complete_ars pp algebra (t, ars));
   258 
   259 end;
   260 
   261 
   262 (* classrel *)
   263 
   264 fun rebuild_arities pp algebra = algebra |> map_arities (fn arities =>
   265   Symtab.empty
   266   |> add_arities_table pp algebra arities);
   267 
   268 fun add_classrel pp rel = rebuild_arities pp o map_classes (fn classes =>
   269   classes |> Graph.add_edge_trans_acyclic rel
   270     handle Graph.CYCLES css => err_cyclic_classes pp css);
   271 
   272 
   273 (* empty and merge *)
   274 
   275 val empty_algebra = make_algebra (Graph.empty, Symtab.empty);
   276 
   277 fun merge_algebra pp
   278    (Algebra {classes = classes1, arities = arities1},
   279     Algebra {classes = classes2, arities = arities2}) =
   280   let
   281     val classes' = Graph.merge_trans_acyclic (op =) (classes1, classes2)
   282       handle Graph.DUP c => err_dup_class c
   283         | Graph.CYCLES css => err_cyclic_classes pp css;
   284     val algebra0 = make_algebra (classes', Symtab.empty);
   285     val arities' =
   286       (case (pointer_eq (classes1, classes2), pointer_eq (arities1, arities2)) of
   287         (true, true) => arities1
   288       | (true, false) =>  (*no completion*)
   289           (arities1, arities2) |> Symtab.join (fn t => fn (ars1, ars2) =>
   290             if pointer_eq (ars1, ars2) then raise Symtab.SAME
   291             else insert_ars pp algebra0 t ars2 ars1)
   292       | (false, true) =>  (*unary completion*)
   293           Symtab.empty
   294           |> add_arities_table pp algebra0 arities1
   295       | (false, false) => (*binary completion*)
   296           Symtab.empty
   297           |> add_arities_table pp algebra0 arities1
   298           |> add_arities_table pp algebra0 arities2);
   299   in make_algebra (classes', arities') end;
   300 
   301 
   302 (* algebra projections *)  (* FIXME potentially violates abstract type integrity *)
   303 
   304 fun subalgebra pp P sargs (algebra as Algebra {classes, arities}) =
   305   let
   306     val restrict_sort = minimize_sort algebra o filter P o Graph.all_succs classes;
   307     fun restrict_arity t (c, Ss) =
   308       if P c then
   309         (case sargs (c, t) of
   310           SOME sorts =>
   311             SOME (c, Ss |> map2 (curry (inter_sort algebra)) sorts |> map restrict_sort)
   312         | NONE => NONE)
   313       else NONE;
   314     val classes' = classes |> Graph.restrict P;
   315     val arities' = arities |> Symtab.map (map_filter o restrict_arity);
   316   in (restrict_sort, rebuild_arities pp (make_algebra (classes', arities'))) end;
   317 
   318 
   319 
   320 (** sorts of types **)
   321 
   322 (* errors -- performance tuning via delayed message composition *)
   323 
   324 datatype class_error =
   325   No_Classrel of class * class |
   326   No_Arity of string * class |
   327   No_Subsort of sort * sort;
   328 
   329 fun class_error pp =
   330   let val ctxt = Syntax.init_pretty pp in
   331     fn No_Classrel (c1, c2) => "No class relation " ^ Syntax.string_of_classrel ctxt [c1, c2]
   332      | No_Arity (a, c) => "No type arity " ^ Syntax.string_of_arity ctxt (a, [], [c])
   333      | No_Subsort (S1, S2) =>
   334         "Cannot derive subsort relation " ^
   335           Syntax.string_of_sort ctxt S1 ^ " < " ^ Syntax.string_of_sort ctxt S2
   336   end;
   337 
   338 exception CLASS_ERROR of class_error;
   339 
   340 
   341 (* instances *)
   342 
   343 fun has_instance algebra a =
   344   forall (AList.defined (op =) (Symtab.lookup_list (arities_of algebra) a));
   345 
   346 fun mg_domain algebra a S =
   347   let
   348     val ars = Symtab.lookup_list (arities_of algebra) a;
   349     fun dom c =
   350       (case AList.lookup (op =) ars c of
   351         NONE => raise CLASS_ERROR (No_Arity (a, c))
   352       | SOME Ss => Ss);
   353     fun dom_inter c Ss = ListPair.map (inter_sort algebra) (dom c, Ss);
   354   in
   355     (case S of
   356       [] => raise Fail "Unknown domain of empty intersection"
   357     | c :: cs => fold dom_inter cs (dom c))
   358   end;
   359 
   360 
   361 (* meet_sort *)
   362 
   363 fun meet_sort algebra =
   364   let
   365     fun inters S S' = inter_sort algebra (S, S');
   366     fun meet _ [] = I
   367       | meet (TFree (_, S)) S' =
   368           if sort_le algebra (S, S') then I
   369           else raise CLASS_ERROR (No_Subsort (S, S'))
   370       | meet (TVar (v, S)) S' =
   371           if sort_le algebra (S, S') then I
   372           else Vartab.map_default (v, S) (inters S')
   373       | meet (Type (a, Ts)) S = fold2 meet Ts (mg_domain algebra a S);
   374   in uncurry meet end;
   375 
   376 fun meet_sort_typ algebra (T, S) =
   377   let val tab = meet_sort algebra (T, S) Vartab.empty;
   378   in Term.map_type_tvar (fn (v, _) => TVar (v, (the o Vartab.lookup tab) v)) end;
   379 
   380 
   381 (* of_sort *)
   382 
   383 fun of_sort algebra =
   384   let
   385     fun ofS (_, []) = true
   386       | ofS (TFree (_, S), S') = sort_le algebra (S, S')
   387       | ofS (TVar (_, S), S') = sort_le algebra (S, S')
   388       | ofS (Type (a, Ts), S) =
   389           let val Ss = mg_domain algebra a S in
   390             ListPair.all ofS (Ts, Ss)
   391           end handle CLASS_ERROR _ => false;
   392   in ofS end;
   393 
   394 
   395 (* animating derivations *)
   396 
   397 fun of_sort_derivation algebra {class_relation, type_constructor, type_variable} =
   398   let
   399     val arities = arities_of algebra;
   400 
   401     fun weaken T D1 S2 =
   402       let val S1 = map snd D1 in
   403         if S1 = S2 then map fst D1
   404         else
   405           S2 |> map (fn c2 =>
   406             (case D1 |> find_first (fn (_, c1) => class_le algebra (c1, c2)) of
   407               SOME d1 => class_relation T d1 c2
   408             | NONE => raise CLASS_ERROR (No_Subsort (S1, S2))))
   409       end;
   410 
   411     fun derive (_, []) = []
   412       | derive (Type (a, Us), S) =
   413           let
   414             val Ss = mg_domain algebra a S;
   415             val dom = map2 (fn U => fn S => derive (U, S) ~~ S) Us Ss;
   416           in
   417             S |> map (fn c =>
   418               let
   419                 val Ss' = the (AList.lookup (op =) (Symtab.lookup_list arities a) c);
   420                 val dom' = map (fn ((U, d), S') => weaken U d S' ~~ S') ((Us ~~ dom) ~~ Ss');
   421               in type_constructor (a, Us) dom' c end)
   422           end
   423       | derive (T, S) = weaken T (type_variable T) S;
   424   in derive end;
   425 
   426 fun classrel_derivation algebra class_relation =
   427   let
   428     fun path (x, c1 :: c2 :: cs) = path (class_relation (x, c1) c2, c2 :: cs)
   429       | path (x, _) = x;
   430   in
   431     fn (x, c1) => fn c2 =>
   432       (case Graph.irreducible_paths (classes_of algebra) (c1, c2) of
   433         [] => raise CLASS_ERROR (No_Classrel (c1, c2))
   434       | cs :: _ => path (x, cs))
   435   end;
   436 
   437 
   438 (* witness_sorts *)
   439 
(*witness_sorts algebra types hyps sorts: find a witness (T, S) for each
  sort S of sorts.  A witness comes from a hypothesis (T, S1) in hyps with
  S1 <= S, or is built as a type t(T1, ..., Tn) for some t in types whose
  most general domain for S is witnessed by T1, ..., Tn.  The empty sort is
  witnessed by propT.  Sorts without a witness are omitted from the result.*)
fun witness_sorts algebra types hyps sorts =
  let
    fun le S1 S2 = sort_le algebra (S1, S2);
    (*(T, S1) serves as witness for S2 whenever S1 <= S2*)
    fun get S2 (T, S1) = if le S1 S2 then SOME (T, S2) else NONE;
    (*most general domain of t for S, or NONE if some arity is missing*)
    fun mg_dom t S = SOME (mg_domain algebra t S) handle CLASS_ERROR _ => NONE;

    (*state threaded through the search: solved = witnesses found so far,
      failed = sorts already known to have no witness;
      path = sorts currently under construction, used to cut off recursion*)
    fun witn_sort _ [] solved_failed = (SOME (propT, []), solved_failed)
      | witn_sort path S (solved, failed) =
          (*S is at least as strong as a sort that already failed*)
          if exists (le S) failed then (NONE, (solved, failed))
          else
            (case get_first (get S) solved of
              SOME w => (SOME w, (solved, failed))
            | NONE =>
                (case get_first (get S) hyps of
                  SOME w => (SOME w, (w :: solved, failed))
                | NONE => witn_types path types S (solved, failed)))

    and witn_sorts path x = fold_map (witn_sort path) x

    (*try the type constructors in order; record S as failed if none works*)
    and witn_types _ [] S (solved, failed) = (NONE, (solved, S :: failed))
      | witn_types path (t :: ts) S solved_failed =
          (case mg_dom t S of
            SOME SS =>
              (*do not descend into stronger args (achieving termination)*)
              if exists (fn D => le D S orelse exists (le D) path) SS then
                witn_types path ts S solved_failed
              else
                let val (ws, (solved', failed')) = witn_sorts (S :: path) SS solved_failed in
                  if forall is_some ws then
                    let val w = (Type (t, map (#1 o the) ws), S)
                    in (SOME w, (w :: solved', failed')) end
                  else witn_types path ts S (solved', failed')
                end
          | NONE => witn_types path ts S solved_failed);

  in map_filter I (#1 (witn_sorts [] sorts ([], []))) end;
   476 
   477 end;