src/Pure/sorts.ML
author wenzelm
Thu Feb 23 15:49:40 2012 +0100 (2012-02-23)
changeset 46614 165886a4fe64
parent 45595 fe57d786fd5b
child 47005 421760a1efe7
permissions -rw-r--r--
clarified Graph.restrict (formerly Graph.subgraph) based on public graph operations;
     1 (*  Title:      Pure/sorts.ML
     2     Author:     Markus Wenzel and Stefan Berghofer, TU Muenchen
     3 
     4 The order-sorted algebra of type classes.
     5 
     6 Classes denote (possibly empty) collections of types that are
     7 partially ordered by class inclusion. They are represented
     8 symbolically by strings.
     9 
    10 Sorts are intersections of finitely many classes. They are represented
    11 by lists of classes.  Normal forms of sorts are sorted lists of
    12 minimal classes (wrt. current class inclusion).
    13 *)
    14 
(*Interface of the order-sorted algebra of type classes: ordered lists of
  sorts, the algebra of classes and arities, and sort checking of types.*)
signature SORTS =
sig
  (*ordered lists of sorts*)
  val make: sort list -> sort Ord_List.T
  val subset: sort Ord_List.T * sort Ord_List.T -> bool
  val union: sort Ord_List.T -> sort Ord_List.T -> sort Ord_List.T
  val subtract: sort Ord_List.T -> sort Ord_List.T -> sort Ord_List.T
  val remove_sort: sort -> sort Ord_List.T -> sort Ord_List.T
  val insert_sort: sort -> sort Ord_List.T -> sort Ord_List.T
  (*collect the sorts occurring in types and terms*)
  val insert_typ: typ -> sort Ord_List.T -> sort Ord_List.T
  val insert_typs: typ list -> sort Ord_List.T -> sort Ord_List.T
  val insert_term: term -> sort Ord_List.T -> sort Ord_List.T
  val insert_terms: term list -> sort Ord_List.T -> sort Ord_List.T
  (*order-sorted algebra: class graph plus table of type arities*)
  type algebra
  val classes_of: algebra -> serial Graph.T
  val arities_of: algebra -> (class * sort list) list Symtab.table
  val all_classes: algebra -> class list
  val super_classes: algebra -> class -> class list
  (*class and sort relations*)
  val class_less: algebra -> class * class -> bool
  val class_le: algebra -> class * class -> bool
  val sort_eq: algebra -> sort * sort -> bool
  val sort_le: algebra -> sort * sort -> bool
  val sorts_le: algebra -> sort list * sort list -> bool
  (*intersection and normal forms*)
  val inter_sort: algebra -> sort * sort -> sort
  val minimize_sort: algebra -> sort -> sort
  val complete_sort: algebra -> sort -> sort
  val minimal_sorts: algebra -> sort list -> sort Ord_List.T
  (*building algebras*)
  val add_class: Proof.context -> class * class list -> algebra -> algebra
  val add_classrel: Proof.context -> class * class -> algebra -> algebra
  val add_arities: Proof.context -> string * (class * sort list) list -> algebra -> algebra
  val empty_algebra: algebra
  val merge_algebra: Proof.context -> algebra * algebra -> algebra
  val subalgebra: Proof.context -> (class -> bool) -> (class * string -> sort list option)
    -> algebra -> (sort -> sort) * algebra
  (*errors: structured values, turned into messages by class_error*)
  type class_error
  val class_error: Proof.context -> class_error -> string
  exception CLASS_ERROR of class_error
  (*sorts of types*)
  val mg_domain: algebra -> string -> sort -> sort list   (*exception CLASS_ERROR*)
  val meet_sort: algebra -> typ * sort
    -> sort Vartab.table -> sort Vartab.table   (*exception CLASS_ERROR*)
  val meet_sort_typ: algebra -> typ * sort -> typ -> typ   (*exception CLASS_ERROR*)
  val of_sort: algebra -> typ * sort -> bool
  (*derivations, parameterized by callbacks that build the result 'a*)
  val of_sort_derivation: algebra ->
    {class_relation: typ -> 'a * class -> class -> 'a,
     type_constructor: string * typ list -> ('a * class) list list -> class -> 'a,
     type_variable: typ -> ('a * class) list} ->
    typ * sort -> 'a list   (*exception CLASS_ERROR*)
  val classrel_derivation: algebra ->
    ('a * class -> class -> 'a) -> 'a * class -> class -> 'a  (*exception CLASS_ERROR*)
  (*search for types inhabiting given sorts*)
  val witness_sorts: algebra -> string list -> (typ * sort) list -> sort list -> (typ * sort) list
end;
    65 
    66 structure Sorts: SORTS =
    67 struct
    68 
    69 
    70 (** ordered lists of sorts **)
    71 
    72 val make = Ord_List.make Term_Ord.sort_ord;
    73 val subset = Ord_List.subset Term_Ord.sort_ord;
    74 val union = Ord_List.union Term_Ord.sort_ord;
    75 val subtract = Ord_List.subtract Term_Ord.sort_ord;
    76 
    77 val remove_sort = Ord_List.remove Term_Ord.sort_ord;
    78 val insert_sort = Ord_List.insert Term_Ord.sort_ord;
    79 
    80 fun insert_typ (TFree (_, S)) Ss = insert_sort S Ss
    81   | insert_typ (TVar (_, S)) Ss = insert_sort S Ss
    82   | insert_typ (Type (_, Ts)) Ss = insert_typs Ts Ss
    83 and insert_typs [] Ss = Ss
    84   | insert_typs (T :: Ts) Ss = insert_typs Ts (insert_typ T Ss);
    85 
    86 fun insert_term (Const (_, T)) Ss = insert_typ T Ss
    87   | insert_term (Free (_, T)) Ss = insert_typ T Ss
    88   | insert_term (Var (_, T)) Ss = insert_typ T Ss
    89   | insert_term (Bound _) Ss = Ss
    90   | insert_term (Abs (_, T, t)) Ss = insert_term t (insert_typ T Ss)
    91   | insert_term (t $ u) Ss = insert_term t (insert_term u Ss);
    92 
    93 fun insert_terms [] Ss = Ss
    94   | insert_terms (t :: ts) Ss = insert_terms ts (insert_term t Ss);
    95 
    96 
    97 
    98 (** order-sorted algebra **)
    99 
   100 (*
   101   classes: graph representing class declarations together with proper
   102     subclass relation, which needs to be transitive and acyclic.
   103 
   104   arities: table of association lists of all type arities; (t, ars)
   105     means that type constructor t has the arities ars; an element
   106     (c, Ss) of ars represents the arity t::(Ss)c.  "Coregularity" of
   107     the arities structure requires that for any two declarations
   108     t::(Ss1)c1 and t::(Ss2)c2 such that c1 <= c2 holds Ss1 <= Ss2.
   109 *)
   110 
   111 datatype algebra = Algebra of
   112  {classes: serial Graph.T,
   113   arities: (class * sort list) list Symtab.table};
   114 
   115 fun classes_of (Algebra {classes, ...}) = classes;
   116 fun arities_of (Algebra {arities, ...}) = arities;
   117 
   118 fun make_algebra (classes, arities) =
   119   Algebra {classes = classes, arities = arities};
   120 
   121 fun map_classes f (Algebra {classes, arities}) = make_algebra (f classes, arities);
   122 fun map_arities f (Algebra {classes, arities}) = make_algebra (classes, f arities);
   123 
   124 
   125 (* classes *)
   126 
   127 fun all_classes (Algebra {classes, ...}) = Graph.all_preds classes (Graph.maximals classes);
   128 
   129 val super_classes = Graph.immediate_succs o classes_of;
   130 
   131 
   132 (* class relations *)
   133 
   134 val class_less = Graph.is_edge o classes_of;
   135 fun class_le algebra (c1, c2) = c1 = c2 orelse class_less algebra (c1, c2);
   136 
   137 
   138 (* sort relations *)
   139 
   140 fun sort_le algebra (S1, S2) =
   141   S1 = S2 orelse forall (fn c2 => exists (fn c1 => class_le algebra (c1, c2)) S1) S2;
   142 
   143 fun sorts_le algebra (Ss1, Ss2) =
   144   ListPair.all (sort_le algebra) (Ss1, Ss2);
   145 
   146 fun sort_eq algebra (S1, S2) =
   147   sort_le algebra (S1, S2) andalso sort_le algebra (S2, S1);
   148 
   149 
   150 (* intersection *)
   151 
   152 fun inter_class algebra c S =
   153   let
   154     fun intr [] = [c]
   155       | intr (S' as c' :: c's) =
   156           if class_le algebra (c', c) then S'
   157           else if class_le algebra (c, c') then intr c's
   158           else c' :: intr c's
   159   in intr S end;
   160 
   161 fun inter_sort algebra (S1, S2) =
   162   sort_strings (fold (inter_class algebra) S1 S2);
   163 
   164 
   165 (* normal forms *)
   166 
   167 fun minimize_sort _ [] = []
   168   | minimize_sort _ (S as [_]) = S
   169   | minimize_sort algebra S =
   170       filter (fn c => not (exists (fn c' => class_less algebra (c', c)) S)) S
   171       |> sort_distinct string_ord;
   172 
   173 fun complete_sort algebra =
   174   Graph.all_succs (classes_of algebra) o minimize_sort algebra;
   175 
   176 fun minimal_sorts algebra raw_sorts =
   177   let
   178     fun le S1 S2 = sort_le algebra (S1, S2);
   179     val sorts = make (map (minimize_sort algebra) raw_sorts);
   180   in sorts |> filter_out (fn S => exists (fn S' => le S' S andalso not (le S S')) sorts) end;
   181 
   182 
   183 
   184 (** build algebras **)
   185 
   186 (* classes *)
   187 
   188 fun err_dup_class c = error ("Duplicate declaration of class: " ^ quote c);
   189 
   190 fun err_cyclic_classes ctxt css =
   191   error (cat_lines (map (fn cs =>
   192     "Cycle in class relation: " ^ Syntax.string_of_classrel ctxt cs) css));
   193 
   194 fun add_class ctxt (c, cs) = map_classes (fn classes =>
   195   let
   196     val classes' = classes |> Graph.new_node (c, serial ())
   197       handle Graph.DUP dup => err_dup_class dup;
   198     val classes'' = classes' |> fold Graph.add_edge_trans_acyclic (map (pair c) cs)
   199       handle Graph.CYCLES css => err_cyclic_classes ctxt css;
   200   in classes'' end);
   201 
   202 
   203 (* arities *)
   204 
   205 local
   206 
   207 fun for_classes _ NONE = ""
   208   | for_classes ctxt (SOME (c1, c2)) = " for classes " ^ Syntax.string_of_classrel ctxt [c1, c2];
   209 
   210 fun err_conflict ctxt t cc (c, Ss) (c', Ss') =
   211   error ("Conflict of type arities" ^ for_classes ctxt cc ^ ":\n  " ^
   212     Syntax.string_of_arity ctxt (t, Ss, [c]) ^ " and\n  " ^
   213     Syntax.string_of_arity ctxt (t, Ss', [c']));
   214 
   215 fun coregular ctxt algebra t (c, Ss) ars =
   216   let
   217     fun conflict (c', Ss') =
   218       if class_le algebra (c, c') andalso not (sorts_le algebra (Ss, Ss')) then
   219         SOME ((c, c'), (c', Ss'))
   220       else if class_le algebra (c', c) andalso not (sorts_le algebra (Ss', Ss)) then
   221         SOME ((c', c), (c', Ss'))
   222       else NONE;
   223   in
   224     (case get_first conflict ars of
   225       SOME ((c1, c2), (c', Ss')) => err_conflict ctxt t (SOME (c1, c2)) (c, Ss) (c', Ss')
   226     | NONE => (c, Ss) :: ars)
   227   end;
   228 
   229 fun complete algebra (c, Ss) = map (rpair Ss) (c :: super_classes algebra c);
   230 
   231 fun insert ctxt algebra t (c, Ss) ars =
   232   (case AList.lookup (op =) ars c of
   233     NONE => coregular ctxt algebra t (c, Ss) ars
   234   | SOME Ss' =>
   235       if sorts_le algebra (Ss, Ss') then ars
   236       else if sorts_le algebra (Ss', Ss)
   237       then coregular ctxt algebra t (c, Ss) (remove (op =) (c, Ss') ars)
   238       else err_conflict ctxt t NONE (c, Ss) (c, Ss'));
   239 
   240 in
   241 
   242 fun insert_ars ctxt algebra t = fold_rev (insert ctxt algebra t);
   243 
   244 fun insert_complete_ars ctxt algebra (t, ars) arities =
   245   let val ars' =
   246     Symtab.lookup_list arities t
   247     |> fold_rev (insert_ars ctxt algebra t) (map (complete algebra) ars);
   248   in Symtab.update (t, ars') arities end;
   249 
   250 fun add_arities ctxt arg algebra =
   251   algebra |> map_arities (insert_complete_ars ctxt algebra arg);
   252 
   253 fun add_arities_table ctxt algebra =
   254   Symtab.fold (fn (t, ars) => insert_complete_ars ctxt algebra (t, ars));
   255 
   256 end;
   257 
   258 
   259 (* classrel *)
   260 
   261 fun rebuild_arities ctxt algebra = algebra |> map_arities (fn arities =>
   262   Symtab.empty
   263   |> add_arities_table ctxt algebra arities);
   264 
   265 fun add_classrel ctxt rel = rebuild_arities ctxt o map_classes (fn classes =>
   266   classes |> Graph.add_edge_trans_acyclic rel
   267     handle Graph.CYCLES css => err_cyclic_classes ctxt css);
   268 
   269 
   270 (* empty and merge *)
   271 
   272 val empty_algebra = make_algebra (Graph.empty, Symtab.empty);
   273 
   274 fun merge_algebra ctxt
   275    (Algebra {classes = classes1, arities = arities1},
   276     Algebra {classes = classes2, arities = arities2}) =
   277   let
   278     val classes' = Graph.merge_trans_acyclic (op =) (classes1, classes2)
   279       handle Graph.DUP c => err_dup_class c
   280         | Graph.CYCLES css => err_cyclic_classes ctxt css;
   281     val algebra0 = make_algebra (classes', Symtab.empty);
   282     val arities' =
   283       (case (pointer_eq (classes1, classes2), pointer_eq (arities1, arities2)) of
   284         (true, true) => arities1
   285       | (true, false) =>  (*no completion*)
   286           (arities1, arities2) |> Symtab.join (fn t => fn (ars1, ars2) =>
   287             if pointer_eq (ars1, ars2) then raise Symtab.SAME
   288             else insert_ars ctxt algebra0 t ars2 ars1)
   289       | (false, true) =>  (*unary completion*)
   290           Symtab.empty
   291           |> add_arities_table ctxt algebra0 arities1
   292       | (false, false) => (*binary completion*)
   293           Symtab.empty
   294           |> add_arities_table ctxt algebra0 arities1
   295           |> add_arities_table ctxt algebra0 arities2);
   296   in make_algebra (classes', arities') end;
   297 
   298 
   299 (* algebra projections *)  (* FIXME potentially violates abstract type integrity *)
   300 
   301 fun subalgebra ctxt P sargs (algebra as Algebra {classes, arities}) =
   302   let
   303     val restrict_sort = minimize_sort algebra o filter P o Graph.all_succs classes;
   304     fun restrict_arity t (c, Ss) =
   305       if P c then
   306         (case sargs (c, t) of
   307           SOME sorts =>
   308             SOME (c, Ss |> map2 (curry (inter_sort algebra)) sorts |> map restrict_sort)
   309         | NONE => NONE)
   310       else NONE;
   311     val classes' = classes |> Graph.restrict P;
   312     val arities' = arities |> Symtab.map (map_filter o restrict_arity);
   313   in (restrict_sort, rebuild_arities ctxt (make_algebra (classes', arities'))) end;
   314 
   315 
   316 
   317 (** sorts of types **)
   318 
   319 (* errors -- performance tuning via delayed message composition *)
   320 
   321 datatype class_error =
   322   No_Classrel of class * class |
   323   No_Arity of string * class |
   324   No_Subsort of sort * sort;
   325 
   326 fun class_error ctxt (No_Classrel (c1, c2)) =
   327       "No class relation " ^ Syntax.string_of_classrel ctxt [c1, c2]
   328   | class_error ctxt (No_Arity (a, c)) =
   329       "No type arity " ^ Syntax.string_of_arity ctxt (a, [], [c])
   330   | class_error ctxt (No_Subsort (S1, S2)) =
   331       "Cannot derive subsort relation " ^
   332         Syntax.string_of_sort ctxt S1 ^ " < " ^ Syntax.string_of_sort ctxt S2;
   333 
   334 exception CLASS_ERROR of class_error;
   335 
   336 
   337 (* mg_domain *)
   338 
   339 fun mg_domain algebra a S =
   340   let
   341     val arities = arities_of algebra;
   342     fun dom c =
   343       (case AList.lookup (op =) (Symtab.lookup_list arities a) c of
   344         NONE => raise CLASS_ERROR (No_Arity (a, c))
   345       | SOME Ss => Ss);
   346     fun dom_inter c Ss = ListPair.map (inter_sort algebra) (dom c, Ss);
   347   in
   348     (case S of
   349       [] => raise Fail "Unknown domain of empty intersection"
   350     | c :: cs => fold dom_inter cs (dom c))
   351   end;
   352 
   353 
   354 (* meet_sort *)
   355 
   356 fun meet_sort algebra =
   357   let
   358     fun inters S S' = inter_sort algebra (S, S');
   359     fun meet _ [] = I
   360       | meet (TFree (_, S)) S' =
   361           if sort_le algebra (S, S') then I
   362           else raise CLASS_ERROR (No_Subsort (S, S'))
   363       | meet (TVar (v, S)) S' =
   364           if sort_le algebra (S, S') then I
   365           else Vartab.map_default (v, S) (inters S')
   366       | meet (Type (a, Ts)) S = fold2 meet Ts (mg_domain algebra a S);
   367   in uncurry meet end;
   368 
   369 fun meet_sort_typ algebra (T, S) =
   370   let val tab = meet_sort algebra (T, S) Vartab.empty;
   371   in Term.map_type_tvar (fn (v, _) => TVar (v, (the o Vartab.lookup tab) v)) end;
   372 
   373 
   374 (* of_sort *)
   375 
   376 fun of_sort algebra =
   377   let
   378     fun ofS (_, []) = true
   379       | ofS (TFree (_, S), S') = sort_le algebra (S, S')
   380       | ofS (TVar (_, S), S') = sort_le algebra (S, S')
   381       | ofS (Type (a, Ts), S) =
   382           let val Ss = mg_domain algebra a S in
   383             ListPair.all ofS (Ts, Ss)
   384           end handle CLASS_ERROR _ => false;
   385   in ofS end;
   386 
   387 
   388 (* animating derivations *)
   389 
   390 fun of_sort_derivation algebra {class_relation, type_constructor, type_variable} =
   391   let
   392     val arities = arities_of algebra;
   393 
   394     fun weaken T D1 S2 =
   395       let val S1 = map snd D1 in
   396         if S1 = S2 then map fst D1
   397         else
   398           S2 |> map (fn c2 =>
   399             (case D1 |> find_first (fn (_, c1) => class_le algebra (c1, c2)) of
   400               SOME d1 => class_relation T d1 c2
   401             | NONE => raise CLASS_ERROR (No_Subsort (S1, S2))))
   402       end;
   403 
   404     fun derive (_, []) = []
   405       | derive (Type (a, Us), S) =
   406           let
   407             val Ss = mg_domain algebra a S;
   408             val dom = map2 (fn U => fn S => derive (U, S) ~~ S) Us Ss;
   409           in
   410             S |> map (fn c =>
   411               let
   412                 val Ss' = the (AList.lookup (op =) (Symtab.lookup_list arities a) c);
   413                 val dom' = map (fn ((U, d), S') => weaken U d S' ~~ S') ((Us ~~ dom) ~~ Ss');
   414               in type_constructor (a, Us) dom' c end)
   415           end
   416       | derive (T, S) = weaken T (type_variable T) S;
   417   in derive end;
   418 
   419 fun classrel_derivation algebra class_relation =
   420   let
   421     fun path (x, c1 :: c2 :: cs) = path (class_relation (x, c1) c2, c2 :: cs)
   422       | path (x, _) = x;
   423   in
   424     fn (x, c1) => fn c2 =>
   425       (case Graph.irreducible_paths (classes_of algebra) (c1, c2) of
   426         [] => raise CLASS_ERROR (No_Classrel (c1, c2))
   427       | cs :: _ => path (x, cs))
   428   end;
   429 
   430 
   431 (* witness_sorts *)
   432 
(*search for witnesses of sorts: "types" are the type constructors to try,
  "hyps" are pairs of a type and its sort that may be used directly, and
  "sorts" are the sorts to be witnessed.  The result contains a pair
  (T, S) for each sort S for which a witness T was found; sorts without
  witness are omitted.*)
fun witness_sorts algebra types hyps sorts =
  let
    fun le S1 S2 = sort_le algebra (S1, S2);
    (*reuse a witness of S1 for S2, provided that S1 <= S2*)
    fun get S2 (T, S1) = if le S1 S2 then SOME (T, S2) else NONE;
    (*most general domain; NONE if mg_domain raises CLASS_ERROR*)
    fun mg_dom t S = SOME (mg_domain algebra t S) handle CLASS_ERROR _ => NONE;

    (*witness for a single sort, threading the (solved, failed) state:
      the empty sort is witnessed by propT; a sort below a failed one
      fails; otherwise try solved, then hyps (recording the hit in solved),
      then the type constructors*)
    fun witn_sort _ [] solved_failed = (SOME (propT, []), solved_failed)
      | witn_sort path S (solved, failed) =
          if exists (le S) failed then (NONE, (solved, failed))
          else
            (case get_first (get S) solved of
              SOME w => (SOME w, (solved, failed))
            | NONE =>
                (case get_first (get S) hyps of
                  SOME w => (SOME w, (w :: solved, failed))
                | NONE => witn_types path types S (solved, failed)))

    (*witnesses for a list of sorts, from left to right*)
    and witn_sorts path x = fold_map (witn_sort path) x

    (*try the type constructors in turn; when none is left, S is recorded
      as failed*)
    and witn_types _ [] S (solved, failed) = (NONE, (solved, S :: failed))
      | witn_types path (t :: ts) S solved_failed =
          (case mg_dom t S of
            SOME SS =>
              (*do not descend into stronger args (achieving termination)*)
              if exists (fn D => le D S orelse exists (le D) path) SS then
                witn_types path ts S solved_failed
              else
                let val (ws, (solved', failed')) = witn_sorts (S :: path) SS solved_failed in
                  (*all argument sorts witnessed: build the witness for S
                    and record it; otherwise continue with the next type
                    constructor, keeping the updated state*)
                  if forall is_some ws then
                    let val w = (Type (t, map (#1 o the) ws), S)
                    in (SOME w, (w :: solved', failed')) end
                  else witn_types path ts S (solved', failed')
                end
          | NONE => witn_types path ts S solved_failed);

  (*start with empty path and empty state; drop the sorts without witness*)
  in map_filter I (#1 (witn_sorts [] sorts ([], []))) end;
   469 
   470 end;