src/Pure/sorts.ML
author wenzelm
Thu Aug 15 16:02:47 2019 +0200 (9 months ago)
changeset 70533 031620901fcd
parent 68295 781a98696638
child 71454 b2c9f94e025f
permissions -rw-r--r--
support for (fully reconstructed) proof terms in Scala;
proper cache_typs;
     1 (*  Title:      Pure/sorts.ML
     2     Author:     Markus Wenzel and Stefan Berghofer, TU Muenchen
     3 
     4 The order-sorted algebra of type classes.
     5 
     6 Classes denote (possibly empty) collections of types that are
     7 partially ordered by class inclusion. They are represented
     8 symbolically by strings.
     9 
    10 Sorts are intersections of finitely many classes. They are represented
    11 by lists of classes.  Normal forms of sorts are sorted lists of
    12 minimal classes (wrt. current class inclusion).
    13 *)
    14 
    15 signature SORTS =
    16 sig
    17   val make: sort list -> sort Ord_List.T
    18   val subset: sort Ord_List.T * sort Ord_List.T -> bool
    19   val union: sort Ord_List.T -> sort Ord_List.T -> sort Ord_List.T
    20   val subtract: sort Ord_List.T -> sort Ord_List.T -> sort Ord_List.T
    21   val remove_sort: sort -> sort Ord_List.T -> sort Ord_List.T
    22   val insert_sort: sort -> sort Ord_List.T -> sort Ord_List.T
    23   val insert_typ: typ -> sort Ord_List.T -> sort Ord_List.T
    24   val insert_typs: typ list -> sort Ord_List.T -> sort Ord_List.T
    25   val insert_term: term -> sort Ord_List.T -> sort Ord_List.T
    26   val insert_terms: term list -> sort Ord_List.T -> sort Ord_List.T
    27   type algebra
    28   val classes_of: algebra -> serial Graph.T
    29   val arities_of: algebra -> (class * sort list) list Symtab.table
    30   val all_classes: algebra -> class list
    31   val super_classes: algebra -> class -> class list
    32   val class_less: algebra -> class * class -> bool
    33   val class_le: algebra -> class * class -> bool
    34   val sort_eq: algebra -> sort * sort -> bool
    35   val sort_le: algebra -> sort * sort -> bool
    36   val sorts_le: algebra -> sort list * sort list -> bool
    37   val inter_sort: algebra -> sort * sort -> sort
    38   val minimize_sort: algebra -> sort -> sort
    39   val complete_sort: algebra -> sort -> sort
    40   val minimal_sorts: algebra -> sort list -> sort Ord_List.T
    41   val add_class: Context.generic -> class * class list -> algebra -> algebra
    42   val add_classrel: Context.generic -> class * class -> algebra -> algebra
    43   val add_arities: Context.generic -> string * (class * sort list) list -> algebra -> algebra
    44   val empty_algebra: algebra
    45   val merge_algebra: Context.generic -> algebra * algebra -> algebra
    46   val dest_algebra: algebra list -> algebra ->
    47     {classrel: (class * class list) list,
    48      arities: (string * sort list * class) list}
    49   val subalgebra: Context.generic -> (class -> bool) -> (class * string -> sort list option)
    50     -> algebra -> (sort -> sort) * algebra
    51   type class_error
    52   val class_error: Context.generic -> class_error -> string
    53   exception CLASS_ERROR of class_error
    54   val has_instance: algebra -> string -> sort -> bool
    55   val mg_domain: algebra -> string -> sort -> sort list   (*exception CLASS_ERROR*)
    56   val meet_sort: algebra -> typ * sort
    57     -> sort Vartab.table -> sort Vartab.table   (*exception CLASS_ERROR*)
    58   val meet_sort_typ: algebra -> typ * sort -> typ -> typ   (*exception CLASS_ERROR*)
    59   val of_sort: algebra -> typ * sort -> bool
    60   val of_sort_derivation: algebra ->
    61     {class_relation: typ -> bool -> 'a * class -> class -> 'a,
    62      type_constructor: string * typ list -> ('a * class) list list -> class -> 'a,
    63      type_variable: typ -> ('a * class) list} ->
    64     typ * sort -> 'a list   (*exception CLASS_ERROR*)
    65   val classrel_derivation: algebra ->
    66     ('a * class -> class -> 'a) -> 'a * class -> class -> 'a  (*exception CLASS_ERROR*)
    67   val witness_sorts: algebra -> string list -> (typ * sort) list -> sort list -> (typ * sort) list
    68 end;
    69 
    70 structure Sorts: SORTS =
    71 struct
    72 
    73 
    74 (** ordered lists of sorts **)
    75 
    76 val make = Ord_List.make Term_Ord.sort_ord;
    77 val subset = Ord_List.subset Term_Ord.sort_ord;
    78 val union = Ord_List.union Term_Ord.sort_ord;
    79 val subtract = Ord_List.subtract Term_Ord.sort_ord;
    80 
    81 val remove_sort = Ord_List.remove Term_Ord.sort_ord;
    82 val insert_sort = Ord_List.insert Term_Ord.sort_ord;
    83 
    84 fun insert_typ (TFree (_, S)) Ss = insert_sort S Ss
    85   | insert_typ (TVar (_, S)) Ss = insert_sort S Ss
    86   | insert_typ (Type (_, Ts)) Ss = insert_typs Ts Ss
    87 and insert_typs [] Ss = Ss
    88   | insert_typs (T :: Ts) Ss = insert_typs Ts (insert_typ T Ss);
    89 
    90 fun insert_term (Const (_, T)) Ss = insert_typ T Ss
    91   | insert_term (Free (_, T)) Ss = insert_typ T Ss
    92   | insert_term (Var (_, T)) Ss = insert_typ T Ss
    93   | insert_term (Bound _) Ss = Ss
    94   | insert_term (Abs (_, T, t)) Ss = insert_term t (insert_typ T Ss)
    95   | insert_term (t $ u) Ss = insert_term t (insert_term u Ss);
    96 
    97 fun insert_terms [] Ss = Ss
    98   | insert_terms (t :: ts) Ss = insert_terms ts (insert_term t Ss);
    99 
   100 
   101 
   102 (** order-sorted algebra **)
   103 
   104 (*
   105   classes: graph representing class declarations together with proper
   106     subclass relation, which needs to be transitive and acyclic.
   107 
   108   arities: table of association lists of all type arities; (t, ars)
   109     means that type constructor t has the arities ars; an element
   110     (c, Ss) of ars represents the arity t::(Ss)c.  "Coregularity" of
   111     the arities structure requires that for any two declarations
   112     t::(Ss1)c1 and t::(Ss2)c2 such that c1 <= c2 holds Ss1 <= Ss2.
   113 *)
   114 
   115 datatype algebra = Algebra of
   116  {classes: serial Graph.T,
   117   arities: (class * sort list) list Symtab.table};
   118 
   119 fun classes_of (Algebra {classes, ...}) = classes;
   120 fun arities_of (Algebra {arities, ...}) = arities;
   121 
   122 fun make_algebra (classes, arities) =
   123   Algebra {classes = classes, arities = arities};
   124 
   125 fun map_classes f (Algebra {classes, arities}) = make_algebra (f classes, arities);
   126 fun map_arities f (Algebra {classes, arities}) = make_algebra (classes, f arities);
   127 
   128 
   129 (* classes *)
   130 
   131 fun all_classes (Algebra {classes, ...}) = Graph.all_preds classes (Graph.maximals classes);
   132 
   133 val super_classes = Graph.immediate_succs o classes_of;
   134 
   135 
   136 (* class relations *)
   137 
   138 val class_less = Graph.is_edge o classes_of;
   139 fun class_le algebra (c1, c2) = c1 = c2 orelse class_less algebra (c1, c2);
   140 
   141 
   142 (* sort relations *)
   143 
   144 fun sort_le algebra (S1, S2) =
   145   S1 = S2 orelse forall (fn c2 => exists (fn c1 => class_le algebra (c1, c2)) S1) S2;
   146 
   147 fun sorts_le algebra (Ss1, Ss2) =
   148   ListPair.all (sort_le algebra) (Ss1, Ss2);
   149 
   150 fun sort_eq algebra (S1, S2) =
   151   sort_le algebra (S1, S2) andalso sort_le algebra (S2, S1);
   152 
   153 
   154 (* intersection *)
   155 
   156 fun inter_class algebra c S =
   157   let
   158     fun intr [] = [c]
   159       | intr (S' as c' :: c's) =
   160           if class_le algebra (c', c) then S'
   161           else if class_le algebra (c, c') then intr c's
   162           else c' :: intr c's
   163   in intr S end;
   164 
   165 fun inter_sort algebra (S1, S2) =
   166   sort_strings (fold (inter_class algebra) S1 S2);
   167 
   168 
   169 (* normal forms *)
   170 
   171 fun minimize_sort _ [] = []
   172   | minimize_sort _ (S as [_]) = S
   173   | minimize_sort algebra S =
   174       filter (fn c => not (exists (fn c' => class_less algebra (c', c)) S)) S
   175       |> sort_distinct string_ord;
   176 
   177 fun complete_sort algebra =
   178   Graph.all_succs (classes_of algebra) o minimize_sort algebra;
   179 
   180 fun minimal_sorts algebra raw_sorts =
   181   let
   182     fun le S1 S2 = sort_le algebra (S1, S2);
   183     val sorts = make (map (minimize_sort algebra) raw_sorts);
   184   in sorts |> filter_out (fn S => exists (fn S' => le S' S andalso not (le S S')) sorts) end;
   185 
   186 
   187 
   188 (** build algebras **)
   189 
   190 (* classes *)
   191 
   192 fun err_dup_class c = error ("Duplicate declaration of class: " ^ quote c);
   193 
   194 fun err_cyclic_classes context css =
   195   error (cat_lines (map (fn cs =>
   196     "Cycle in class relation: " ^ Syntax.string_of_classrel (Syntax.init_pretty context) cs) css));
   197 
   198 fun add_class context (c, cs) = map_classes (fn classes =>
   199   let
   200     val classes' = classes |> Graph.new_node (c, serial ())
   201       handle Graph.DUP dup => err_dup_class dup;
   202     val classes'' = classes' |> fold Graph.add_edge_trans_acyclic (map (pair c) cs)
   203       handle Graph.CYCLES css => err_cyclic_classes context css;
   204   in classes'' end);
   205 
   206 
   207 (* arities *)
   208 
   209 local
   210 
   211 fun for_classes _ NONE = ""
   212   | for_classes ctxt (SOME (c1, c2)) = " for classes " ^ Syntax.string_of_classrel ctxt [c1, c2];
   213 
   214 fun err_conflict context t cc (c, Ss) (c', Ss') =
   215   let val ctxt = Syntax.init_pretty context in
   216     error ("Conflict of type arities" ^ for_classes ctxt cc ^ ":\n  " ^
   217       Syntax.string_of_arity ctxt (t, Ss, [c]) ^ " and\n  " ^
   218       Syntax.string_of_arity ctxt (t, Ss', [c']))
   219   end;
   220 
   221 fun coregular context algebra t (c, Ss) ars =
   222   let
   223     fun conflict (c', Ss') =
   224       if class_le algebra (c, c') andalso not (sorts_le algebra (Ss, Ss')) then
   225         SOME ((c, c'), (c', Ss'))
   226       else if class_le algebra (c', c) andalso not (sorts_le algebra (Ss', Ss)) then
   227         SOME ((c', c), (c', Ss'))
   228       else NONE;
   229   in
   230     (case get_first conflict ars of
   231       SOME ((c1, c2), (c', Ss')) => err_conflict context t (SOME (c1, c2)) (c, Ss) (c', Ss')
   232     | NONE => (c, Ss) :: ars)
   233   end;
   234 
   235 fun complete algebra (c, Ss) = map (rpair Ss) (c :: super_classes algebra c);
   236 
   237 fun insert context algebra t (c, Ss) ars =
   238   (case AList.lookup (op =) ars c of
   239     NONE => coregular context algebra t (c, Ss) ars
   240   | SOME Ss' =>
   241       if sorts_le algebra (Ss, Ss') then ars
   242       else if sorts_le algebra (Ss', Ss)
   243       then coregular context algebra t (c, Ss) (remove (op =) (c, Ss') ars)
   244       else err_conflict context t NONE (c, Ss) (c, Ss'));
   245 
   246 in
   247 
   248 fun insert_ars context algebra t = fold_rev (insert context algebra t);
   249 
   250 fun insert_complete_ars context algebra (t, ars) arities =
   251   let val ars' =
   252     Symtab.lookup_list arities t
   253     |> fold_rev (insert_ars context algebra t) (map (complete algebra) ars);
   254   in Symtab.update (t, ars') arities end;
   255 
   256 fun add_arities context arg algebra =
   257   algebra |> map_arities (insert_complete_ars context algebra arg);
   258 
   259 fun add_arities_table context algebra =
   260   Symtab.fold (fn (t, ars) => insert_complete_ars context algebra (t, ars));
   261 
   262 end;
   263 
   264 
   265 (* classrel *)
   266 
   267 fun rebuild_arities context algebra = algebra |> map_arities (fn arities =>
   268   Symtab.empty
   269   |> add_arities_table context algebra arities);
   270 
   271 fun add_classrel context rel = rebuild_arities context o map_classes (fn classes =>
   272   classes |> Graph.add_edge_trans_acyclic rel
   273     handle Graph.CYCLES css => err_cyclic_classes context css);
   274 
   275 
   276 (* empty and merge *)
   277 
   278 val empty_algebra = make_algebra (Graph.empty, Symtab.empty);
   279 
   280 fun merge_algebra context
   281    (Algebra {classes = classes1, arities = arities1},
   282     Algebra {classes = classes2, arities = arities2}) =
   283   let
   284     val classes' = Graph.merge_trans_acyclic (op =) (classes1, classes2)
   285       handle Graph.DUP c => err_dup_class c
   286         | Graph.CYCLES css => err_cyclic_classes context css;
   287     val algebra0 = make_algebra (classes', Symtab.empty);
   288     val arities' =
   289       (case (pointer_eq (classes1, classes2), pointer_eq (arities1, arities2)) of
   290         (true, true) => arities1
   291       | (true, false) =>  (*no completion*)
   292           (arities1, arities2) |> Symtab.join (fn t => fn (ars1, ars2) =>
   293             if pointer_eq (ars1, ars2) then raise Symtab.SAME
   294             else insert_ars context algebra0 t ars2 ars1)
   295       | (false, true) =>  (*unary completion*)
   296           Symtab.empty
   297           |> add_arities_table context algebra0 arities1
   298       | (false, false) => (*binary completion*)
   299           Symtab.empty
   300           |> add_arities_table context algebra0 arities1
   301           |> add_arities_table context algebra0 arities2);
   302   in make_algebra (classes', arities') end;
   303 
   304 
   305 (* destruct *)
   306 
   307 fun dest_algebra parents (Algebra {classes, arities}) =
   308   let
   309     fun new_classrel rel = not (exists (fn algebra => class_less algebra rel) parents);
   310     val classrel =
   311       (classes, []) |-> Graph.fold (fn (c, (_, (_, ds))) =>
   312         (case filter (fn d => new_classrel (c, d)) (Graph.Keys.dest ds) of
   313           [] => I
   314         | ds' => cons (c, sort_strings ds')))
   315       |> sort_by #1;
   316 
   317     fun is_arity t ar algebra = member (op =) (Symtab.lookup_list (arities_of algebra) t) ar;
   318     fun add_arity t (c, Ss) = not (exists (is_arity t (c, Ss)) parents) ? cons (t, Ss, c);
   319     val arities =
   320       (arities, []) |-> Symtab.fold (fn (t, ars) => fold_rev (add_arity t) (sort_by #1 ars))
   321       |> sort_by #1;
   322   in {classrel = classrel, arities = arities} end;
   323 
   324 
   325 (* algebra projections *)  (* FIXME potentially violates abstract type integrity *)
   326 
   327 fun subalgebra context P sargs (algebra as Algebra {classes, arities}) =
   328   let
   329     val restrict_sort = minimize_sort algebra o filter P o Graph.all_succs classes;
   330     fun restrict_arity t (c, Ss) =
   331       if P c then
   332         (case sargs (c, t) of
   333           SOME sorts =>
   334             SOME (c, Ss |> map2 (curry (inter_sort algebra)) sorts |> map restrict_sort)
   335         | NONE => NONE)
   336       else NONE;
   337     val classes' = classes |> Graph.restrict P;
   338     val arities' = arities |> Symtab.map (map_filter o restrict_arity);
   339   in (restrict_sort, rebuild_arities context (make_algebra (classes', arities'))) end;
   340 
   341 
   342 
   343 (** sorts of types **)
   344 
   345 (* errors -- performance tuning via delayed message composition *)
   346 
   347 datatype class_error =
   348   No_Classrel of class * class |
   349   No_Arity of string * class |
   350   No_Subsort of sort * sort;
   351 
   352 fun class_error context =
   353   let val ctxt = Syntax.init_pretty context in
   354     fn No_Classrel (c1, c2) => "No class relation " ^ Syntax.string_of_classrel ctxt [c1, c2]
   355      | No_Arity (a, c) => "No type arity " ^ Syntax.string_of_arity ctxt (a, [], [c])
   356      | No_Subsort (S1, S2) =>
   357         "Cannot derive subsort relation " ^
   358           Syntax.string_of_sort ctxt S1 ^ " < " ^ Syntax.string_of_sort ctxt S2
   359   end;
   360 
   361 exception CLASS_ERROR of class_error;
   362 
   363 
   364 (* instances *)
   365 
   366 fun has_instance algebra a =
   367   forall (AList.defined (op =) (Symtab.lookup_list (arities_of algebra) a));
   368 
   369 fun mg_domain algebra a S =
   370   let
   371     val ars = Symtab.lookup_list (arities_of algebra) a;
   372     fun dom c =
   373       (case AList.lookup (op =) ars c of
   374         NONE => raise CLASS_ERROR (No_Arity (a, c))
   375       | SOME Ss => Ss);
   376     fun dom_inter c Ss = ListPair.map (inter_sort algebra) (dom c, Ss);
   377   in
   378     (case S of
   379       [] => raise Fail "Unknown domain of empty intersection"
   380     | c :: cs => fold dom_inter cs (dom c))
   381   end;
   382 
   383 
   384 (* meet_sort *)
   385 
   386 fun meet_sort algebra =
   387   let
   388     fun inters S S' = inter_sort algebra (S, S');
   389     fun meet _ [] = I
   390       | meet (TFree (_, S)) S' =
   391           if sort_le algebra (S, S') then I
   392           else raise CLASS_ERROR (No_Subsort (S, S'))
   393       | meet (TVar (v, S)) S' =
   394           if sort_le algebra (S, S') then I
   395           else Vartab.map_default (v, S) (inters S')
   396       | meet (Type (a, Ts)) S = fold2 meet Ts (mg_domain algebra a S);
   397   in uncurry meet end;
   398 
   399 fun meet_sort_typ algebra (T, S) =
   400   let val tab = meet_sort algebra (T, S) Vartab.empty;
   401   in Term.map_type_tvar (fn (v, _) => TVar (v, (the o Vartab.lookup tab) v)) end;
   402 
   403 
   404 (* of_sort *)
   405 
   406 fun of_sort algebra =
   407   let
   408     fun ofS (_, []) = true
   409       | ofS (TFree (_, S), S') = sort_le algebra (S, S')
   410       | ofS (TVar (_, S), S') = sort_le algebra (S, S')
   411       | ofS (Type (a, Ts), S) =
   412           let val Ss = mg_domain algebra a S in
   413             ListPair.all ofS (Ts, Ss)
   414           end handle CLASS_ERROR _ => false;
   415   in ofS end;
   416 
   417 
   418 (* animating derivations *)
   419 
   420 fun of_sort_derivation algebra {class_relation, type_constructor, type_variable} =
   421   let
   422     val arities = arities_of algebra;
   423 
   424     fun weaken T D1 S2 =
   425       let val S1 = map snd D1 in
   426         if S1 = S2 then map fst D1
   427         else
   428           S2 |> map (fn c2 =>
   429             (case D1 |> filter (fn (_, c1) => class_le algebra (c1, c2)) of
   430               [d1] => class_relation T true d1 c2
   431             | (d1 :: _ :: _) => class_relation T false d1 c2
   432             | [] => raise CLASS_ERROR (No_Subsort (S1, S2))))
   433       end;
   434 
   435     fun derive (_, []) = []
   436       | derive (Type (a, Us), S) =
   437           let
   438             val Ss = mg_domain algebra a S;
   439             val dom = map2 (fn U => fn S => derive (U, S) ~~ S) Us Ss;
   440           in
   441             S |> map (fn c =>
   442               let
   443                 val Ss' = the (AList.lookup (op =) (Symtab.lookup_list arities a) c);
   444                 val dom' = map (fn ((U, d), S') => weaken U d S' ~~ S') ((Us ~~ dom) ~~ Ss');
   445               in type_constructor (a, Us) dom' c end)
   446           end
   447       | derive (T, S) = weaken T (type_variable T) S;
   448   in derive end;
   449 
   450 fun classrel_derivation algebra class_relation =
   451   let
   452     fun path (x, c1 :: c2 :: cs) = path (class_relation (x, c1) c2, c2 :: cs)
   453       | path (x, _) = x;
   454   in
   455     fn (x, c1) => fn c2 =>
   456       (case Graph.irreducible_paths (classes_of algebra) (c1, c2) of
   457         [] => raise CLASS_ERROR (No_Classrel (c1, c2))
   458       | cs :: _ => path (x, cs))
   459   end;
   460 
   461 
   462 (* witness_sorts *)
   463 
   464 fun witness_sorts algebra types hyps sorts =
   465   let
   466     fun le S1 S2 = sort_le algebra (S1, S2);
   467     fun get S2 (T, S1) = if le S1 S2 then SOME (T, S2) else NONE;
   468     fun mg_dom t S = SOME (mg_domain algebra t S) handle CLASS_ERROR _ => NONE;
   469 
   470     fun witn_sort _ [] solved_failed = (SOME (propT, []), solved_failed)
   471       | witn_sort path S (solved, failed) =
   472           if exists (le S) failed then (NONE, (solved, failed))
   473           else
   474             (case get_first (get S) solved of
   475               SOME w => (SOME w, (solved, failed))
   476             | NONE =>
   477                 (case get_first (get S) hyps of
   478                   SOME w => (SOME w, (w :: solved, failed))
   479                 | NONE => witn_types path types S (solved, failed)))
   480 
   481     and witn_sorts path x = fold_map (witn_sort path) x
   482 
   483     and witn_types _ [] S (solved, failed) = (NONE, (solved, S :: failed))
   484       | witn_types path (t :: ts) S solved_failed =
   485           (case mg_dom t S of
   486             SOME SS =>
   487               (*do not descend into stronger args (achieving termination)*)
   488               if exists (fn D => le D S orelse exists (le D) path) SS then
   489                 witn_types path ts S solved_failed
   490               else
   491                 let val (ws, (solved', failed')) = witn_sorts (S :: path) SS solved_failed in
   492                   if forall is_some ws then
   493                     let val w = (Type (t, map (#1 o the) ws), S)
   494                     in (SOME w, (w :: solved', failed')) end
   495                   else witn_types path ts S (solved', failed')
   496                 end
   497           | NONE => witn_types path ts S solved_failed);
   498 
   499   in map_filter I (#1 (witn_sorts [] sorts ([], []))) end;
   500 
   501 end;