src/HOL/Tools/Mirabelle/mirabelle_sledgehammer.ML
changeset 73847 58f6b41efe88
parent 73797 f7ea394490f5
child 73849 4eac16052a94
--- a/src/HOL/Tools/Mirabelle/mirabelle_sledgehammer.ML	Sun Jun 06 21:39:26 2021 +0200
+++ b/src/HOL/Tools/Mirabelle/mirabelle_sledgehammer.ML	Thu Jun 10 11:21:57 2021 +0200
@@ -1,11 +1,14 @@
 (*  Title:      HOL/Mirabelle/Tools/mirabelle_sledgehammer.ML
-    Author:     Jasmin Blanchette and Sascha Boehme and Tobias Nipkow, TU Munich
+    Author:     Jasmin Blanchette, TU Munich
+    Author:     Sascha Boehme, TU Munich
+    Author:     Tobias Nipkow, TU Munich
     Author:     Makarius
+    Author:     Martin Desharnais, UniBw Munich
 
 Mirabelle action: "sledgehammer".
 *)
 
-structure Mirabelle_Sledgehammer: sig end =
+structure Mirabelle_Sledgehammer: MIRABELLE_ACTION =
 struct
 
 (*To facilitate synching the description of Mirabelle Sledgehammer parameters
@@ -23,7 +26,6 @@
 val isar_proofsK = "isar_proofs" (*=SMART_BOOL: enable Isar proof generation*)
 val keepK = "keep" (*=PATH: path where to keep temporary files created by sledgehammer*)
 val lam_transK = "lam_trans" (*=STRING: lambda translation scheme*)
-val max_callsK = "max_calls" (*=NUM: max. no. of calls to sledgehammer*)
 val max_factsK = "max_facts" (*=NUM: max. relevant clauses to use*)
 val max_mono_itersK = "max_mono_iters" (*=NUM: max. iterations of monomorphiser*)
 val max_new_mono_instancesK = "max_new_mono_instances" (*=NUM: max. new monomorphic instances*)
@@ -40,8 +42,6 @@
 val type_encK = "type_enc" (*=STRING: type encoding scheme*)
 val uncurried_aliasesK = "uncurried_aliases" (*=SMART_BOOL: use fresh function names to alias curried applications*)
 
-val separator = "-----"
-
 (*FIXME sensible to have Mirabelle-level Sledgehammer defaults?*)
 (*defaults used in this Mirabelle action*)
 val preplay_timeout_default = "1"
@@ -52,7 +52,6 @@
 val strict_default = "false"
 val max_facts_default = "smart"
 val slice_default = "true"
-val max_calls_default = 10000000
 val check_trivial_default = false
 
 (*If a key is present in args then augment a list with its pair*)
@@ -193,7 +192,7 @@
 
 fun inc_proof_method_time t = map_re_data
  (fn (calls,success,nontriv_calls, nontriv_success, proofs,time,timeout,lemmas,posns)
-  => (calls, success, nontriv_calls, nontriv_success, proofs, time + t, timeout, lemmas,posns))
+    => (calls, success, nontriv_calls, nontriv_success, proofs, time + t, timeout, lemmas,posns))
 
 val inc_proof_method_timeout = map_re_data
   (fn (calls,success,nontriv_calls, nontriv_success, proofs,time,timeout,lemmas,posns)
@@ -218,90 +217,62 @@
 fun avg_time t n =
   if n > 0 then (Real.fromInt t / 1000.0) / Real.fromInt n else 0.0
 
-fun log_sh_data (ShData
-    {calls, success, nontriv_calls, nontriv_success, lemmas, max_lems, time_isa, time_prover, time_prover_fail}) =
-  let
-    val props =
-     [("sh_calls", str calls),
-      ("sh_success", str success),
-      ("sh_nontriv_calls", str nontriv_calls),
-      ("sh_nontriv_success", str nontriv_success),
-      ("sh_lemmas", str lemmas),
-      ("sh_max_lems", str max_lems),
-      ("sh_time_isa", str3 (ms time_isa)),
-      ("sh_time_prover", str3 (ms time_prover)),
-      ("sh_time_prover_fail", str3 (ms time_prover_fail))]
-    val text =
-      "\nTotal number of sledgehammer calls: " ^ str calls ^
-      "\nNumber of successful sledgehammer calls: " ^ str success ^
-      "\nNumber of sledgehammer lemmas: " ^ str lemmas ^
-      "\nMax number of sledgehammer lemmas: " ^ str max_lems ^
-      "\nSuccess rate: " ^ percentage success calls ^ "%" ^
-      "\nTotal number of nontrivial sledgehammer calls: " ^ str nontriv_calls ^
-      "\nNumber of successful nontrivial sledgehammer calls: " ^ str nontriv_success ^
-      "\nTotal time for sledgehammer calls (Isabelle): " ^ str3 (ms time_isa) ^
-      "\nTotal time for successful sledgehammer calls (ATP): " ^ str3 (ms time_prover) ^
-      "\nTotal time for failed sledgehammer calls (ATP): " ^ str3 (ms time_prover_fail) ^
-      "\nAverage time for sledgehammer calls (Isabelle): " ^
-        str3 (avg_time time_isa calls) ^
-      "\nAverage time for successful sledgehammer calls (ATP): " ^
-        str3 (avg_time time_prover success) ^
-      "\nAverage time for failed sledgehammer calls (ATP): " ^
-        str3 (avg_time time_prover_fail (calls - success))
-  in (props, text) end
+fun log_sh_data (ShData {calls, success, nontriv_calls, nontriv_success, lemmas, max_lems, time_isa,
+      time_prover, time_prover_fail}) =  (*multi-line text summary of the Sledgehammer counters*)
+  "\nTotal number of sledgehammer calls: " ^ str calls ^
+  "\nNumber of successful sledgehammer calls: " ^ str success ^
+  "\nNumber of sledgehammer lemmas: " ^ str lemmas ^
+  "\nMax number of sledgehammer lemmas: " ^ str max_lems ^
+  "\nSuccess rate: " ^ percentage success calls ^ "%" ^
+  "\nTotal number of nontrivial sledgehammer calls: " ^ str nontriv_calls ^
+  "\nNumber of successful nontrivial sledgehammer calls: " ^ str nontriv_success ^
+  "\nTotal time for sledgehammer calls (Isabelle): " ^ str3 (ms time_isa) ^
+  "\nTotal time for successful sledgehammer calls (ATP): " ^ str3 (ms time_prover) ^
+  "\nTotal time for failed sledgehammer calls (ATP): " ^ str3 (ms time_prover_fail) ^
+  "\nAverage time for sledgehammer calls (Isabelle): " ^
+    str3 (avg_time time_isa calls) ^  (*avg_time gives 0.0 when the count is not positive*)
+  "\nAverage time for successful sledgehammer calls (ATP): " ^
+    str3 (avg_time time_prover success) ^
+  "\nAverage time for failed sledgehammer calls (ATP): " ^
+    str3 (avg_time time_prover_fail (calls - success))  (*failed calls = calls - success*)
 
-fun log_re_data sh_calls (ReData {calls, success, nontriv_calls,
-     nontriv_success, proofs, time, timeout, lemmas = (lemmas, lems_sos, lems_max), posns}) =
+fun log_re_data sh_calls (ReData {calls, success, nontriv_calls, nontriv_success, proofs, time,
+      timeout, lemmas = (lemmas, lems_sos, lems_max), posns}) =  (*proof method summary text*)
   let
     val proved =
       posns |> map (fn (pos, triv) =>
         str0 (Position.line_of pos) ^ ":" ^ str0 (Position.offset_of pos) ^
         (if triv then "[T]" else ""))
-    val props =
-     [("re_calls", str calls),
-      ("re_success", str success),
-      ("re_nontriv_calls", str nontriv_calls),
-      ("re_nontriv_success", str nontriv_success),
-      ("re_proofs", str proofs),
-      ("re_time", str3 (ms time)),
-      ("re_timeout", str timeout),
-      ("re_lemmas", str lemmas),
-      ("re_lems_sos", str lems_sos),
-      ("re_lems_max", str lems_max),
-      ("re_proved", space_implode "," proved)]
-    val text =
-      "\nTotal number of proof method calls: " ^ str calls ^
-      "\nNumber of successful proof method calls: " ^ str success ^
-        " (proof: " ^ str proofs ^ ")" ^
-      "\nNumber of proof method timeouts: " ^ str timeout ^
-      "\nSuccess rate: " ^ percentage success sh_calls ^ "%" ^
-      "\nTotal number of nontrivial proof method calls: " ^ str nontriv_calls ^
-      "\nNumber of successful nontrivial proof method calls: " ^ str nontriv_success ^
-        " (proof: " ^ str proofs ^ ")" ^
-      "\nNumber of successful proof method lemmas: " ^ str lemmas ^
-      "\nSOS of successful proof method lemmas: " ^ str lems_sos ^
-      "\nMax number of successful proof method lemmas: " ^ str lems_max ^
-      "\nTotal time for successful proof method calls: " ^ str3 (ms time) ^
-      "\nAverage time for successful proof method calls: " ^ str3 (avg_time time success) ^
-      "\nProved: " ^ space_implode " " proved
-  in (props, text) end
+  in
+    "\nTotal number of proof method calls: " ^ str calls ^
+    "\nNumber of successful proof method calls: " ^ str success ^
+      " (proof: " ^ str proofs ^ ")" ^
+    "\nNumber of proof method timeouts: " ^ str timeout ^
+    "\nSuccess rate: " ^ percentage success sh_calls ^ "%" ^  (*relative to Sledgehammer calls*)
+    "\nTotal number of nontrivial proof method calls: " ^ str nontriv_calls ^
+    "\nNumber of successful nontrivial proof method calls: " ^ str nontriv_success ^
+      " (proof: " ^ str proofs ^ ")" ^
+    "\nNumber of successful proof method lemmas: " ^ str lemmas ^
+    "\nSOS of successful proof method lemmas: " ^ str lems_sos ^
+    "\nMax number of successful proof method lemmas: " ^ str lems_max ^
+    "\nTotal time for successful proof method calls: " ^ str3 (ms time) ^
+    "\nAverage time for successful proof method calls: " ^ str3 (avg_time time success) ^
+    "\nProved: " ^ space_implode " " proved  (*"line:offset" per entry, "[T]" if triv is set*)
+  end
 
 in
 
-fun log_data index (Data {sh, re_u}) =
+fun log_data (Data {sh, re_u}) =  (*report text; "" when there were no Sledgehammer calls*)
   let
     val ShData {calls=sh_calls, ...} = sh
     val ReData {calls=re_calls, ...} = re_u
   in
     if sh_calls > 0 then
-      let
-        val (props1, text1) = log_sh_data sh
-        val (props2, text2) = log_re_data sh_calls re_u
-        val text =
-          "\n\nReport #" ^ string_of_int index ^ ":\n" ^
-          (if re_calls > 0 then text1 ^ "\n" ^ text2 else text1)
-      in [Mirabelle.log_report (props1 @ props2) [XML.Text text]] end
-    else []
+      let val text1 = log_sh_data sh in  (*proof method part is appended only if re_calls > 0*)
+        if re_calls > 0 then text1 ^ "\n" ^ log_re_data sh_calls re_u else text1
+      end
+    else
+      ""
   end
 
 end
@@ -375,7 +346,7 @@
     fun set_file_name (SOME dir) =
         Config.put Sledgehammer_Prover_ATP.atp_dest_dir dir
         #> Config.put Sledgehammer_Prover_ATP.atp_problem_prefix
-          ("prob_" ^ str0 (Position.line_of pos) ^ "__")
+          ("prob_" ^ StringCvt.padLeft #"0" 5 (str0 (Position.line_of pos)) ^ "__")
         #> Config.put SMT_Config.debug_files
           (dir ^ "/" ^ Name.desymbolize (SOME false) (ATP_Util.timestamp ()) ^ "_"
           ^ serial_string ())
@@ -457,9 +428,10 @@
 
 in
 
-fun run_sledgehammer change_data trivial args proof_method named_thms pos st =
+fun run_sledgehammer change_data thy_index trivial args proof_method named_thms pos st =
   let
     val thy = Proof.theory_of st
+    val thy_name = Context.theory_name thy
     val triv_str = if trivial then "[T] " else ""
     val _ = change_data inc_sh_calls
     val _ = if trivial then () else change_data inc_sh_nontriv_calls
@@ -482,6 +454,12 @@
     val force_sos = AList.lookup (op =) args force_sosK
       |> Option.map (curry (op <>) "false")
     val dir = AList.lookup (op =) args keepK
+      |> Option.map (fn dir =>
+        let val subdir = StringCvt.padLeft #"0" 4 (string_of_int thy_index) ^ "_" ^ thy_name in
+          Path.append (Path.explode dir) (Path.basic subdir)
+          |> Isabelle_System.make_directory
+          |> Path.implode
+        end)
     val timeout = Mirabelle.get_int_argument args (prover_timeoutK, 30)
     (* always use a hard timeout, but give some slack so that the automatic
        minimizer has a chance to do its magic *)
@@ -587,14 +565,14 @@
       Mirabelle.can_apply timeout (do_method named_thms) st
 
     fun with_time (false, t) = "failed (" ^ string_of_int t ^ ")"
-      | with_time (true, t) = (change_data inc_proof_method_success;
-          if trivial then ()
-          else change_data inc_proof_method_nontriv_success;
-          change_data (inc_proof_method_lemmas (length named_thms));
-          change_data (inc_proof_method_time t);
-          change_data (inc_proof_method_posns (pos, trivial));
-          if name = "proof" then change_data inc_proof_method_proofs else ();
-          "succeeded (" ^ string_of_int t ^ ")")
+      | with_time (true, t) =
+          (change_data inc_proof_method_success;
+           if trivial then () else change_data inc_proof_method_nontriv_success;
+           change_data (inc_proof_method_lemmas (length named_thms));
+           change_data (inc_proof_method_time t);
+           change_data (inc_proof_method_posns (pos, trivial));
+           if name = "proof" then change_data inc_proof_method_proofs else ();
+           "succeeded (" ^ string_of_int t ^ ")")
     fun timed_method named_thms =
       with_time (Mirabelle.cpu_time apply_method named_thms)
         handle Timeout.TIMEOUT _ => (change_data inc_proof_method_timeout; "timeout")
@@ -606,70 +584,40 @@
 
 val try_timeout = seconds 5.0
 
-fun catch e =
-  e () handle exn =>
-    if Exn.is_interrupt exn then Exn.reraise exn
-    else Mirabelle.print_exn exn
-
-(* crude hack *)
-val num_sledgehammer_calls = Unsynchronized.ref 0
+fun make_action ({arguments, timeout, ...} : Mirabelle.action_context) =
+  let  (*yields a per-command runner and a final report function over shared statistics*)
+    val check_trivial =
+      Mirabelle.get_bool_argument arguments (check_trivialK, check_trivial_default)
 
-val _ =
-  Theory.setup (Mirabelle.theory_action \<^binding>\<open>sledgehammer\<close>
-    (fn context => fn commands =>
-      let
-        val {index, tag = sh_tag, arguments = args, timeout, ...} = context
-        fun proof_method_tag meth = "#" ^ string_of_int index ^ " " ^ meth ^ " (sledgehammer): "
-
-        val data = Unsynchronized.ref empty_data
-        val change_data = Unsynchronized.change data
-
-        val max_calls = Mirabelle.get_int_argument args (max_callsK, max_calls_default)
-        val check_trivial = Mirabelle.get_bool_argument args (check_trivialK, check_trivial_default)
+    val data = Synchronized.var "Mirabelle_Sledgehammer.data" empty_data
+    val change_data = Synchronized.change data
 
-        val results =
-          commands |> maps (fn command =>
-            let
-              val {name, pos, pre = st, ...} = command
-              val goal = Thm.major_prem_of (#goal (Proof.goal st))
-              val log =
-                if can Logic.dest_conjunction goal orelse can Logic.dest_equals goal then []
-                else
-                  let
-                    val _ = Unsynchronized.inc num_sledgehammer_calls
-                  in
-                    if !num_sledgehammer_calls > max_calls then
-                      ["Skipping because max number of calls reached"]
-                    else
-                      let
-                        val meth = Unsynchronized.ref ""
-                        val named_thms =
-                          Unsynchronized.ref (NONE : ((string * stature) * thm list) list option)
-                        val trivial =
-                          if check_trivial then
-                            Try0.try0 (SOME try_timeout) ([], [], [], []) st
-                              handle Timeout.TIMEOUT _ => false
-                          else false
-                        val log1 =
-                          sh_tag ^ catch (fn () =>
-                            run_sledgehammer change_data trivial args meth named_thms pos st)
-                        val log2 =
-                          (case ! named_thms of
-                            SOME thms =>
-                              [separator,
-                               proof_method_tag (!meth) ^
-                               catch (fn () =>
-                                  run_proof_method change_data trivial false name meth thms
-                                    timeout pos st)]
-                          | NONE => [])
-                      in log1 :: log2 end
-                  end
-            in
-              if null log then []
-              else [Mirabelle.log_command command [XML.Text (cat_lines log)]]
-            end)
+    fun run_action ({theory_index, name, pos, pre, ...} : Mirabelle.command) =
+      let val goal = Thm.major_prem_of (#goal (Proof.goal pre)) in
+        if can Logic.dest_conjunction goal orelse can Logic.dest_equals goal then
+          ""  (*skipped: Logic.dest_conjunction or Logic.dest_equals succeeds on the goal*)
+        else
+          let
+            val meth = Unsynchronized.ref ""  (*handed to run_sledgehammer, read back below*)
+            val named_thms =
+              Unsynchronized.ref (NONE : ((string * stature) * thm list) list option)
+            val trivial =
+              check_trivial andalso Try0.try0 (SOME try_timeout) ([], [], [], []) pre
+              handle Timeout.TIMEOUT _ => false  (*a try0 timeout counts as non-trivial*)
+            val log1 =
+              run_sledgehammer change_data theory_index trivial arguments meth named_thms pos pre
+            val log2 =
+              (case !named_thms of
+                SOME thms =>
+                !meth ^ " (sledgehammer): " ^ run_proof_method change_data trivial false name meth
+                  thms timeout pos pre
+              | NONE => "")
+          in log1 ^ "\n" ^ log2 end  (*log2 = "" if named_thms stayed NONE*)
+      end
 
-        val report = log_data index (! data)
-      in results @ report end))
+    fun finalize () = log_data (Synchronized.value data)  (*summary text from the final data*)
+  in {run_action = run_action, finalize = finalize} end
+
+val () = Mirabelle.register_action "sledgehammer" make_action  (*registers make_action by name*)
 
 end