src/HOL/Tools/SMT/z3_proof_reconstruction.ML
changeset 41131 fc9d503c3d67
parent 41130 130771a48c70
child 41172 a17c2d669c40
--- a/src/HOL/Tools/SMT/z3_proof_reconstruction.ML	Wed Dec 15 10:12:48 2010 +0100
+++ b/src/HOL/Tools/SMT/z3_proof_reconstruction.ML	Wed Dec 15 10:12:48 2010 +0100
@@ -132,43 +132,70 @@
 (** core proof rules **)
 
 (* assumption *)
+
 local
-  val remove_trigger = @{lemma "trigger t p == p"
-    by (rule eq_reflection, rule trigger_def)}
-
-  val remove_weight = @{lemma "weight w p == p"
-    by (rule eq_reflection, rule weight_def)}
-
-  val prep_rules = [@{thm Let_def}, remove_trigger, remove_weight,
-    L.rewrite_true]
+  val remove_trigger = mk_meta_eq @{thm SMT.trigger_def}
+  val remove_weight = mk_meta_eq @{thm SMT.weight_def}
+  val remove_fun_app = mk_meta_eq @{thm SMT.fun_app_def}
 
   fun rewrite_conv ctxt eqs = Simplifier.full_rewrite
     (Simplifier.context ctxt Simplifier.empty_ss addsimps eqs)
 
-  fun rewrites f ctxt eqs = map (f (Conv.fconv_rule (rewrite_conv ctxt eqs)))
+  val prep_rules = [@{thm Let_def}, remove_trigger, remove_weight,
+    remove_fun_app, L.rewrite_true]
+
+  fun rewrite ctxt eqs = Conv.fconv_rule (rewrite_conv ctxt eqs)
 
   fun burrow_snd_option f (i, thm) = Option.map (pair i) (f thm)
-  fun lookup_assm ctxt assms ct =
-    (case T.net_instance' burrow_snd_option assms ct of
-      SOME ithm => ithm
-    | _ => z3_exn ("not asserted: " ^
-        quote (Syntax.string_of_term ctxt (Thm.term_of ct))))
+
+  fun lookup_assm assms_net ct =
+    T.net_instance' burrow_snd_option assms_net ct
+    |> Option.map (fn ithm as (_, thm) => (ithm, Thm.cprop_of thm aconvc ct))
 in
-fun prepare_assms ctxt rewrite_rules assms =
+
+fun add_asserted outer_ctxt rewrite_rules assms asserted ctxt =
   let
-    val eqs = rewrites I ctxt [L.rewrite_true] rewrite_rules
-    val assms' =
+    val eqs = map (rewrite ctxt [L.rewrite_true]) rewrite_rules
+    val eqs' = union Thm.eq_thm eqs prep_rules
+
+    val assms_net =
       assms
-      |> rewrites apsnd ctxt (union Thm.eq_thm eqs prep_rules)
+      |> map (apsnd (rewrite ctxt eqs'))
       |> map (apsnd (Conv.fconv_rule Thm.eta_conversion))
-  in (eqs, T.thm_net_of snd assms') end
+      |> T.thm_net_of snd 
+
+    fun revert_conv ctxt = rewrite_conv ctxt eqs' then_conv Thm.eta_conversion
+
+    fun assume thm ctxt =
+      let
+        val ct = Thm.cprem_of thm 1
+        val (thm', ctxt') = yield_singleton Assumption.add_assumes ct ctxt
+      in (Thm.implies_elim thm thm', ctxt') end
 
-fun asserted ctxt (eqs, assms) ct =
-  let val revert_conv = rewrite_conv ctxt eqs then_conv Thm.eta_conversion
-  in Thm (T.with_conv revert_conv (snd o lookup_assm ctxt assms) ct) end
+    fun add (idx, ct) ((is, thms), (ctxt, ptab)) =
+      let
+        val thm1 = 
+          Thm.trivial ct
+          |> Conv.fconv_rule (Conv.arg1_conv (revert_conv outer_ctxt))
+        val thm2 = singleton (Variable.export ctxt outer_ctxt) thm1
+      in
+        (case lookup_assm assms_net (Thm.cprem_of thm2 1) of
+          NONE =>
+            let val (thm, ctxt') = assume thm1 ctxt
+            in ((is, thms), (ctxt', Inttab.update (idx, Thm thm) ptab)) end
+        | SOME ((i, th), exact) =>
+            let
+              val (thm, ctxt') =
+                if exact then (Thm.implies_elim thm1 th, ctxt)
+                else assume thm1 ctxt
+              val thms' = if exact then thms else th :: thms
+            in 
+              ((insert (op =) i is, thms'),
+                (ctxt', Inttab.update (idx, Thm thm) ptab))
+            end)
+      end
+  in fold add asserted (([], []), (ctxt, Inttab.empty)) end
 
-fun find_assm ctxt (unfolds, assms) ct =
-  fst (lookup_assm ctxt assms (Thm.rhs_of (rewrite_conv ctxt unfolds ct)))
 end
 
 
@@ -715,12 +742,12 @@
   fun not_supported r = raise Fail ("Z3: proof rule not implemented: " ^
     quote (P.string_of_rule r))
 
-  fun prove_step assms simpset vars r ps ct (cxp as (cx, ptab)) =
+  fun prove_step simpset vars r ps ct (cxp as (cx, ptab)) =
     (case (r, ps) of
       (* core rules *)
       (P.True_Axiom, _) => (Thm L.true_thm, cxp)
-    | (P.Asserted, _) => (asserted cx assms ct, cxp)
-    | (P.Goal, _) => (asserted cx assms ct, cxp)
+    | (P.Asserted, _) => raise Fail "bad assertion"
+    | (P.Goal, _) => raise Fail "bad assertion"
     | (P.Modus_Ponens, [(p, _), (q, _)]) => (mp q (thm_of p), cxp)
     | (P.Modus_Ponens_Oeq, [(p, _), (q, _)]) => (mp q (thm_of p), cxp)
     | (P.And_Elim, [(p, i)]) => and_elim (p, i) ct ptab ||> pair cx
@@ -774,55 +801,48 @@
       SOME p => (p, idx)
     | NONE => z3_exn ("unknown proof id: " ^ quote (string_of_int idx)))
 
-  fun prove assms simpset vars (idx, step) (_, cxp as (ctxt, ptab)) =
+  fun prove simpset vars (idx, step) (_, cxp as (ctxt, ptab)) =
     let
       val P.Proof_Step {rule=r, prems, prop, ...} = step
       val ps = map (lookup_proof ptab) prems
       val _ = trace_before ctxt idx r
       val (thm, (ctxt', ptab')) =
         cxp
-        |> prove_step assms simpset vars r ps prop
+        |> prove_step simpset vars r ps prop
         |> tap (check_after idx r ps prop)
     in (thm, (ctxt', Inttab.update (idx, thm) ptab')) end
 
-  val disch_rules = map (pair false)
-    [@{thm allI}, @{thm refl}, @{thm reflexive}]
+  val disch_rules = [@{thm allI}, @{thm refl}, @{thm reflexive}]
+  fun all_disch_rules rules = map (pair false) (disch_rules @ rules)
 
-  fun disch_assm thm =
+  fun disch_assm rules thm =
     if Thm.nprems_of thm = 0 then Drule.flexflex_unique thm
     else
-      (case Seq.pull (Thm.biresolution false disch_rules 1 thm) of
-        SOME (thm', _) => disch_assm thm'
+      (case Seq.pull (Thm.biresolution false rules 1 thm) of
+        SOME (thm', _) => disch_assm rules thm'
       | NONE => raise THM ("failed to discharge premise", 1, [thm]))
 
-  fun discharge outer_ctxt (p, (inner_ctxt, _)) =
+  fun discharge rules outer_ctxt (p, (inner_ctxt, _)) =
     thm_of p
     |> singleton (ProofContext.export inner_ctxt outer_ctxt)
-    |> disch_assm    
-
-  fun filter_assms ctxt assms =
-    let
-      fun add_assm r ct =
-        (case r of
-          P.Asserted => insert (op =) (find_assm ctxt assms ct)
-        | P.Goal => insert (op =) (find_assm ctxt assms ct)
-        | _ => I)
-    in fold (fn (_, P.Proof_Step {rule, prop, ...}) => add_assm rule prop) end
+    |> disch_assm rules
 in
 
 fun reconstruct outer_ctxt recon output =
   let
     val {context=ctxt, typs, terms, rewrite_rules, assms} = recon
-    val (steps, vars, ctxt') = P.parse ctxt typs terms output
-    val assms' = prepare_assms ctxt' rewrite_rules assms
-    val simpset = T.make_simpset ctxt' (Z3_Simps.get ctxt')
+    val (asserted, steps, vars, ctxt1) = P.parse ctxt typs terms output
+
+    val simpset = T.make_simpset ctxt1 (Z3_Simps.get ctxt1)
+
+    val ((is, rules), cxp as (ctxt2, _)) =
+      add_asserted outer_ctxt rewrite_rules assms asserted ctxt1
   in
-    if Config.get ctxt' SMT_Config.filter_only_facts then
-      (filter_assms ctxt' assms' steps [], @{thm TrueI})
+    if Config.get ctxt2 SMT_Config.filter_only_facts then (is, @{thm TrueI})
     else
-      (Thm @{thm TrueI}, (ctxt', Inttab.empty))
-      |> fold (prove assms' simpset vars) steps 
-      |> discharge outer_ctxt
+      (Thm @{thm TrueI}, cxp)
+      |> fold (prove simpset vars) steps 
+      |> discharge (all_disch_rules rules) outer_ctxt
       |> pair []
   end