honor original format of conjecture or hypotheses in Z3-to-Isar proofs
authorblanchet
Fri, 16 May 2014 19:13:50 +0200
changeset 56981 3ef45ce002b5
parent 56980 9c5220e05e04
child 56982 51d4189d95cf
honor original format of conjecture or hypotheses in Z3-to-Isar proofs
src/HOL/Tools/SMT2/smt2_normalize.ML
src/HOL/Tools/SMT2/smt2_solver.ML
src/HOL/Tools/SMT2/z3_new_isar.ML
src/HOL/Tools/Sledgehammer/sledgehammer_prover_smt2.ML
--- a/src/HOL/Tools/SMT2/smt2_normalize.ML	Fri May 16 17:11:56 2014 +0200
+++ b/src/HOL/Tools/SMT2/smt2_normalize.ML	Fri May 16 19:13:50 2014 +0200
@@ -8,6 +8,11 @@
 sig
   val drop_fact_warning: Proof.context -> thm -> unit
   val atomize_conv: Proof.context -> conv
+
+  val special_quant_table: (string * thm) list
+  val case_bool_entry: string * thm
+  val abs_min_max_table: (string * thm) list
+
   type extra_norm = Proof.context -> thm list * thm list -> thm list * thm list
   val add_extra_norm: SMT2_Util.class * extra_norm -> Context.generic -> Context.generic
   val normalize: Proof.context -> (int option * thm) list -> (int * thm) list
@@ -71,13 +76,13 @@
 
 (** unfold special quantifiers **)
 
+val special_quant_table = [
+  (@{const_name Ex1}, @{thm Ex1_def_raw}),
+  (@{const_name Ball}, @{thm Ball_def_raw}),
+  (@{const_name Bex}, @{thm Bex_def_raw})]
+
 local
-  val special_quants = [
-    (@{const_name Ex1}, @{thm Ex1_def_raw}),
-    (@{const_name Ball}, @{thm Ball_def_raw}),
-    (@{const_name Bex}, @{thm Bex_def_raw})]
-  
-  fun special_quant (Const (n, _)) = AList.lookup (op =) special_quants n
+  fun special_quant (Const (n, _)) = AList.lookup (op =) special_quant_table n
     | special_quant _ = NONE
 
   fun special_quant_conv _ ct =
@@ -89,7 +94,7 @@
 fun unfold_special_quants_conv ctxt =
   SMT2_Util.if_exists_conv (is_some o special_quant) (Conv.top_conv special_quant_conv ctxt)
 
-val setup_unfolded_quants = fold (SMT2_Builtin.add_builtin_fun_ext'' o fst) special_quants
+val setup_unfolded_quants = fold (SMT2_Builtin.add_builtin_fun_ext'' o fst) special_quant_table
 
 end
 
@@ -326,6 +331,8 @@
 
 (** rewrite bool case expressions as if expressions **)
 
+val case_bool_entry = (@{const_name "bool.case_bool"}, @{thm case_bool_if})
+
 local
   fun is_case_bool (Const (@{const_name "bool.case_bool"}, _)) = true
     | is_case_bool _ = false
@@ -345,14 +352,14 @@
 
 (** unfold abs, min and max **)
 
+val abs_min_max_table = [
+  (@{const_name min}, @{thm min_def_raw}),
+  (@{const_name max}, @{thm max_def_raw}),
+  (@{const_name abs}, @{thm abs_if_raw})]
+
 local
-  val defs = [
-    (@{const_name min}, @{thm min_def_raw}),
-    (@{const_name max}, @{thm max_def_raw}),
-    (@{const_name abs}, @{thm abs_if_raw})]
-
   fun abs_min_max ctxt (Const (n, Type (@{type_name fun}, [T, _]))) =
-        (case AList.lookup (op =) defs n of
+        (case AList.lookup (op =) abs_min_max_table n of
           NONE => NONE
         | SOME thm => if SMT2_Builtin.is_builtin_typ_ext ctxt T then SOME thm else NONE)
     | abs_min_max _ _ = NONE
@@ -366,7 +373,7 @@
 fun unfold_abs_min_max_conv ctxt =
   SMT2_Util.if_exists_conv (is_some o abs_min_max ctxt) (Conv.top_conv unfold_amm_conv ctxt)
   
-val setup_abs_min_max = fold (SMT2_Builtin.add_builtin_fun_ext'' o fst) defs
+val setup_abs_min_max = fold (SMT2_Builtin.add_builtin_fun_ext'' o fst) abs_min_max_table
 
 end
 
--- a/src/HOL/Tools/SMT2/smt2_solver.ML	Fri May 16 17:11:56 2014 +0200
+++ b/src/HOL/Tools/SMT2/smt2_solver.ML	Fri May 16 19:13:50 2014 +0200
@@ -29,7 +29,7 @@
   (*filter*)
   val smt2_filter: Proof.context -> thm -> ('a * (int option * thm)) list -> int -> Time.time ->
     {outcome: SMT2_Failure.failure option, rewrite_rules: thm list, conjecture_id: int,
-     helper_ids: (int * thm) list, fact_ids: (int * ('a * thm)) list,
+     prem_ids: int list, helper_ids: (int * thm) list, fact_ids: (int * ('a * thm)) list,
      z3_proof: Z3_New_Proof.z3_step list}
 
   (*tactic*)
@@ -260,17 +260,20 @@
     val iwthms = map_index I wthms
 
     val conjecture_i = 0
-    val facts_i = 1 + length wprems
+    val prems_i = 1
+    val facts_i = prems_i + length wprems
   in
     wthms
     |> apply_solver ctxt
     |> (fn (((iidths0, z3_proof), _), {rewrite_rules, ...}) =>
-      let val iidths = if can_filter ctxt then iidths0 else map (apsnd (apfst (K no_id))) iwthms
+      let
+        val iidths = if can_filter ctxt then iidths0 else map (apsnd (apfst (K no_id))) iwthms
+        fun id_of_index i = the_default no_id (Option.map fst (AList.lookup (op =) iidths i))
       in
         {outcome = NONE,
          rewrite_rules = rewrite_rules,
-         conjecture_id =
-           the_default no_id (Option.map fst (AList.lookup (op =) iidths conjecture_i)),
+         conjecture_id = id_of_index conjecture_i,
+         prem_ids = map id_of_index (prems_i upto facts_i - 1),
          helper_ids = map_filter (try (fn (~1, idth) => idth)) iidths,
          fact_ids = map_filter (fn (i, (id, _)) =>
            try (apsnd (apsnd snd o nth xwfacts)) (id, i - facts_i)) iidths,
@@ -278,7 +281,7 @@
       end)
   end
   handle SMT2_Failure.SMT fail => {outcome = SOME fail, rewrite_rules = [], conjecture_id = no_id,
-    helper_ids = [], fact_ids = [], z3_proof = []}
+    prem_ids = [], helper_ids = [], fact_ids = [], z3_proof = []}
 
 
 (* SMT tactic *)
--- a/src/HOL/Tools/SMT2/z3_new_isar.ML	Fri May 16 17:11:56 2014 +0200
+++ b/src/HOL/Tools/SMT2/z3_new_isar.ML	Fri May 16 19:13:50 2014 +0200
@@ -8,8 +8,8 @@
 sig
   type ('a, 'b) atp_step = ('a, 'b) ATP_Proof.atp_step
 
-  val atp_proof_of_z3_proof: theory -> thm list -> int -> (int * string) list ->
-    Z3_New_Proof.z3_step list -> (term, string) atp_step list
+  val atp_proof_of_z3_proof: Proof.context -> thm list -> term list -> term -> int list -> int ->
+    (int * string) list -> Z3_New_Proof.z3_step list -> (term, string) atp_step list
 end;
 
 structure Z3_New_Isar: Z3_NEW_ISAR =
@@ -83,38 +83,67 @@
   Term.map_abs_vars (perhaps (try Name.dest_skolem))
   #> Term.map_aterms (perhaps (try (fn Free (s, T) => Free (Name.dest_skolem s, T))))
 
-fun atp_proof_of_z3_proof thy rewrite_rules conjecture_id fact_ids proof =
+fun atp_proof_of_z3_proof ctxt rewrite_rules hyp_ts concl_t prem_ids conjecture_id fact_ids proof =
   let
-    fun step_of (Z3_New_Proof.Z3_Step {id, rule, prems, concl, ...}) =
+    val thy = Proof_Context.theory_of ctxt
+
+    fun steps_of (Z3_New_Proof.Z3_Step {id, rule, prems, concl, ...}) =
       let
         fun step_name_of id = (string_of_int id, the_list (AList.lookup (op =) fact_ids id))
 
-        val name as (_, ss) = step_name_of id
+        val name as (sid, ss) = step_name_of id
 
-        val role =
-          (case rule of
-            Z3_New_Proof.Asserted =>
-              if not (null ss) then Axiom
-              else if id = conjecture_id then Negated_Conjecture
-              else Hypothesis
-          | Z3_New_Proof.Rewrite => Lemma
-          | Z3_New_Proof.Rewrite_Star => Lemma
-          | Z3_New_Proof.Skolemize => Lemma
-          | Z3_New_Proof.Th_Lemma _ => Lemma
-          | _ => Plain)
-
-        val concl' = concl
+        val concl' =
+          concl
           |> Raw_Simplifier.rewrite_term thy rewrite_rules []
           |> Object_Logic.atomize_term thy
           |> simplify_bool
           |> unskolemize_names
           |> HOLogic.mk_Trueprop
+
+        fun standard_step role =
+          (name, role, concl', Z3_New_Proof.string_of_rule rule, map step_name_of prems)
       in
-        (name, role, concl', Z3_New_Proof.string_of_rule rule, map step_name_of prems)
+        (case rule of
+          Z3_New_Proof.Asserted =>
+          let
+            val name0 = (sid ^ "a", ss)
+            val (role0, concl0) =
+              if not (null ss) then
+                (Axiom, concl(*FIXME*))
+              else if id = conjecture_id then
+                (Conjecture, concl_t)
+              else
+                (Hypothesis,
+                 (case find_index (curry (op =) id) prem_ids of
+                   ~1 => concl
+                 | i => nth hyp_ts i))
+
+            val normalize_prems =
+              SMT2_Normalize.case_bool_entry :: SMT2_Normalize.special_quant_table @
+              SMT2_Normalize.abs_min_max_table
+              |> map_filter (fn (c, th) =>
+                if exists_Const (curry (op =) c o fst) concl0 then
+                  let val s = short_thm_name ctxt th in SOME (s, [s]) end
+                else
+                  NONE)
+          in
+            if null normalize_prems then
+              [(name, role0, concl0, Z3_New_Proof.string_of_rule rule, [])]
+            else
+              [(name0, role0, concl0, Z3_New_Proof.string_of_rule rule, []),
+               (name, Plain, concl', Z3_New_Proof.string_of_rule Z3_New_Proof.Rewrite,
+                name0 :: normalize_prems)]
+          end
+        | Z3_New_Proof.Rewrite => [standard_step Lemma]
+        | Z3_New_Proof.Rewrite_Star => [standard_step Lemma]
+        | Z3_New_Proof.Skolemize => [standard_step Lemma]
+        | Z3_New_Proof.Th_Lemma _ => [standard_step Lemma]
+        | _ => [standard_step Plain])
       end
   in
     proof
-    |> map step_of
+    |> maps steps_of
     |> inline_z3_defs []
     |> inline_z3_hypotheses [] []
   end
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_prover_smt2.ML	Fri May 16 17:11:56 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_prover_smt2.ML	Fri May 16 19:13:50 2014 +0200
@@ -161,8 +161,8 @@
               reraise exn
             else
               {outcome = SOME (SMT2_Failure.Other_Failure (Runtime.exn_message exn)),
-               rewrite_rules = [], conjecture_id = ~1, helper_ids = [], fact_ids = [],
-               z3_proof = []}
+               rewrite_rules = [], conjecture_id = ~1, prem_ids = [], helper_ids = [],
+               fact_ids = [], z3_proof = []}
 
         val death = Timer.checkRealTimer timer
         val outcome0 = if is_none outcome0 then SOME outcome else outcome0
@@ -227,8 +227,8 @@
       end
 
     val weighted_factss = map (apsnd weight_facts) factss
-    val {outcome, filter_result = {conjecture_id, rewrite_rules, helper_ids, fact_ids, z3_proof,
-           ...}, used_from, run_time} =
+    val {outcome, filter_result = {rewrite_rules, conjecture_id, prem_ids, helper_ids, fact_ids,
+           z3_proof, ...}, used_from, run_time} =
       smt2_filter_loop name params state goal subgoal weighted_factss
     val used_named_facts = map snd fact_ids
     val used_facts = map fst used_named_facts
@@ -245,7 +245,7 @@
               val fact_ids =
                 map (fn (id, th) => (id, short_thm_name ctxt th)) helper_ids @
                 map (fn (id, ((name, _), _)) => (id, name)) fact_ids
-              val atp_proof = Z3_New_Isar.atp_proof_of_z3_proof thy rewrite_rules conjecture_id
+              val atp_proof = Z3_New_Isar.atp_proof_of_z3_proof thy rewrite_rules prem_ids conjecture_id
                 fact_ids z3_proof
               val isar_params =
                 K (verbose, (NONE, NONE), preplay_timeout, compress_isar, try0_isar,