maintain sort constraints from type instantiations, with pro-forma derivation to collect oracles/thms;
authorwenzelm
Sat, 03 Aug 2019 12:58:53 +0200
changeset 70459 f0a445c5a82c
parent 70458 9e2173eb23eb
child 70460 b2b44fd1b6ec
maintain sort constraints from type instantiations, with pro-forma derivation to collect oracles/thms; tuned;
src/Pure/context.ML
src/Pure/global_theory.ML
src/Pure/goal.ML
src/Pure/thm.ML
--- a/src/Pure/context.ML	Fri Aug 02 14:14:49 2019 +0200
+++ b/src/Pure/context.ML	Sat Aug 03 12:58:53 2019 +0200
@@ -33,6 +33,7 @@
   val timing: bool Unsynchronized.ref
   val parents_of: theory -> theory list
   val ancestors_of: theory -> theory list
+  val theory_id_ord: theory_id * theory_id -> order
   val theory_id_long_name: theory_id -> string
   val theory_id_name: theory_id -> string
   val theory_long_name: theory -> string
@@ -166,6 +167,7 @@
 fun make_history name = {name = name, stage = SOME (init_stage ())};
 fun make_ancestry parents ancestors = {parents = parents, ancestors = ancestors};
 
+val theory_id_ord = int_ord o apply2 (#id o identity_of_id);
 val theory_id_long_name = #name o history_of_id;
 val theory_id_name = Long_Name.base_name o theory_id_long_name;
 val theory_long_name = #name o history_of;
--- a/src/Pure/global_theory.ML	Fri Aug 02 14:14:49 2019 +0200
+++ b/src/Pure/global_theory.ML	Sat Aug 03 12:58:53 2019 +0200
@@ -122,12 +122,16 @@
 val unofficial1 = Name_Flags {pre = true, official = false};
 val unofficial2 = Name_Flags {pre = false, official = false};
 
-fun name_thm No_Name_Flags _ thm = thm
-  | name_thm (Name_Flags {pre, official}) name thm = thm
-      |> (official andalso (not pre orelse Thm.derivation_name thm = "")) ?
-          Thm.name_derivation name
-      |> (name <> "" andalso (not pre orelse not (Thm.has_name_hint thm))) ?
-          Thm.put_name_hint name;
+fun name_thm name_flags name =
+  Thm.solve_constraints #> (fn thm =>
+    (case name_flags of
+      No_Name_Flags => thm
+    | Name_Flags {pre, official} =>
+        thm
+        |> (official andalso (not pre orelse Thm.derivation_name thm = "")) ?
+            Thm.name_derivation name
+        |> (name <> "" andalso (not pre orelse not (Thm.has_name_hint thm))) ?
+            Thm.put_name_hint name));
 
 end;
 
--- a/src/Pure/goal.ML	Fri Aug 02 14:14:49 2019 +0200
+++ b/src/Pure/goal.ML	Sat Aug 03 12:58:53 2019 +0200
@@ -132,13 +132,15 @@
         Drule.forall_intr_list fixes #>
         Thm.adjust_maxidx_thm ~1 #>
         Thm.generalize (map #1 tfrees, []) 0 #>
-        Thm.strip_shyps);
+        Thm.strip_shyps #>
+        Thm.solve_constraints);
     val local_result =
       Thm.future global_result global_prop
       |> Thm.close_derivation
       |> Thm.instantiate (instT, [])
       |> Drule.forall_elim_list fixes
-      |> fold (Thm.elim_implies o Thm.assume) assms;
+      |> fold (Thm.elim_implies o Thm.assume) assms
+      |> Thm.solve_constraints;
   in local_result end;
 
 
--- a/src/Pure/thm.ML	Fri Aug 02 14:14:49 2019 +0200
+++ b/src/Pure/thm.ML	Sat Aug 03 12:58:53 2019 +0200
@@ -68,6 +68,7 @@
   val theory_name: thm -> string
   val maxidx_of: thm -> int
   val maxidx_thm: thm -> int -> int
+  val constraints_of: thm -> (theory * (typ * sort)) list
   val shyps_of: thm -> sort Ord_List.T
   val hyps_of: thm -> term list
   val prop_of: thm -> term
@@ -139,6 +140,7 @@
   val combination: thm -> thm -> thm
   val equal_intr: thm -> thm -> thm
   val equal_elim: thm -> thm -> thm
+  val solve_constraints: thm -> thm
   val flexflex_rule: Proof.context option -> thm -> thm Seq.seq
   val generalize: string list * string list -> int -> thm -> thm
   val generalize_cterm: string list * string list -> int -> cterm -> cterm
@@ -360,11 +362,51 @@
 
 (*** Derivations and Theorems ***)
 
+(* sort constraints *)
+
+type constraint = theory * (typ * sort);
+
+local
+
+val constraint_ord : constraint * constraint -> order =
+  prod_ord (Context.theory_id_ord o apply2 Context.theory_id)
+    (prod_ord Term_Ord.typ_ord Term_Ord.sort_ord);
+
+val smash_atyps =
+  map_atyps (fn TVar (_, S) => Term.aT S | TFree (_, S) => Term.aT S | T => T);
+
+in
+
+val union_constraints = Ord_List.union constraint_ord;
+
+fun insert_constraints thy (T, S) =
+  let
+    val ignored =
+      S = [] orelse
+        (case T of
+          TFree (_, S') => S = S'
+        | TVar (_, S') => S = S'
+        | _ => false);
+  in if ignored then I else Ord_List.insert constraint_ord (thy, (smash_atyps T, S)) end;
+
+fun insert_constraints_env thy env =
+  let
+    val tyenv = Envir.type_env env;
+    fun insert ([], _) = I
+      | insert (S, T) = insert_constraints thy (Envir.norm_type tyenv T, S);
+  in tyenv |> Vartab.fold (insert o #2) end;
+
+end;
+
+
+(* datatype thm *)
+
 datatype thm = Thm of
  deriv *                        (*derivation*)
  {cert: Context.certificate,    (*background theory certificate*)
   tags: Properties.T,           (*additional annotations/comments*)
   maxidx: int,                  (*maximum index of any Var or TVar*)
+  constraints: constraint Ord_List.T,  (*implicit proof obligations for sort constraints*)
   shyps: sort Ord_List.T,       (*sort hypotheses*)
   hyps: term Ord_List.T,        (*hypotheses*)
   tpairs: (term * term) list,   (*flex-flex pairs*)
@@ -396,6 +438,7 @@
         | _ => I) th
   end;
 
+
 fun terms_of_tpairs tpairs = fold_rev (fn (t, u) => cons t o cons u) tpairs [];
 
 fun eq_tpairs ((t, u), (t', u')) = t aconv t' andalso u aconv u';
@@ -427,6 +470,7 @@
 
 val maxidx_of = #maxidx o rep_thm;
 fun maxidx_thm th i = Int.max (maxidx_of th, i);
+val constraints_of = #constraints o rep_thm;
 val shyps_of = #shyps o rep_thm;
 val hyps_of = #hyps o rep_thm;
 val prop_of = #prop o rep_thm;
@@ -486,13 +530,18 @@
       Cterm {cert = Context.Certificate_Id (Context.theory_id thy),
         t = t, T = T, maxidx = maxidx, sorts = sorts});
 
-fun trim_context th =
+fun trim_context_thm th =
   (case th of
-    Thm (_, {cert = Context.Certificate_Id _, ...}) => th
-  | Thm (der, {cert = Context.Certificate thy, tags, maxidx, shyps, hyps, tpairs, prop}) =>
+    Thm (_, {constraints = _ :: _, ...}) =>
+      raise THM ("trim_context: pending sort constraints", 0, [th])
+  | Thm (_, {cert = Context.Certificate_Id _, ...}) => th
+  | Thm (der,
+      {cert = Context.Certificate thy, tags, maxidx, constraints = [], shyps, hyps,
+        tpairs, prop}) =>
       Thm (der,
        {cert = Context.Certificate_Id (Context.theory_id thy),
-        tags = tags, maxidx = maxidx, shyps = shyps, hyps = hyps, tpairs = tpairs, prop = prop}));
+        tags = tags, maxidx = maxidx, constraints = [], shyps = shyps, hyps = hyps,
+        tpairs = tpairs, prop = prop}));
 
 fun transfer_ctyp thy' cT =
   let
@@ -522,7 +571,7 @@
 
 fun transfer thy' th =
   let
-    val Thm (der, {cert, tags, maxidx, shyps, hyps, tpairs, prop}) = th;
+    val Thm (der, {cert, tags, maxidx, constraints, shyps, hyps, tpairs, prop}) = th;
     val _ =
       Context.subthy_id (Context.certificate_theory_id cert, Context.theory_id thy') orelse
         raise CONTEXT ("Cannot transfer: not a super theory", [], [], [th],
@@ -535,6 +584,7 @@
        {cert = cert',
         tags = tags,
         maxidx = maxidx,
+        constraints = constraints,
         shyps = shyps,
         hyps = hyps,
         tpairs = tpairs,
@@ -584,9 +634,9 @@
 
 
 (*implicit alpha-conversion*)
-fun renamed_prop prop' (Thm (der, {cert, tags, maxidx, shyps, hyps, tpairs, prop})) =
+fun renamed_prop prop' (Thm (der, {cert, tags, maxidx, constraints, shyps, hyps, tpairs, prop})) =
   if prop aconv prop' then
-    Thm (der, {cert = cert, tags = tags, maxidx = maxidx, shyps = shyps,
+    Thm (der, {cert = cert, tags = tags, maxidx = maxidx, constraints = constraints, shyps = shyps,
       hyps = hyps, tpairs = tpairs, prop = prop'})
   else raise TERM ("renamed_prop: props disagree", [prop, prop']);
 
@@ -612,7 +662,7 @@
 fun weaken raw_ct th =
   let
     val ct as Cterm {t = A, T, sorts, maxidx = maxidxA, ...} = adjust_maxidx_cterm ~1 raw_ct;
-    val Thm (der, {tags, maxidx, shyps, hyps, tpairs, prop, ...}) = th;
+    val Thm (der, {tags, maxidx, constraints, shyps, hyps, tpairs, prop, ...}) = th;
   in
     if T <> propT then
       raise THM ("weaken: assumptions must have type prop", 0, [])
@@ -623,6 +673,7 @@
        {cert = join_certificate1 (ct, th),
         tags = tags,
         maxidx = maxidx,
+        constraints = constraints,
         shyps = Sorts.union sorts shyps,
         hyps = insert_hyps A hyps,
         tpairs = tpairs,
@@ -656,11 +707,11 @@
 fun promise_ord ((i, _), (j, _)) = int_ord (j, i);
 
 fun deriv_rule2 f
-    (Deriv {promises = ps1, body = PBody {oracles = oras1, thms = thms1, proof = prf1}})
-    (Deriv {promises = ps2, body = PBody {oracles = oras2, thms = thms2, proof = prf2}}) =
+    (Deriv {promises = ps1, body = PBody {oracles = oracles1, thms = thms1, proof = prf1}})
+    (Deriv {promises = ps2, body = PBody {oracles = oracles2, thms = thms2, proof = prf2}}) =
   let
     val ps = Ord_List.union promise_ord ps1 ps2;
-    val oras = Proofterm.unions_oracles [oras1, oras2];
+    val oracles = Proofterm.unions_oracles [oracles1, oracles2];
     val thms = Proofterm.unions_thms [thms1, thms2];
     val prf =
       (case ! Proofterm.proofs of
@@ -668,7 +719,7 @@
       | 1 => MinProof
       | 0 => MinProof
       | i => error ("Illegal level of detail for proof objects: " ^ string_of_int i));
-  in make_deriv ps oras thms prf end;
+  in make_deriv ps oracles thms prf end;
 
 fun deriv_rule1 f = deriv_rule2 (K f) empty_deriv;
 fun deriv_rule0 prf = deriv_rule1 I (make_deriv [] [] [] prf);
@@ -717,11 +768,12 @@
 fun future_result i orig_cert orig_shyps orig_prop thm =
   let
     fun err msg = raise THM ("future_result: " ^ msg, 0, [thm]);
-    val Thm (Deriv {promises, ...}, {cert, shyps, hyps, tpairs, prop, ...}) = thm;
+    val Thm (Deriv {promises, ...}, {cert, constraints, shyps, hyps, tpairs, prop, ...}) = thm;
 
     val _ = Context.eq_certificate (cert, orig_cert) orelse err "bad theory";
     val _ = prop aconv orig_prop orelse err "bad prop";
-    val _ = null tpairs orelse err "bad tpairs";
+    val _ = null constraints orelse err "bad sort constraints";
+    val _ = null tpairs orelse err "bad flex-flex constraints";
     val _ = null hyps orelse err "bad hyps";
     val _ = Sorts.subset (shyps, orig_shyps) orelse err "bad shyps";
     val _ = forall (fn (j, _) => i <> j) promises orelse err "bad dependencies";
@@ -743,6 +795,7 @@
      {cert = cert,
       tags = [],
       maxidx = maxidx,
+      constraints = [],
       shyps = sorts,
       hyps = [],
       tpairs = [],
@@ -764,8 +817,9 @@
              val maxidx = maxidx_of_term prop;
              val shyps = Sorts.insert_term prop [];
            in
-             Thm (der, {cert = cert, tags = [],
-               maxidx = maxidx, shyps = shyps, hyps = [], tpairs = [], prop = prop})
+             Thm (der,
+               {cert = cert, tags = [], maxidx = maxidx,
+                 constraints = [], shyps = shyps, hyps = [], tpairs = [], prop = prop})
            end);
   in
     (case get_first get_ax (Theory.nodes_of thy0) of
@@ -782,8 +836,8 @@
 
 val get_tags = #tags o rep_thm;
 
-fun map_tags f (Thm (der, {cert, tags, maxidx, shyps, hyps, tpairs, prop})) =
-  Thm (der, {cert = cert, tags = f tags, maxidx = maxidx,
+fun map_tags f (Thm (der, {cert, tags, maxidx, constraints, shyps, hyps, tpairs, prop})) =
+  Thm (der, {cert = cert, tags = f tags, maxidx = maxidx, constraints = constraints,
     shyps = shyps, hyps = hyps, tpairs = tpairs, prop = prop});
 
 
@@ -792,14 +846,16 @@
 fun norm_proof (th as Thm (der, args)) =
   Thm (deriv_rule1 (Proofterm.rew_proof (theory_of_thm th)) der, args);
 
-fun adjust_maxidx_thm i (th as Thm (der, {cert, tags, maxidx, shyps, hyps, tpairs, prop})) =
+fun adjust_maxidx_thm i
+    (th as Thm (der, {cert, tags, maxidx, constraints, shyps, hyps, tpairs, prop})) =
   if maxidx = i then th
   else if maxidx < i then
-    Thm (der, {maxidx = i, cert = cert, tags = tags, shyps = shyps,
+    Thm (der, {maxidx = i, cert = cert, tags = tags, constraints = constraints, shyps = shyps,
       hyps = hyps, tpairs = tpairs, prop = prop})
   else
     Thm (der, {maxidx = Int.max (maxidx_tpairs tpairs (maxidx_of_term prop), i),
-      cert = cert, tags = tags, shyps = shyps, hyps = hyps, tpairs = tpairs, prop = prop});
+      cert = cert, tags = tags, constraints = constraints, shyps = shyps,
+      hyps = hyps, tpairs = tpairs, prop = prop});
 
 
 
@@ -870,8 +926,47 @@
 val arity_proof = proof_of oo the_arity;
 
 
+(* solve sort constraints by pro-forma proof *)
 
-(*** Theorems with official name ***)
+local
+
+fun union_digest (oracles1, thms1) (oracles2, thms2) =
+  (Proofterm.unions_oracles [oracles1, oracles2], Proofterm.unions_thms [thms1, thms2]);
+
+fun thm_digest (Thm (Deriv {body = Proofterm.PBody {oracles, thms, ...}, ...}, _)) =
+  (oracles, thms);
+
+fun constraint_digest thy =
+  Sorts.of_sort_derivation (Sign.classes_of thy)
+   {class_relation = fn _ => fn _ => fn (digest, c1) => fn c2 =>
+      if c1 = c2 then ([], []) else union_digest digest (thm_digest (the_classrel thy (c1, c2))),
+    type_constructor = fn (a, _) => fn dom => fn c =>
+      let val arity_digest = thm_digest (the_arity thy (a, (map o map) #2 dom, c))
+      in (fold o fold) (union_digest o #1) dom arity_digest end,
+    type_variable = fn T => map (pair ([], [])) (Type.sort_of_atyp T)};
+
+in
+
+fun solve_constraints (thm as Thm (_, {constraints = [], ...})) = thm
+  | solve_constraints (Thm (der, args)) =
+      let
+        val {cert, tags, maxidx, constraints, shyps, hyps, tpairs, prop} = args;
+
+        val Deriv {promises, body = Proofterm.PBody {oracles, thms, proof}} = der;
+        val (oracles', thms') = (oracles, thms)
+          |> fold (fold union_digest o uncurry constraint_digest) constraints;
+        val body' = Proofterm.PBody {oracles = oracles', thms = thms', proof = proof};
+      in
+        Thm (Deriv {promises = promises, body = body'},
+          {constraints = [], cert = cert, tags = tags, maxidx = maxidx,
+            shyps = shyps, hyps = hyps, tpairs = tpairs, prop = prop})
+      end;
+
+end;
+
+
+
+(*** Closed theorems with official name ***)
 
 (*non-deterministic, depends on unknown promises*)
 fun derivation_closed (Thm (Deriv {body, ...}, _)) =
@@ -881,21 +976,28 @@
 fun derivation_name (Thm (Deriv {body, ...}, {shyps, hyps, prop, ...})) =
   Proofterm.get_name shyps hyps prop (Proofterm.proof_of body);
 
-fun name_derivation name (thm as Thm (der, args)) =
-  let
-    val Deriv {promises, body} = der;
-    val {shyps, hyps, prop, tpairs, ...} = args;
-    val _ = null tpairs orelse raise THM ("put_name: unsolved flex-flex constraints", 0, [thm]);
-    val thy = theory_of_thm thm;
+fun name_derivation name =
+  solve_constraints #> (fn thm as Thm (der, args) =>
+    let
+      val thy = theory_of_thm thm;
+
+      val Deriv {promises, body} = der;
+      val {shyps, hyps, prop, tpairs, ...} = args;
+
+      fun err msg = raise THM ("name_derivation: " ^ msg, 0, [thm]);
+      val _ = null tpairs orelse err "bad flex-flex constraints";
 
-    val ps = map (apsnd (Future.map fulfill_body)) promises;
-    val (pthm, proof) =
-      Proofterm.thm_proof thy (classrel_proof thy) (arity_proof thy) name shyps hyps prop ps body;
-    val der' = make_deriv [] [] [pthm] proof;
-  in Thm (der', args) end;
+      val ps = map (apsnd (Future.map fulfill_body)) promises;
+      val (pthm, proof) =
+        Proofterm.thm_proof thy (classrel_proof thy) (arity_proof thy) name shyps hyps prop ps body;
+      val der' = make_deriv [] [] [pthm] proof;
+    in Thm (der', args) end);
 
-fun close_derivation thm =
-  if derivation_closed thm then thm else name_derivation "" thm;
+val close_derivation =
+  solve_constraints #> (fn thm =>
+    if derivation_closed thm then thm else name_derivation "" thm);
+
+val trim_context = solve_constraints #> trim_context_thm;
 
 
 
@@ -915,6 +1017,7 @@
              {cert = Context.join_certificate (Context.Certificate thy', cert2),
               tags = [],
               maxidx = maxidx,
+              constraints = [],
               shyps = sorts,
               hyps = [],
               tpairs = [],
@@ -943,6 +1046,7 @@
      {cert = cert,
       tags = [],
       maxidx = ~1,
+      constraints = [],
       shyps = sorts,
       hyps = [prop],
       tpairs = [],
@@ -958,7 +1062,7 @@
 *)
 fun implies_intr
     (ct as Cterm {t = A, T, maxidx = maxidx1, sorts, ...})
-    (th as Thm (der, {maxidx = maxidx2, hyps, shyps, tpairs, prop, ...})) =
+    (th as Thm (der, {maxidx = maxidx2, hyps, constraints, shyps, tpairs, prop, ...})) =
   if T <> propT then
     raise THM ("implies_intr: assumptions must have type prop", 0, [th])
   else
@@ -966,6 +1070,7 @@
      {cert = join_certificate1 (ct, th),
       tags = [],
       maxidx = Int.max (maxidx1, maxidx2),
+      constraints = constraints,
       shyps = Sorts.union sorts shyps,
       hyps = remove_hyps A hyps,
       tpairs = tpairs,
@@ -979,24 +1084,26 @@
 *)
 fun implies_elim thAB thA =
   let
-    val Thm (derA, {maxidx = maxidx1, hyps = hypsA, shyps = shypsA, tpairs = tpairsA,
-      prop = propA, ...}) = thA
-    and Thm (der, {maxidx = maxidx2, hyps, shyps, tpairs, prop, ...}) = thAB;
+    val Thm (derA,
+      {maxidx = maxidx1, hyps = hypsA, constraints = constraintsA, shyps = shypsA,
+        tpairs = tpairsA, prop = propA, ...}) = thA
+    and Thm (der, {maxidx = maxidx2, hyps, constraints, shyps, tpairs, prop, ...}) = thAB;
     fun err () = raise THM ("implies_elim: major premise", 0, [thAB, thA]);
   in
-    case prop of
+    (case prop of
       Const ("Pure.imp", _) $ A $ B =>
         if A aconv propA then
           Thm (deriv_rule2 (curry Proofterm.%%) der derA,
            {cert = join_certificate2 (thAB, thA),
             tags = [],
             maxidx = Int.max (maxidx1, maxidx2),
+            constraints = union_constraints constraintsA constraints,
             shyps = Sorts.union shypsA shyps,
             hyps = union_hyps hypsA hyps,
             tpairs = union_tpairs tpairsA tpairs,
             prop = B})
         else err ()
-    | _ => err ()
+    | _ => err ())
   end;
 
 (*Forall introduction.  The Free or Var x must not be free in the hypotheses.
@@ -1008,13 +1115,14 @@
 *)
 fun forall_intr
     (ct as Cterm {maxidx = maxidx1, t = x, T, sorts, ...})
-    (th as Thm (der, {maxidx = maxidx2, shyps, hyps, tpairs, prop, ...})) =
+    (th as Thm (der, {maxidx = maxidx2, constraints, shyps, hyps, tpairs, prop, ...})) =
   let
     fun result a =
       Thm (deriv_rule1 (Proofterm.forall_intr_proof x a) der,
        {cert = join_certificate1 (ct, th),
         tags = [],
         maxidx = Int.max (maxidx1, maxidx2),
+        constraints = constraints,
         shyps = Sorts.union sorts shyps,
         hyps = hyps,
         tpairs = tpairs,
@@ -1037,7 +1145,7 @@
 *)
 fun forall_elim
     (ct as Cterm {t, T, maxidx = maxidx1, sorts, ...})
-    (th as Thm (der, {maxidx = maxidx2, shyps, hyps, tpairs, prop, ...})) =
+    (th as Thm (der, {maxidx = maxidx2, constraints, shyps, hyps, tpairs, prop, ...})) =
   (case prop of
     Const ("Pure.all", Type ("fun", [Type ("fun", [qary, _]), _])) $ A =>
       if T <> qary then
@@ -1047,6 +1155,7 @@
          {cert = join_certificate1 (ct, th),
           tags = [],
           maxidx = Int.max (maxidx1, maxidx2),
+          constraints = constraints,
           shyps = Sorts.union sorts shyps,
           hyps = hyps,
           tpairs = tpairs,
@@ -1064,6 +1173,7 @@
    {cert = cert,
     tags = [],
     maxidx = maxidx,
+    constraints = [],
     shyps = sorts,
     hyps = [],
     tpairs = [],
@@ -1074,13 +1184,14 @@
   ------
   u \<equiv> t
 *)
-fun symmetric (th as Thm (der, {cert, maxidx, shyps, hyps, tpairs, prop, ...})) =
+fun symmetric (th as Thm (der, {cert, maxidx, constraints, shyps, hyps, tpairs, prop, ...})) =
   (case prop of
     (eq as Const ("Pure.eq", _)) $ t $ u =>
       Thm (deriv_rule1 Proofterm.symmetric der,
        {cert = cert,
         tags = [],
         maxidx = maxidx,
+        constraints = constraints,
         shyps = shyps,
         hyps = hyps,
         tpairs = tpairs,
@@ -1094,10 +1205,10 @@
 *)
 fun transitive th1 th2 =
   let
-    val Thm (der1, {maxidx = maxidx1, hyps = hyps1, shyps = shyps1, tpairs = tpairs1,
-      prop = prop1, ...}) = th1
-    and Thm (der2, {maxidx = maxidx2, hyps = hyps2, shyps = shyps2, tpairs = tpairs2,
-      prop = prop2, ...}) = th2;
+    val Thm (der1, {maxidx = maxidx1, hyps = hyps1, constraints = constraints1, shyps = shyps1,
+        tpairs = tpairs1, prop = prop1, ...}) = th1
+    and Thm (der2, {maxidx = maxidx2, hyps = hyps2, constraints = constraints2, shyps = shyps2,
+        tpairs = tpairs2, prop = prop2, ...}) = th2;
     fun err msg = raise THM ("transitive: " ^ msg, 0, [th1, th2]);
   in
     case (prop1, prop2) of
@@ -1108,6 +1219,7 @@
            {cert = join_certificate2 (th1, th2),
             tags = [],
             maxidx = Int.max (maxidx1, maxidx2),
+            constraints = union_constraints constraints1 constraints2,
             shyps = Sorts.union shyps1 shyps2,
             hyps = union_hyps hyps1 hyps2,
             tpairs = union_tpairs tpairs1 tpairs2,
@@ -1130,6 +1242,7 @@
      {cert = cert,
       tags = [],
       maxidx = maxidx,
+      constraints = [],
       shyps = sorts,
       hyps = [],
       tpairs = [],
@@ -1141,6 +1254,7 @@
    {cert = cert,
     tags = [],
     maxidx = maxidx,
+    constraints = [],
     shyps = sorts,
     hyps = [],
     tpairs = [],
@@ -1151,6 +1265,7 @@
    {cert = cert,
     tags = [],
     maxidx = maxidx,
+    constraints = [],
     shyps = sorts,
     hyps = [],
     tpairs = [],
@@ -1164,7 +1279,7 @@
 *)
 fun abstract_rule a
     (Cterm {t = x, T, sorts, ...})
-    (th as Thm (der, {cert, maxidx, hyps, shyps, tpairs, prop, ...})) =
+    (th as Thm (der, {cert, maxidx, hyps, constraints, shyps, tpairs, prop, ...})) =
   let
     val (t, u) = Logic.dest_equals prop
       handle TERM _ => raise THM ("abstract_rule: premise not an equality", 0, [th]);
@@ -1173,6 +1288,7 @@
        {cert = cert,
         tags = [],
         maxidx = maxidx,
+        constraints = constraints,
         shyps = Sorts.union sorts shyps,
         hyps = hyps,
         tpairs = tpairs,
@@ -1196,10 +1312,10 @@
 *)
 fun combination th1 th2 =
   let
-    val Thm (der1, {maxidx = maxidx1, shyps = shyps1, hyps = hyps1, tpairs = tpairs1,
-      prop = prop1, ...}) = th1
-    and Thm (der2, {maxidx = maxidx2, shyps = shyps2, hyps = hyps2, tpairs = tpairs2,
-      prop = prop2, ...}) = th2;
+    val Thm (der1, {maxidx = maxidx1, constraints = constraints1, shyps = shyps1,
+        hyps = hyps1, tpairs = tpairs1, prop = prop1, ...}) = th1
+    and Thm (der2, {maxidx = maxidx2, constraints = constraints2, shyps = shyps2,
+        hyps = hyps2, tpairs = tpairs2, prop = prop2, ...}) = th2;
     fun chktypes fT tT =
       (case fT of
         Type ("fun", [T1, _]) =>
@@ -1216,6 +1332,7 @@
            {cert = join_certificate2 (th1, th2),
             tags = [],
             maxidx = Int.max (maxidx1, maxidx2),
+            constraints = union_constraints constraints1 constraints2,
             shyps = Sorts.union shyps1 shyps2,
             hyps = union_hyps hyps1 hyps2,
             tpairs = union_tpairs tpairs1 tpairs2,
@@ -1230,10 +1347,10 @@
 *)
 fun equal_intr th1 th2 =
   let
-    val Thm (der1, {maxidx = maxidx1, shyps = shyps1, hyps = hyps1, tpairs = tpairs1,
-      prop = prop1, ...}) = th1
-    and Thm (der2, {maxidx = maxidx2, shyps = shyps2, hyps = hyps2, tpairs = tpairs2,
-      prop = prop2, ...}) = th2;
+    val Thm (der1, {maxidx = maxidx1, constraints = constraints1, shyps = shyps1,
+      hyps = hyps1, tpairs = tpairs1, prop = prop1, ...}) = th1
+    and Thm (der2, {maxidx = maxidx2, constraints = constraints2, shyps = shyps2,
+      hyps = hyps2, tpairs = tpairs2, prop = prop2, ...}) = th2;
     fun err msg = raise THM ("equal_intr: " ^ msg, 0, [th1, th2]);
   in
     (case (prop1, prop2) of
@@ -1243,6 +1360,7 @@
            {cert = join_certificate2 (th1, th2),
             tags = [],
             maxidx = Int.max (maxidx1, maxidx2),
+            constraints = union_constraints constraints1 constraints2,
             shyps = Sorts.union shyps1 shyps2,
             hyps = union_hyps hyps1 hyps2,
             tpairs = union_tpairs tpairs1 tpairs2,
@@ -1258,10 +1376,10 @@
 *)
 fun equal_elim th1 th2 =
   let
-    val Thm (der1, {maxidx = maxidx1, shyps = shyps1, hyps = hyps1,
-      tpairs = tpairs1, prop = prop1, ...}) = th1
-    and Thm (der2, {maxidx = maxidx2, shyps = shyps2, hyps = hyps2,
-      tpairs = tpairs2, prop = prop2, ...}) = th2;
+    val Thm (der1, {maxidx = maxidx1, constraints = constraints1, shyps = shyps1,
+      hyps = hyps1, tpairs = tpairs1, prop = prop1, ...}) = th1
+    and Thm (der2, {maxidx = maxidx2, constraints = constraints2, shyps = shyps2,
+      hyps = hyps2, tpairs = tpairs2, prop = prop2, ...}) = th2;
     fun err msg = raise THM ("equal_elim: " ^ msg, 0, [th1, th2]);
   in
     (case prop1 of
@@ -1271,6 +1389,7 @@
            {cert = join_certificate2 (th1, th2),
             tags = [],
             maxidx = Int.max (maxidx1, maxidx2),
+            constraints = union_constraints constraints1 constraints2,
             shyps = Sorts.union shyps1 shyps2,
             hyps = union_hyps hyps1 hyps2,
             tpairs = union_tpairs tpairs1 tpairs2,
@@ -1287,25 +1406,31 @@
   Instantiates the theorem and deletes trivial tpairs.  Resulting
   sequence may contain multiple elements if the tpairs are not all
   flex-flex.*)
-fun flexflex_rule opt_ctxt (th as Thm (der, {cert, maxidx, shyps, hyps, tpairs, prop, ...})) =
-  let val (context, cert') = make_context_certificate [th] opt_ctxt cert in
-    Unify.smash_unifiers context tpairs (Envir.empty maxidx)
-    |> Seq.map (fn env =>
-        if Envir.is_empty env then th
-        else
-          let
-            val tpairs' = tpairs |> map (apply2 (Envir.norm_term env))
-              (*remove trivial tpairs, of the form t \<equiv> t*)
-              |> filter_out (op aconv);
-            val der' = deriv_rule1 (Proofterm.norm_proof' env) der;
-            val prop' = Envir.norm_term env prop;
-            val maxidx = maxidx_tpairs tpairs' (maxidx_of_term prop');
-            val shyps = Envir.insert_sorts env shyps;
-          in
-            Thm (der', {cert = cert', tags = [], maxidx = maxidx,
-              shyps = shyps, hyps = hyps, tpairs = tpairs', prop = prop'})
-          end)
-  end;
+fun flexflex_rule opt_ctxt =
+  solve_constraints #> (fn th =>
+    let
+      val Thm (der, {cert, maxidx, constraints, shyps, hyps, tpairs, prop, ...}) = th;
+      val (context, cert') = make_context_certificate [th] opt_ctxt cert;
+    in
+      Unify.smash_unifiers context tpairs (Envir.empty maxidx)
+      |> Seq.map (fn env =>
+          if Envir.is_empty env then th
+          else
+            let
+              val tpairs' = tpairs |> map (apply2 (Envir.norm_term env))
+                (*remove trivial tpairs, of the form t \<equiv> t*)
+                |> filter_out (op aconv);
+              val der' = deriv_rule1 (Proofterm.norm_proof' env) der;
+              val constraints' =
+                insert_constraints_env (Context.certificate_theory cert') env constraints;
+              val prop' = Envir.norm_term env prop;
+              val maxidx = maxidx_tpairs tpairs' (maxidx_of_term prop');
+              val shyps = Envir.insert_sorts env shyps;
+            in
+              Thm (der', {cert = cert', tags = [], maxidx = maxidx, constraints = constraints',
+                shyps = shyps, hyps = hyps, tpairs = tpairs', prop = prop'})
+            end)
+    end);
 
 
 (*Generalization of fixed variables
@@ -1317,7 +1442,7 @@
 fun generalize ([], []) _ th = th
   | generalize (tfrees, frees) idx th =
       let
-        val Thm (der, {cert, maxidx, shyps, hyps, tpairs, prop, ...}) = th;
+        val Thm (der, {cert, maxidx, constraints, shyps, hyps, tpairs, prop, ...}) = th;
         val _ = idx <= maxidx andalso raise THM ("generalize: bad index", idx, [th]);
 
         val bad_type =
@@ -1341,6 +1466,7 @@
          {cert = cert,
           tags = [],
           maxidx = maxidx',
+          constraints = constraints,
           shyps = shyps,
           hyps = hyps,
           tpairs = tpairs',
@@ -1416,7 +1542,7 @@
 fun instantiate ([], []) th = th
   | instantiate (instT, inst) th =
       let
-        val Thm (der, {cert, hyps, shyps, tpairs, prop, ...}) = th;
+        val Thm (der, {cert, hyps, constraints, shyps, tpairs, prop, ...}) = th;
         val (inst', (instT', (cert', shyps'))) =
           (cert, shyps) |> fold_map add_inst inst ||> fold_map add_instT instT
             handle CONTEXT (msg, cTs, cts, ths, context) =>
@@ -1425,12 +1551,17 @@
         val (prop', maxidx1) = subst prop ~1;
         val (tpairs', maxidx') =
           fold_map (fn (t, u) => fn i => subst t i ||>> subst u) tpairs maxidx1;
+
+        val thy' = Context.certificate_theory cert';
+        val constraints' =
+          fold (fn ((_, S), (T, _)) => insert_constraints thy' (T, S)) instT' constraints;
       in
         Thm (deriv_rule1
           (fn d => Proofterm.instantiate (map (apsnd #1) instT', map (apsnd #1) inst') d) der,
          {cert = cert',
           tags = [],
           maxidx = maxidx',
+          constraints = constraints',
           shyps = shyps',
           hyps = hyps,
           tpairs = tpairs',
@@ -1464,6 +1595,7 @@
      {cert = cert,
       tags = [],
       maxidx = maxidx,
+      constraints = [],
       shyps = sorts,
       hyps = [],
       tpairs = [],
@@ -1487,6 +1619,7 @@
        {cert = cert,
         tags = [],
         maxidx = maxidx,
+        constraints = insert_constraints thy (T, [c]) [],
         shyps = sorts,
         hyps = [],
         tpairs = [],
@@ -1496,7 +1629,7 @@
 
 (*Remove extra sorts that are witnessed by type signature information*)
 fun strip_shyps (thm as Thm (_, {shyps = [], ...})) = thm
-  | strip_shyps (thm as Thm (der, {cert, tags, maxidx, shyps, hyps, tpairs, prop})) =
+  | strip_shyps (thm as Thm (der, {cert, tags, maxidx, constraints, shyps, hyps, tpairs, prop})) =
       let
         val thy = theory_of_thm thm;
         val algebra = Sign.classes_of thy;
@@ -1510,41 +1643,43 @@
       in
         Thm (deriv_rule_unconditional
           (Proofterm.strip_shyps_proof algebra present witnessed extra') der,
-         {cert = cert, tags = tags, maxidx = maxidx,
+         {cert = cert, tags = tags, maxidx = maxidx, constraints = constraints,
           shyps = shyps', hyps = hyps, tpairs = tpairs, prop = prop})
       end;
 
 (*Internalize sort constraints of type variables*)
-fun unconstrainT (thm as Thm (der, args)) =
-  let
-    val Deriv {promises, body} = der;
-    val {cert, shyps, hyps, tpairs, prop, ...} = args;
-    val thy = theory_of_thm thm;
+val unconstrainT =
+  solve_constraints #> (fn thm as Thm (der, args) =>
+    let
+      val Deriv {promises, body} = der;
+      val {cert, shyps, hyps, tpairs, prop, ...} = args;
+      val thy = theory_of_thm thm;
 
-    fun err msg = raise THM ("unconstrainT: " ^ msg, 0, [thm]);
-    val _ = null hyps orelse err "illegal hyps";
-    val _ = null tpairs orelse err "unsolved flex-flex constraints";
-    val tfrees = rev (Term.add_tfree_names prop []);
-    val _ = null tfrees orelse err ("illegal free type variables " ^ commas_quote tfrees);
+      fun err msg = raise THM ("unconstrainT: " ^ msg, 0, [thm]);
+      val _ = null hyps orelse err "bad hyps";
+      val _ = null tpairs orelse err "bad flex-flex constraints";
+      val tfrees = rev (Term.add_tfree_names prop []);
+      val _ = null tfrees orelse err ("illegal free type variables " ^ commas_quote tfrees);
 
-    val ps = map (apsnd (Future.map fulfill_body)) promises;
-    val (pthm, proof) =
-      Proofterm.unconstrain_thm_proof thy (classrel_proof thy) (arity_proof thy) shyps prop ps body;
-    val der' = make_deriv [] [] [pthm] proof;
-    val prop' = Proofterm.thm_node_prop (#2 pthm);
-  in
-    Thm (der',
-     {cert = cert,
-      tags = [],
-      maxidx = maxidx_of_term prop',
-      shyps = [[]],  (*potentially redundant*)
-      hyps = [],
-      tpairs = [],
-      prop = prop'})
-  end;
+      val ps = map (apsnd (Future.map fulfill_body)) promises;
+      val (pthm, proof) =
+        Proofterm.unconstrain_thm_proof thy (classrel_proof thy) (arity_proof thy) shyps prop ps body;
+      val der' = make_deriv [] [] [pthm] proof;
+      val prop' = Proofterm.thm_node_prop (#2 pthm);
+    in
+      Thm (der',
+       {cert = cert,
+        tags = [],
+        maxidx = maxidx_of_term prop',
+        constraints = [],
+        shyps = [[]],  (*potentially redundant*)
+        hyps = [],
+        tpairs = [],
+        prop = prop'})
+    end);
 
 (*Replace all TFrees not fixed or in the hyps by new TVars*)
-fun varifyT_global' fixed (Thm (der, {cert, maxidx, shyps, hyps, tpairs, prop, ...})) =
+fun varifyT_global' fixed (Thm (der, {cert, maxidx, constraints, shyps, hyps, tpairs, prop, ...})) =
   let
     val tfrees = fold Term.add_tfrees hyps fixed;
     val prop1 = attach_tpairs tpairs prop;
@@ -1555,6 +1690,7 @@
      {cert = cert,
       tags = [],
       maxidx = Int.max (0, maxidx),
+      constraints = constraints,
       shyps = shyps,
       hyps = hyps,
       tpairs = rev (map Logic.dest_equals ts),
@@ -1564,7 +1700,7 @@
 val varifyT_global = #2 o varifyT_global' [];
 
 (*Replace all TVars by TFrees that are often new*)
-fun legacy_freezeT (Thm (der, {cert, shyps, hyps, tpairs, prop, ...})) =
+fun legacy_freezeT (Thm (der, {cert, constraints, shyps, hyps, tpairs, prop, ...})) =
   let
     val prop1 = attach_tpairs tpairs prop;
     val prop2 = Type.legacy_freeze prop1;
@@ -1574,6 +1710,7 @@
      {cert = cert,
       tags = [],
       maxidx = maxidx_of_term prop2,
+      constraints = constraints,
       shyps = shyps,
       hyps = hyps,
       tpairs = rev (map Logic.dest_equals ts),
@@ -1595,6 +1732,7 @@
   end;
 
 
+
 (*** Inference rules for tactics ***)
 
 (*Destruct proof state into constraints, other goals, goal(i), rest *)
@@ -1612,7 +1750,7 @@
     val inc = gmax + 1;
     val lift_abs = Logic.lift_abs inc gprop;
     val lift_all = Logic.lift_all inc gprop;
-    val Thm (der, {maxidx, shyps, hyps, tpairs, prop, ...}) = orule;
+    val Thm (der, {maxidx, constraints, shyps, hyps, tpairs, prop, ...}) = orule;
     val (As, B) = Logic.strip_horn prop;
   in
     if T <> propT then raise THM ("lift_rule: the term must have type prop", 0, [])
@@ -1621,13 +1759,14 @@
        {cert = join_certificate1 (goal, orule),
         tags = [],
         maxidx = maxidx + inc,
+        constraints = constraints,
         shyps = Sorts.union shyps sorts,  (*sic!*)
         hyps = hyps,
         tpairs = map (apply2 lift_abs) tpairs,
         prop = Logic.list_implies (map lift_all As, lift_all B)})
   end;
 
-fun incr_indexes i (thm as Thm (der, {cert, maxidx, shyps, hyps, tpairs, prop, ...})) =
+fun incr_indexes i (thm as Thm (der, {cert, maxidx, constraints, shyps, hyps, tpairs, prop, ...})) =
   if i < 0 then raise THM ("negative increment", 0, [thm])
   else if i = 0 then thm
   else
@@ -1635,6 +1774,7 @@
      {cert = cert,
       tags = [],
       maxidx = maxidx + i,
+      constraints = constraints,
       shyps = shyps,
       hyps = hyps,
       tpairs = map (apply2 (Logic.incr_indexes ([], [], i))) tpairs,
@@ -1643,7 +1783,7 @@
 (*Solve subgoal Bi of proof state B1...Bn/C by assumption. *)
 fun assumption opt_ctxt i state =
   let
-    val Thm (der, {cert, maxidx, shyps, hyps, ...}) = state;
+    val Thm (der, {cert, maxidx, constraints, shyps, hyps, ...}) = state;
     val (context, cert') = make_context_certificate [state] opt_ctxt cert;
     val (tpairs, Bs, Bi, C) = dest_state (state, i);
     fun newth n (env, tpairs) =
@@ -1652,6 +1792,7 @@
             Proofterm.assumption_proof Bs Bi n) der,
        {tags = [],
         maxidx = Envir.maxidx_of env,
+        constraints = insert_constraints_env (Context.certificate_theory cert') env constraints,
         shyps = Envir.insert_sorts env shyps,
         hyps = hyps,
         tpairs =
@@ -1679,7 +1820,7 @@
   Checks if Bi's conclusion is alpha/eta-convertible to one of its assumptions*)
 fun eq_assumption i state =
   let
-    val Thm (der, {cert, maxidx, shyps, hyps, ...}) = state;
+    val Thm (der, {cert, maxidx, constraints, shyps, hyps, ...}) = state;
     val (tpairs, Bs, Bi, C) = dest_state (state, i);
     val (_, asms, concl) = Logic.assum_problems (~1, Bi);
   in
@@ -1690,6 +1831,7 @@
          {cert = cert,
           tags = [],
           maxidx = maxidx,
+          constraints = constraints,
           shyps = shyps,
           hyps = hyps,
           tpairs = tpairs,
@@ -1700,7 +1842,7 @@
 (*For rotate_tac: fast rotation of assumptions of subgoal i*)
 fun rotate_rule k i state =
   let
-    val Thm (der, {cert, maxidx, shyps, hyps, ...}) = state;
+    val Thm (der, {cert, maxidx, constraints, shyps, hyps, ...}) = state;
     val (tpairs, Bs, Bi, C) = dest_state (state, i);
     val params = Term.strip_all_vars Bi;
     val rest = Term.strip_all_body Bi;
@@ -1719,6 +1861,7 @@
      {cert = cert,
       tags = [],
       maxidx = maxidx,
+      constraints = constraints,
       shyps = shyps,
       hyps = hyps,
       tpairs = tpairs,
@@ -1731,7 +1874,7 @@
   number of premises.  Useful with eresolve_tac and underlies defer_tac*)
 fun permute_prems j k rl =
   let
-    val Thm (der, {cert, maxidx, shyps, hyps, tpairs, prop, ...}) = rl;
+    val Thm (der, {cert, maxidx, constraints, shyps, hyps, tpairs, prop, ...}) = rl;
     val prems = Logic.strip_imp_prems prop
     and concl = Logic.strip_imp_concl prop;
     val moved_prems = List.drop (prems, j)
@@ -1750,6 +1893,7 @@
      {cert = cert,
       tags = [],
       maxidx = maxidx,
+      constraints = constraints,
       shyps = shyps,
       hyps = hyps,
       tpairs = tpairs,
@@ -1814,6 +1958,7 @@
   rename_bvs (fold_rev Term.match_bvars dpairs []) dpairs;
 
 
+
 (*** RESOLUTION ***)
 
 (** Lifting optimizations **)
@@ -1866,8 +2011,8 @@
 in
 fun bicompose_aux opt_ctxt {flatten, match, incremented} (state, (stpairs, Bs, Bi, C), lifted)
                         (eres_flg, orule, nsubgoal) =
- let val Thm (sder, {maxidx=smax, shyps=sshyps, hyps=shyps, ...}) = state
-     and Thm (rder, {maxidx=rmax, shyps=rshyps, hyps=rhyps,
+ let val Thm (sder, {maxidx=smax, constraints = constraints2, shyps = shyps2, hyps = hyps2, ...}) = state
+     and Thm (rder, {maxidx=rmax, constraints = constraints1, shyps = shyps1, hyps = hyps1,
              tpairs=rtpairs, prop=rprop,...}) = orule
          (*How many hyps to skip over during normalization*)
      and nlift = Logic.count_prems (strip_all_body Bi) + (if eres_flg then ~1 else 0)
@@ -1890,6 +2035,9 @@
                 else (*normalize the new rule fully*)
                   (ntps, (map normt (Bs @ As), normt C))
              end
+           val constraints' =
+             union_constraints constraints1 constraints2
+             |> insert_constraints_env (Context.certificate_theory cert) env;
            val th =
              Thm (deriv_rule2
                    ((if Envir.is_empty env then I
@@ -1900,8 +2048,9 @@
                     (Proofterm.bicompose_proof flatten Bs oldAs As A n (nlift+1))) rder' sder,
                 {tags = [],
                  maxidx = Envir.maxidx_of env,
-                 shyps = Envir.insert_sorts env (Sorts.union rshyps sshyps),
-                 hyps = union_hyps rhyps shyps,
+                 constraints = constraints',
+                 shyps = Envir.insert_sorts env (Sorts.union shyps1 shyps2),
+                 hyps = union_hyps hyps1 hyps2,
                  tpairs = ntpairs,
                  prop = Logic.list_implies normp,
                  cert = cert})