--- a/src/HOL/Tools/Mirabelle/mirabelle.ML Sat May 15 17:40:36 2021 +0200
+++ b/src/HOL/Tools/Mirabelle/mirabelle.ML Sat May 15 22:06:05 2021 +0200
@@ -8,7 +8,8 @@
(*core*)
val print_name: string -> string
val print_properties: Properties.T -> string
- type context = {tag: string, arguments: Properties.T, timeout: Time.time, theory: theory}
+ type context =
+ {index: int, tag: string, arguments: Properties.T, timeout: Time.time, theory: theory}
type command = {name: string, pos: Position.T, pre: Proof.state, post: Toplevel.state}
val theory_action: binding -> (context -> command list -> XML.body) -> theory -> theory
val log_command: command -> XML.body -> XML.tree
@@ -48,7 +49,8 @@
(* actions *)
type command = {name: string, pos: Position.T, pre: Proof.state, post: Toplevel.state};
-type context = {tag: string, arguments: Properties.T, timeout: Time.time, theory: theory};
+type context =
+ {index: int, tag: string, arguments: Properties.T, timeout: Time.time, theory: theory};
structure Data = Theory_Data
(
@@ -82,7 +84,7 @@
let
val action = #2 (Name_Space.check (Context.Theory thy) (Data.get thy) (name, Position.none));
val tag = "#" ^ Value.print_int index ^ " " ^ name ^ ": ";
- val context = {tag = tag, arguments = arguments, timeout = timeout, theory = thy};
+ val context = {index = index, tag = tag, arguments = arguments, timeout = timeout, theory = thy};
val export_body = action context commands;
val export_head = log_action name arguments;
val export_name = Path.binding0 (Path.basic "mirabelle" + Path.basic (Value.print_int index));
--- a/src/HOL/Tools/Mirabelle/mirabelle_sledgehammer.ML Sat May 15 17:40:36 2021 +0200
+++ b/src/HOL/Tools/Mirabelle/mirabelle_sledgehammer.ML Sat May 15 22:06:05 2021 +0200
@@ -1,10 +1,11 @@
(* Title: HOL/Mirabelle/Tools/mirabelle_sledgehammer.ML
Author: Jasmin Blanchette and Sascha Boehme and Tobias Nipkow, TU Munich
+ Author: Makarius
Mirabelle action: "sledgehammer".
*)
-structure Mirabelle_Sledgehammer : MIRABELLE_ACTION =
+structure Mirabelle_Sledgehammer: sig end =
struct
(*To facilitate synching the description of Mirabelle Sledgehammer parameters
@@ -40,9 +41,6 @@
val type_encK = "type_enc" (*=STRING: type encoding scheme*)
val uncurried_aliasesK = "uncurried_aliases" (*=SMART_BOOL: use fresh function names to alias curried applications*)
-fun sh_tag id = "#" ^ string_of_int id ^ " sledgehammer: "
-fun proof_method_tag meth id = "#" ^ string_of_int id ^ " " ^ (!meth) ^ " (sledgehammer): "
-
val separator = "-----"
(*FIXME sensible to have Mirabelle-level Sledgehammer defaults?*)
@@ -121,6 +119,8 @@
re_u: re_data (* proof method with unminimized set of lemmas *)
}
+type change_data = (data -> data) -> unit
+
fun make_data (sh, re_u) = Data {sh=sh, re_u=re_u}
val empty_data = make_data (empty_sh_data, empty_re_data)
@@ -216,89 +216,98 @@
val str = string_of_int
val str3 = Real.fmt (StringCvt.FIX (SOME 3))
fun percentage a b = string_of_int (a * 100 div b)
-fun time t = Real.fromInt t / 1000.0
+fun ms t = Real.fromInt t / 1000.0
fun avg_time t n =
if n > 0 then (Real.fromInt t / 1000.0) / Real.fromInt n else 0.0
-fun log_sh_data log
- (calls, success, nontriv_calls, nontriv_success, lemmas, max_lems, time_isa, time_prover, time_prover_fail) =
- (log ("Total number of sledgehammer calls: " ^ str calls);
- log ("Number of successful sledgehammer calls: " ^ str success);
- log ("Number of sledgehammer lemmas: " ^ str lemmas);
- log ("Max number of sledgehammer lemmas: " ^ str max_lems);
- log ("Success rate: " ^ percentage success calls ^ "%");
- log ("Total number of nontrivial sledgehammer calls: " ^ str nontriv_calls);
- log ("Number of successful nontrivial sledgehammer calls: " ^ str nontriv_success);
- log ("Total time for sledgehammer calls (Isabelle): " ^ str3 (time time_isa));
- log ("Total time for successful sledgehammer calls (ATP): " ^ str3 (time time_prover));
- log ("Total time for failed sledgehammer calls (ATP): " ^ str3 (time time_prover_fail));
- log ("Average time for sledgehammer calls (Isabelle): " ^
- str3 (avg_time time_isa calls));
- log ("Average time for successful sledgehammer calls (ATP): " ^
- str3 (avg_time time_prover success));
- log ("Average time for failed sledgehammer calls (ATP): " ^
- str3 (avg_time time_prover_fail (calls - success)))
- )
+fun log_sh_data (ShData
+ {calls, success, nontriv_calls, nontriv_success, lemmas, max_lems, time_isa, time_prover, time_prover_fail}) =
+ let
+ val props =
+ [("sh_calls", str calls),
+ ("sh_success", str success),
+ ("sh_nontriv_calls", str nontriv_calls),
+ ("sh_nontriv_success", str nontriv_success),
+ ("sh_lemmas", str lemmas),
+ ("sh_max_lems", str max_lems),
+ ("sh_time_isa", str3 (ms time_isa)),
+ ("sh_time_prover", str3 (ms time_prover)),
+ ("sh_time_prover_fail", str3 (ms time_prover_fail))]
+ val text =
+ "\nTotal number of sledgehammer calls: " ^ str calls ^
+ "\nNumber of successful sledgehammer calls: " ^ str success ^
+ "\nNumber of sledgehammer lemmas: " ^ str lemmas ^
+ "\nMax number of sledgehammer lemmas: " ^ str max_lems ^
+ "\nSuccess rate: " ^ percentage success calls ^ "%" ^
+ "\nTotal number of nontrivial sledgehammer calls: " ^ str nontriv_calls ^
+ "\nNumber of successful nontrivial sledgehammer calls: " ^ str nontriv_success ^
+ "\nTotal time for sledgehammer calls (Isabelle): " ^ str3 (ms time_isa) ^
+ "\nTotal time for successful sledgehammer calls (ATP): " ^ str3 (ms time_prover) ^
+ "\nTotal time for failed sledgehammer calls (ATP): " ^ str3 (ms time_prover_fail) ^
+ "\nAverage time for sledgehammer calls (Isabelle): " ^
+ str3 (avg_time time_isa calls) ^
+ "\nAverage time for successful sledgehammer calls (ATP): " ^
+ str3 (avg_time time_prover success) ^
+ "\nAverage time for failed sledgehammer calls (ATP): " ^
+ str3 (avg_time time_prover_fail (calls - success))
+ in (props, text) end
-fun str_of_pos (pos, triv) =
- str0 (Position.line_of pos) ^ ":" ^ str0 (Position.offset_of pos) ^
- (if triv then "[T]" else "")
-
-fun log_re_data log tag sh_calls (re_calls, re_success, re_nontriv_calls,
- re_nontriv_success, re_proofs, re_time, re_timeout,
- (lemmas, lems_sos, lems_max), re_posns) =
- (log ("Total number of " ^ tag ^ "proof method calls: " ^ str re_calls);
- log ("Number of successful " ^ tag ^ "proof method calls: " ^ str re_success ^
- " (proof: " ^ str re_proofs ^ ")");
- log ("Number of " ^ tag ^ "proof method timeouts: " ^ str re_timeout);
- log ("Success rate: " ^ percentage re_success sh_calls ^ "%");
- log ("Total number of nontrivial " ^ tag ^ "proof method calls: " ^ str re_nontriv_calls);
- log ("Number of successful nontrivial " ^ tag ^ "proof method calls: " ^ str re_nontriv_success ^
- " (proof: " ^ str re_proofs ^ ")");
- log ("Number of successful " ^ tag ^ "proof method lemmas: " ^ str lemmas);
- log ("SOS of successful " ^ tag ^ "proof method lemmas: " ^ str lems_sos);
- log ("Max number of successful " ^ tag ^ "proof method lemmas: " ^ str lems_max);
- log ("Total time for successful " ^ tag ^ "proof method calls: " ^ str3 (time re_time));
- log ("Average time for successful " ^ tag ^ "proof method calls: " ^
- str3 (avg_time re_time re_success));
- if tag=""
- then log ("Proved: " ^ space_implode " " (map str_of_pos re_posns))
- else ()
- )
+fun log_re_data sh_calls (ReData {calls, success, nontriv_calls,
+ nontriv_success, proofs, time, timeout, lemmas = (lemmas, lems_sos, lems_max), posns}) =
+ let
+ val proved =
+ posns |> map (fn (pos, triv) =>
+ str0 (Position.line_of pos) ^ ":" ^ str0 (Position.offset_of pos) ^
+ (if triv then "[T]" else ""))
+ val props =
+ [("re_calls", str calls),
+ ("re_success", str success),
+ ("re_nontriv_calls", str nontriv_calls),
+ ("re_nontriv_success", str nontriv_success),
+ ("re_proofs", str proofs),
+ ("re_time", str3 (ms time)),
+ ("re_timeout", str timeout),
+ ("re_lemmas", str lemmas),
+ ("re_lems_sos", str lems_sos),
+ ("re_lems_max", str lems_max),
+ ("re_proved", space_implode "," proved)]
+ val text =
+ "\nTotal number of proof method calls: " ^ str calls ^
+ "\nNumber of successful proof method calls: " ^ str success ^
+ " (proof: " ^ str proofs ^ ")" ^
+ "\nNumber of proof method timeouts: " ^ str timeout ^
+ "\nSuccess rate: " ^ percentage success sh_calls ^ "%" ^
+ "\nTotal number of nontrivial proof method calls: " ^ str nontriv_calls ^
+ "\nNumber of successful nontrivial proof method calls: " ^ str nontriv_success ^
+ " (proof: " ^ str proofs ^ ")" ^
+ "\nNumber of successful proof method lemmas: " ^ str lemmas ^
+ "\nSOS of successful proof method lemmas: " ^ str lems_sos ^
+ "\nMax number of successful proof method lemmas: " ^ str lems_max ^
+ "\nTotal time for successful proof method calls: " ^ str3 (ms time) ^
+ "\nAverage time for successful proof method calls: " ^ str3 (avg_time time success) ^
+ "\nProved: " ^ space_implode " " proved
+ in (props, text) end
in
-fun log_data id log (Data {sh, re_u}) =
+fun log_data index (Data {sh, re_u}) =
let
val ShData {calls=sh_calls, ...} = sh
-
- fun app_if (ReData {calls, ...}) f = if calls > 0 then f () else ()
- fun log_re tag m =
- log_re_data log tag sh_calls (tuple_of_re_data m)
- fun log_proof_method (tag1, m1) = app_if m1 (fn () => (log_re tag1 m1; log ""))
+ val ReData {calls=re_calls, ...} = re_u
in
- if sh_calls > 0
- then
- (log ("\n\n\nReport #" ^ string_of_int id ^ ":\n");
- log_sh_data log (tuple_of_sh_data sh);
- log "";
- log_proof_method ("", re_u))
- else ()
+ if sh_calls > 0 then
+ let
+ val (props1, text1) = log_sh_data sh
+ val (props2, text2) = log_re_data sh_calls re_u
+ val text =
+ "\n\nReport #" ^ string_of_int index ^ ":\n" ^
+ (if re_calls > 0 then text1 ^ "\n" ^ text2 else text1)
+ in [Mirabelle.log_report (props1 @ props2) [XML.Text text]] end
+ else []
end
end
-(* Warning: we implicitly assume single-threaded execution here *)
-val data = Unsynchronized.ref ([] : (int * data) list)
-
-fun init id thy = (Unsynchronized.change data (cons (id, empty_data)); thy)
-fun done id ({log, ...}: Mirabelle.done_args) =
- AList.lookup (op =) (!data) id
- |> Option.map (log_data id log)
- |> K ()
-
-fun change_data id f = (Unsynchronized.change data (AList.map_entry (op =) id f); ())
-
fun get_prover_name thy args =
let
fun default_prover_name () =
@@ -450,13 +459,12 @@
in
-fun run_sledgehammer trivial args proof_method named_thms id
- ({pre=st, log, pos, ...}: Mirabelle.run_args) =
+fun run_sledgehammer change_data trivial args proof_method named_thms pos st =
let
val thy = Proof.theory_of st
val triv_str = if trivial then "[T] " else ""
- val _ = change_data id inc_sh_calls
- val _ = if trivial then () else change_data id inc_sh_nontriv_calls
+ val _ = change_data inc_sh_calls
+ val _ = if trivial then () else change_data inc_sh_nontriv_calls
val prover_name = get_prover_name thy args
val fact_filter = AList.lookup (op =) args fact_filterK |> the_default fact_filter_default
val type_enc = AList.lookup (op =) args type_encK |> the_default type_enc_default
@@ -476,7 +484,7 @@
val force_sos = AList.lookup (op =) args force_sosK
|> Option.map (curry (op <>) "false")
val dir = AList.lookup (op =) args keepK
- val timeout = Mirabelle.get_int_setting args (prover_timeoutK, 30)
+ val timeout = Mirabelle.get_int_argument args (prover_timeoutK, 30)
(* always use a hard timeout, but give some slack so that the automatic
minimizer has a chance to do its magic *)
val preplay_timeout = AList.lookup (op =) args preplay_timeoutK
@@ -502,23 +510,23 @@
name
|> Option.map (pair (name, stature))
in
- change_data id inc_sh_success;
- if trivial then () else change_data id inc_sh_nontriv_success;
- change_data id (inc_sh_lemmas (length names));
- change_data id (inc_sh_max_lems (length names));
- change_data id (inc_sh_time_isa time_isa);
- change_data id (inc_sh_time_prover time_prover);
+ change_data inc_sh_success;
+ if trivial then () else change_data inc_sh_nontriv_success;
+ change_data (inc_sh_lemmas (length names));
+ change_data (inc_sh_max_lems (length names));
+ change_data (inc_sh_time_isa time_isa);
+ change_data (inc_sh_time_prover time_prover);
proof_method := proof_method_from_msg args msg;
named_thms := SOME (map_filter get_thms names);
- log (sh_tag id ^ triv_str ^ "succeeded (" ^ string_of_int time_isa ^ "+" ^
- string_of_int time_prover ^ ") [" ^ prover_name ^ "]:\n" ^ msg)
+ triv_str ^ "succeeded (" ^ string_of_int time_isa ^ "+" ^
+ string_of_int time_prover ^ ") [" ^ prover_name ^ "]:\n" ^ msg
end
| SH_FAIL (time_isa, time_prover) =>
let
- val _ = change_data id (inc_sh_time_isa time_isa)
- val _ = change_data id (inc_sh_time_prover_fail time_prover)
- in log (sh_tag id ^ triv_str ^ "failed: " ^ msg) end
- | SH_ERROR => log (sh_tag id ^ "failed: " ^ msg))
+ val _ = change_data (inc_sh_time_isa time_isa)
+ val _ = change_data (inc_sh_time_prover_fail time_prover)
+ in triv_str ^ "failed: " ^ msg end
+ | SH_ERROR => "failed: " ^ msg)
end
end
@@ -531,8 +539,7 @@
("slice", "false"),
("timeout", timeout |> Time.toSeconds |> string_of_int)]
-fun run_proof_method trivial full name meth named_thms id
- ({pre=st, timeout, log, pos, ...}: Mirabelle.run_args) =
+fun run_proof_method change_data trivial full name meth named_thms timeout pos st =
let
fun do_method named_thms ctxt =
let
@@ -582,79 +589,96 @@
Mirabelle.can_apply timeout (do_method named_thms) st
fun with_time (false, t) = "failed (" ^ string_of_int t ^ ")"
- | with_time (true, t) = (change_data id inc_proof_method_success;
+ | with_time (true, t) = (change_data inc_proof_method_success;
if trivial then ()
- else change_data id inc_proof_method_nontriv_success;
- change_data id (inc_proof_method_lemmas (length named_thms));
- change_data id (inc_proof_method_time t);
- change_data id (inc_proof_method_posns (pos, trivial));
- if name = "proof" then change_data id inc_proof_method_proofs else ();
+ else change_data inc_proof_method_nontriv_success;
+ change_data (inc_proof_method_lemmas (length named_thms));
+ change_data (inc_proof_method_time t);
+ change_data (inc_proof_method_posns (pos, trivial));
+ if name = "proof" then change_data inc_proof_method_proofs else ();
"succeeded (" ^ string_of_int t ^ ")")
fun timed_method named_thms =
- (with_time (Mirabelle.cpu_time apply_method named_thms), true)
- handle Timeout.TIMEOUT _ => (change_data id inc_proof_method_timeout; ("timeout", false))
- | ERROR msg => ("error: " ^ msg, false)
+ with_time (Mirabelle.cpu_time apply_method named_thms)
+ handle Timeout.TIMEOUT _ => (change_data inc_proof_method_timeout; "timeout")
+ | ERROR msg => ("error: " ^ msg)
- val _ = log separator
- val _ = change_data id inc_proof_method_calls
- val _ = if trivial then () else change_data id inc_proof_method_nontriv_calls
- in
- named_thms
- |> timed_method
- |>> log o prefix (proof_method_tag meth id)
- |> snd
- end
+ val _ = change_data inc_proof_method_calls
+ val _ = if trivial then () else change_data inc_proof_method_nontriv_calls
+ in timed_method named_thms end
val try_timeout = seconds 5.0
+fun catch e =
+ e () handle exn =>
+ if Exn.is_interrupt exn then Exn.reraise exn
+ else Mirabelle.print_exn exn
+
(* crude hack *)
val num_sledgehammer_calls = Unsynchronized.ref 0
val remaining_stride = Unsynchronized.ref stride_default
-fun sledgehammer_action args =
- let
- val stride = Mirabelle.get_int_setting args (strideK, stride_default)
- val max_calls = Mirabelle.get_int_setting args (max_callsK, max_calls_default)
- val check_trivial = Mirabelle.get_bool_setting args (check_trivialK, check_trivial_default)
- in
- fn id => fn (st as {pre, name, log, ...}: Mirabelle.run_args) =>
- let val goal = Thm.major_prem_of (#goal (Proof.goal pre)) in
- if can Logic.dest_conjunction goal orelse can Logic.dest_equals goal then
- ()
- else if !remaining_stride > 1 then
- (* We still have some steps to do *)
- (remaining_stride := !remaining_stride - 1;
- log "Skipping because of stride")
- else
- (* This was the last step, now run the action *)
- let
- val _ = remaining_stride := stride
- val _ = num_sledgehammer_calls := !num_sledgehammer_calls + 1
- in
- if !num_sledgehammer_calls > max_calls then
- log "Skipping because max number of calls reached"
- else
- let
- val meth = Unsynchronized.ref ""
- val named_thms =
- Unsynchronized.ref (NONE : ((string * stature) * thm list) list option)
- val trivial =
- if check_trivial then
- Try0.try0 (SOME try_timeout) ([], [], [], []) pre
- handle Timeout.TIMEOUT _ => false
- else false
- fun apply_method () =
- (Mirabelle.catch_result (proof_method_tag meth) false
- (run_proof_method trivial false name meth (these (!named_thms))) id st; ())
- in
- Mirabelle.catch sh_tag (run_sledgehammer trivial args meth named_thms) id st;
- if is_some (!named_thms) then apply_method () else ()
- end
- end
- end
- end
+val _ =
+ Theory.setup (Mirabelle.theory_action \<^binding>\<open>sledgehammer\<close>
+ (fn context => fn commands =>
+ let
+ val {index, tag = sh_tag, arguments = args, timeout, ...} = context
+ fun proof_method_tag meth = "#" ^ string_of_int index ^ " " ^ meth ^ " (sledgehammer): "
+
+ val data = Unsynchronized.ref empty_data
+ val change_data = Unsynchronized.change data
+
+ val stride = Mirabelle.get_int_argument args (strideK, stride_default)
+ val max_calls = Mirabelle.get_int_argument args (max_callsK, max_calls_default)
+ val check_trivial = Mirabelle.get_bool_argument args (check_trivialK, check_trivial_default)
-fun invoke args =
- Mirabelle.register (init, sledgehammer_action args, done)
+ val results =
+ commands |> maps (fn command =>
+ let
+ val {name, pos, pre = st, ...} = command
+ val goal = Thm.major_prem_of (#goal (Proof.goal st))
+ val log =
+ if can Logic.dest_conjunction goal orelse can Logic.dest_equals goal then []
+ else if !remaining_stride > 1 then
+ (* We still have some steps to do *)
+ (Unsynchronized.dec remaining_stride; ["Skipping because of stride"])
+ else
+ (* This was the last step, now run the action *)
+ let
+ val _ = remaining_stride := stride
+ val _ = Unsynchronized.inc num_sledgehammer_calls
+ in
+ if !num_sledgehammer_calls > max_calls then
+ ["Skipping because max number of calls reached"]
+ else
+ let
+ val meth = Unsynchronized.ref ""
+ val named_thms =
+ Unsynchronized.ref (NONE : ((string * stature) * thm list) list option)
+ val trivial =
+ if check_trivial then
+ Try0.try0 (SOME try_timeout) ([], [], [], []) st
+ handle Timeout.TIMEOUT _ => false
+ else false
+ val log1 =
+ sh_tag ^ catch (fn () =>
+ run_sledgehammer change_data trivial args meth named_thms pos st)
+ val log2 =
+ (case ! named_thms of
+ SOME thms =>
+ [separator,
+ proof_method_tag (!meth) ^
+ catch (fn () =>
+ run_proof_method change_data trivial false name meth thms
+ timeout pos st)]
+ | NONE => [])
+ in log1 :: log2 end
+ end
+ in
+ if null log then []
+ else [Mirabelle.log_command command [XML.Text (cat_lines log)]]
+ end)
+
+ val report = log_data index (! data)
+ in results @ report end))
end