correctly reconstruct helper facts (e.g. 'nat_int') in Isar proofs
authorblanchet
Thu, 13 Mar 2014 14:48:20 +0100
changeset 56104 fd6e132ee4fb
parent 56103 6689512f3710
child 56105 75dc126f5dcb
correctly reconstruct helper facts (e.g. 'nat_int') in Isar proofs
src/HOL/Tools/ATP/atp_proof_reconstruct.ML
src/HOL/Tools/ATP/atp_util.ML
src/HOL/Tools/SMT2/smt2_normalize.ML
src/HOL/Tools/SMT2/smt2_solver.ML
src/HOL/Tools/SMT2/smt2_translate.ML
src/HOL/Tools/SMT2/z3_new_isar.ML
src/HOL/Tools/SMT2/z3_new_replay.ML
src/HOL/Tools/Sledgehammer/sledgehammer_prover_smt2.ML
--- a/src/HOL/Tools/ATP/atp_proof_reconstruct.ML	Thu Mar 13 14:48:20 2014 +0100
+++ b/src/HOL/Tools/ATP/atp_proof_reconstruct.ML	Thu Mar 13 14:48:20 2014 +0100
@@ -377,22 +377,12 @@
       union (op =) (filter (fn (_, (_, status)) => status = Non_Rec_Def) facts) facts')
     accum fact_names
 
-val isa_ext = Thm.get_name_hint @{thm ext}
-val isa_short_ext = Long_Name.base_name isa_ext
-
-fun ext_name ctxt =
-  if Thm.eq_thm_prop (@{thm ext},
-       singleton (Attrib.eval_thms ctxt) (Facts.named isa_short_ext, [])) then
-    isa_short_ext
-  else
-    isa_ext
-
 val leo2_extcnf_equal_neg_rule = "extcnf_equal_neg"
 val leo2_unfold_def_rule = "unfold_def"
 
 fun add_fact ctxt fact_names ((_, ss), _, _, rule, deps) =
   (if rule = leo2_extcnf_equal_neg_rule then
-     insert (op =) (ext_name ctxt, (Global, General))
+     insert (op =) (short_thm_name ctxt ext, (Global, General))
    else if rule = leo2_unfold_def_rule then
      (* LEO 1.3.3 does not record definitions properly, leading to missing dependencies in the TSTP
         proof. Remove the next line once this is fixed. *)
@@ -401,7 +391,7 @@
      (fn [] =>
          (* agsyHOL and Satallax don't include definitions in their unsatisfiable cores, so we
             assume the worst and include them all here. *)
-         [(ext_name ctxt, (Global, General))] |> add_non_rec_defs fact_names
+         [(short_thm_name ctxt ext, (Global, General))] |> add_non_rec_defs fact_names
        | facts => facts)
    else
      I)
--- a/src/HOL/Tools/ATP/atp_util.ML	Thu Mar 13 14:48:20 2014 +0100
+++ b/src/HOL/Tools/ATP/atp_util.ML	Thu Mar 13 14:48:20 2014 +0100
@@ -48,8 +48,8 @@
   val is_legitimate_tptp_def : term -> bool
   val transform_elim_prop : term -> term
   val specialize_type : theory -> (string * typ) -> term -> term
-  val strip_subgoal :
-    thm -> int -> Proof.context -> (string * typ) list * term list * term
+  val strip_subgoal : thm -> int -> Proof.context -> (string * typ) list * term list * term
+  val short_thm_name : Proof.context -> thm -> string
 end;
 
 structure ATP_Util : ATP_UTIL =
@@ -425,4 +425,13 @@
     val concl_t = t |> Logic.strip_assums_concl |> curry subst_bounds frees
   in (rev params, hyp_ts, concl_t) end
 
+fun short_thm_name ctxt th =
+  let
+    val long = Thm.get_name_hint th
+    val short = Long_Name.base_name long
+  in
+    if Thm.eq_thm_prop (th, singleton (Attrib.eval_thms ctxt) (Facts.named short, [])) then short
+    else long
+  end
+
 end;
--- a/src/HOL/Tools/SMT2/smt2_normalize.ML	Thu Mar 13 14:48:20 2014 +0100
+++ b/src/HOL/Tools/SMT2/smt2_normalize.ML	Thu Mar 13 14:48:20 2014 +0100
@@ -497,7 +497,7 @@
   let
     val (is, thms) = split_list ithms
     val (thms', extra_thms) = f thms
-  in (is ~~ thms') @ map (pair ~1) extra_thms end
+  in (is ~~ thms') @ tag_list (length is) extra_thms end
 
 fun unfold2 ctxt ithms =
   ithms
--- a/src/HOL/Tools/SMT2/smt2_solver.ML	Thu Mar 13 14:48:20 2014 +0100
+++ b/src/HOL/Tools/SMT2/smt2_solver.ML	Thu Mar 13 14:48:20 2014 +0100
@@ -20,20 +20,20 @@
     cex_parser: (Proof.context -> SMT2_Translate.replay_data -> string list ->
       term list * term list) option,
     replay: (Proof.context -> SMT2_Translate.replay_data -> string list ->
-      ((int * int) list * Z3_New_Proof.z3_step list) * thm) option }
+      ((int * (int * thm)) list * Z3_New_Proof.z3_step list) * thm) option }
 
   (*registry*)
   val add_solver: solver_config -> theory -> theory
   val solver_name_of: Proof.context -> string
   val available_solvers_of: Proof.context -> string list
   val apply_solver: Proof.context -> (int * (int option * thm)) list ->
-    ((int * int) list * Z3_New_Proof.z3_step list) * thm
+    ((int * (int * thm)) list * Z3_New_Proof.z3_step list) * thm
   val default_max_relevant: Proof.context -> string -> int
 
   (*filter*)
-  val smt2_filter: Proof.context -> thm list -> thm -> ('a * (int option * thm)) list -> int ->
-    Time.time -> {outcome: SMT2_Failure.failure option, used_fact_infos: (int * ('a * thm)) list,
-      z3_proof: Z3_New_Proof.z3_step list}
+  val smt2_filter: Proof.context -> thm -> ('a * (int option * thm)) list -> int -> Time.time ->
+    {outcome: SMT2_Failure.failure option, conjecture_id: int, helper_ids: (int * thm) list,
+     fact_ids: (int * ('a * thm)) list, z3_proof: Z3_New_Proof.z3_step list}
 
   (*tactic*)
   val smt2_tac: Proof.context -> thm list -> int -> tactic
@@ -152,7 +152,7 @@
   cex_parser: (Proof.context -> SMT2_Translate.replay_data -> string list ->
     term list * term list) option,
   replay: (Proof.context -> SMT2_Translate.replay_data -> string list ->
-    ((int * int) list * Z3_New_Proof.z3_step list) * thm) option }
+    ((int * (int * thm)) list * Z3_New_Proof.z3_step list) * thm) option }
 
 
 (* registry *)
@@ -162,7 +162,7 @@
   default_max_relevant: int,
   supports_filter: bool,
   replay: Proof.context -> string list * SMT2_Translate.replay_data ->
-    ((int * int) list * Z3_New_Proof.z3_step list) * thm }
+    ((int * (int * thm)) list * Z3_New_Proof.z3_step list) * thm }
 
 structure Solvers = Generic_Data
 (
@@ -258,7 +258,7 @@
 
 val cnot = Thm.cterm_of @{theory} @{const Not}
 
-fun smt2_filter ctxt facts goal xwthms i time_limit =
+fun smt2_filter ctxt goal xwfacts i time_limit =
   let
     val ctxt =
       ctxt
@@ -273,53 +273,61 @@
         SOME ct => ct
       | NONE => raise SMT2_Failure.SMT (SMT2_Failure.Other_Failure "goal is not a HOL term"))
 
-    val xthms = map (apsnd snd) xwthms
+    val iwconjecture = (~1, (NONE, Thm.assume cprop))
+    val iwprems = map (pair ~2 o pair NONE) prems
+    val iwfacts = map_index I (map snd xwfacts)
 
-    val used_fact_infos_of =
-      if supports_filter ctxt then map_filter (try (apsnd (nth xthms)))
-      else K (map (pair ~1) xthms)
+    val n = length iwfacts
+    val xfacts = map (apsnd snd) xwfacts
   in
-    map snd xwthms
-    |> map_index I
-    |> append (map (pair ~1 o pair NONE) (Thm.assume cprop :: prems @ facts))
+    iwconjecture :: iwprems @ iwfacts
     |> check_topsorts ctxt
     |> apply_solver ctxt
     |> fst
-    |> (fn (idis, z3_proof) =>
-      {outcome = NONE, used_fact_infos = used_fact_infos_of idis, z3_proof = z3_proof})
+    |> (fn (iidths0, z3_proof) =>
+      let val iidths = if supports_filter ctxt then iidths0 else map (apsnd (apfst (K ~1))) iwfacts
+      in
+        {outcome = NONE, 
+         conjecture_id =
+           the_default ~1 (Option.map fst (AList.lookup (op =) iidths (fst iwconjecture))),
+         helper_ids = map_filter (fn (i, (id, th)) => if i >= n then SOME (id, th) else NONE) iidths,
+         fact_ids = map_filter (fn (i, (id, _)) => try (apsnd (nth xfacts)) (id, i)) iidths,
+         z3_proof = z3_proof}
+      end)
   end
-  handle SMT2_Failure.SMT fail => {outcome = SOME fail, used_fact_infos = [], z3_proof = []}
+  handle SMT2_Failure.SMT fail => {outcome = SOME fail, conjecture_id = ~1, helper_ids = [],
+    fact_ids = [], z3_proof = []}
 
 
 (* SMT tactic *)
 
 local
-  fun trace_assumptions ctxt iwthms idis =
+  fun trace_assumptions ctxt iwfacts iidths =
     let
-      val wthms =
-        idis
-        |> map snd
+      val wfacts =
+        iidths
+        |> map fst
         |> filter (fn i => i >= 0)
-        |> map_filter (AList.lookup (op =) iwthms)
+        |> map_filter (AList.lookup (op =) iwfacts)
     in
-      if Config.get ctxt SMT2_Config.trace_used_facts andalso length wthms > 0 then
+      if Config.get ctxt SMT2_Config.trace_used_facts andalso length wfacts > 0 then
         tracing (Pretty.string_of (Pretty.big_list "SMT used facts:"
-          (map (Display.pretty_thm ctxt o snd) wthms)))
+          (map (Display.pretty_thm ctxt o snd) wfacts)))
       else ()
     end
 
-  fun solve ctxt iwthms =
-    iwthms
+  fun solve ctxt iwfacts =
+    iwfacts
     |> check_topsorts ctxt
     |> apply_solver ctxt
-    |>> apfst (trace_assumptions ctxt iwthms)
+    |>> apfst (trace_assumptions ctxt iwfacts)
     |> snd
 
   fun str_of ctxt fail =
     SMT2_Failure.string_of_failure ctxt fail
     |> prefix ("Solver " ^ SMT2_Config.solver_of ctxt ^ ": ")
 
-  fun safe_solve ctxt iwthms = SOME (solve ctxt iwthms)
+  fun safe_solve ctxt iwfacts = SOME (solve ctxt iwfacts)
     handle
       SMT2_Failure.SMT (fail as SMT2_Failure.Counterexample _) =>
         (SMT2_Config.verbose_msg ctxt (str_of ctxt) fail; NONE)
--- a/src/HOL/Tools/SMT2/smt2_translate.ML	Thu Mar 13 14:48:20 2014 +0100
+++ b/src/HOL/Tools/SMT2/smt2_translate.ML	Thu Mar 13 14:48:20 2014 +0100
@@ -321,8 +321,7 @@
   exception BAD_PATTERN of unit
 
   fun wrap_in_if pat t =
-    if pat then raise BAD_PATTERN ()
-    else @{const If (bool)} $ t $ @{const True} $ @{const False}
+    if pat then raise BAD_PATTERN () else @{const If (bool)} $ t $ @{const True} $ @{const False}
 
   fun is_builtin_conn_or_pred ctxt c ts =
     is_some (SMT2_Builtin.dest_builtin_conn ctxt c ts) orelse
@@ -338,8 +337,7 @@
         (@{const True}, []) => t
       | (@{const False}, []) => t
       | (u as Const (@{const_name If}, _), [t1, t2, t3]) =>
-          if pat then raise BAD_PATTERN ()
-          else u $ in_form t1 $ in_term pat t2 $ in_term pat t3
+          if pat then raise BAD_PATTERN () else u $ in_form t1 $ in_term pat t2 $ in_term pat t3
       | (Const (c as (n, _)), ts) =>
           if is_builtin_conn_or_pred ctxt c ts then wrap_in_if pat (in_form t)
           else if is_quant n then wrap_in_if pat (in_form t)
@@ -357,11 +355,9 @@
       | in_pat t = raise TERM ("bad pattern", [t])
 
     and in_pats ps =
-      in_list @{typ "SMT2.pattern list"}
-        (SOME o in_list @{typ SMT2.pattern} (try in_pat)) ps
+      in_list @{typ "SMT2.pattern list"} (SOME o in_list @{typ SMT2.pattern} (try in_pat)) ps
 
-    and in_trigger ((c as @{const SMT2.trigger}) $ p $ t) =
-          c $ in_pats p $ in_weight t
+    and in_trigger ((c as @{const SMT2.trigger}) $ p $ t) = c $ in_pats p $ in_weight t
       | in_trigger t = in_weight t
 
     and in_form t =
@@ -462,7 +458,7 @@
       let val (Us, U) = SMT2_Util.dest_funT (length ts) T
       in
         fold_map transT Us ##>> transT U #-> (fn Up =>
-        add_fun t (SOME Up) ##>> fold_map trans ts #>> SApp)
+          add_fun t (SOME Up) ##>> fold_map trans ts #>> SApp)
       end
 
     val (us, trx') = fold_map trans ts trx
@@ -528,13 +524,12 @@
           |> pair ctxt')
 
     val ((rewrite_rules, builtin), ts4) = folify ctxt2 ts3
-
-    val rewrite_rules' = fun_app_eq :: rewrite_rules
+      |>> apfst (cons fun_app_eq)
   in
     (ts4, tr_context)
     |-> intermediate header dtyps (builtin SMT2_Builtin.dest_builtin) ctxt2
     |>> uncurry (serialize comments)
-    ||> replay_data_of ctxt2 rewrite_rules' ithms
+    ||> replay_data_of ctxt2 rewrite_rules ithms
   end
 
 end
--- a/src/HOL/Tools/SMT2/z3_new_isar.ML	Thu Mar 13 14:48:20 2014 +0100
+++ b/src/HOL/Tools/SMT2/z3_new_isar.ML	Thu Mar 13 14:48:20 2014 +0100
@@ -8,7 +8,7 @@
 sig
   type ('a, 'b) atp_step = ('a, 'b) ATP_Proof.atp_step
 
-  val atp_proof_of_z3_proof: theory -> Z3_New_Proof.z3_step list -> (int * string) list ->
+  val atp_proof_of_z3_proof: theory -> int -> (int * string) list -> Z3_New_Proof.z3_step list ->
     (term, string) atp_step list
 end;
 
@@ -74,7 +74,7 @@
 
 fun simplify_line (name, role, t, rule, deps) = (name, role, simplify_prop t, rule, deps)
 
-fun atp_proof_of_z3_proof thy proof fact_ids =
+fun atp_proof_of_z3_proof thy conjecture_id fact_ids proof =
   let
     fun step_of (Z3_New_Proof.Z3_Step {id, rule, prems, concl, ...}) =
       let
@@ -84,7 +84,9 @@
         val role =
           (case rule of
             Z3_New_Proof.Asserted =>
-              if null ss then Negated_Conjecture (* FIXME: or hypothesis! *) else Axiom
+              if not (null ss) then Axiom
+              else if id = conjecture_id then Negated_Conjecture
+              else Hypothesis
           | Z3_New_Proof.Rewrite => Lemma
           | Z3_New_Proof.Rewrite_Star => Lemma
           | Z3_New_Proof.Skolemize => Lemma
--- a/src/HOL/Tools/SMT2/z3_new_replay.ML	Thu Mar 13 14:48:20 2014 +0100
+++ b/src/HOL/Tools/SMT2/z3_new_replay.ML	Thu Mar 13 14:48:20 2014 +0100
@@ -8,7 +8,7 @@
 signature Z3_NEW_REPLAY =
 sig
   val replay: Proof.context -> SMT2_Translate.replay_data -> string list ->
-    ((int * int) list * Z3_New_Proof.z3_step list) * thm
+    ((int * (int * thm)) list * Z3_New_Proof.z3_step list) * thm
 end
 
 structure Z3_New_Replay: Z3_NEW_REPLAY =
@@ -106,14 +106,14 @@
         val (thm', ctxt') = yield_singleton Assumption.add_assumes ct ctxt
       in (thm' RS thm, ctxt') end
 
-    fun add1 id fixes thm1 ((i, th), exact) ((idis, thms), (ctxt, ptab)) =
+    fun add1 id fixes thm1 ((i, th), exact) ((iidths, thms), (ctxt, ptab)) =
       let
         val (thm, ctxt') = if exact then (Thm.implies_elim thm1 th, ctxt) else assume thm1 ctxt
         val thms' = if exact then thms else th :: thms
-      in (((id, i) :: idis, thms'), (ctxt', Inttab.update (id, (fixes, thm)) ptab)) end
+      in (((i, (id, th)) :: iidths, thms'), (ctxt', Inttab.update (id, (fixes, thm)) ptab)) end
 
     fun add (Z3_New_Proof.Z3_Step {id, rule, concl, fixes, ...})
-        (cx as ((idis, thms), (ctxt, ptab))) =
+        (cx as ((iidths, thms), (ctxt, ptab))) =
       if Z3_New_Replay_Methods.is_assumption rule andalso rule <> Z3_New_Proof.Hypothesis then
         let
           val ct = SMT2_Util.certify ctxt concl
@@ -123,7 +123,7 @@
           (case lookup_assm assms_net (Thm.cprem_of thm2 1) of
             [] =>
               let val (thm, ctxt') = assume thm1 ctxt
-              in ((idis, thms), (ctxt', Inttab.update (id, (fixes, thm)) ptab)) end
+              in ((iidths, thms), (ctxt', Inttab.update (id, (fixes, thm)) ptab)) end
           | ithms => fold (add1 id fixes thm1) ithms cx)
         end
       else
@@ -176,10 +176,10 @@
     ({context=ctxt, typs, terms, rewrite_rules, assms} : SMT2_Translate.replay_data) output =
   let
     val (steps, ctxt2) = Z3_New_Proof.parse typs terms output ctxt
-    val ((idis, rules), (ctxt3, assumed)) = add_asserted outer_ctxt rewrite_rules assms steps ctxt2
+    val ((iidths, rules), (ctxt3, assumed)) = add_asserted outer_ctxt rewrite_rules assms steps ctxt2
   in
     if Config.get ctxt3 SMT2_Config.filter_only_facts then
-      ((idis, steps), TrueI)
+      ((iidths, steps), TrueI)
     else
       let
         val ctxt4 = put_simpset (Z3_New_Replay_Util.make_simpset ctxt3 []) ctxt3
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_prover_smt2.ML	Thu Mar 13 14:48:20 2014 +0100
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_prover_smt2.ML	Thu Mar 13 14:48:20 2014 +0100
@@ -154,14 +154,14 @@
         val birth = Timer.checkRealTimer timer
         val _ = if debug then Output.urgent_message "Invoking SMT solver..." else ()
 
-        val {outcome, used_fact_infos, z3_proof} =
-          SMT2_Solver.smt2_filter ctxt [] goal weighted_facts i slice_timeout
+        val filter_result as {outcome, ...} =
+          SMT2_Solver.smt2_filter ctxt goal weighted_facts i slice_timeout
           handle exn =>
             if Exn.is_interrupt exn orelse debug then
               reraise exn
             else
               {outcome = SOME (SMT2_Failure.Other_Failure (ML_Compiler.exn_message exn)),
-               used_fact_infos = [], z3_proof = []}
+               conjecture_id = ~1, helper_ids = [], fact_ids = [], z3_proof = []}
 
         val death = Timer.checkRealTimer timer
         val outcome0 = if is_none outcome0 then SOME outcome else outcome0
@@ -206,9 +206,8 @@
             do_slice timeout (slice + 1) outcome0 time_so_far weighted_factss
           end
         else
-          {outcome = if is_none outcome then NONE else the outcome0,
-           used_fact_infos = used_fact_infos, used_from = map (apsnd snd) weighted_facts,
-           z3_proof = z3_proof, run_time = time_so_far}
+          {outcome = if is_none outcome then NONE else the outcome0, filter_result = filter_result,
+           used_from = map (apsnd snd) weighted_facts, run_time = time_so_far}
       end
   in
     do_slice timeout 1 NONE Time.zeroTime
@@ -227,9 +226,9 @@
       end
 
     val weighted_factss = map (apsnd weight_facts) factss
-    val {outcome, used_fact_infos, used_from, z3_proof, run_time} =
-      smt2_filter_loop name params state goal subgoal weighted_factss
-    val used_named_facts = map snd used_fact_infos
+    val {outcome, filter_result = {conjecture_id, helper_ids, fact_ids, z3_proof, ...},
+         used_from, run_time} = smt2_filter_loop name params state goal subgoal weighted_factss
+    val used_named_facts = map snd fact_ids
     val used_facts = map fst used_named_facts
     val outcome = Option.map failure_of_smt2_failure outcome
 
@@ -241,8 +240,10 @@
              SMT2_Method (bunch_of_proof_methods (smt_proofs <> SOME false) false liftingN)),
          fn preplay =>
             let
-              val fact_ids = map (fn (id, ((name, _), _)) => (id, name)) used_fact_infos
-              val atp_proof = Z3_New_Isar.atp_proof_of_z3_proof thy z3_proof fact_ids
+              val fact_ids =
+                map (fn (id, th) => (id, short_thm_name ctxt th)) helper_ids @
+                map (fn (id, ((name, _), _)) => (id, name)) fact_ids
+              val atp_proof = Z3_New_Isar.atp_proof_of_z3_proof thy conjecture_id fact_ids z3_proof
               val isar_params =
                 K (verbose, (NONE, NONE), preplay_timeout, compress_isar, try0_isar,
                    minimize <> SOME false, atp_proof, goal)