--- a/src/HOL/Mirabelle/Tools/mirabelle_sledgehammer.ML Fri Feb 12 11:18:12 2021 +0100
+++ b/src/HOL/Mirabelle/Tools/mirabelle_sledgehammer.ML Fri Feb 12 11:18:44 2021 +0100
@@ -33,6 +33,7 @@
val sliceK = "slice" (*=BOOL: allow sledgehammer-level strategy-scheduling*)
val smt_proofsK = "smt_proofs" (*=BOOL: enable SMT proof generation*)
val strictK = "strict" (*=BOOL: run in strict mode*)
+val strideK = "stride" (*=NUM: run every nth goal*)
val term_orderK = "term_order" (*=STRING: term order (in E)*)
val type_encK = "type_enc" (*=STRING: type encoding scheme*)
val uncurried_aliasesK = "uncurried_aliases" (*=SMART_BOOL: use fresh function names to alias curried applications*)
@@ -50,9 +51,10 @@
val fact_filter_default = "smart"
val type_enc_default = "smart"
val strict_default = "false"
+val stride_default = 1
val max_facts_default = "smart"
val slice_default = "true"
-val max_calls_default = "10000000"
+val max_calls_default = 10000000
val trivial_default = "false"
(*If a key is present in args then augment a list with its pair*)
@@ -605,37 +607,49 @@
(* crude hack *)
val num_sledgehammer_calls = Unsynchronized.ref 0
+val remaining_stride = Unsynchronized.ref stride_default
-fun sledgehammer_action args id (st as {pre, name, ...}: Mirabelle.run_args) =
- let val goal = Thm.major_prem_of (#goal (Proof.goal pre)) in
- if can Logic.dest_conjunction goal orelse can Logic.dest_equals goal
- then () else
- let
- val max_calls =
- AList.lookup (op =) args max_callsK |> the_default max_calls_default
- |> Int.fromString |> the
- val _ = num_sledgehammer_calls := !num_sledgehammer_calls + 1;
- in
- if !num_sledgehammer_calls > max_calls then ()
- else
- let
- val meth = Unsynchronized.ref ""
- val named_thms =
- Unsynchronized.ref (NONE : ((string * stature) * thm list) list option)
- val trivial =
- if AList.lookup (op =) args check_trivialK |> the_default trivial_default
- |> Value.parse_bool then
- Try0.try0 (SOME try_timeout) ([], [], [], []) pre
- handle Timeout.TIMEOUT _ => false
- else false
- fun apply_method () =
- (Mirabelle.catch_result (proof_method_tag meth) false
- (run_proof_method trivial false name meth (these (!named_thms))) id st; ())
- in
- Mirabelle.catch sh_tag (run_sledgehammer trivial args meth named_thms) id st;
- if is_some (!named_thms) then apply_method () else ()
- end
- end
+fun sledgehammer_action args =
+ let
+ val stride = Mirabelle.get_int_setting args (strideK, stride_default)
+ val max_calls = Mirabelle.get_int_setting args (max_callsK, max_calls_default)
+ in
+ fn id => fn (st as {pre, name, log, ...}: Mirabelle.run_args) =>
+ let val goal = Thm.major_prem_of (#goal (Proof.goal pre)) in
+ if can Logic.dest_conjunction goal orelse can Logic.dest_equals goal then
+ ()
+ else if !remaining_stride > 1 then
+ (* We still have some steps to do *)
+ (remaining_stride := !remaining_stride - 1;
+ log "Skipping because of stride")
+ else
+ (* This was the last step, now run the action *)
+ let
+ val _ = remaining_stride := stride
+ val _ = num_sledgehammer_calls := !num_sledgehammer_calls + 1
+ in
+ if !num_sledgehammer_calls > max_calls then
+ log "Skipping because max number of calls reached"
+ else
+ let
+ val meth = Unsynchronized.ref ""
+ val named_thms =
+ Unsynchronized.ref (NONE : ((string * stature) * thm list) list option)
+ val trivial =
+ if AList.lookup (op =) args check_trivialK |> the_default trivial_default
+ |> Value.parse_bool then
+ Try0.try0 (SOME try_timeout) ([], [], [], []) pre
+ handle Timeout.TIMEOUT _ => false
+ else false
+ fun apply_method () =
+ (Mirabelle.catch_result (proof_method_tag meth) false
+ (run_proof_method trivial false name meth (these (!named_thms))) id st; ())
+ in
+ Mirabelle.catch sh_tag (run_sledgehammer trivial args meth named_thms) id st;
+ if is_some (!named_thms) then apply_method () else ()
+ end
+ end
+ end
end
fun invoke args =