tuned mirabelle_sledgehammer to have a single call to Synchronized.change per run
authordesharna
Fri, 21 Jan 2022 15:29:36 +0100
changeset 74996 1f4c39ffb116
parent 74991 d699eb2d26ad
child 74997 d4a52993a81e
tuned mirabelle_sledgehammer to have a single call to Synchronized.change per run
src/HOL/Tools/Mirabelle/mirabelle_sledgehammer.ML
--- a/src/HOL/Tools/Mirabelle/mirabelle_sledgehammer.ML	Fri Jan 21 12:09:55 2022 +0100
+++ b/src/HOL/Tools/Mirabelle/mirabelle_sledgehammer.ML	Fri Jan 21 15:29:36 2022 +0100
@@ -310,15 +310,13 @@
 
 in
 
-fun run_sledgehammer change_data (params as {provers, ...}) output_dir
-  e_selection_heuristic term_order force_sos keep_probs keep_proofs proof_method_from_msg thy_index
-  trivial proof_method named_thms pos st =
+fun run_sledgehammer (params as {provers, ...}) output_dir e_selection_heuristic term_order
+  force_sos keep_probs keep_proofs proof_method_from_msg thy_index trivial proof_method named_thms
+  pos st =
   let
     val thy = Proof.theory_of st
     val thy_name = Context.theory_name thy
     val triv_str = if trivial then "[T] " else ""
-    val _ = change_data inc_sh_calls
-    val _ = if trivial then () else change_data inc_sh_nontriv_calls
     val keep =
       if keep_probs orelse keep_proofs then
         let val subdir = StringCvt.padLeft #"0" 4 (string_of_int thy_index) ^ "_" ^ thy_name in
@@ -332,7 +330,7 @@
     val prover_name = hd provers
     val (sledgehamer_outcome, msg, cpu_time) =
       run_sh params e_selection_heuristic term_order force_sos keep pos st
-    val outcome_msg =
+    val (outcome_msg, change_data) =
       (case sledgehamer_outcome of
         Sledgehammer.SH_Some {used_facts, run_time, ...} =>
         let
@@ -342,21 +340,23 @@
             try (Sledgehammer_Util.thms_of_name (Proof.context_of st))
               name
             |> Option.map (pair (name, stature))
+          val outcome_msg =
+            " (" ^ string_of_int cpu_time ^ "+" ^ string_of_int time_prover ^ ")" ^
+            " [" ^ prover_name ^ "]:\n"
+          val change_data =
+            inc_sh_success
+            #> not trivial ? inc_sh_nontriv_success
+            #> inc_sh_lemmas num_used_facts
+            #> inc_sh_max_lems num_used_facts
+            #> inc_sh_time_prover time_prover
         in
-          change_data inc_sh_success;
-          if trivial then () else change_data inc_sh_nontriv_success;
-          change_data (inc_sh_lemmas num_used_facts);
-          change_data (inc_sh_max_lems num_used_facts);
-          change_data (inc_sh_time_prover time_prover);
           proof_method := proof_method_from_msg msg;
           named_thms := SOME (map_filter get_thms used_facts);
-          " (" ^ string_of_int cpu_time ^ "+" ^ string_of_int time_prover ^ ")" ^
-          " [" ^ prover_name ^ "]:\n"
+          (outcome_msg, change_data)
         end
-      | _ => "")
+      | _ => ("", I))
   in
-    change_data (inc_sh_time_isa cpu_time);
-    (sledgehamer_outcome, triv_str ^ outcome_msg ^ msg)
+    (sledgehamer_outcome, triv_str ^ outcome_msg ^ msg, change_data #> inc_sh_time_isa cpu_time)
   end
 
 end
@@ -369,7 +369,7 @@
    ("slice", "false"),
    ("timeout", timeout |> Time.toSeconds |> string_of_int)]
 
-fun run_proof_method change_data trivial full name meth named_thms timeout pos st =
+fun run_proof_method trivial full name meth named_thms timeout pos st =
   let
     fun do_method named_thms ctxt =
       let
@@ -418,23 +418,25 @@
     fun apply_method named_thms =
       Mirabelle.can_apply timeout (do_method named_thms) st
 
-    fun with_time (false, t) = "failed (" ^ string_of_int t ^ ")"
+    fun with_time (false, t) = ("failed (" ^ string_of_int t ^ ")", I)
       | with_time (true, t) =
-          (change_data inc_proof_method_success;
-           if trivial then () else change_data inc_proof_method_nontriv_success;
-           change_data (inc_proof_method_lemmas (length named_thms));
-           change_data (inc_proof_method_time t);
-           change_data (inc_proof_method_posns (pos, trivial));
-           if name = "proof" then change_data inc_proof_method_proofs else ();
-           "succeeded (" ^ string_of_int t ^ ")")
+          ("succeeded (" ^ string_of_int t ^ ")",
+           inc_proof_method_success
+           #> not trivial ? inc_proof_method_nontriv_success
+           #> inc_proof_method_lemmas (length named_thms)
+           #> inc_proof_method_time t
+           #> inc_proof_method_posns (pos, trivial)
+           #> name = "proof" ? inc_proof_method_proofs)
     fun timed_method named_thms =
       with_time (Mirabelle.cpu_time apply_method named_thms)
-        handle Timeout.TIMEOUT _ => (change_data inc_proof_method_timeout; "timeout")
-          | ERROR msg => ("error: " ^ msg)
-
-    val _ = change_data inc_proof_method_calls
-    val _ = if trivial then () else change_data inc_proof_method_nontriv_calls
-  in timed_method named_thms end
+        handle Timeout.TIMEOUT _ => ("timeout", inc_proof_method_timeout)
+          | ERROR msg => ("error: " ^ msg, I)
+  in
+    timed_method named_thms
+    |> apsnd (fn change_data => change_data
+      #> inc_proof_method_calls
+      #> not trivial ? inc_proof_method_nontriv_calls)
+  end
 
 val try_timeout = seconds 5.0
 
@@ -459,7 +461,6 @@
             | _ => error "sledgehammer action requires one and only one prover"))
 
     val data = Synchronized.var "Mirabelle_Sledgehammer.data" empty_data
-    val change_data = Synchronized.change data
 
     val init_msg = "Params for sledgehammer: " ^ Sledgehammer_Prover.string_of_params params
 
@@ -475,16 +476,18 @@
             val trivial =
               check_trivial andalso Try0.try0 (SOME try_timeout) ([], [], [], []) pre
               handle Timeout.TIMEOUT _ => false
-            val (outcome, log1) =
-              run_sledgehammer change_data params output_dir e_selection_heuristic term_order
+            val (outcome, log1, change_data1) =
+              run_sledgehammer params output_dir e_selection_heuristic term_order
                 force_sos keep_probs keep_proofs proof_method_from_msg theory_index trivial meth
                 named_thms pos pre
-            val log2 =
+            val (log2, change_data2) =
               (case !named_thms of
                 SOME thms =>
-                !meth ^ " (sledgehammer): " ^ run_proof_method change_data trivial false name meth
-                  thms timeout pos pre
-              | NONE => "")
+                run_proof_method trivial false name meth thms timeout pos pre
+                |> apfst (prefix (!meth ^ " (sledgehammer): "))
+              | NONE => ("", I))
+            val () = Synchronized.change data
+              (change_data1 #> change_data2 #> inc_sh_calls #> not trivial ? inc_sh_nontriv_calls)
           in
             log1 ^ "\n" ^ log2
             |> Symbol.trim_blanks