- reconstruct_proof no longer relies on TypeInfer.infer_types
authorberghofe
Mon, 21 Oct 2002 17:19:51 +0200
changeset 13669 a9f229eafba7
parent 13668 11397ea8b438
child 13670 c71b905a852a
- reconstruct_proof no longer relies on TypeInfer.infer_types - fixed problem with theorems containing TFrees
src/Pure/Proof/reconstruct.ML
--- a/src/Pure/Proof/reconstruct.ML	Mon Oct 21 17:17:40 2002 +0200
+++ b/src/Pure/Proof/reconstruct.ML	Mon Oct 21 17:19:51 2002 +0200
@@ -47,13 +47,8 @@
                  iTs=Vartab.merge (op =) (iTs1, iTs2),
                  maxidx=Int.max (maxidx1, maxidx2)};
 
-fun strip_abs (_::Ts) (Abs (_, _, t)) = strip_abs Ts t
-  | strip_abs _ t = t;
 
-
-(********************************************************************************
-  generate constraints for proof term
-*********************************************************************************)
+(**** generate constraints for proof term ****)
 
 fun mk_var env Ts T = 
   let val (env', v) = Envir.genvar "a" (env, rev Ts ---> T)
@@ -65,38 +60,6 @@
 
 fun mk_abs Ts t = foldl (fn (u, T) => Abs ("", T, u)) (t, Ts);
 
-fun make_Tconstraints_cprf maxidx cprf =
-  let
-    fun mk_Tcnstrts maxidx Ts (Abst (s, Some T, cprf)) =
-          let val (cs, cprf', maxidx') = mk_Tcnstrts maxidx (T::Ts) cprf;
-          in (cs, Abst (s, Some T, cprf'), maxidx') end
-      | mk_Tcnstrts maxidx Ts (Abst (s, None, cprf)) =
-          let
-            val T' = TVar (("'t", maxidx+1), ["logic"]);
-            val (cs, cprf', maxidx') = mk_Tcnstrts (maxidx+1) (T'::Ts) cprf;
-          in (cs, Abst (s, Some T', cprf'), maxidx') end
-      | mk_Tcnstrts maxidx Ts (AbsP (s, Some t, cprf)) =
-          let val (cs, cprf', maxidx') = mk_Tcnstrts maxidx Ts cprf;
-          in ((mk_abs Ts t, rev Ts ---> propT)::cs, AbsP (s, Some t, cprf'), maxidx') end
-      | mk_Tcnstrts maxidx Ts (AbsP (s, None, cprf)) =
-          let val (cs, cprf', maxidx') = mk_Tcnstrts maxidx Ts cprf;
-          in (cs, AbsP (s, None, cprf'), maxidx') end
-      | mk_Tcnstrts maxidx Ts (cprf1 %% cprf2) =
-          let
-            val (cs, cprf1', maxidx') = mk_Tcnstrts maxidx Ts cprf1;
-            val (cs', cprf2', maxidx'') = mk_Tcnstrts maxidx' Ts cprf2;
-          in (cs' @ cs, cprf1' %% cprf2', maxidx'') end
-      | mk_Tcnstrts maxidx Ts (cprf % Some t) =
-          let val (cs, cprf', maxidx') = mk_Tcnstrts maxidx Ts cprf;
-          in ((mk_abs Ts t, rev Ts ---> TypeInfer.logicT)::cs,
-            cprf' % Some t, maxidx')
-          end
-      | mk_Tcnstrts maxidx Ts (cprf % None) =
-          let val (cs, cprf', maxidx') = mk_Tcnstrts maxidx Ts cprf;
-          in (cs, cprf % None, maxidx') end
-      | mk_Tcnstrts maxidx _ cprf = ([], cprf, maxidx);
-  in mk_Tcnstrts maxidx [] cprf end;
-
 fun unifyT sg env T U =
   let
     val Envir.Envir {asol, iTs, maxidx} = env;
@@ -105,6 +68,44 @@
   handle Type.TUNIFY => error ("Non-unifiable types:\n" ^
     Sign.string_of_typ sg T ^ "\n\n" ^ Sign.string_of_typ sg U);
 
+fun chaseT (env as Envir.Envir {iTs, ...}) (T as TVar (ixn, _)) =
+      (case Vartab.lookup (iTs, ixn) of None => T | Some T' => chaseT env T')
+  | chaseT _ T = T;
+
+fun infer_type sg (env as Envir.Envir {maxidx, asol, iTs}) Ts vTs
+      (t as Const (s, T)) = if T = dummyT then (case Sign.const_type sg s of
+          None => error ("reconstruct_proof: No such constant: " ^ quote s)
+        | Some T => 
+            let val T' = incr_tvar (maxidx + 1) T
+            in (Const (s, T'), T', vTs,
+              Envir.Envir {maxidx = maxidx + 1, asol = asol, iTs = iTs})
+            end)
+      else (t, T, vTs, env)
+  | infer_type sg env Ts vTs (t as Free (s, T)) =
+      if T = dummyT then (case Symtab.lookup (vTs, s) of
+          None =>
+            let val (env', T) = mk_tvar (env, [])
+            in (Free (s, T), T, Symtab.update_new ((s, T), vTs), env') end
+        | Some T => (Free (s, T), T, vTs, env))
+      else (t, T, vTs, env)
+  | infer_type sg env Ts vTs (Var _) = error "reconstruct_proof: internal error"
+  | infer_type sg env Ts vTs (Abs (s, T, t)) =
+      let
+        val (env', T') = if T = dummyT then mk_tvar (env, []) else (env, T);
+        val (t', U, vTs', env'') = infer_type sg env' (T' :: Ts) vTs t
+      in (Abs (s, T', t'), T' --> U, vTs', env'') end
+  | infer_type sg env Ts vTs (t $ u) =
+      let
+        val (t', T, vTs1, env1) = infer_type sg env Ts vTs t;
+        val (u', U, vTs2, env2) = infer_type sg env1 Ts vTs1 u;
+      in (case chaseT env2 T of
+          Type ("fun", [U', V]) => (t' $ u', V, vTs2, unifyT sg env2 U U')
+        | _ =>
+          let val (env3, V) = mk_tvar (env2, [])
+          in (t' $ u', V, vTs2, unifyT sg env3 T (U --> V)) end)
+      end
+  | infer_type sg env Ts vTs (t as Bound i) = (t, nth_elem (i, Ts), vTs, env);
+
 fun decompose sg env Ts t u = case (Envir.head_norm env t, Envir.head_norm env u) of
     (Const ("all", _) $ t, Const ("all", _) $ u) => decompose sg env Ts t u
   | (Const ("==>", _) $ t1 $ t2, Const ("==>", _) $ u1 $ u2) =>
@@ -117,104 +118,112 @@
 fun cantunify sg t u = error ("Non-unifiable terms:\n" ^
   Sign.string_of_term sg t ^ "\n\n" ^ Sign.string_of_term sg u);
 
-fun make_constraints_cprf sg env ts cprf =
+fun make_constraints_cprf sg env cprf =
   let
-    fun add_cnstrt Ts prop prf cs env ts (t, u) =
+    fun add_cnstrt Ts prop prf cs env vTs (t, u) =
       let
         val t' = mk_abs Ts t;
         val u' = mk_abs Ts u
       in
-        (prop, prf, cs, Pattern.unify (sg, env, [(t', u')]), ts)
+        (prop, prf, cs, Pattern.unify (sg, env, [(t', u')]), vTs)
         handle Pattern.Pattern =>
             let val (env', cs') = decompose sg env [] t' u'
-            in (prop, prf, cs @ cs', env', ts) end
+            in (prop, prf, cs @ cs', env', vTs) end
         | Pattern.Unif =>
             cantunify sg (Envir.norm_term env t') (Envir.norm_term env u')
       end;
 
-    fun mk_cnstrts_atom env ts prop opTs prf =
+    fun mk_cnstrts_atom env vTs prop opTs prf =
           let
             val tvars = term_tvars prop;
-            val (env', Ts) = if_none (apsome (pair env) opTs)
-              (foldl_map (mk_tvar o apsnd snd) (env, tvars));
-            val prop' = subst_TVars (map fst tvars ~~ Ts) (forall_intr_vfs prop)
-              handle LIST _ => error ("Wrong number of type arguments for " ^
-                quote (fst (get_name_tags [] prop prf)))
-          in (prop', change_type (Some Ts) prf, [], env', ts) end;
+            val tfrees = term_tfrees prop;
+            val (prop', fmap) = Type.varify (prop, []);
+            val (env', Ts) = (case opTs of
+                None => foldl_map mk_tvar (env, map snd tvars @ map snd tfrees)
+              | Some Ts => (env, Ts));
+            val prop'' = subst_TVars (map fst tvars @ map snd fmap ~~ Ts)
+              (forall_intr_vfs prop') handle LIST _ =>
+                error ("Wrong number of type arguments for " ^
+                  quote (fst (get_name_tags [] prop prf)))
+          in (prop'', change_type (Some Ts) prf, [], env', vTs) end;
 
-    fun mk_cnstrts env _ Hs ts (PBound i) = (nth_elem (i, Hs), PBound i, [], env, ts)
-      | mk_cnstrts env Ts Hs ts (Abst (s, Some T, cprf)) =
-          let val (t, prf, cnstrts, env', ts') =
-              mk_cnstrts env (T::Ts) (map (incr_boundvars 1) Hs) ts cprf;
+    fun head_norm (prop, prf, cnstrts, env, vTs) =
+      (Envir.head_norm env prop, prf, cnstrts, env, vTs);
+ 
+    fun mk_cnstrts env _ Hs vTs (PBound i) = (nth_elem (i, Hs), PBound i, [], env, vTs)
+      | mk_cnstrts env Ts Hs vTs (Abst (s, opT, cprf)) =
+          let
+            val (env', T) = (case opT of
+              None => mk_tvar (env, logicS) | Some T => (env, T));
+            val (t, prf, cnstrts, env'', vTs') =
+              mk_cnstrts env' (T::Ts) (map (incr_boundvars 1) Hs) vTs cprf;
           in (Const ("all", (T --> propT) --> propT) $ Abs (s, T, t), Abst (s, Some T, prf),
-            cnstrts, env', ts')
+            cnstrts, env'', vTs')
           end
-      | mk_cnstrts env Ts Hs (t::ts) (AbsP (s, Some _, cprf)) =
+      | mk_cnstrts env Ts Hs vTs (AbsP (s, Some t, cprf)) =
           let
-            val (u, prf, cnstrts, env', ts') = mk_cnstrts env Ts (t::Hs) ts cprf;
-            val t' = strip_abs Ts t;
-          in (Logic.mk_implies (t', u), AbsP (s, Some t', prf), cnstrts, env', ts')
+            val (t', _, vTs', env') = infer_type sg env Ts vTs t;
+            val (u, prf, cnstrts, env'', vTs'') = mk_cnstrts env' Ts (t'::Hs) vTs' cprf;
+          in (Logic.mk_implies (t', u), AbsP (s, Some t', prf), cnstrts, env'', vTs'')
           end
-      | mk_cnstrts env Ts Hs ts (AbsP (s, None, cprf)) =
+      | mk_cnstrts env Ts Hs vTs (AbsP (s, None, cprf)) =
           let
             val (env', t) = mk_var env Ts propT;
-            val (u, prf, cnstrts, env'', ts') = mk_cnstrts env' Ts (t::Hs) ts cprf;
-          in (Logic.mk_implies (t, u), AbsP (s, Some t, prf), cnstrts, env'', ts')
+            val (u, prf, cnstrts, env'', vTs') = mk_cnstrts env' Ts (t::Hs) vTs cprf;
+          in (Logic.mk_implies (t, u), AbsP (s, Some t, prf), cnstrts, env'', vTs')
           end
-      | mk_cnstrts env Ts Hs ts (cprf1 %% cprf2) =
-          let val (u, prf2, cnstrts, env', ts') = mk_cnstrts env Ts Hs ts cprf2
-          in (case mk_cnstrts env' Ts Hs ts' cprf1 of
-              (Const ("==>", _) $ u' $ t', prf1, cnstrts', env'', ts'') =>
+      | mk_cnstrts env Ts Hs vTs (cprf1 %% cprf2) =
+          let val (u, prf2, cnstrts, env', vTs') = mk_cnstrts env Ts Hs vTs cprf2
+          in (case head_norm (mk_cnstrts env' Ts Hs vTs' cprf1) of
+              (Const ("==>", _) $ u' $ t', prf1, cnstrts', env'', vTs'') =>
                 add_cnstrt Ts t' (prf1 %% prf2) (cnstrts' @ cnstrts)
-                  env'' ts'' (u, u')
-            | (t, prf1, cnstrts', env'', ts'') =>
+                  env'' vTs'' (u, u')
+            | (t, prf1, cnstrts', env'', vTs'') =>
                 let val (env''', v) = mk_var env'' Ts propT
                 in add_cnstrt Ts v (prf1 %% prf2) (cnstrts' @ cnstrts)
-                  env''' ts'' (t, Logic.mk_implies (u, v))
+                  env''' vTs'' (t, Logic.mk_implies (u, v))
                 end)
           end
-      | mk_cnstrts env Ts Hs (t::ts) (cprf % Some _) =
-          let val t' = strip_abs Ts t
-          in (case mk_cnstrts env Ts Hs ts cprf of
+      | mk_cnstrts env Ts Hs vTs (cprf % Some t) =
+          let val (t', U, vTs1, env1) = infer_type sg env Ts vTs t
+          in (case head_norm (mk_cnstrts env1 Ts Hs vTs1 cprf) of
              (Const ("all", Type ("fun", [Type ("fun", [T, _]), _])) $ f,
-                 prf, cnstrts, env', ts') =>
-               let val env'' = unifyT sg env' T (Envir.fastype env' Ts t')
-               in (betapply (f, t'), prf % Some t', cnstrts, env'', ts')
+                 prf, cnstrts, env2, vTs2) =>
+               let val env3 = unifyT sg env2 T U
+               in (betapply (f, t'), prf % Some t', cnstrts, env3, vTs2)
                end
-           | (u, prf, cnstrts, env', ts') =>
-               let
-                 val T = Envir.fastype env' Ts t';
-                 val (env'', v) = mk_var env' Ts (T --> propT);
+           | (u, prf, cnstrts, env2, vTs2) =>
+               let val (env3, v) = mk_var env2 Ts (U --> propT);
                in
-                 add_cnstrt Ts (v $ t') (prf % Some t') cnstrts env'' ts'
-                   (u, Const ("all", (T --> propT) --> propT) $ v)
+                 add_cnstrt Ts (v $ t') (prf % Some t') cnstrts env3 vTs2
+                   (u, Const ("all", (U --> propT) --> propT) $ v)
                end)
           end
-      | mk_cnstrts env Ts Hs ts (cprf % None) =
-          (case mk_cnstrts env Ts Hs ts cprf of
+      | mk_cnstrts env Ts Hs vTs (cprf % None) =
+          (case head_norm (mk_cnstrts env Ts Hs vTs cprf) of
              (Const ("all", Type ("fun", [Type ("fun", [T, _]), _])) $ f,
-                 prf, cnstrts, env', ts') =>
+                 prf, cnstrts, env', vTs') =>
                let val (env'', t) = mk_var env' Ts T
-               in (betapply (f, t), prf % Some t, cnstrts, env'', ts')
+               in (betapply (f, t), prf % Some t, cnstrts, env'', vTs')
                end
-           | (u, prf, cnstrts, env', ts') =>
+           | (u, prf, cnstrts, env', vTs') =>
                let
                  val (env1, T) = mk_tvar (env', ["logic"]);
                  val (env2, v) = mk_var env1 Ts (T --> propT);
                  val (env3, t) = mk_var env2 Ts T
                in
-                 add_cnstrt Ts (v $ t) (prf % Some t) cnstrts env3 ts'
+                 add_cnstrt Ts (v $ t) (prf % Some t) cnstrts env3 vTs'
                    (u, Const ("all", (T --> propT) --> propT) $ v)
                end)
-      | mk_cnstrts env _ _ ts (prf as PThm (_, _, prop, opTs)) =
-          mk_cnstrts_atom env ts prop opTs prf
-      | mk_cnstrts env _ _ ts (prf as PAxm (_, prop, opTs)) =
-          mk_cnstrts_atom env ts prop opTs prf
-      | mk_cnstrts env _ _ ts (prf as Oracle (_, prop, opTs)) =
-          mk_cnstrts_atom env ts prop opTs prf
-      | mk_cnstrts env _ _ ts (Hyp t) = (t, Hyp t, [], env, ts)
+      | mk_cnstrts env _ _ vTs (prf as PThm (_, _, prop, opTs)) =
+          mk_cnstrts_atom env vTs prop opTs prf
+      | mk_cnstrts env _ _ vTs (prf as PAxm (_, prop, opTs)) =
+          mk_cnstrts_atom env vTs prop opTs prf
+      | mk_cnstrts env _ _ vTs (prf as Oracle (_, prop, opTs)) =
+          mk_cnstrts_atom env vTs prop opTs prf
+      | mk_cnstrts env _ _ vTs (Hyp t) = (t, Hyp t, [], env, vTs)
       | mk_cnstrts _ _ _ _ _ = error "reconstruct_proof: minimal proof object"
-  in mk_cnstrts env [] [] ts cprf end;
+  in mk_cnstrts env [] [] Symtab.empty cprf end;
 
 fun add_term_ixns (is, Var (i, T)) = add_typ_ixns (i ins is, T)
   | add_term_ixns (is, Free (_, T)) = add_typ_ixns (is, T)
@@ -224,9 +233,7 @@
   | add_term_ixns (is, _) = is;
 
 
-(********************************************************************************
-  update list of free variables of constraints
-*********************************************************************************)
+(**** update list of free variables of constraints ****)
 
 fun upd_constrs env cs =
   let
@@ -243,17 +250,15 @@
           end
   in check_cs cs end;
 
-(********************************************************************************
-  solution of constraints
-*********************************************************************************)
+(**** solution of constraints ****)
 
 fun solve _ [] bigenv = bigenv
   | solve sg cs bigenv =
       let
         fun search env [] = error ("Unsolvable constraints:\n" ^
               Pretty.string_of (Pretty.chunks (map (fn (_, p, _) =>
-                Sign.pretty_term sg (Logic.mk_flexpair (pairself
-                  (Envir.norm_term bigenv) p))) cs)))
+                Display.pretty_flexpair (Sign.pretty_term sg) (pairself
+                  (Envir.norm_term bigenv) p)) cs)))
           | search env ((u, p as (t1, t2), vs)::ps) =
               if u then
                 let
@@ -278,24 +283,14 @@
       end;
 
 
-(********************************************************************************
-  reconstruction of proofs
-*********************************************************************************)
+(**** reconstruction of proofs ****)
 
 fun reconstruct_proof sg prop cprf =
   let
     val (cprf' % Some prop', thawf) = freeze_thaw_prf (cprf % Some prop);
-    val _ = message "Collecting type constraints...";
-    val (Tcs, cprf'', maxidx) = make_Tconstraints_cprf 0 cprf';
-    val (ts, Ts) = ListPair.unzip Tcs;
-    val tsig = Sign.tsig_of sg;
-    val {classrel, arities, ...} = Type.rep_tsig tsig;
-    val _ = message "Solving type constraints...";
-    val (ts', _, unifier) = TypeInfer.infer_types (Sign.pretty_term sg) (Sign.pretty_typ sg)
-      (Sign.const_type sg) classrel arities [] false (K true) ts Ts;
-    val env = Envir.Envir {asol = Vartab.empty, iTs = Vartab.make unifier, maxidx = maxidx};
-    val _ = message "Collecting term constraints...";
-    val (t, prf, cs, env, _) = make_constraints_cprf sg env ts' cprf'';
+    val _ = message "Collecting constraints...";
+    val (t, prf, cs, env, _) = make_constraints_cprf sg
+      (Envir.empty (maxidx_of_proof cprf)) cprf';
     val cs' = map (fn p => (true, p, op union
       (pairself (map (fst o dest_Var) o term_vars) p))) (map (pairself (Envir.norm_term env)) ((t, prop')::cs));
     val _ = message ("Solving remaining constraints (" ^ string_of_int (length cs') ^ ") ...");
@@ -305,7 +300,10 @@
   end;
 
 fun prop_of_atom prop Ts =
-  subst_TVars (map fst (term_tvars prop) ~~ Ts) (forall_intr_vfs prop);
+  let val (prop', fmap) = Type.varify (prop, []);
+  in subst_TVars (map fst (term_tvars prop) @ map snd fmap ~~ Ts)
+    (forall_intr_vfs prop')
+  end;
 
 fun prop_of' Hs (PBound i) = nth_elem (i, Hs)
   | prop_of' Hs (Abst (s, Some T, prf)) =
@@ -327,9 +325,7 @@
 val prop_of = prop_of' [];
 
 
-(********************************************************************************
-  expand and reconstruct subproofs
-*********************************************************************************)
+(**** expand and reconstruct subproofs ****)
 
 fun expand_proof sg thms prf =
   let
@@ -355,7 +351,7 @@
           let
             fun inc i =
               map_proof_terms (Logic.incr_indexes ([], i)) (incr_tvar i);
-            val (maxidx', i, prf, prfs') = (case assoc (prfs, (a, prop)) of
+            val (maxidx', prf, prfs') = (case assoc (prfs, (a, prop)) of
                 None =>
                   let
                     val _ = message ("Reconstructing proof of " ^ a);
@@ -364,15 +360,19 @@
                       (reconstruct_proof sg prop cprf);
                     val (maxidx', prfs', prf) = expand
                       (maxidx_of_proof prf') prfs prf'
-                  in (maxidx' + maxidx + 1, maxidx + 1, inc (maxidx + 1) prf,
+                  in (maxidx' + maxidx + 1, inc (maxidx + 1) prf,
                     ((a, prop), (maxidx', prf)) :: prfs')
                   end
               | Some (maxidx', prf) => (maxidx' + maxidx + 1,
-                  maxidx + 1, inc (maxidx + 1) prf, prfs));
-            val tye = map (fn ((s, j), _) => (s, i + j)) (term_tvars prop) ~~ Ts
+                  inc (maxidx + 1) prf, prfs));
+            val tfrees = term_tfrees prop;
+            val tye = map (fn ((s, j), _) => (s, maxidx + 1 + j))
+              (term_tvars prop) @ map (rpair ~1 o fst) tfrees ~~ Ts;
+            val varify = map_type_tfree (fn p as (a, S) =>
+              if p mem tfrees then TVar ((a, ~1), S) else TFree p)
           in
-            (maxidx', prfs',
-             map_proof_terms (subst_TVars tye) (typ_subst_TVars tye) prf)
+            (maxidx', prfs', map_proof_terms (subst_TVars tye o
+               map_term_types varify) (typ_subst_TVars tye o varify) prf)
           end
       | expand maxidx prfs prf = (maxidx, prfs, prf);