--- a/src/HOL/Tools/Mirabelle/mirabelle_sledgehammer.ML Sun Jun 06 21:39:26 2021 +0200
+++ b/src/HOL/Tools/Mirabelle/mirabelle_sledgehammer.ML Thu Jun 10 11:21:57 2021 +0200
@@ -1,11 +1,14 @@
(* Title: HOL/Mirabelle/Tools/mirabelle_sledgehammer.ML
- Author: Jasmin Blanchette and Sascha Boehme and Tobias Nipkow, TU Munich
+ Author: Jasmin Blanchette, TU Munich
+ Author: Sascha Boehme, TU Munich
+ Author: Tobias Nipkow, TU Munich
Author: Makarius
+ Author: Martin Desharnais, UniBw Munich
Mirabelle action: "sledgehammer".
*)
-structure Mirabelle_Sledgehammer: sig end =
+structure Mirabelle_Sledgehammer: MIRABELLE_ACTION =
struct
(*To facilitate synching the description of Mirabelle Sledgehammer parameters
@@ -23,7 +26,6 @@
val isar_proofsK = "isar_proofs" (*=SMART_BOOL: enable Isar proof generation*)
val keepK = "keep" (*=PATH: path where to keep temporary files created by sledgehammer*)
val lam_transK = "lam_trans" (*=STRING: lambda translation scheme*)
-val max_callsK = "max_calls" (*=NUM: max. no. of calls to sledgehammer*)
val max_factsK = "max_facts" (*=NUM: max. relevant clauses to use*)
val max_mono_itersK = "max_mono_iters" (*=NUM: max. iterations of monomorphiser*)
val max_new_mono_instancesK = "max_new_mono_instances" (*=NUM: max. new monomorphic instances*)
@@ -40,8 +42,6 @@
val type_encK = "type_enc" (*=STRING: type encoding scheme*)
val uncurried_aliasesK = "uncurried_aliases" (*=SMART_BOOL: use fresh function names to alias curried applications*)
-val separator = "-----"
-
(*FIXME sensible to have Mirabelle-level Sledgehammer defaults?*)
(*defaults used in this Mirabelle action*)
val preplay_timeout_default = "1"
@@ -52,7 +52,6 @@
val strict_default = "false"
val max_facts_default = "smart"
val slice_default = "true"
-val max_calls_default = 10000000
val check_trivial_default = false
(*If a key is present in args then augment a list with its pair*)
@@ -193,7 +192,7 @@
fun inc_proof_method_time t = map_re_data
(fn (calls,success,nontriv_calls, nontriv_success, proofs,time,timeout,lemmas,posns)
- => (calls, success, nontriv_calls, nontriv_success, proofs, time + t, timeout, lemmas,posns))
+ => (calls, success, nontriv_calls, nontriv_success, proofs, time + t, timeout, lemmas,posns)) (*only the time field changes: it grows by t*)
val inc_proof_method_timeout = map_re_data
(fn (calls,success,nontriv_calls, nontriv_success, proofs,time,timeout,lemmas,posns)
@@ -218,90 +217,62 @@
fun avg_time t n =
if n > 0 then (Real.fromInt t / 1000.0) / Real.fromInt n else 0.0
-fun log_sh_data (ShData
- {calls, success, nontriv_calls, nontriv_success, lemmas, max_lems, time_isa, time_prover, time_prover_fail}) =
- let
- val props =
- [("sh_calls", str calls),
- ("sh_success", str success),
- ("sh_nontriv_calls", str nontriv_calls),
- ("sh_nontriv_success", str nontriv_success),
- ("sh_lemmas", str lemmas),
- ("sh_max_lems", str max_lems),
- ("sh_time_isa", str3 (ms time_isa)),
- ("sh_time_prover", str3 (ms time_prover)),
- ("sh_time_prover_fail", str3 (ms time_prover_fail))]
- val text =
- "\nTotal number of sledgehammer calls: " ^ str calls ^
- "\nNumber of successful sledgehammer calls: " ^ str success ^
- "\nNumber of sledgehammer lemmas: " ^ str lemmas ^
- "\nMax number of sledgehammer lemmas: " ^ str max_lems ^
- "\nSuccess rate: " ^ percentage success calls ^ "%" ^
- "\nTotal number of nontrivial sledgehammer calls: " ^ str nontriv_calls ^
- "\nNumber of successful nontrivial sledgehammer calls: " ^ str nontriv_success ^
- "\nTotal time for sledgehammer calls (Isabelle): " ^ str3 (ms time_isa) ^
- "\nTotal time for successful sledgehammer calls (ATP): " ^ str3 (ms time_prover) ^
- "\nTotal time for failed sledgehammer calls (ATP): " ^ str3 (ms time_prover_fail) ^
- "\nAverage time for sledgehammer calls (Isabelle): " ^
- str3 (avg_time time_isa calls) ^
- "\nAverage time for successful sledgehammer calls (ATP): " ^
- str3 (avg_time time_prover success) ^
- "\nAverage time for failed sledgehammer calls (ATP): " ^
- str3 (avg_time time_prover_fail (calls - success))
- in (props, text) end
+fun log_sh_data (ShData {calls, success, nontriv_calls, nontriv_success, lemmas, max_lems, time_isa,
+    time_prover, time_prover_fail}) =
+  (*plain-text summary of the sledgehammer statistics; every entry starts on a fresh line*)
+  String.concat ["\nTotal number of sledgehammer calls: ", str calls,
+    "\nNumber of successful sledgehammer calls: ", str success,
+    "\nNumber of sledgehammer lemmas: ", str lemmas,
+    "\nMax number of sledgehammer lemmas: ", str max_lems,
+    "\nSuccess rate: ", percentage success calls, "%",
+    "\nTotal number of nontrivial sledgehammer calls: ", str nontriv_calls,
+    "\nNumber of successful nontrivial sledgehammer calls: ", str nontriv_success,
+    "\nTotal time for sledgehammer calls (Isabelle): ", str3 (ms time_isa),
+    "\nTotal time for successful sledgehammer calls (ATP): ", str3 (ms time_prover),
+    "\nTotal time for failed sledgehammer calls (ATP): ", str3 (ms time_prover_fail),
+    "\nAverage time for sledgehammer calls (Isabelle): ", str3 (avg_time time_isa calls),
+    "\nAverage time for successful sledgehammer calls (ATP): ",
+      str3 (avg_time time_prover success),
+    "\nAverage time for failed sledgehammer calls (ATP): ",
+      str3 (avg_time time_prover_fail (calls - success))]
-fun log_re_data sh_calls (ReData {calls, success, nontriv_calls,
- nontriv_success, proofs, time, timeout, lemmas = (lemmas, lems_sos, lems_max), posns}) =
+fun log_re_data sh_calls (ReData {calls, success, nontriv_calls, nontriv_success, proofs, time,
+ timeout, lemmas = (lemmas, lems_sos, lems_max), posns}) =
let
val proved =
posns |> map (fn (pos, triv) =>
str0 (Position.line_of pos) ^ ":" ^ str0 (Position.offset_of pos) ^
(if triv then "[T]" else ""))
- val props =
- [("re_calls", str calls),
- ("re_success", str success),
- ("re_nontriv_calls", str nontriv_calls),
- ("re_nontriv_success", str nontriv_success),
- ("re_proofs", str proofs),
- ("re_time", str3 (ms time)),
- ("re_timeout", str timeout),
- ("re_lemmas", str lemmas),
- ("re_lems_sos", str lems_sos),
- ("re_lems_max", str lems_max),
- ("re_proved", space_implode "," proved)]
- val text =
- "\nTotal number of proof method calls: " ^ str calls ^
- "\nNumber of successful proof method calls: " ^ str success ^
- " (proof: " ^ str proofs ^ ")" ^
- "\nNumber of proof method timeouts: " ^ str timeout ^
- "\nSuccess rate: " ^ percentage success sh_calls ^ "%" ^
- "\nTotal number of nontrivial proof method calls: " ^ str nontriv_calls ^
- "\nNumber of successful nontrivial proof method calls: " ^ str nontriv_success ^
- " (proof: " ^ str proofs ^ ")" ^
- "\nNumber of successful proof method lemmas: " ^ str lemmas ^
- "\nSOS of successful proof method lemmas: " ^ str lems_sos ^
- "\nMax number of successful proof method lemmas: " ^ str lems_max ^
- "\nTotal time for successful proof method calls: " ^ str3 (ms time) ^
- "\nAverage time for successful proof method calls: " ^ str3 (avg_time time success) ^
- "\nProved: " ^ space_implode " " proved
- in (props, text) end
+ in (*plain-text summary of the proof method statistics; the success rate is relative to sh_calls, and the Proved entry lists the recorded positions as line:offset with [T] marking trivial ones*)
+ "\nTotal number of proof method calls: " ^ str calls ^
+ "\nNumber of successful proof method calls: " ^ str success ^
+ " (proof: " ^ str proofs ^ ")" ^
+ "\nNumber of proof method timeouts: " ^ str timeout ^
+ "\nSuccess rate: " ^ percentage success sh_calls ^ "%" ^
+ "\nTotal number of nontrivial proof method calls: " ^ str nontriv_calls ^
+ "\nNumber of successful nontrivial proof method calls: " ^ str nontriv_success ^
+ " (proof: " ^ str proofs ^ ")" ^
+ "\nNumber of successful proof method lemmas: " ^ str lemmas ^
+ "\nSOS of successful proof method lemmas: " ^ str lems_sos ^
+ "\nMax number of successful proof method lemmas: " ^ str lems_max ^
+ "\nTotal time for successful proof method calls: " ^ str3 (ms time) ^
+ "\nAverage time for successful proof method calls: " ^ str3 (avg_time time success) ^
+ "\nProved: " ^ space_implode " " proved
+ end
in
-fun log_data index (Data {sh, re_u}) =
+fun log_data (Data {sh, re_u}) = (*final report text: the sledgehammer summary, followed by the proof method summary when re_calls > 0; the empty string when sledgehammer was never called*)
let
val ShData {calls=sh_calls, ...} = sh
val ReData {calls=re_calls, ...} = re_u
in
if sh_calls > 0 then
- let
- val (props1, text1) = log_sh_data sh
- val (props2, text2) = log_re_data sh_calls re_u
- val text =
- "\n\nReport #" ^ string_of_int index ^ ":\n" ^
- (if re_calls > 0 then text1 ^ "\n" ^ text2 else text1)
- in [Mirabelle.log_report (props1 @ props2) [XML.Text text]] end
- else []
+ let val text1 = log_sh_data sh in
+ if re_calls > 0 then text1 ^ "\n" ^ log_re_data sh_calls re_u else text1
+ end
+ else
+ ""
end
end
@@ -375,7 +346,7 @@
fun set_file_name (SOME dir) =
Config.put Sledgehammer_Prover_ATP.atp_dest_dir dir
#> Config.put Sledgehammer_Prover_ATP.atp_problem_prefix
- ("prob_" ^ str0 (Position.line_of pos) ^ "__")
+ ("prob_" ^ StringCvt.padLeft #"0" 5 (str0 (Position.line_of pos)) ^ "__")
#> Config.put SMT_Config.debug_files
(dir ^ "/" ^ Name.desymbolize (SOME false) (ATP_Util.timestamp ()) ^ "_"
^ serial_string ())
@@ -457,9 +428,10 @@
in
-fun run_sledgehammer change_data trivial args proof_method named_thms pos st =
+fun run_sledgehammer change_data thy_index trivial args proof_method named_thms pos st =
let
val thy = Proof.theory_of st
+ val thy_name = Context.theory_name thy
val triv_str = if trivial then "[T] " else ""
val _ = change_data inc_sh_calls
val _ = if trivial then () else change_data inc_sh_nontriv_calls
@@ -482,6 +454,12 @@
val force_sos = AList.lookup (op =) args force_sosK
|> Option.map (curry (op <>) "false")
val dir = AList.lookup (op =) args keepK
+ |> Option.map (fn dir => (*one subdirectory per theory below the keep directory, named by the theory index zero-padded to 4 characters, an underscore, and the theory name*)
+ let val subdir = StringCvt.padLeft #"0" 4 (string_of_int thy_index) ^ "_" ^ thy_name in
+ Path.append (Path.explode dir) (Path.basic subdir)
+ |> Isabelle_System.make_directory
+ |> Path.implode
+ end)
val timeout = Mirabelle.get_int_argument args (prover_timeoutK, 30)
(* always use a hard timeout, but give some slack so that the automatic
minimizer has a chance to do its magic *)
@@ -587,14 +565,14 @@
Mirabelle.can_apply timeout (do_method named_thms) st
fun with_time (false, t) = "failed (" ^ string_of_int t ^ ")"
- | with_time (true, t) = (change_data inc_proof_method_success;
- if trivial then ()
- else change_data inc_proof_method_nontriv_success;
- change_data (inc_proof_method_lemmas (length named_thms));
- change_data (inc_proof_method_time t);
- change_data (inc_proof_method_posns (pos, trivial));
- if name = "proof" then change_data inc_proof_method_proofs else ();
- "succeeded (" ^ string_of_int t ^ ")")
+ | with_time (true, t) = (*on success the statistics are updated first, then the elapsed time is rendered*)
+ (change_data inc_proof_method_success;
+ if trivial then () else change_data inc_proof_method_nontriv_success;
+ change_data (inc_proof_method_lemmas (length named_thms));
+ change_data (inc_proof_method_time t);
+ change_data (inc_proof_method_posns (pos, trivial));
+ if name = "proof" then change_data inc_proof_method_proofs else ();
+ "succeeded (" ^ string_of_int t ^ ")")
fun timed_method named_thms =
with_time (Mirabelle.cpu_time apply_method named_thms)
handle Timeout.TIMEOUT _ => (change_data inc_proof_method_timeout; "timeout")
@@ -606,70 +584,40 @@
val try_timeout = seconds 5.0
-fun catch e =
- e () handle exn =>
- if Exn.is_interrupt exn then Exn.reraise exn
- else Mirabelle.print_exn exn
-
-(* crude hack *)
-val num_sledgehammer_calls = Unsynchronized.ref 0
+fun make_action ({arguments, timeout, ...} : Mirabelle.action_context) = (*builds the action record: run_action handles one command, finalize renders the accumulated statistics*)
+ let
+ val check_trivial =
+ Mirabelle.get_bool_argument arguments (check_trivialK, check_trivial_default)
-val _ =
- Theory.setup (Mirabelle.theory_action \<^binding>\<open>sledgehammer\<close>
- (fn context => fn commands =>
- let
- val {index, tag = sh_tag, arguments = args, timeout, ...} = context
- fun proof_method_tag meth = "#" ^ string_of_int index ^ " " ^ meth ^ " (sledgehammer): "
-
- val data = Unsynchronized.ref empty_data
- val change_data = Unsynchronized.change data
-
- val max_calls = Mirabelle.get_int_argument args (max_callsK, max_calls_default)
- val check_trivial = Mirabelle.get_bool_argument args (check_trivialK, check_trivial_default)
+ val data = Synchronized.var "Mirabelle_Sledgehammer.data" empty_data (*statistics shared by run_action and finalize of this action instance*)
+ val change_data = Synchronized.change data
- val results =
- commands |> maps (fn command =>
- let
- val {name, pos, pre = st, ...} = command
- val goal = Thm.major_prem_of (#goal (Proof.goal st))
- val log =
- if can Logic.dest_conjunction goal orelse can Logic.dest_equals goal then []
- else
- let
- val _ = Unsynchronized.inc num_sledgehammer_calls
- in
- if !num_sledgehammer_calls > max_calls then
- ["Skipping because max number of calls reached"]
- else
- let
- val meth = Unsynchronized.ref ""
- val named_thms =
- Unsynchronized.ref (NONE : ((string * stature) * thm list) list option)
- val trivial =
- if check_trivial then
- Try0.try0 (SOME try_timeout) ([], [], [], []) st
- handle Timeout.TIMEOUT _ => false
- else false
- val log1 =
- sh_tag ^ catch (fn () =>
- run_sledgehammer change_data trivial args meth named_thms pos st)
- val log2 =
- (case ! named_thms of
- SOME thms =>
- [separator,
- proof_method_tag (!meth) ^
- catch (fn () =>
- run_proof_method change_data trivial false name meth thms
- timeout pos st)]
- | NONE => [])
- in log1 :: log2 end
- end
- in
- if null log then []
- else [Mirabelle.log_command command [XML.Text (cat_lines log)]]
- end)
+ fun run_action ({theory_index, name, pos, pre, ...} : Mirabelle.command) =
+ let val goal = Thm.major_prem_of (#goal (Proof.goal pre)) in
+ if can Logic.dest_conjunction goal orelse can Logic.dest_equals goal then
+ "" (*goals on which Logic.dest_conjunction or Logic.dest_equals succeeds are skipped with an empty log*)
+ else
+ let
+ val meth = Unsynchronized.ref ""
+ val named_thms =
+ Unsynchronized.ref (NONE : ((string * stature) * thm list) list option)
+ val trivial =
+ check_trivial andalso Try0.try0 (SOME try_timeout) ([], [], [], []) pre
+ handle Timeout.TIMEOUT _ => false
+ val log1 =
+ run_sledgehammer change_data theory_index trivial arguments meth named_thms pos pre
+ val log2 =
+ (case !named_thms of (*the proof method is replayed only if named_thms was set, presumably by run_sledgehammer -- confirm*)
+ SOME thms =>
+ !meth ^ " (sledgehammer): " ^ run_proof_method change_data trivial false name meth
+ thms timeout pos pre
+ | NONE => "")
+ in log1 ^ "\n" ^ log2 end
+ end
- val report = log_data index (! data)
- in results @ report end))
+ fun finalize () = log_data (Synchronized.value data) (*report text over everything accumulated in data so far*)
+ in {run_action = run_action, finalize = finalize} end
+
+val () = Mirabelle.register_action "sledgehammer" make_action (*registers make_action under the action name "sledgehammer"*)
end