src/Pure/Isar/obtain.ML
changeset 19779 5c77dfb74c7b
parent 19585 70a1ce3b23ae
child 19844 2c1fdc397ded
--- a/src/Pure/Isar/obtain.ML	Mon Jun 05 21:54:24 2006 +0200
+++ b/src/Pure/Isar/obtain.ML	Mon Jun 05 21:54:25 2006 +0200
@@ -173,7 +173,7 @@
 
 local
 
-fun match_params ctxt vars rule =
+fun unify_params ctxt vars raw_rule =
   let
     val thy = ProofContext.theory_of ctxt;
     val string_of_typ = ProofContext.string_of_typ ctxt;
@@ -181,26 +181,27 @@
 
     fun err msg th = error (msg ^ ":\n" ^ ProofContext.string_of_thm ctxt th);
 
+    val maxidx = fold (Term.maxidx_typ o snd) vars ~1;
+    val rule = Thm.incr_indexes (maxidx + 1) raw_rule;
+
     val params = RuleCases.strip_params (Logic.nth_prem (1, Thm.prop_of rule));
     val m = length vars;
     val n = length params;
-    val _ = conditional (m > n)
-      (fn () => err "More variables than parameters in obtained rule" rule);
+    val _ = m <= n orelse err "More variables than parameters in obtained rule" rule;
 
-    fun match ((x, SOME T), (y, U)) tyenv =
-        ((x, T), Sign.typ_match thy (U, T) tyenv handle Type.TYPE_MATCH =>
-          err ("Failed to match variable " ^
-            string_of_term (Free (x, T)) ^ " against parameter " ^
-            string_of_term (Syntax.mark_boundT (y, Envir.norm_type tyenv U)) ^ " in") rule)
-      | match ((x, NONE), (_, U)) tyenv = ((x, U), tyenv);
-    val (xs, tyenv) = fold_map match (vars ~~ Library.take (m, params)) Vartab.empty;
-    val ys = Library.drop (m, params);
+    fun unify ((x, T), (y, U)) (tyenv, max) = Sign.typ_unify thy (T, U) (tyenv, max)
+      handle Type.TUNIFY =>
+        err ("Failed to unify variable " ^
+          string_of_term (Free (x, Envir.norm_type tyenv T)) ^ " against parameter " ^
+          string_of_term (Syntax.mark_boundT (y, Envir.norm_type tyenv U)) ^ " in") rule;
+    val (tyenv, _) = fold unify (vars ~~ Library.take (m, params))
+      (Vartab.empty, Int.max (maxidx, Thm.maxidx_of rule));
     val norm_type = Envir.norm_type tyenv;
 
-    val xs' = xs |> map (apsnd norm_type);
-    val ys' =
-      map Syntax.internal (Term.variantlist (map fst ys, map fst xs)) ~~
-      map (norm_type o snd) ys;
+    val xs = map (apsnd norm_type) vars;
+    val ys = map (apsnd norm_type) (Library.drop (m, params));
+    val ys' = map Syntax.internal (Term.variantlist (map fst ys, map fst xs)) ~~ map #2 ys;
+
     val instT =
       fold (Term.add_tvarsT o #2) params []
       |> map (TVar #> (fn T => (Thm.ctyp_of thy T, Thm.ctyp_of thy (norm_type T))));
@@ -212,11 +213,15 @@
       if null tvars andalso null vars then ()
       else err ("Illegal schematic variable(s) " ^
         commas (map (string_of_typ o TVar) tvars @ map (string_of_term o Var) vars) ^ " in") rule';
-  in (xs' @ ys', rule') end;
+  in (xs @ ys', rule') end;
 
 fun inferred_type (x, _, mx) ctxt =
   let val ((_, T), ctxt') = ProofContext.inferred_param x ctxt
-  in ((x, SOME T, mx), ctxt') end;
+  in ((x, T, mx), ctxt') end;
+
+fun polymorphic (vars, ctxt) =
+  let val Ts = map Logic.dest_type (ProofContext.polymorphic ctxt (map (Logic.mk_type o #2) vars))
+  in map2 (fn (x, _, mx) => fn T => ((x, T), mx)) vars Ts end;
 
 fun gen_guess prep_vars raw_vars int state =
   let
@@ -226,7 +231,8 @@
     val chain_facts = if can Proof.assert_chain state then Proof.the_facts state else [];
 
     val (thesis_var, thesis) = bind_judgment ctxt AutoBind.thesisN;
-    val (vars, _) = ctxt |> prep_vars (map Syntax.no_syn raw_vars) |-> fold_map inferred_type;
+    val vars = ctxt |> prep_vars (map Syntax.no_syn raw_vars)
+      |-> fold_map inferred_type |> polymorphic;
 
     fun check_result th =
       (case Thm.prems_of th of
@@ -237,33 +243,37 @@
       | [] => error "Goal solved -- nothing guessed."
       | _ => error ("Guess split into several cases:\n" ^ ProofContext.string_of_thm ctxt th));
 
-    fun guess_context raw_rule =
+    fun guess_context [_, raw_rule] =
       let
-        val (parms, rule) = match_params ctxt (map (fn (x, T, _) => (x, T)) vars) raw_rule;
+        val (parms, rule) = unify_params ctxt (map #1 vars) raw_rule;
         val (bind, _) = ProofContext.bind_fixes (map #1 parms) ctxt;
         val ts = map (bind o Free) parms;
         val ps = map dest_Free ts;
         val asms =
           Logic.strip_assums_hyp (Logic.nth_prem (1, Thm.prop_of rule))
           |> map (fn asm => (Term.betapplys (Term.list_abs (ps, asm), ts), []));
-        val _ = conditional (null asms) (fn () => error "Trivial result -- nothing guessed");
+        val _ = not (null asms) orelse error "Trivial result -- nothing guessed";
       in
         Proof.fix_i (map (apsnd SOME) parms)
         #> Proof.assm_i (K (obtain_export ctxt ts rule)) [(("", []), asms)]
         #> Proof.add_binds_i AutoBind.no_facts
       end;
 
-    val before_qed = SOME (Method.primitive_text (Goal.conclude #> Goal.protect));
-    fun after_qed [[res]] =
-      (check_result res; Proof.end_block #> Seq.map (`Proof.the_fact #-> guess_context));
+    val goal = Var (("guess", 0), propT);
+    fun print_result ctxt' (k, [(s, [_, th])]) =
+      ProofDisplay.print_results int ctxt' (k, [(s, [th])]);
+    val before_qed = SOME (Method.primitive_text (Goal.conclude #> (fn th =>
+      Goal.protect (Conjunction.intr (Drule.mk_term (Thm.cprop_of th)) th))));
+    fun after_qed [[_, res]] =
+      (check_result res; Proof.end_block #> Seq.map (`Proof.the_facts #-> guess_context));
   in
     state
     |> Proof.enter_forward
     |> Proof.begin_block
     |> Proof.fix_i [(AutoBind.thesisN, NONE)]
     |> Proof.chain_facts chain_facts
-    |> Proof.local_goal (ProofDisplay.print_results int) (K I) (apsnd (rpair I))
-      "guess" before_qed after_qed [(("", []), [Var (("guess", 0), propT)])]
+    |> Proof.local_goal print_result (K I) (apsnd (rpair I))
+      "guess" before_qed after_qed [(("", []), [Logic.mk_term goal, goal])]
     |> Proof.refine (Method.primitive_text (K (Goal.init (Thm.cterm_of thy thesis)))) |> Seq.hd
   end;