thread through step IDs from Z3 to Sledgehammer
authorblanchet
Thu, 13 Mar 2014 13:18:14 +0100
changeset 56099 bc036c1cf111
parent 56098 d530cc905c2f
child 56100 0dc5f68a7802
thread through step IDs from Z3 to Sledgehammer
src/HOL/Tools/SMT2/smt2_solver.ML
src/HOL/Tools/SMT2/z3_new_isar.ML
src/HOL/Tools/SMT2/z3_new_replay.ML
src/HOL/Tools/Sledgehammer/sledgehammer_prover_smt2.ML
--- a/src/HOL/Tools/SMT2/smt2_solver.ML	Thu Mar 13 13:18:14 2014 +0100
+++ b/src/HOL/Tools/SMT2/smt2_solver.ML	Thu Mar 13 13:18:14 2014 +0100
@@ -20,20 +20,20 @@
     cex_parser: (Proof.context -> SMT2_Translate.replay_data -> string list ->
       term list * term list) option,
     replay: (Proof.context -> SMT2_Translate.replay_data -> string list ->
-      (int list * Z3_New_Proof.z3_step list) * thm) option }
+      ((int * int) list * Z3_New_Proof.z3_step list) * thm) option }
 
   (*registry*)
   val add_solver: solver_config -> theory -> theory
   val solver_name_of: Proof.context -> string
   val available_solvers_of: Proof.context -> string list
   val apply_solver: Proof.context -> (int * (int option * thm)) list ->
-    (int list * Z3_New_Proof.z3_step list) * thm
+    ((int * int) list * Z3_New_Proof.z3_step list) * thm
   val default_max_relevant: Proof.context -> string -> int
 
   (*filter*)
   val smt2_filter: Proof.context -> thm list -> thm -> ('a * (int option * thm)) list -> int ->
-    Time.time -> {outcome: SMT2_Failure.failure option, used_facts: ('a * thm) list,
-      z3_steps: Z3_New_Proof.z3_step list}
+    Time.time -> {outcome: SMT2_Failure.failure option, used_fact_infos: (int * ('a * thm)) list,
+      z3_proof: Z3_New_Proof.z3_step list}
 
   (*tactic*)
   val smt2_tac: Proof.context -> thm list -> int -> tactic
@@ -152,7 +152,7 @@
   cex_parser: (Proof.context -> SMT2_Translate.replay_data -> string list ->
     term list * term list) option,
   replay: (Proof.context -> SMT2_Translate.replay_data -> string list ->
-    (int list * Z3_New_Proof.z3_step list) * thm) option }
+    ((int * int) list * Z3_New_Proof.z3_step list) * thm) option }
 
 
 (* registry *)
@@ -162,7 +162,7 @@
   default_max_relevant: int,
   supports_filter: bool,
   replay: Proof.context -> string list * SMT2_Translate.replay_data ->
-    (int list * Z3_New_Proof.z3_step list) * thm }
+    ((int * int) list * Z3_New_Proof.z3_step list) * thm }
 
 structure Solvers = Generic_Data
 (
@@ -275,7 +275,9 @@
 
     val xthms = map (apsnd snd) xwthms
 
-    val filter_thms = if supports_filter ctxt then map_filter (try (nth xthms)) else K xthms
+    val used_fact_infos_of =
+      if supports_filter ctxt then map_filter (try (apsnd (nth xthms)))
+      else K (map (pair ~1) xthms)
   in
     map snd xwthms
     |> map_index I
@@ -283,23 +285,24 @@
     |> check_topsorts ctxt
     |> apply_solver ctxt
     |> fst
-    |> (fn (is, z3_steps) => {outcome = NONE, used_facts = filter_thms is, z3_steps = z3_steps})
+    |> (fn (idis, z3_proof) =>
+      {outcome = NONE, used_fact_infos = used_fact_infos_of idis, z3_proof = z3_proof})
   end
-  handle SMT2_Failure.SMT fail => {outcome = SOME fail, used_facts = [], z3_steps = []}
+  handle SMT2_Failure.SMT fail => {outcome = SOME fail, used_fact_infos = [], z3_proof = []}
 
 
 (* SMT tactic *)
 
 local
-  fun trace_assumptions ctxt iwthms idxs =
+  fun trace_assumptions ctxt iwthms idis =
     let
       val wthms =
-        idxs
+        idis
+        |> map snd
         |> filter (fn i => i >= 0)
         |> map_filter (AList.lookup (op =) iwthms)
     in
-      if Config.get ctxt SMT2_Config.trace_used_facts andalso length wthms > 0
-      then
+      if Config.get ctxt SMT2_Config.trace_used_facts andalso length wthms > 0 then
         tracing (Pretty.string_of (Pretty.big_list "SMT used facts:"
           (map (Display.pretty_thm ctxt o snd) wthms)))
       else ()
--- a/src/HOL/Tools/SMT2/z3_new_isar.ML	Thu Mar 13 13:18:14 2014 +0100
+++ b/src/HOL/Tools/SMT2/z3_new_isar.ML	Thu Mar 13 13:18:14 2014 +0100
@@ -8,7 +8,8 @@
 sig
   type ('a, 'b) atp_step = ('a, 'b) ATP_Proof.atp_step
 
-  val atp_proof_of_z3_proof: theory -> Z3_New_Proof.z3_step list -> (term, string) atp_step list
+  val atp_proof_of_z3_proof: theory -> Z3_New_Proof.z3_step list -> (int * string) list ->
+    (term, string) atp_step list
 end;
 
 structure Z3_New_Isar: Z3_NEW_ISAR =
@@ -73,29 +74,24 @@
 
 fun simplify_line (name, role, t, rule, deps) = (name, role, simplify_prop t, rule, deps)
 
-fun atp_proof_of_z3_proof thy proof =
+fun atp_proof_of_z3_proof thy proof fact_ids =
   let
-    fun step_name_of id = (string_of_int id, [])
-
-    (* FIXME: find actual conjecture *)
-    val id_of_conjecture =
-      proof
-      |> find_first (fn Z3_New_Proof.Z3_Step {rule, ...} => rule = Z3_New_Proof.Asserted)
-      |> Option.map (fn Z3_New_Proof.Z3_Step {id, ...} => id)
-
     fun step_of (Z3_New_Proof.Z3_Step {id, rule, prems, concl, ...}) =
       let
+        fun step_name_of id = (string_of_int id, the_list (AList.lookup (op =) fact_ids id))
+
+        val name as (_, ss) = step_name_of id
         val role =
           (case rule of
             Z3_New_Proof.Asserted =>
-              if id_of_conjecture = SOME id then Negated_Conjecture else Hypothesis
+              if null ss then Negated_Conjecture (* FIXME: or hypothesis! *) else Axiom
           | Z3_New_Proof.Rewrite => Lemma
           | Z3_New_Proof.Rewrite_Star => Lemma
           | Z3_New_Proof.Skolemize => Lemma
           | Z3_New_Proof.Th_Lemma _ => Lemma
           | _ => Plain)
       in
-        (step_name_of id, role, HOLogic.mk_Trueprop (Object_Logic.atomize_term thy concl),
+        (name, role, HOLogic.mk_Trueprop (Object_Logic.atomize_term thy concl),
          Z3_New_Proof.string_of_rule rule, map step_name_of prems)
       end
   in
--- a/src/HOL/Tools/SMT2/z3_new_replay.ML	Thu Mar 13 13:18:14 2014 +0100
+++ b/src/HOL/Tools/SMT2/z3_new_replay.ML	Thu Mar 13 13:18:14 2014 +0100
@@ -8,7 +8,7 @@
 signature Z3_NEW_REPLAY =
 sig
   val replay: Proof.context -> SMT2_Translate.replay_data -> string list ->
-    (int list * Z3_New_Proof.z3_step list) * thm
+    ((int * int) list * Z3_New_Proof.z3_step list) * thm
 end
 
 structure Z3_New_Replay: Z3_NEW_REPLAY =
@@ -106,26 +106,24 @@
         val (thm', ctxt') = yield_singleton Assumption.add_assumes ct ctxt
       in (thm' RS thm, ctxt') end
 
-    fun add1 id fixes thm1 ((i, th), exact) ((is, thms), (ctxt, ptab)) =
+    fun add1 id fixes thm1 ((i, th), exact) ((idis, thms), (ctxt, ptab)) =
       let
         val (thm, ctxt') = if exact then (Thm.implies_elim thm1 th, ctxt) else assume thm1 ctxt
         val thms' = if exact then thms else th :: thms
-      in ((insert (op =) i is, thms'), (ctxt', Inttab.update (id, (fixes, thm)) ptab)) end
+      in (((id, i) :: idis, thms'), (ctxt', Inttab.update (id, (fixes, thm)) ptab)) end
 
     fun add (Z3_New_Proof.Z3_Step {id, rule, concl, fixes, ...})
-        (cx as ((is, thms), (ctxt, ptab))) =
+        (cx as ((idis, thms), (ctxt, ptab))) =
       if Z3_New_Replay_Methods.is_assumption rule andalso rule <> Z3_New_Proof.Hypothesis then
         let
           val ct = SMT2_Util.certify ctxt concl
-          val thm1 =
-            Thm.trivial ct
-            |> Conv.fconv_rule (Conv.arg1_conv (revert_conv outer_ctxt))
+          val thm1 = Thm.trivial ct |> Conv.fconv_rule (Conv.arg1_conv (revert_conv outer_ctxt))
           val thm2 = singleton (Variable.export ctxt outer_ctxt) thm1
         in
           (case lookup_assm assms_net (Thm.cprem_of thm2 1) of
             [] =>
               let val (thm, ctxt') = assume thm1 ctxt
-              in ((is, thms), (ctxt', Inttab.update (id, (fixes, thm)) ptab)) end
+              in ((idis, thms), (ctxt', Inttab.update (id, (fixes, thm)) ptab)) end
           | ithms => fold (add1 id fixes thm1) ithms cx)
         end
       else
@@ -178,10 +176,10 @@
     ({context=ctxt, typs, terms, rewrite_rules, assms} : SMT2_Translate.replay_data) output =
   let
     val (steps, ctxt2) = Z3_New_Proof.parse typs terms output ctxt
-    val ((is, rules), (ctxt3, assumed)) = add_asserted outer_ctxt rewrite_rules assms steps ctxt2
+    val ((idis, rules), (ctxt3, assumed)) = add_asserted outer_ctxt rewrite_rules assms steps ctxt2
   in
     if Config.get ctxt3 SMT2_Config.filter_only_facts then
-      ((is, steps), TrueI)
+      ((idis, steps), TrueI)
     else
       let
         val ctxt4 = put_simpset (Z3_New_Replay_Util.make_simpset ctxt3 []) ctxt3
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_prover_smt2.ML	Thu Mar 13 13:18:14 2014 +0100
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_prover_smt2.ML	Thu Mar 13 13:18:14 2014 +0100
@@ -154,14 +154,14 @@
         val birth = Timer.checkRealTimer timer
         val _ = if debug then Output.urgent_message "Invoking SMT solver..." else ()
 
-        val {outcome, used_facts, z3_steps} =
+        val {outcome, used_fact_infos, z3_proof} =
           SMT2_Solver.smt2_filter ctxt [] goal weighted_facts i slice_timeout
           handle exn =>
             if Exn.is_interrupt exn orelse debug then
               reraise exn
             else
               {outcome = SOME (SMT2_Failure.Other_Failure (ML_Compiler.exn_message exn)),
-               used_facts = [], z3_steps = []}
+               used_fact_infos = [], z3_proof = []}
 
         val death = Timer.checkRealTimer timer
         val outcome0 = if is_none outcome0 then SOME outcome else outcome0
@@ -206,8 +206,9 @@
             do_slice timeout (slice + 1) outcome0 time_so_far weighted_factss
           end
         else
-          {outcome = if is_none outcome then NONE else the outcome0, used_facts = used_facts,
-           used_from = map (apsnd snd) weighted_facts, z3_steps = z3_steps, run_time = time_so_far}
+          {outcome = if is_none outcome then NONE else the outcome0,
+           used_fact_infos = used_fact_infos, used_from = map (apsnd snd) weighted_facts,
+           z3_proof = z3_proof, run_time = time_so_far}
       end
   in
     do_slice timeout 1 NONE Time.zeroTime
@@ -225,21 +226,23 @@
         map (weight_smt2_fact ctxt num_facts) (facts ~~ (0 upto num_facts - 1))
       end
 
-    val weighted_factss = factss |> map (apsnd weight_facts)
-    val {outcome, used_facts = used_pairs, used_from, z3_steps, run_time} =
+    val weighted_factss = map (apsnd weight_facts) factss
+    val {outcome, used_fact_infos, used_from, z3_proof, run_time} =
       smt2_filter_loop name params state goal subgoal weighted_factss
-    val used_facts = used_pairs |> map fst
-    val outcome = outcome |> Option.map failure_of_smt2_failure
+    val used_named_facts = map snd used_fact_infos
+    val used_facts = map fst used_named_facts
+    val outcome = Option.map failure_of_smt2_failure outcome
 
     val (preplay, message, message_tail) =
       (case outcome of
         NONE =>
         (Lazy.lazy (fn () =>
-           play_one_line_proof mode debug verbose preplay_timeout used_pairs state subgoal
+           play_one_line_proof mode debug verbose preplay_timeout used_named_facts state subgoal
              SMT2_Method (bunch_of_proof_methods (smt_proofs <> SOME false) false liftingN)),
          fn preplay =>
             let
-              val atp_proof = Z3_New_Isar.atp_proof_of_z3_proof thy z3_steps
+              val fact_ids = map (fn (id, ((name, _), _)) => (id, name)) used_fact_infos
+              val atp_proof = Z3_New_Isar.atp_proof_of_z3_proof thy z3_proof fact_ids
               val isar_params =
                 K (verbose, (NONE, NONE), preplay_timeout, compress_isar, try0_isar,
                    minimize <> SOME false, atp_proof, goal)