ctyp: added 'sorts' field;
authorwenzelm
Fri, 01 Jul 2005 14:42:00 +0200
changeset 16656 18b0cb22057d
parent 16655 3e4d726aaed1
child 16657 a6f65f47eda1
ctyp: added 'sorts' field; may_insert_typ/term/env_sorts: observe Sign.all_sorts_nonempty; may_insert_env_sorts: insert sorts of type subst only; instantiate: insert sorts of insts; tuned;
src/Pure/thm.ML
--- a/src/Pure/thm.ML	Fri Jul 01 14:41:59 2005 +0200
+++ b/src/Pure/thm.ML	Fri Jul 01 14:42:00 2005 +0200
@@ -11,7 +11,11 @@
   sig
   (*certified types*)
   type ctyp
-  val rep_ctyp: ctyp -> {thy: theory, sign: theory, T: typ}
+  val rep_ctyp: ctyp ->
+   {thy: theory,
+    sign: theory,       (*obsolete*)
+    T: typ,
+    sorts: sort list}
   val theory_of_ctyp: ctyp -> theory
   val typ_of: ctyp -> typ
   val ctyp_of: theory -> typ -> ctyp
@@ -21,7 +25,12 @@
   type cterm
   exception CTERM of string
   val rep_cterm: cterm ->
-    {thy: theory, sign: theory, t: term, T: typ, maxidx: int, sorts: sort list}
+   {thy: theory,
+    sign: theory,       (*obsolete*)
+    t: term,
+    T: typ,
+    maxidx: int,
+    sorts: sort list}
   val crep_cterm: cterm ->
     {thy: theory, sign: theory, t: term, T: ctyp, maxidx: int, sorts: sort list}
   val theory_of_cterm: cterm -> theory
@@ -44,7 +53,8 @@
   (*meta theorems*)
   type thm
   val rep_thm: thm ->
-   {thy: theory, sign: theory,
+   {thy: theory,
+    sign: theory,       (*obsolete*)
     der: bool * Proofterm.proof,
     maxidx: int,
     shyps: sort list,
@@ -52,7 +62,8 @@
     tpairs: (term * term) list,
     prop: term}
   val crep_thm: thm ->
-   {thy: theory, sign: theory,
+   {thy: theory,
+    sign: theory,       (*obsolete*)
     der: bool * Proofterm.proof,
     maxidx: int,
     shyps: sort list,
@@ -67,12 +78,12 @@
   val sign_of_thm: thm -> theory    (*obsolete*)
   val prop_of: thm -> term
   val proof_of: thm -> Proofterm.proof
-  val transfer: theory -> thm -> thm
   val tpairs_of: thm -> (term * term) list
+  val concl_of: thm -> term
   val prems_of: thm -> term list
   val nprems_of: thm -> int
-  val concl_of: thm -> term
   val cprop_of: thm -> cterm
+  val transfer: theory -> thm -> thm
   val extra_shyps: thm -> sort list
   val strip_shyps: thm -> thm
   val get_axiom_i: theory -> string -> thm
@@ -146,28 +157,48 @@
 structure Thm: THM =
 struct
 
+
 (*** Certified terms and types ***)
 
+(** collect occurrences of sorts -- unless all sorts non-empty **)
+
+fun may_insert_typ_sorts thy T =
+  if Sign.all_sorts_nonempty thy then I
+  else Sorts.insert_typ T;
+
+fun may_insert_term_sorts thy t =
+  if Sign.all_sorts_nonempty thy then I
+  else Sorts.insert_term t;
+
+(*NB: type unification may invent new sorts*)
+fun may_insert_env_sorts thy (env as Envir.Envir {iTs, ...}) =
+  if Sign.all_sorts_nonempty thy then I
+  else Vartab.fold (fn (_, (_, T)) => Sorts.insert_typ T) iTs;
+
+
+
 (** certified types **)
 
-datatype ctyp = Ctyp of {thy_ref: theory_ref, T: typ};
+datatype ctyp = Ctyp of {thy_ref: theory_ref, T: typ, sorts: sort list};
 
-fun rep_ctyp (Ctyp {thy_ref, T}) =
+fun rep_ctyp (Ctyp {thy_ref, T, sorts}) =
   let val thy = Theory.deref thy_ref
-  in {thy = thy, sign = thy, T = T} end;
+  in {thy = thy, sign = thy, T = T, sorts = sorts} end;
 
-val theory_of_ctyp = #thy o rep_ctyp;
+fun theory_of_ctyp (Ctyp {thy_ref, ...}) = Theory.deref thy_ref;
 
 fun typ_of (Ctyp {T, ...}) = T;
 
-fun ctyp_of thy T =
-  Ctyp {thy_ref = Theory.self_ref thy, T = Sign.certify_typ thy T};
+fun ctyp_of thy raw_T =
+  let val T = Sign.certify_typ thy raw_T
+  in Ctyp {thy_ref = Theory.self_ref thy, T = T, sorts = may_insert_typ_sorts thy T []} end;
 
 fun read_ctyp thy s =
-  Ctyp {thy_ref = Theory.self_ref thy, T = Sign.read_typ (thy, K NONE) s};
+  let val T = Sign.read_typ (thy, K NONE) s
+  in Ctyp {thy_ref = Theory.self_ref thy, T = T, sorts = may_insert_typ_sorts thy T []} end;
 
-fun dest_ctyp (Ctyp {thy_ref, T = Type (s, Ts)}) =
-      map (fn T => Ctyp {thy_ref = thy_ref, T = T}) Ts
+fun dest_ctyp (Ctyp {thy_ref, T = Type (s, Ts), sorts}) =
+      map (fn T => Ctyp {thy_ref = thy_ref, T = T, sorts = sorts}) Ts
   | dest_ctyp cT = [cT];
 
 
@@ -188,7 +219,7 @@
 
 fun crep_cterm (Cterm {thy_ref, t, T, maxidx, sorts}) =
   let val thy = Theory.deref thy_ref in
-   {thy = thy, sign = thy, t = t, T = Ctyp {thy_ref = thy_ref, T = T},
+   {thy = thy, sign = thy, t = t, T = Ctyp {thy_ref = thy_ref, T = T, sorts = sorts},
     maxidx = maxidx, sorts = sorts}
   end;
 
@@ -197,33 +228,38 @@
 
 fun term_of (Cterm {t, ...}) = t;
 
-fun ctyp_of_term (Cterm {thy_ref, T, ...}) = Ctyp {thy_ref = thy_ref, T = T};
+fun ctyp_of_term (Cterm {thy_ref, T, sorts, ...}) =
+  Ctyp {thy_ref = thy_ref, T = T, sorts = sorts};
 
 fun cterm_of thy tm =
   let
     val (t, T, maxidx) = Sign.certify_term (Sign.pp thy) thy tm;
-    val sorts = Sorts.insert_term t [];
+    val sorts = may_insert_term_sorts thy t [];
   in Cterm {thy_ref = Theory.self_ref thy, t = t, T = T, maxidx = maxidx, sorts = sorts} end;
 
+fun merge_thys0 (Cterm {thy_ref = r1, ...}) (Cterm {thy_ref = r2, ...}) =
+  Theory.merge_refs (r1, r2);
+
 exception CTERM of string;
 
 (*Destruct application in cterms*)
 fun dest_comb (Cterm {thy_ref, T, maxidx, sorts, t = A $ B}) =
-      let val typeA = fastype_of A;
-          val typeB =
-            case typeA of Type("fun",[S,T]) => S
-                        | _ => sys_error "Function type expected in dest_comb";
+      let
+        val typeA = fastype_of A;
+        val typeB =
+          (case typeA of Type ("fun", [S, T]) => S
+          | _ => sys_error "Function type expected in dest_comb");
       in
-      (Cterm {thy_ref=thy_ref, maxidx=maxidx, sorts=sorts, t=A, T=typeA},
-       Cterm {thy_ref=thy_ref, maxidx=maxidx, sorts=sorts, t=B, T=typeB})
+        (Cterm {t = A, T = typeA, thy_ref = thy_ref, maxidx = maxidx, sorts = sorts},
+         Cterm {t = B, T = typeB, thy_ref = thy_ref, maxidx = maxidx, sorts = sorts})
       end
   | dest_comb _ = raise CTERM "dest_comb";
 
 (*Destruct abstraction in cterms*)
-fun dest_abs a (Cterm {thy_ref, T as Type("fun",[_,S]), maxidx, sorts, t=Abs(x,ty,M)}) =
-      let val (y,N) = variant_abs (if_none a x, ty, M)
-      in (Cterm {thy_ref = thy_ref, T = ty, maxidx = 0, sorts = sorts, t = Free(y,ty)},
-          Cterm {thy_ref = thy_ref, T = S, maxidx = maxidx, sorts = sorts, t = N})
+fun dest_abs a (Cterm {t = Abs (x, ty, M), T as Type("fun",[_,S]), thy_ref, maxidx, sorts}) =
+      let val (y, N) = variant_abs (if_none a x, ty, M) in
+        (Cterm {t = Free (y, ty), T = ty, thy_ref = thy_ref, maxidx = 0, sorts = sorts},
+          Cterm {t = N, T = S, thy_ref = thy_ref, maxidx = maxidx, sorts = sorts})
       end
   | dest_abs _ _ = raise CTERM "dest_abs";
 
@@ -234,43 +270,45 @@
 
 (*Form cterm out of a function and an argument*)
 fun capply
-  (Cterm {t=f, thy_ref=thy_ref1, T=Type("fun",[dty,rty]), maxidx=maxidx1, sorts = sorts1})
-  (Cterm {t=x, thy_ref=thy_ref2, T, maxidx=maxidx2, sorts = sorts2}) =
+  (cf as Cterm {t = f, T = Type ("fun", [dty, rty]), maxidx = maxidx1, sorts = sorts1, ...})
+  (cx as Cterm {t = x, T, maxidx = maxidx2, sorts = sorts2, ...}) =
     if T = dty then
-      Cterm{t = f $ x,
-        thy_ref=Theory.merge_refs(thy_ref1,thy_ref2), T=rty,
-        maxidx=Int.max(maxidx1, maxidx2),
+      Cterm {thy_ref = merge_thys0 cf cx,
+        t = f $ x,
+        T = rty,
+        maxidx = Int.max (maxidx1, maxidx2),
         sorts = Sorts.union sorts1 sorts2}
       else raise CTERM "capply: types don't agree"
   | capply _ _ = raise CTERM "capply: first arg is not a function"
 
 fun cabs
-  (Cterm {t=t1, thy_ref=thy_ref1, T=T1, maxidx=maxidx1, sorts = sorts1})
-  (Cterm {t=t2, thy_ref=thy_ref2, T=T2, maxidx=maxidx2, sorts = sorts2}) =
+  (ct1 as Cterm {t = t1, T = T1, maxidx = maxidx1, sorts = sorts1, ...})
+  (ct2 as Cterm {t = t2, T = T2, maxidx = maxidx2, sorts = sorts2, ...}) =
     let val t = lambda t1 t2 handle TERM _ => raise CTERM "cabs: first arg is not a variable" in
-      Cterm {t = t, T = T1 --> T2, thy_ref = Theory.merge_refs (thy_ref1, thy_ref2),
-        maxidx = Int.max (maxidx1, maxidx2), sorts = Sorts.union sorts1 sorts2}
+      Cterm {thy_ref = merge_thys0 ct1 ct2,
+        t = t, T = T1 --> T2,
+        maxidx = Int.max (maxidx1, maxidx2),
+        sorts = Sorts.union sorts1 sorts2}
     end;
 
 (*Matching of cterms*)
-fun gen_cterm_match mtch
-    (Cterm {thy_ref = thy_ref1, maxidx = maxidx1, t = t1, sorts = sorts1, ...},
-     Cterm {thy_ref = thy_ref2, maxidx = maxidx2, t = t2, sorts = sorts2, ...}) =
+fun gen_cterm_match match
+    (ct1 as Cterm {t = t1, maxidx = maxidx1, sorts = sorts1, ...},
+     ct2 as Cterm {t = t2, maxidx = maxidx2, sorts = sorts2, ...}) =
   let
-    val thy_ref = Theory.merge_refs (thy_ref1, thy_ref2);
-    val tsig = Sign.tsig_of (Theory.deref thy_ref);
-    val (Tinsts, tinsts) = mtch tsig (t1, t2);
+    val thy_ref = merge_thys0 ct1 ct2;
+    val (Tinsts, tinsts) = match (Sign.tsig_of (Theory.deref thy_ref)) (t1, t2);
     val maxidx = Int.max (maxidx1, maxidx2);
     val sorts = Sorts.union sorts1 sorts2;
-    fun mk_cTinsts (ixn, (S, T)) =
-      (Ctyp {thy_ref = thy_ref, T = TVar (ixn, S)},
-       Ctyp {thy_ref = thy_ref, T = T});
-    fun mk_ctinsts (ixn, (T, t)) =
+    fun mk_cTinst (ixn, (S, T)) =
+      (Ctyp {T = TVar (ixn, S), thy_ref = thy_ref, sorts = sorts},
+       Ctyp {T = T, thy_ref = thy_ref, sorts = sorts});
+    fun mk_ctinst (ixn, (T, t)) =
       let val T = Envir.typ_subst_TVars Tinsts T in
-        (Cterm {thy_ref = thy_ref, maxidx = maxidx, T = T, t = Var (ixn, T), sorts = sorts},
-         Cterm {thy_ref = thy_ref, maxidx = maxidx, T = T, t = t, sorts = sorts})
+        (Cterm {t = Var (ixn, T), T = T, thy_ref = thy_ref, maxidx = maxidx, sorts = sorts},
+         Cterm {t = t, T = T, thy_ref = thy_ref, maxidx = maxidx, sorts = sorts})
       end;
-  in (map mk_cTinsts (Vartab.dest Tinsts), map mk_ctinsts (Vartab.dest tinsts)) end;
+  in (Vartab.fold (cons o mk_cTinst) Tinsts [], Vartab.fold (cons o mk_ctinst) tinsts []) end;
 
 val cterm_match = gen_cterm_match Pattern.match;
 val cterm_first_order_match = gen_cterm_match Pattern.first_order_match;
@@ -320,8 +358,10 @@
   tpairs: (term * term) list,  (*flex-flex pairs*)
   prop: term};                 (*conclusion*)
 
-fun terms_of_tpairs tpairs = List.concat (map (fn (t, u) => [t, u]) tpairs);
-val union_tpairs = gen_merge_lists (op = : (term * term) * (term * term) -> bool);
+fun terms_of_tpairs tpairs = fold_rev (fn (t, u) => cons t o cons u) tpairs [];
+
+fun eq_tpairs ((t, u), (t', u')) = t aconv t' andalso u aconv u';
+val union_tpairs = gen_merge_lists eq_tpairs;
 
 fun attach_tpairs tpairs prop =
   Logic.list_implies (map Logic.mk_equals tpairs, prop);
@@ -355,7 +395,7 @@
 fun applys_attributes (x_ths, atts) = foldl_map (Library.apply atts) x_ths;
 
 
-(* shyps and hyps *)
+(* hyps *)
 
 val remove_hyps = OrdList.remove Term.term_ord;
 val union_hyps = OrdList.union Term.term_ord;
@@ -374,7 +414,7 @@
     Context.joinable (thy1, thy2) andalso
     Sorts.eq_set (shyps1, shyps2) andalso
     eq_set_hyps (hyps1, hyps2) andalso
-    aconvs (terms_of_tpairs tpairs1, terms_of_tpairs tpairs2) andalso
+    equal_lists eq_tpairs (tpairs1, tpairs2) andalso
     prop1 aconv prop2
   end;
 
@@ -401,20 +441,7 @@
 fun cprop_of (Thm {thy_ref, maxidx, shyps, prop, ...}) =
   Cterm {thy_ref = thy_ref, maxidx = maxidx, T = propT, t = prop, sorts = shyps};
 
-
-(* merge theories of cterms/thms; raise exception if incompatible *)
-
-fun merge_thys1
-    (Cterm {thy_ref = r1, ...}) (th as Thm {thy_ref = r2, ...}) =
-  Theory.merge_refs (r1, r2) handle TERM (msg, _) => raise THM (msg, 0, [th]);
-
-fun merge_thys2
-    (th1 as Thm {thy_ref = r1, ...}) (th2 as Thm {thy_ref = r2, ...}) =
-  Theory.merge_refs (r1, r2) handle TERM (msg, _) => raise THM (msg, 0, [th1, th2]);
-
-
-(* explicit transfer thm to super theory *)
-
+(*explicit transfer to a super theory*)
 fun transfer thy' thm =
   let
     val Thm {thy_ref, der, maxidx, shyps, hyps, tpairs, prop} = thm;
@@ -428,38 +455,43 @@
   end;
 
 
+(* merge theories of cterms/thms; raise exception if incompatible *)
+
+fun merge_thys1 (Cterm {thy_ref = r1, ...}) (th as Thm {thy_ref = r2, ...}) =
+  Theory.merge_refs (r1, r2) handle TERM (msg, _) => raise THM (msg, 0, [th]);
+
+fun merge_thys2 (th1 as Thm {thy_ref = r1, ...}) (th2 as Thm {thy_ref = r2, ...}) =
+  Theory.merge_refs (r1, r2) handle TERM (msg, _) => raise THM (msg, 0, [th1, th2]);
+
+
 
 (** sort contexts of theorems **)
 
-fun insert_env_sorts (env as Envir.Envir {iTs, asol, ...}) =
-  Vartab.fold (fn (_, (_, t)) => Sorts.insert_term t) asol o
-  Vartab.fold (fn (_, (_, T)) => Sorts.insert_typ T) iTs;
-
-fun insert_thm_sorts (Thm {hyps, tpairs, prop, ...}) =
-  fold (fn (t, u) => Sorts.insert_term t o Sorts.insert_term u) tpairs o
-  Sorts.insert_terms hyps o Sorts.insert_term prop;
-
-(*dangling sort constraints of a thm*)
-fun extra_shyps (th as Thm {shyps, ...}) =
-  Sorts.subtract (insert_thm_sorts th []) shyps;
-
-
-(* strip_shyps *)
+fun present_sorts (Thm {hyps, tpairs, prop, ...}) =
+  fold (fn (t, u) => Sorts.insert_term t o Sorts.insert_term u) tpairs
+    (Sorts.insert_terms hyps (Sorts.insert_term prop []));
 
 (*remove extra sorts that are non-empty by virtue of type signature information*)
 fun strip_shyps (thm as Thm {shyps = [], ...}) = thm
   | strip_shyps (thm as Thm {thy_ref, der, maxidx, shyps, hyps, tpairs, prop}) =
       let
         val thy = Theory.deref thy_ref;
-        val present_sorts = insert_thm_sorts thm [];
-        val extra_shyps = Sorts.subtract present_sorts shyps;
-        val witnessed_shyps = Sign.witness_sorts thy present_sorts extra_shyps;
+        val shyps' =
+          if Sign.all_sorts_nonempty thy then []
+          else
+            let
+              val present = present_sorts thm;
+              val extra = Sorts.subtract present shyps;
+              val witnessed = map #2 (Sign.witness_sorts thy present extra);
+            in Sorts.subtract witnessed shyps end;
       in
         Thm {thy_ref = thy_ref, der = der, maxidx = maxidx,
-             shyps = Sorts.subtract (map #2 witnessed_shyps) shyps,
-             hyps = hyps, tpairs = tpairs, prop = prop}
+          shyps = shyps', hyps = hyps, tpairs = tpairs, prop = prop}
       end;
 
+(*dangling sort constraints of a thm*)
+fun extra_shyps (th as Thm {shyps, ...}) = Sorts.subtract (present_sorts th) shyps;
+
 
 
 (** Axioms **)
@@ -473,7 +505,7 @@
           Thm {thy_ref = Theory.self_ref thy,
             der = Pt.infer_derivs' I (false, Pt.axm_proof name prop),
             maxidx = maxidx_of_term prop,
-            shyps = Sorts.insert_term prop [],
+            shyps = may_insert_term_sorts thy prop [],
             hyps = [],
             tpairs = [],
             prop = prop});
@@ -531,7 +563,7 @@
 
 (** primitive rules **)
 
-(*The assumption rule A |- A in a theory*)
+(*The assumption rule A |- A*)
 fun assume raw_ct =
   let val Cterm {thy_ref, t = prop, T, maxidx, sorts} = adjust_maxidx raw_ct in
     if T <> propT then
@@ -584,7 +616,7 @@
     case prop of
       imp $ A $ B =>
         if imp = Term.implies andalso A aconv propA then
-          Thm {thy_ref= merge_thys2 thAB thA,
+          Thm {thy_ref = merge_thys2 thAB thA,
             der = Pt.infer_derivs (curry Pt.%%) der derA,
             maxidx = Int.max (maxA, maxidx),
             shyps = Sorts.union shypsA shyps,
@@ -596,11 +628,11 @@
   end;
 
 (*Forall introduction.  The Free or Var x must not be free in the hypotheses.
-   [x]
-    :
-    A
-  -----
-  !!x.A
+    [x]
+     :
+     A
+  ------
+  !!x. A
 *)
 fun forall_intr
     (ct as Cterm {t = x, T, sorts, ...})
@@ -626,7 +658,7 @@
   end;
 
 (*Forall elimination
-  !!x.A
+  !!x. A
   ------
   A[t/x]
 *)
@@ -654,7 +686,7 @@
   t == t
 *)
 fun reflexive (ct as Cterm {thy_ref, t, T, maxidx, sorts}) =
-  Thm {thy_ref= thy_ref,
+  Thm {thy_ref = thy_ref,
     der = Pt.infer_derivs' I (false, Pt.reflexive),
     maxidx = maxidx,
     shyps = sorts,
@@ -696,7 +728,7 @@
       ((eq as Const ("==", Type (_, [T, _]))) $ t1 $ u, Const ("==", _) $ u' $ t2) =>
         if not (u aconv u') then err "middle term"
         else
-          Thm {thy_ref= merge_thys2 th1 th2,
+          Thm {thy_ref = merge_thys2 th1 th2,
             der = Pt.infer_derivs (Pt.transitive u T) der1 der2,
             maxidx = Int.max (max1, max2),
             shyps = Sorts.union shyps1 shyps2,
@@ -707,7 +739,7 @@
   end;
 
 (*Beta-conversion
-  (%x.t)(u) == t[u/x]
+  (%x. t)(u) == t[u/x]
   fully beta-reduces the term if full = true
 *)
 fun beta_conversion full (Cterm {thy_ref, t, T, maxidx, sorts}) =
@@ -876,7 +908,7 @@
           Thm {thy_ref = thy_ref,
             der = Pt.infer_derivs' (Pt.norm_proof' env) der,
             maxidx = maxidx_of_terms (prop' :: terms_of_tpairs tpairs'),
-            shyps = insert_env_sorts env shyps,
+            shyps = may_insert_env_sorts (Theory.deref thy_ref) env shyps,
             hyps = hyps,
             tpairs = tpairs',
             prop = prop'}
@@ -884,50 +916,44 @@
 
 
 (*Instantiation of Vars
-            A
-  ---------------------
-  A[t1/v1, ...., tn/vn]
+           A
+  --------------------
+  A[t1/v1, ..., tn/vn]
 *)
 
 local
 
-(*Check that all the terms are Vars and are distinct*)
-fun instl_ok ts = forall is_Var ts andalso null (findrep ts);
-
 fun pretty_typing thy t T =
   Pretty.block [Sign.pretty_term thy t, Pretty.str " ::", Pretty.brk 1, Sign.pretty_typ thy T];
 
-(*For instantiate: process pair of cterms, merge theories*)
-fun add_ctpair ((ct, cu), (thy_ref, tpairs)) =
+fun add_ctpair ((thy, sorts), (ct, cu)) =
   let
-    val Cterm {thy_ref = thy_reft, t = t, T = T, ...} = ct
-    and Cterm {thy_ref = thy_refu, t = u, T = U, ...} = cu;
-    val thy_ref_merged = Theory.merge_refs (thy_ref, Theory.merge_refs (thy_reft, thy_refu));
-    val thy_merged = Theory.deref thy_ref_merged;
+    val Cterm {t = t, T = T, sorts = sorts1, ...} = ct
+    and Cterm {t = u, T = U, sorts = sorts2, ...} = cu;
+    val thy' = Theory.merge (thy, Theory.deref (merge_thys0 ct cu));
+    val sorts' = Sorts.union sorts2 (Sorts.union sorts1 sorts);
   in
-    if T = U then (thy_ref_merged, (t, u) :: tpairs)
+    if T = U then ((thy', sorts'), (t, u))
     else raise TYPE (Pretty.string_of (Pretty.block
      [Pretty.str "instantiate: type conflict",
-      Pretty.fbrk, pretty_typing thy_merged t T,
-      Pretty.fbrk, pretty_typing thy_merged u U]), [T,U], [t,u])
+      Pretty.fbrk, pretty_typing thy' t T,
+      Pretty.fbrk, pretty_typing thy' u U]), [T,U], [t,u])
   end;
 
-fun add_ctyp
-  ((Ctyp {T = T as TVar (_, S), thy_ref = thy_refT},
-    Ctyp {T = U, thy_ref = thy_refU}), (thy_ref, vTs)) =
-      let
-        val thy_ref_merged = Theory.merge_refs
-          (thy_ref, Theory.merge_refs (thy_refT, thy_refU));
-        val thy_merged = Theory.deref thy_ref_merged;
-      in
-        if Type.of_sort (Sign.tsig_of thy_merged) (U, S) then
-          (thy_ref_merged, (T, U) :: vTs)
-        else raise TYPE ("Type not of sort " ^ Sign.string_of_sort thy_merged S, [U], [])
-      end
-  | add_ctyp ((Ctyp {T, thy_ref}, _), _) =
-      raise TYPE (Pretty.string_of (Pretty.block
+fun add_ctyp ((thy, sorts), (cT, cU)) =
+  let
+    val Ctyp {T = T, thy_ref = thy_ref1, sorts = sorts1, ...} = cT
+    and Ctyp {T = U, thy_ref = thy_ref2, sorts = sorts2, ...} = cU;
+    val thy' = Theory.merge (thy, Theory.deref (Theory.merge_refs (thy_ref1, thy_ref2)));
+    val sorts' = Sorts.union sorts2 (Sorts.union sorts1 sorts);
+  in
+    (case T of TVar (_, S) =>
+      if Type.of_sort (Sign.tsig_of thy') (U, S) then ((thy', sorts'), (T, U))
+      else raise TYPE ("Type not of sort " ^ Sign.string_of_sort thy' S, [U], [])
+    | _ => raise TYPE (Pretty.string_of (Pretty.block
         [Pretty.str "instantiate: not a type variable",
-         Pretty.fbrk, Sign.pretty_typ (Theory.deref thy_ref) T]), [T], []);
+         Pretty.fbrk, Sign.pretty_typ thy' T]), [T], []))
+  end;
 
 in
 
@@ -936,31 +962,28 @@
   Does NOT normalize the resulting theorem!*)
 fun instantiate ([], []) th = th
   | instantiate (vcTs, ctpairs) th =
-  let
-    val Thm {thy_ref, der, maxidx, hyps, shyps, tpairs = dpairs, prop} = th;
-    val (newthy_ref, tpairs) = foldr add_ctpair (thy_ref, []) ctpairs;
-    val (newthy_ref, vTs) = foldr add_ctyp (newthy_ref, []) vcTs;
-    fun subst t = subst_atomic tpairs (map_term_types (typ_subst_atomic vTs) t);
-    val newprop = subst prop;
-    val newdpairs = map (pairself subst) dpairs;
-    val newth =
-      Thm {thy_ref = newthy_ref,
-        der = Pt.infer_derivs' (Pt.instantiate vTs tpairs) der,
-        maxidx = maxidx_of_terms (newprop :: terms_of_tpairs newdpairs),
-        shyps = shyps
-          |> fold (Sorts.insert_typ o #2) vTs
-          |> fold (Sorts.insert_term o #2) tpairs,
-        hyps = hyps,
-        tpairs = newdpairs,
-        prop = newprop};
-  in
-    if not (instl_ok (map #1 tpairs)) then
-      raise THM ("instantiate: variables not distinct", 0, [th])
-    else if not (null (findrep (map #1 vTs))) then
-      raise THM ("instantiate: type variables not distinct", 0, [th])
-    else newth
-  end
-  handle TYPE (msg, _, _) => raise THM (msg, 0, [th]);
+      let
+        val Thm {thy_ref, der, maxidx, hyps, shyps, tpairs = dpairs, prop} = th;
+        val (thy_sorts, tpairs) = foldl_map add_ctpair ((Theory.deref thy_ref, shyps), ctpairs);
+        val ((thy', shyps'), vTs) = foldl_map add_ctyp (thy_sorts, vcTs);
+        fun subst t = subst_atomic tpairs (map_term_types (typ_subst_atomic vTs) t);
+        val prop' = subst prop;
+        val dpairs' = map (pairself subst) dpairs;
+      in
+        if not (forall (is_Var o #1) tpairs andalso null (gen_duplicates eq_fst tpairs)) then
+          raise THM ("instantiate: variables not distinct", 0, [th])
+        else if not (null (gen_duplicates eq_fst vTs)) then
+          raise THM ("instantiate: type variables not distinct", 0, [th])
+        else
+          Thm {thy_ref = Theory.self_ref thy',
+            der = Pt.infer_derivs' (Pt.instantiate vTs tpairs) der,
+            maxidx = maxidx_of_terms (prop' :: terms_of_tpairs dpairs'),
+            shyps = shyps',
+            hyps = hyps,
+            tpairs = dpairs',
+            prop = prop'}
+      end
+      handle TYPE (msg, _, _) => raise THM (msg, 0, [th]);
 
 end;
 
@@ -1076,6 +1099,7 @@
 fun assumption i state =
   let
     val Thm {thy_ref, der, maxidx, shyps, hyps, prop, ...} = state;
+    val thy = Theory.deref thy_ref;
     val (tpairs, Bs, Bi, C) = dest_state (state, i);
     fun newth n (env as Envir.Envir {maxidx, ...}, tpairs) =
       Thm {thy_ref = thy_ref,
@@ -1083,7 +1107,7 @@
           ((if Envir.is_empty env then I else (Pt.norm_proof' env)) o
             Pt.assumption_proof Bs Bi n) der,
         maxidx = maxidx,
-        shyps = insert_env_sorts env shyps,
+        shyps = may_insert_env_sorts thy env shyps,
         hyps = hyps,
         tpairs =
           if Envir.is_empty env then tpairs
@@ -1096,7 +1120,7 @@
     fun addprfs [] _ = Seq.empty
       | addprfs ((t, u) :: apairs) n = Seq.make (fn () => Seq.pull
           (Seq.mapp (newth n)
-            (Unify.unifiers (Theory.deref thy_ref, Envir.empty maxidx, (t, u) :: tpairs))
+            (Unify.unifiers (thy, Envir.empty maxidx, (t, u) :: tpairs))
             (addprfs apairs (n + 1))))
   in addprfs (Logic.assum_pairs (~1, Bi)) 1 end;
 
@@ -1150,7 +1174,7 @@
 
 (*Rotates a rule's premises to the left by k, leaving the first j premises
   unchanged.  Does nothing if k=0 or if k equals n-j, where n is the
-  number of premises.  Useful with etac and underlies tactic/defer_tac*)
+  number of premises.  Useful with etac and underlies defer_tac*)
 fun permute_prems j k rl =
   let
     val Thm {thy_ref, der, maxidx, shyps, hyps, tpairs, prop} = rl;
@@ -1227,7 +1251,7 @@
        prop = prop'});
 
 
-(* strip_apply f A(,B) strips off all assumptions/parameters from A
+(* strip_apply f (A, B) strips off all assumptions/parameters from A
    introduced by lifting over B, and applies f to remaining part of A*)
 fun strip_apply f =
   let fun strip(Const("==>",_)$ A1 $ B1,
@@ -1338,7 +1362,7 @@
                        curry op oo (Pt.norm_proof' env))
                     (Pt.bicompose_proof Bs oldAs As A n)) rder' sder,
                  maxidx = maxidx,
-                 shyps = insert_env_sorts env (Sorts.union rshyps sshyps),
+                 shyps = may_insert_env_sorts thy env (Sorts.union rshyps sshyps),
                  hyps = union_hyps rhyps shyps,
                  tpairs = ntpairs,
                  prop = Logic.list_implies normp}
@@ -1433,7 +1457,7 @@
           Thm {thy_ref = Theory.self_ref thy',
             der = (true, Pt.oracle_proof name prop),
             maxidx = maxidx,
-            shyps = Sorts.insert_term prop [],
+            shyps = may_insert_term_sorts thy' prop [],
             hyps = [],
             tpairs = [],
             prop = prop}
@@ -1443,9 +1467,7 @@
 fun invoke_oracle thy =
   invoke_oracle_i thy o NameSpace.intern (Theory.oracle_space thy);
 
-
 end;
 
-
 structure BasicThm: BASIC_THM = Thm;
 open BasicThm;