tuned instantiate (avoid subst_atomic, subst_atomic_types);
authorwenzelm
Tue, 19 Jul 2005 17:21:56 +0200
changeset 16884 1678a796b6b2
parent 16883 a89fafe1cbd8
child 16885 cabcd33cde18
tuned instantiate (avoid subst_atomic, subst_atomic_types); Logic.incr_tvar;
src/Pure/thm.ML
--- a/src/Pure/thm.ML	Tue Jul 19 17:21:55 2005 +0200
+++ b/src/Pure/thm.ML	Tue Jul 19 17:21:56 2005 +0200
@@ -309,7 +309,7 @@
   if i < 0 then raise CTERM "negative increment"
   else if i = 0 then ct
   else Cterm {thy_ref = thy_ref, t = Logic.incr_indexes ([], i) t,
-    T = Term.incr_tvar i T, maxidx = maxidx + i, sorts = sorts};
+    T = Logic.incr_tvar i T, maxidx = maxidx + i, sorts = sorts};
 
 
 
@@ -374,7 +374,7 @@
 
 fun eq_tpairs ((t, u), (t', u')) = t aconv t' andalso u aconv u';
 val union_tpairs = gen_merge_lists eq_tpairs;
-val maxidx_tpairs = fold (fn (t, u) => Term.maxidx_term t o Term.maxidx_term u);
+val maxidx_tpairs = fold (fn (t, u) => Term.maxidx_term t #> Term.maxidx_term u);
 
 fun attach_tpairs tpairs prop =
   Logic.list_implies (map Logic.mk_equals tpairs, prop);
@@ -897,7 +897,7 @@
         let
           val tpairs' = tpairs |> map (pairself (Envir.norm_term env))
             (*remove trivial tpairs, of the form t==t*)
-            |> List.filter (not o op aconv);
+            |> filter_out (op aconv);
           val prop' = Envir.norm_term env prop;
         in
           Thm {thy_ref = thy_ref,
@@ -921,29 +921,34 @@
 fun pretty_typing thy t T =
   Pretty.block [Sign.pretty_term thy t, Pretty.str " ::", Pretty.brk 1, Sign.pretty_typ thy T];
 
-fun add_ctpair ((thy, sorts), (ct, cu)) =
+fun add_inst (ct, cu) (thy_ref, sorts) =
   let
-    val Cterm {t = t, T = T, sorts = sorts1, ...} = ct
-    and Cterm {t = u, T = U, sorts = sorts2, ...} = cu;
-    val thy' = Theory.merge (thy, Theory.deref (merge_thys0 ct cu));
-    val sorts' = Sorts.union sorts2 (Sorts.union sorts1 sorts);
+    val Cterm {t = t, T = T, ...} = ct
+    and Cterm {t = u, T = U, sorts = sorts_u, ...} = cu;
+    val thy_ref' = Theory.merge_refs (thy_ref, merge_thys0 ct cu);
+    val sorts' = Sorts.union sorts_u sorts;
   in
-    if T = U then ((thy', sorts'), (t, u))
-    else raise TYPE (Pretty.string_of (Pretty.block
-     [Pretty.str "instantiate: type conflict",
-      Pretty.fbrk, pretty_typing thy' t T,
-      Pretty.fbrk, pretty_typing thy' u U]), [T,U], [t,u])
+    (case t of Var v =>
+      if T = U then ((v, u), (thy_ref', sorts'))
+      else raise TYPE (Pretty.string_of (Pretty.block
+       [Pretty.str "instantiate: type conflict",
+        Pretty.fbrk, pretty_typing (Theory.deref thy_ref') t T,
+        Pretty.fbrk, pretty_typing (Theory.deref thy_ref') u U]), [T, U], [t, u])
+    | _ => raise TYPE (Pretty.string_of (Pretty.block
+       [Pretty.str "instantiate: not a variable",
+        Pretty.fbrk, Sign.pretty_term (Theory.deref thy_ref') t]), [], [t]))
   end;
 
-fun add_ctyp ((thy, sorts), (cT, cU)) =
+fun add_instT (cT, cU) (thy_ref, sorts) =
   let
-    val Ctyp {T, thy_ref = thy_ref1, sorts = sorts1, ...} = cT
-    and Ctyp {T = U, thy_ref = thy_ref2, sorts = sorts2, ...} = cU;
-    val thy' = Theory.merge (thy, Theory.deref (Theory.merge_refs (thy_ref1, thy_ref2)));
-    val sorts' = Sorts.union sorts2 (Sorts.union sorts1 sorts);
+    val Ctyp {T, thy_ref = thy_ref1, ...} = cT
+    and Ctyp {T = U, thy_ref = thy_ref2, sorts = sorts_U, ...} = cU;
+    val thy_ref' = Theory.merge_refs (thy_ref, Theory.merge_refs (thy_ref1, thy_ref2));
+    val thy' = Theory.deref thy_ref';
+    val sorts' = Sorts.union sorts_U sorts;
   in
-    (case T of TVar (_, S) =>
-      if Type.of_sort (Sign.tsig_of thy') (U, S) then ((thy', sorts'), (T, U))
+    (case T of TVar (v as (_, S)) =>
+      if Type.of_sort (Sign.tsig_of thy') (U, S) then ((v, U), (thy_ref', sorts'))
       else raise TYPE ("Type not of sort " ^ Sign.string_of_sort thy' S, [U], [])
     | _ => raise TYPE (Pretty.string_of (Pretty.block
         [Pretty.str "instantiate: not a type variable",
@@ -956,26 +961,26 @@
   Instantiates distinct Vars by terms of same type.
   Does NOT normalize the resulting theorem!*)
 fun instantiate ([], []) th = th
-  | instantiate (vcTs, ctpairs) th =
+  | instantiate (instT, inst) th =
       let
-        val Thm {thy_ref, der, hyps, shyps, tpairs = dpairs, prop, ...} = th;
-        val (context, tpairs) = foldl_map add_ctpair ((Theory.deref thy_ref, shyps), ctpairs);
-        val ((thy', shyps'), vTs) = foldl_map add_ctyp (context, vcTs);
-        fun subst t = subst_atomic tpairs (subst_atomic_types vTs t);
+        val Thm {thy_ref, der, hyps, shyps, tpairs, prop, ...} = th;
+        val (inst', (instT', (thy_ref', shyps'))) =
+          (thy_ref, shyps) |> fold_map add_inst inst ||> fold_map add_instT instT;
+        val subst = Term.instantiate (instT', inst');
         val prop' = subst prop;
-        val dpairs' = map (pairself subst) dpairs;
+        val tpairs' = map (pairself subst) tpairs;
       in
-        if not (forall (is_Var o #1) tpairs andalso null (gen_duplicates eq_fst tpairs)) then
+        if has_duplicates (fn ((v, _), (v', _)) => Term.eq_var (v, v')) inst' then
           raise THM ("instantiate: variables not distinct", 0, [th])
-        else if not (null (gen_duplicates eq_fst vTs)) then
+        else if has_duplicates (fn ((v, _), (v', _)) => Term.eq_tvar (v, v')) instT' then
           raise THM ("instantiate: type variables not distinct", 0, [th])
         else
-          Thm {thy_ref = Theory.self_ref thy',
-            der = Pt.infer_derivs' (Pt.instantiate vTs tpairs) der,
-            maxidx = maxidx_tpairs dpairs' (maxidx_of_term prop'),
+          Thm {thy_ref = thy_ref',
+            der = Pt.infer_derivs' (Pt.instantiate (instT', inst')) der,
+            maxidx = maxidx_tpairs tpairs' (maxidx_of_term prop'),
             shyps = shyps',
             hyps = hyps,
-            tpairs = dpairs',
+            tpairs = tpairs',
             prop = prop'}
       end
       handle TYPE (msg, _, _) => raise THM (msg, 0, [th]);
@@ -1083,7 +1088,8 @@
   else if i = 0 then thm
   else
     Thm {thy_ref = thy_ref,
-      der = Pt.infer_derivs' (Pt.map_proof_terms (Logic.incr_indexes ([], i)) (incr_tvar i)) der,
+      der = Pt.infer_derivs'
+        (Pt.map_proof_terms (Logic.incr_indexes ([], i)) (Logic.incr_tvar i)) der,
       maxidx = maxidx + i,
       shyps = shyps,
       hyps = hyps,