added preplay results to sledgehammer_output
authordesharna
Tue, 29 Mar 2022 17:12:15 +0200
changeset 75372 4c8d1ef258d3
parent 75371 136f79711c2a
child 75373 48736d743e8c
added preplay results to sledgehammer_output
src/HOL/Tools/Mirabelle/mirabelle_sledgehammer.ML
src/HOL/Tools/Sledgehammer/sledgehammer.ML
--- a/src/HOL/Tools/Mirabelle/mirabelle_sledgehammer.ML	Thu Mar 31 18:14:32 2022 +0200
+++ b/src/HOL/Tools/Mirabelle/mirabelle_sledgehammer.ML	Tue Mar 29 17:12:15 2022 +0200
@@ -20,12 +20,14 @@
 (* NOTE: Do not forget to update the Sledgehammer documentation to reflect changes here. *)
 
 val check_trivialK = "check_trivial" (*=BOOL: check if goals are "trivial"*)
+val exhaustive_preplayK = "exhaustive_preplay" (*=BOOL: show exhaustive preplay data*)
 val keep_probsK = "keep_probs" (*=BOOL: keep temporary problem files created by sledgehammer*)
 val keep_proofsK = "keep_proofs" (*=BOOL: keep temporary proof files created by ATPs*)
 val proof_methodK = "proof_method" (*=STRING: how to reconstruct proofs (e.g. using metis)*)
 
 (*defaults used in this Mirabelle action*)
 val check_trivial_default = false
+val exhaustive_preplay_default = false
 val keep_probs_default = false
 val keep_proofs_default = false
 
@@ -300,7 +302,7 @@
 in
 
 fun run_sledgehammer (params as {provers, ...}) output_dir keep_probs keep_proofs
-    proof_method_from_msg thy_index trivial pos st =
+    exhaustive_preplay proof_method_from_msg thy_index trivial pos st =
   let
     val thy = Proof.theory_of st
     val thy_name = Context.theory_name thy
@@ -317,9 +319,9 @@
         NONE
     val prover_name = hd provers
     val (sledgehamer_outcome, msg, cpu_time) = run_sh params keep pos st
-    val (time_prover, change_data, proof_method_and_used_thms) =
+    val (time_prover, change_data, proof_method_and_used_thms, exhaustive_preplay_msg) =
       (case sledgehamer_outcome of
-        Sledgehammer.SH_Some {used_facts, run_time, ...} =>
+        Sledgehammer.SH_Some ({used_facts, run_time, ...}, preplay_results) =>
         let
           val num_used_facts = length used_facts
           val time_prover = Time.toMilliseconds run_time
@@ -333,17 +335,32 @@
             #> inc_sh_lemmas num_used_facts
             #> inc_sh_max_lems num_used_facts
             #> inc_sh_time_prover time_prover
+
+          val exhaustive_preplay_msg =
+            if exhaustive_preplay then
+              preplay_results
+              |> map
+                (fn (meth, play_outcome, used_facts) =>
+                    "Preplay: " ^
+                    Sledgehammer_Proof_Methods.string_of_proof_method (map fst used_facts) meth ^
+                    " (" ^ Sledgehammer_Proof_Methods.string_of_play_outcome play_outcome ^ ")")
+              |> cat_lines
+            else
+              ""
         in
           (SOME time_prover, change_data,
-           SOME (proof_method_from_msg msg, map_filter get_thms used_facts))
+           SOME (proof_method_from_msg msg, map_filter get_thms used_facts),
+           exhaustive_preplay_msg)
         end
-      | _ => (NONE, I, NONE))
+      | _ => (NONE, I, NONE, ""))
     val outcome_msg =
       "(SH " ^ string_of_int cpu_time ^ "ms" ^
       (case time_prover of NONE => "" | SOME ms => ", ATP " ^ string_of_int ms ^ "ms") ^
       ") [" ^ prover_name ^ "]: "
   in
-    (sledgehamer_outcome, triv_str ^ outcome_msg ^ msg, change_data #> inc_sh_time_isa cpu_time,
+    (sledgehamer_outcome, triv_str ^ outcome_msg ^ msg ^
+       (if exhaustive_preplay_msg = "" then "" else ("\n" ^ exhaustive_preplay_msg)),
+     change_data #> inc_sh_time_isa cpu_time,
      proof_method_and_used_thms)
   end
 
@@ -434,6 +451,8 @@
       Mirabelle.get_bool_argument arguments (check_trivialK, check_trivial_default)
     val keep_probs = Mirabelle.get_bool_argument arguments (keep_probsK, keep_probs_default)
     val keep_proofs = Mirabelle.get_bool_argument arguments (keep_proofsK, keep_proofs_default)
+    val exhaustive_preplay =
+      Mirabelle.get_bool_argument arguments (exhaustive_preplayK, exhaustive_preplay_default)
     val proof_method_from_msg = proof_method_from_msg arguments
 
     val params = Sledgehammer_Commands.default_params \<^theory> arguments
@@ -450,8 +469,8 @@
           let
             val trivial = check_trivial andalso try0 pre handle Timeout.TIMEOUT _ => false
             val (outcome, log1, change_data1, proof_method_and_used_thms) =
-              run_sledgehammer params output_dir keep_probs keep_proofs proof_method_from_msg
-                theory_index trivial pos pre
+              run_sledgehammer params output_dir keep_probs keep_proofs exhaustive_preplay
+                proof_method_from_msg theory_index trivial pos pre
             val (log2, change_data2) =
               (case proof_method_and_used_thms of
                 SOME (proof_method, used_thms) =>
--- a/src/HOL/Tools/Sledgehammer/sledgehammer.ML	Thu Mar 31 18:14:32 2022 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer.ML	Tue Mar 29 17:12:15 2022 +0200
@@ -19,16 +19,15 @@
   type prover_problem = Sledgehammer_Prover.prover_problem
   type prover_result = Sledgehammer_Prover.prover_result
 
+  type preplay_result = proof_method * play_outcome * (string * stature) list
+
   datatype sledgehammer_outcome =
-    SH_Some of prover_result
+    SH_Some of prover_result * preplay_result list
   | SH_Unknown
   | SH_Timeout
   | SH_None
 
   val short_string_of_sledgehammer_outcome : sledgehammer_outcome -> string
-
-  val play_one_line_proof : bool -> Time.time -> (string * stature) list -> Proof.state -> int ->
-    proof_method * proof_method list list -> (string * stature) list * (proof_method * play_outcome)
   val string_of_factss : (string * fact list) list -> string
   val run_sledgehammer : params -> mode -> (string -> unit) option -> int -> fact_override ->
     Proof.state -> bool * (sledgehammer_outcome * string)
@@ -53,8 +52,10 @@
 open Sledgehammer_Prover_Minimize
 open Sledgehammer_MaSh
 
+type preplay_result = proof_method * play_outcome * (string * stature) list
+
 datatype sledgehammer_outcome =
-  SH_Some of prover_result
+  SH_Some of prover_result * preplay_result list
 | SH_Unknown
 | SH_Timeout
 | SH_None
@@ -83,26 +84,21 @@
     |> the_default (SH_Unknown, "")
   end
 
-fun play_one_line_proof minimize timeout used_facts state i (preferred_meth, methss) =
+fun play_one_line_proofs minimize timeout used_facts state i methss =
   (if timeout = Time.zeroTime then
-     (used_facts, (preferred_meth, Play_Timed_Out Time.zeroTime))
+     []
    else
      let
        val ctxt = Proof.context_of state
-
-       val fact_names = used_facts |> filter_out (fn (_, (sc, _)) => sc = Chained) |> map fst
+       val used_facts = filter_out (fn (_, (sc, _)) => sc = Chained) used_facts
+       val fact_names = map fst used_facts
        val {facts = chained, goal, ...} = Proof.goal state
        val goal_t = Logic.get_goal (Thm.prop_of goal) i
 
-       fun try_methss [] [] = (used_facts, (preferred_meth, Play_Timed_Out Time.zeroTime))
-         | try_methss ress [] =
-           (used_facts,
-            (case AList.lookup (op =) ress preferred_meth of
-              SOME play => (preferred_meth, play)
-            | NONE => hd (sort (play_outcome_ord o apply2 snd) (rev ress))))
+       fun try_methss ress [] = ress
          | try_methss ress (meths :: methss) =
            let
-             fun mk_step fact_names meths =
+             fun mk_step meths =
                Prove {
                  qualifiers = [],
                  obtains = [],
@@ -112,27 +108,50 @@
                  facts = ([], fact_names),
                  proof_methods = meths,
                  comment = ""}
+             val ress' =
+               preplay_isar_step ctxt chained timeout [] (mk_step meths)
+               |> map (fn result as (meth, play_outcome) =>
+                  (case (minimize, play_outcome) of
+                    (true, Played time) =>
+                    let
+                      val (time', used_names') =
+                        minimized_isar_step ctxt chained time (mk_step [meth])
+                        ||> (facts_of_isar_step #> snd)
+                      val used_facts' = filter (member (op =) used_names' o fst) used_facts
+                    in
+                      (meth, Played time', used_facts')
+                    end
+                  | _ => (meth, play_outcome, used_facts)))
+             val any_succeeded = exists (fn (_, Played _, _) => true | _ => false) ress'
            in
-             (case preplay_isar_step ctxt chained timeout [] (mk_step fact_names meths) of
-               (res as (meth, Played time)) :: _ =>
-               if not minimize then
-                 (used_facts, res)
-               else
-                 let
-                   val (time', used_names') =
-                     minimized_isar_step ctxt chained time (mk_step fact_names [meth])
-                     ||> (facts_of_isar_step #> snd)
-                   val used_facts' = filter (member (op =) used_names' o fst) used_facts
-                 in
-                   (used_facts', (meth, Played time'))
-                 end
-             | ress' => try_methss (ress' @ ress) methss)
+             try_methss (ress' @ ress) (if any_succeeded then [] else methss)
            end
      in
        try_methss [] methss
      end)
-  |> (fn (used_facts, (meth, play)) =>
-        (used_facts |> filter_out (fn (_, (sc, _)) => sc = Chained), (meth, play)))
+  |> map (fn (meth, play_outcome, used_facts) => (meth, play_outcome, filter_out (fn (_, (sc, _)) => sc = Chained) used_facts))
+  |> sort (play_outcome_ord o apply2 (fn (_, play_outcome, _) => play_outcome))
+
+fun select_one_line_proof used_facts preferred_meth preplay_results =
+  (case preplay_results of
+    [] => (used_facts, (preferred_meth, Play_Timed_Out Time.zeroTime))
+  | (best_meth, best_outcome, best_used_facts) :: results' =>
+    let
+      val (prefered_outcome, prefered_used_facts) =
+        (case find_first (fn (meth, _, _) => meth = preferred_meth) preplay_results of
+          NONE => (Play_Timed_Out Time.zeroTime, used_facts)
+        | SOME (_, prefered_outcome, prefered_used_facts) =>
+          (prefered_outcome, prefered_used_facts))
+    in
+      (case (prefered_outcome, best_outcome) of
+        (* If prefered_meth succeeded, use it irrespective of other preplay results *)
+        (Played _, _) => (prefered_used_facts, (preferred_meth, prefered_outcome))
+        (* If prefered_meth did not succeed but best method did, use best method *)
+      | (_, Played _) => (best_used_facts, (best_meth, best_outcome))
+        (* If neither succeeded, use preferred_meth *)
+      | (_, _) => (prefered_used_facts, (preferred_meth, prefered_outcome)))
+    end)
+  |> apfst (filter_out (fn (_, (sc, _)) => sc = Chained))
 
 fun launch_prover (params as {verbose, spy, slices, timeout, ...}) mode learn
     (problem as {state, subgoal, factss, ...} : prover_problem)
@@ -201,15 +220,22 @@
 fun preplay_prover_result ({ minimize, preplay_timeout, ...} : params) state subgoal
     (result as {outcome, used_facts, preferred_methss, message, ...} : prover_result) =
   let
-    val output =
+    val (output, chosen_preplay_outcome) =
       if outcome = SOME ATP_Proof.TimedOut then
-        SH_Timeout
+        (SH_Timeout, select_one_line_proof used_facts (fst preferred_methss) [])
       else if is_some outcome then
-        SH_None
+        (SH_None, select_one_line_proof used_facts (fst preferred_methss) [])
       else
-        SH_Some result
-    fun output_message () = message (fn () =>
-      play_one_line_proof minimize preplay_timeout used_facts state subgoal preferred_methss)
+        let
+          val preplay_results =
+            play_one_line_proofs minimize preplay_timeout used_facts state subgoal
+              (snd preferred_methss)
+          val chosen_preplay_outcome =
+            select_one_line_proof used_facts (fst preferred_methss) preplay_results
+        in
+          (SH_Some (result, preplay_results), chosen_preplay_outcome)
+        end
+    fun output_message () = message (fn () => chosen_preplay_outcome)
   in
     (output, output_message)
   end