src/Pure/sorts.ML
author wenzelm
Sun Mar 18 13:04:22 2012 +0100 (2012-03-18)
changeset 47005 421760a1efe7
parent 46614 165886a4fe64
child 48272 db75b4005d9a
permissions -rw-r--r--
maintain generic context naming in structure Name_Space (NB: empty = default_naming, init = local_naming);
more explicit Context.generic for Name_Space.declare/define and derivatives (NB: naming changed after Proof_Context.init_global);
prefer Context.pretty in low-level operations of structure Sorts/Type (defer full Syntax.init_pretty until error output);
simplified signatures;
     1 (*  Title:      Pure/sorts.ML
     2     Author:     Markus Wenzel and Stefan Berghofer, TU Muenchen
     3 
     4 The order-sorted algebra of type classes.
     5 
     6 Classes denote (possibly empty) collections of types that are
     7 partially ordered by class inclusion. They are represented
     8 symbolically by strings.
     9 
    10 Sorts are intersections of finitely many classes. They are represented
    11 by lists of classes.  Normal forms of sorts are sorted lists of
    12 minimal classes (wrt. current class inclusion).
    13 *)
    14 
    15 signature SORTS =
    16 sig
         (*ordered lists of sorts: Ord_List operations instantiated with Term_Ord.sort_ord*)
    17   val make: sort list -> sort Ord_List.T
    18   val subset: sort Ord_List.T * sort Ord_List.T -> bool
    19   val union: sort Ord_List.T -> sort Ord_List.T -> sort Ord_List.T
    20   val subtract: sort Ord_List.T -> sort Ord_List.T -> sort Ord_List.T
    21   val remove_sort: sort -> sort Ord_List.T -> sort Ord_List.T
    22   val insert_sort: sort -> sort Ord_List.T -> sort Ord_List.T
         (*accumulate the sorts attached to TFree/TVar occurrences in types and terms*)
    23   val insert_typ: typ -> sort Ord_List.T -> sort Ord_List.T
    24   val insert_typs: typ list -> sort Ord_List.T -> sort Ord_List.T
    25   val insert_term: term -> sort Ord_List.T -> sort Ord_List.T
    26   val insert_terms: term list -> sort Ord_List.T -> sort Ord_List.T
         (*order-sorted algebra: a class graph together with a table of type arities*)
    27   type algebra
    28   val classes_of: algebra -> serial Graph.T
    29   val arities_of: algebra -> (class * sort list) list Symtab.table
    30   val all_classes: algebra -> class list
    31   val super_classes: algebra -> class -> class list
         (*class and sort relations, intersection, normal forms*)
    32   val class_less: algebra -> class * class -> bool
    33   val class_le: algebra -> class * class -> bool
    34   val sort_eq: algebra -> sort * sort -> bool
    35   val sort_le: algebra -> sort * sort -> bool
    36   val sorts_le: algebra -> sort list * sort list -> bool
    37   val inter_sort: algebra -> sort * sort -> sort
    38   val minimize_sort: algebra -> sort -> sort
    39   val complete_sort: algebra -> sort -> sort
    40   val minimal_sorts: algebra -> sort list -> sort Ord_List.T
         (*building algebras; in this structure the Context.pretty argument is
           only consumed when an error message is composed*)
    41   val add_class: Context.pretty -> class * class list -> algebra -> algebra
    42   val add_classrel: Context.pretty -> class * class -> algebra -> algebra
    43   val add_arities: Context.pretty -> string * (class * sort list) list -> algebra -> algebra
    44   val empty_algebra: algebra
    45   val merge_algebra: Context.pretty -> algebra * algebra -> algebra
    46   val subalgebra: Context.pretty -> (class -> bool) -> (class * string -> sort list option)
    47     -> algebra -> (sort -> sort) * algebra
         (*sorts of types; failures are raised as CLASS_ERROR and turned into
           text by class_error*)
    48   type class_error
    49   val class_error: Context.pretty -> class_error -> string
    50   exception CLASS_ERROR of class_error
    51   val mg_domain: algebra -> string -> sort -> sort list   (*exception CLASS_ERROR*)
    52   val meet_sort: algebra -> typ * sort
    53     -> sort Vartab.table -> sort Vartab.table   (*exception CLASS_ERROR*)
    54   val meet_sort_typ: algebra -> typ * sort -> typ -> typ   (*exception CLASS_ERROR*)
    55   val of_sort: algebra -> typ * sort -> bool
    56   val of_sort_derivation: algebra ->
    57     {class_relation: typ -> 'a * class -> class -> 'a,
    58      type_constructor: string * typ list -> ('a * class) list list -> class -> 'a,
    59      type_variable: typ -> ('a * class) list} ->
    60     typ * sort -> 'a list   (*exception CLASS_ERROR*)
    61   val classrel_derivation: algebra ->
    62     ('a * class -> class -> 'a) -> 'a * class -> class -> 'a  (*exception CLASS_ERROR*)
    63   val witness_sorts: algebra -> string list -> (typ * sort) list -> sort list -> (typ * sort) list
    64 end;
    65 
    66 structure Sorts: SORTS =
    67 struct
    68 
    69 
    70 (** ordered lists of sorts **)
    71 
    (*Ord_List operations on sorts, all instantiated with the one order
      Term_Ord.sort_ord.*)
    local
      val ord = Term_Ord.sort_ord;
    in

    val make = Ord_List.make ord;
    val subset = Ord_List.subset ord;
    val union = Ord_List.union ord;
    val subtract = Ord_List.subtract ord;

    val remove_sort = Ord_List.remove ord;
    val insert_sort = Ord_List.insert ord;

    end;
    79 
    (*Insert into the ordered list Ss the sort of every TFree and TVar that
      occurs in a type (insert_typ) or in a list of types (insert_typs).
      Type constructors contribute only through their arguments.*)
    fun insert_typ (Type (_, Ts)) Ss = insert_typs Ts Ss
      | insert_typ (TFree (_, S)) Ss = insert_sort S Ss
      | insert_typ (TVar (_, S)) Ss = insert_sort S Ss
    and insert_typs Ts Ss = fold insert_typ Ts Ss;
    85 
    (*Insert into Ss the sorts found in all type annotations of a term:
      the types of Const, Free, Var and of the bound variable of Abs.
      Bound contributes nothing; for an application the argument is
      traversed before the function part.*)
    fun insert_term tm Ss =
      (case tm of
        t $ u => Ss |> insert_term u |> insert_term t
      | Abs (_, T, body) => Ss |> insert_typ T |> insert_term body
      | Bound _ => Ss
      | Const (_, T) => insert_typ T Ss
      | Free (_, T) => insert_typ T Ss
      | Var (_, T) => insert_typ T Ss);
    92 
    (*insert_term iterated over a list of terms, from left to right*)
    fun insert_terms ts Ss = fold insert_term ts Ss;
    95 
    96 
    97 
    98 (** order-sorted algebra **)
    99 
   100 (*
   101   classes: graph representing class declarations together with proper
   102     subclass relation, which needs to be transitive and acyclic.
   103 
   104   arities: table of association lists of all type arities; (t, ars)
   105     means that type constructor t has the arities ars; an element
   106     (c, Ss) of ars represents the arity t::(Ss)c.  "Coregularity" of
   107     the arities structure requires that for any two declarations
   108     t::(Ss1)c1 and t::(Ss2)c2 such that c1 <= c2 holds Ss1 <= Ss2.
   109 *)
   110 
   (*An algebra packages the class graph with the table of type arities.*)
   datatype algebra = Algebra of
     {classes: serial Graph.T,
      arities: (class * sort list) list Symtab.table};

   (*build an algebra from its two components*)
   fun make_algebra (classes, arities) =
     Algebra {classes = classes, arities = arities};

   (*component selectors*)
   fun classes_of (Algebra fields) = #classes fields;
   fun arities_of (Algebra fields) = #arities fields;

   (*apply f to one component and keep the other one unchanged*)
   fun map_classes f algebra = make_algebra (f (classes_of algebra), arities_of algebra);
   fun map_arities f algebra = make_algebra (classes_of algebra, f (arities_of algebra));
   123 
   124 
   125 (* classes *)
   126 
   (*all_classes: Graph.all_preds of the class graph, started from its
     Graph.maximals*)
   fun all_classes algebra =
     let val classes = classes_of algebra
     in Graph.all_preds classes (Graph.maximals classes) end;

   (*super_classes: immediate successors of a class in the class graph*)
   fun super_classes algebra = Graph.immediate_succs (classes_of algebra);
   130 
   131 
   132 (* class relations *)
   133 
   (*class_less (c1, c2): there is an edge from c1 to c2 in the class graph*)
   fun class_less algebra = Graph.is_edge (classes_of algebra);

   (*class_le: reflexive closure of class_less*)
   fun class_le algebra (c1, c2) =
     if c1 = c2 then true else class_less algebra (c1, c2);
   136 
   137 
   138 (* sort relations *)
   139 
   (*sort_le (S1, S2): the sorts are equal as lists, or every class of S2
     lies above (wrt. class_le) some class of S1*)
   fun sort_le algebra (S1, S2) =
     let
       fun below c2 c1 = class_le algebra (c1, c2);
       fun covered c2 = exists (below c2) S1;
     in S1 = S2 orelse forall covered S2 end;

   (*sorts_le: sort_le on corresponding positions; ListPair.all ignores
     excess elements of the longer list*)
   fun sorts_le algebra = ListPair.all (sort_le algebra);

   (*sort_eq: sort_le in both directions*)
   fun sort_eq algebra (S1, S2) =
     if sort_le algebra (S1, S2) then sort_le algebra (S2, S1) else false;
   148 
   149 
   150 (* intersection *)
   151 
   (*Intersect the sort S with a single class c.  The scan stops at the
     first member that is already below c (c is then redundant), drops
     members that lie above c, and appends c if the end is reached.*)
   fun inter_class algebra c S =
     let
       fun le cc = class_le algebra cc;
       fun scan [] = [c]
         | scan (rest as d :: ds) =
             if le (d, c) then rest
             else if le (c, d) then scan ds
             else d :: scan ds;
     in scan S end;
   160 
   (*Intersection of two sorts: the classes of S1 are merged one by one
     into S2 via inter_class; the result is passed through sort_strings.*)
   fun inter_sort algebra (S1, S2) =
     S2 |> fold (inter_class algebra) S1 |> sort_strings;
   163 
   164 
   165 (* normal forms *)
   166 
   (*Drop every class that has a proper subclass (wrt. class_less) within
     the same sort, then sort by string_ord and remove duplicates.  Empty
     and singleton sorts are returned as they are.*)
   fun minimize_sort algebra S =
     (case S of
       [] => []
     | [_] => S
     | _ =>
         let fun has_smaller c = exists (fn c' => class_less algebra (c', c)) S
         in sort_distinct string_ord (filter_out has_smaller S) end);
   172 
   (*Minimize the sort, then expand it via Graph.all_succs on the class graph.*)
   fun complete_sort algebra S =
     Graph.all_succs (classes_of algebra) (minimize_sort algebra S);
   175 
   (*Minimize each given sort, collect the results as an ordered list, and
     discard every sort that has a strictly smaller sort (wrt. sort_le) in
     that list.*)
   fun minimal_sorts algebra raw_sorts =
     let
       val sorts = make (map (minimize_sort algebra) raw_sorts);
       fun strictly_below S S' =
         sort_le algebra (S', S) andalso not (sort_le algebra (S, S'));
       fun dominated S = exists (strictly_below S) sorts;
     in filter_out dominated sorts end;
   181 
   182 
   183 
   184 (** build algebras **)
   185 
   186 (* classes *)
   187 
   (*error: class declared twice*)
   fun err_dup_class c =
     error ("Duplicate declaration of class: " ^ quote c);

   (*error: one message line per reported cycle; pretty printing is only
     initialized while a line is composed*)
   fun err_cyclic_classes pp css =
     let
       fun line cs =
         "Cycle in class relation: " ^ Syntax.string_of_classrel (Syntax.init_pretty pp) cs;
     in error (cat_lines (map line css)) end;
   193 
   (*Declare class c with the classes cs as successors in the class graph.
     A fresh serial number is attached to the new node.  Graph.DUP and
     Graph.CYCLES are turned into user errors.*)
   fun add_class pp (c, cs) =
     let
       fun declare classes =
         Graph.new_node (c, serial ()) classes
           handle Graph.DUP dup => err_dup_class dup;
       fun relate classes =
         fold (fn c' => Graph.add_edge_trans_acyclic (c, c')) cs classes
           handle Graph.CYCLES css => err_cyclic_classes pp css;
     in map_classes (relate o declare) end;
   201 
   202 
   203 (* arities *)
   204 
   205 local
   206 
   (*optional message suffix naming the class relation involved in a conflict*)
   fun for_classes ctxt cc =
     (case cc of
       NONE => ""
     | SOME (c1, c2) => " for classes " ^ Syntax.string_of_classrel ctxt [c1, c2]);

   (*error: two arities of type constructor t contradict each other*)
   fun err_conflict pp t cc (c, Ss) (c', Ss') =
     let
       val ctxt = Syntax.init_pretty pp;
       fun arity (d, Ds) = Syntax.string_of_arity ctxt (t, Ds, [d]);
     in
       error ("Conflict of type arities" ^ for_classes ctxt cc ^ ":\n  " ^
         arity (c, Ss) ^ " and\n  " ^ arity (c', Ss'))
     end;
   216 
   (*Add arity (c, Ss) of type constructor t in front of ars, after checking
     it against every existing arity (c', Ss'): if c <= c' then Ss <= Ss'
     must hold, and likewise with the roles exchanged.  The first violation
     found is reported via err_conflict.*)
   fun coregular pp algebra t (c, Ss) ars =
     let
       fun le cc = class_le algebra cc;
       fun doms_le SSs = sorts_le algebra SSs;
       fun conflict (ar' as (c', Ss')) =
         if le (c, c') andalso not (doms_le (Ss, Ss')) then SOME ((c, c'), ar')
         else if le (c', c) andalso not (doms_le (Ss', Ss)) then SOME ((c', c), ar')
         else NONE;
     in
       (case get_first conflict ars of
         NONE => (c, Ss) :: ars
       | SOME (rel, ar') => err_conflict pp t (SOME rel) (c, Ss) ar')
     end;
   230 
   (*Replicate an arity for class c and each of super_classes c, with the
     same argument sorts.*)
   fun complete algebra (c, Ss) =
     map (fn c' => (c', Ss)) (c :: super_classes algebra c);
   232 
   (*Insert arity (c, Ss) of type constructor t into ars.  If c already has
     an entry Ss': a new arity that is at most as general (Ss <= Ss') is
     ignored, a more general one replaces the old entry, and incomparable
     argument sorts are a conflict.  New entries pass through coregular.*)
   fun insert pp algebra t (ar as (c, Ss)) ars =
     (case AList.lookup (op =) ars c of
       SOME Ss' =>
         if sorts_le algebra (Ss, Ss') then ars
         else if not (sorts_le algebra (Ss', Ss)) then err_conflict pp t NONE ar (c, Ss')
         else coregular pp algebra t ar (remove (op =) (c, Ss') ars)
     | NONE => coregular pp algebra t ar ars);
   241 
   242 in
   243 
   (*insert a list of arities of type constructor t, last element first*)
   fun insert_ars pp algebra t ars = fold_rev (insert pp algebra t) ars;

   (*complete each given arity via complete, insert the results into the
     arities already recorded for t, and update the table entry of t*)
   fun insert_complete_ars pp algebra (t, ars) arities =
     let
       val completed = map (complete algebra) ars;
       val ars' = fold_rev (insert_ars pp algebra t) completed (Symtab.lookup_list arities t);
     in Symtab.update (t, ars') arities end;

   (*add arities of one type constructor to the algebra; the algebra as
     given (before the update) supplies the class relation*)
   fun add_arities pp arg algebra =
     map_arities (insert_complete_ars pp algebra arg) algebra;

   (*insert all entries of an arities table into another one*)
   fun add_arities_table pp algebra =
     Symtab.fold (insert_complete_ars pp algebra);
   257 
   258 end;
   259 
   260 
   261 (* classrel *)
   262 
   (*Rebuild the arities table from scratch: all entries are re-inserted
     into an empty table wrt. the class graph of the given algebra.*)
   fun rebuild_arities pp algebra =
     map_arities (fn arities => add_arities_table pp algebra arities Symtab.empty) algebra;
   266 
   (*Add one edge to the class graph (Graph.CYCLES becomes a user error),
     then rebuild the arities table wrt. the extended graph.*)
   fun add_classrel pp rel =
     let
       fun extend classes =
         Graph.add_edge_trans_acyclic rel classes
           handle Graph.CYCLES css => err_cyclic_classes pp css;
     in rebuild_arities pp o map_classes extend end;
   270 
   271 
   272 (* empty and merge *)
   273 
   (*algebra without classes and arities*)
   val empty_algebra = Algebra {classes = Graph.empty, arities = Symtab.empty};

   (*Merge two algebras.  The class graphs are merged first; Graph.DUP and
     Graph.CYCLES become user errors.  How much work the arities need is
     decided by pointer equality of the components:
       same classes, same arities: reuse the table as it is;
       same classes only: join the tables without completion;
       same arities only: unary completion wrt. the merged classes;
       otherwise: binary completion of both tables.*)
   fun merge_algebra pp
      (Algebra {classes = classes1, arities = arities1},
       Algebra {classes = classes2, arities = arities2}) =
     let
       val classes' =
         Graph.merge_trans_acyclic (op =) (classes1, classes2)
           handle Graph.DUP c => err_dup_class c
             | Graph.CYCLES css => err_cyclic_classes pp css;
       val algebra0 = make_algebra (classes', Symtab.empty);

       val same_classes = pointer_eq (classes1, classes2);
       val same_arities = pointer_eq (arities1, arities2);

       (*join two entries of the same type constructor; physically equal
         entries are left alone via Symtab.SAME*)
       fun join_ars t (ars1, ars2) =
         if pointer_eq (ars1, ars2) then raise Symtab.SAME
         else insert_ars pp algebra0 t ars2 ars1;

       (*re-insert the given tables, in order, into an empty table*)
       fun completion tabs = fold (add_arities_table pp algebra0) tabs Symtab.empty;

       val arities' =
         if same_classes andalso same_arities then arities1
         else if same_classes then Symtab.join join_ars (arities1, arities2)
         else if same_arities then completion [arities1]
         else completion [arities1, arities2];
     in make_algebra (classes', arities') end;
   299 
   300 
   301 (* algebra projections *)  (* FIXME potentially violates abstract type integrity *)
   302 
   (*Project the algebra onto the classes satisfying P.  An arity (c, Ss) of
     type constructor t survives only if P c holds and sargs (c, t) yields
     sorts; these are intersected position-wise with Ss and then restricted.
     Returns the sort restriction function together with the new algebra,
     whose arities are rebuilt wrt. the restricted class graph.*)
   fun subalgebra pp P sargs (algebra as Algebra {classes, arities}) =
     let
       fun restrict_sort S =
         minimize_sort algebra (filter P (Graph.all_succs classes S));

       fun restrict_arity t (c, Ss) =
         if not (P c) then NONE
         else
           (case sargs (c, t) of
             NONE => NONE
           | SOME sorts =>
               SOME (c, map restrict_sort (map2 (curry (inter_sort algebra)) sorts Ss)));

       val classes' = Graph.restrict P classes;
       val arities' = Symtab.map (fn t => map_filter (restrict_arity t)) arities;
     in (restrict_sort, rebuild_arities pp (make_algebra (classes', arities'))) end;
   316 
   317 
   318 
   319 (** sorts of types **)
   320 
   321 (* errors -- performance tuning via delayed message composition *)
   322 
   (*Symbolic reasons for a failed sort derivation; the message text is
     only composed on demand by class_error.*)
   datatype class_error =
     No_Classrel of class * class |
     No_Arity of string * class |
     No_Subsort of sort * sort;

   (*Render a class_error as text.  Pretty printing is initialized once,
     when the function is applied to pp.*)
   fun class_error pp =
     let
       val ctxt = Syntax.init_pretty pp;
       fun message (No_Classrel (c1, c2)) =
             "No class relation " ^ Syntax.string_of_classrel ctxt [c1, c2]
         | message (No_Arity (a, c)) =
             "No type arity " ^ Syntax.string_of_arity ctxt (a, [], [c])
         | message (No_Subsort (S1, S2)) =
             "Cannot derive subsort relation " ^
               Syntax.string_of_sort ctxt S1 ^ " < " ^ Syntax.string_of_sort ctxt S2;
     in message end;

   exception CLASS_ERROR of class_error;
   338 
   339 
   340 (* mg_domain *)
   341 
   (*Argument sorts required of type constructor a so that it belongs to
     sort S: the domains of the arities of a for each class of S,
     intersected position-wise.

     Raises CLASS_ERROR (No_Arity (a, c)) if a has no arity for some class
     c of S, and Fail for the empty sort.

     The arity list of a does not depend on the class, so it is looked up
     once instead of once per class of S.*)
   fun mg_domain algebra a S =
     let
       val ars = Symtab.lookup_list (arities_of algebra) a;
       fun dom c =
         (case AList.lookup (op =) ars c of
           NONE => raise CLASS_ERROR (No_Arity (a, c))
         | SOME Ss => Ss);
       fun dom_inter c Ss = ListPair.map (inter_sort algebra) (dom c, Ss);
     in
       (case S of
         [] => raise Fail "Unknown domain of empty intersection"
       | c :: cs => fold dom_inter cs (dom c))
     end;
   355 
   356 
   357 (* meet_sort *)
   358 
   (*Compute the sort constraints on type variables that make type T belong
     to sort S, as an update of a table from variable index names to sorts.
       - the empty sort imposes nothing;
       - a TFree must already satisfy the sort, otherwise
         CLASS_ERROR (No_Subsort ...) is raised;
       - a TVar that does not satisfy it gets the required sort intersected
         with its table entry (initially its own sort);
       - a type constructor passes the constraints from mg_domain on to its
         arguments.*)
   fun meet_sort algebra =
     let
       fun le SS = sort_le algebra SS;
       fun meet _ [] = I
         | meet (Type (a, Ts)) S = fold2 meet Ts (mg_domain algebra a S)
         | meet (TFree (_, S)) S' =
             if le (S, S') then I
             else raise CLASS_ERROR (No_Subsort (S, S'))
         | meet (TVar (v, S)) S' =
             if le (S, S') then I
             else Vartab.map_default (v, S) (fn S0 => inter_sort algebra (S', S0));
     in fn (T, S) => meet T S end;
   371 
   (*Strengthen the sorts of type variables such that T belongs to sort S,
     and return the resulting substitution as a function on types.

     meet_sort only records variables whose sort actually had to be
     strengthened.  Variables without a table entry (already satisfying the
     constraint, or not occurring in T at all) therefore keep their own
     sort; looking them up with "the" would raise Option.

     May raise CLASS_ERROR, via meet_sort.*)
   fun meet_sort_typ algebra (T, S) =
     let val tab = meet_sort algebra (T, S) Vartab.empty in
       Term.map_type_tvar (fn (v, S0) =>
         TVar (v, the_default S0 (Vartab.lookup tab v)))
     end;
   375 
   376 
   377 (* of_sort *)
   378 
   (*Check whether a type belongs to a sort.  Everything belongs to the
     empty sort; variables are checked via sort_le on their own sort; for a
     type constructor the arguments are checked against mg_domain, and any
     CLASS_ERROR raised there counts as failure.*)
   fun of_sort algebra =
     let
       fun ofS (T, S) =
         (case (T, S) of
           (_, []) => true
         | (Type (a, Ts), _) =>
             (ListPair.all ofS (Ts, mg_domain algebra a S)
               handle CLASS_ERROR _ => false)
         | (TFree (_, S0), _) => sort_le algebra (S0, S)
         | (TVar (_, S0), _) => sort_le algebra (S0, S));
     in ofS end;
   389 
   390 
   391 (* animating derivations *)
   392 
   (*Animate the derivation of "T has sort S", producing one result per
     class of S.  The three callbacks build the results:
       type_variable T: derivations for a variable, paired with their classes;
       class_relation T (d, c1) c2: weaken a derivation for c1 to class c2;
       type_constructor (a, Us) dom c: derivation for a constructor from the
         derivations of its arguments.

     Raises CLASS_ERROR: No_Subsort from weakening, or whatever mg_domain
     raises.

     The arity list of constructor a is looked up once per Type node rather
     than once per class of S.*)
   fun of_sort_derivation algebra {class_relation, type_constructor, type_variable} =
     let
       val arities = arities_of algebra;

       (*adapt derivations D1 (paired with their classes) to the sort S2*)
       fun weaken T D1 S2 =
         let val S1 = map snd D1 in
           if S1 = S2 then map fst D1
           else
             S2 |> map (fn c2 =>
               (case D1 |> find_first (fn (_, c1) => class_le algebra (c1, c2)) of
                 SOME d1 => class_relation T d1 c2
               | NONE => raise CLASS_ERROR (No_Subsort (S1, S2))))
         end;

       fun derive (_, []) = []
         | derive (Type (a, Us), S) =
             let
               val Ss = mg_domain algebra a S;
               val dom = map2 (fn U => fn S => derive (U, S) ~~ S) Us Ss;
               val ars = Symtab.lookup_list arities a;
             in
               S |> map (fn c =>
                 let
                   (*mg_domain above has already succeeded, i.e. found an
                     arity of a for every class of S*)
                   val Ss' = the (AList.lookup (op =) ars c);
                   val dom' = map (fn ((U, d), S') => weaken U d S' ~~ S') ((Us ~~ dom) ~~ Ss');
                 in type_constructor (a, Us) dom' c end)
             end
         | derive (T, S) = weaken T (type_variable T) S;
     in derive end;
   421 
   (*Derive a class relation from c1 to c2 along the first path returned by
     Graph.irreducible_paths, applying class_relation once per step of that
     path.  Raises CLASS_ERROR (No_Classrel (c1, c2)) if there is no path.*)
   fun classrel_derivation algebra class_relation =
     let
       fun walk x (c1 :: (rest as c2 :: _)) = walk (class_relation (x, c1) c2) rest
         | walk x _ = x;
     in
       fn (x, c1) => fn c2 =>
         (case Graph.irreducible_paths (classes_of algebra) (c1, c2) of
           cs :: _ => walk x cs
         | [] => raise CLASS_ERROR (No_Classrel (c1, c2)))
     end;
   432 
   433 
   434 (* witness_sorts *)
   435 
   (*Search witnesses for the given sorts: pairs (T, S) of a type and the
     sort it was found for.  Candidates are the hypotheses hyps and types
     built from the type constructors types.  Sorts without a witness are
     dropped from the result.  The search threads a state (solved, failed)
     of witnesses found so far and sorts known to be unsolvable; path holds
     the sorts currently under construction.*)
   fun witness_sorts algebra types hyps sorts =
     let
       fun le S1 S2 = sort_le algebra (S1, S2);
       (*reuse a known witness of sort S1 for the weaker sort S2*)
       fun get S2 (T, S1) = if le S1 S2 then SOME (T, S2) else NONE;
       fun mg_dom t S = SOME (mg_domain algebra t S) handle CLASS_ERROR _ => NONE;

       (*the empty sort is witnessed by propT*)
       fun witn_sort _ [] solved_failed = (SOME (propT, []), solved_failed)
         | witn_sort path S (state as (solved, failed)) =
             if exists (le S) failed then (NONE, state)
             else
               (case get_first (get S) solved of
                 SOME w => (SOME w, state)
               | NONE =>
                   (case get_first (get S) hyps of
                     SOME w => (SOME w, (w :: solved, failed))
                   | NONE => witn_types path types S state))

       and witn_sorts path Ss = fold_map (witn_sort path) Ss

       (*no type constructor left: record S as failed*)
       and witn_types _ [] S (solved, failed) = (NONE, (solved, S :: failed))
         | witn_types path (t :: ts) S state =
             let
               (*give up on constructor t and continue with the remaining ones*)
               fun skip st = witn_types path ts S st;
             in
               (case mg_dom t S of
                 NONE => skip state
               | SOME SS =>
                   (*do not descend into stronger args (achieving termination)*)
                   if exists (fn D => le D S orelse exists (le D) path) SS then skip state
                   else
                     let val (ws, (solved', failed')) = witn_sorts (S :: path) SS state in
                       if forall is_some ws then
                         let val w = (Type (t, map (#1 o the) ws), S)
                         in (SOME w, (w :: solved', failed')) end
                       else skip (solved', failed')
                     end)
             end;

     in map_filter I (#1 (witn_sorts [] sorts ([], []))) end;
   472 
   473 end;