# HG changeset patch # User desharna # Date 1613125124 -3600 # Node ID a34b49841585d8b5c904437a935a951ebf65768f # Parent f6f1242ed3672872147725b36bdcea582a20624f added stride option to Mirabelle diff -r f6f1242ed367 -r a34b49841585 src/HOL/Mirabelle/Tools/mirabelle_sledgehammer.ML --- a/src/HOL/Mirabelle/Tools/mirabelle_sledgehammer.ML Fri Feb 12 11:18:12 2021 +0100 +++ b/src/HOL/Mirabelle/Tools/mirabelle_sledgehammer.ML Fri Feb 12 11:18:44 2021 +0100 @@ -33,6 +33,7 @@ val sliceK = "slice" (*=BOOL: allow sledgehammer-level strategy-scheduling*) val smt_proofsK = "smt_proofs" (*=BOOL: enable SMT proof generation*) val strictK = "strict" (*=BOOL: run in strict mode*) +val strideK = "stride" (*=NUM: run every nth goal*) val term_orderK = "term_order" (*=STRING: term order (in E)*) val type_encK = "type_enc" (*=STRING: type encoding scheme*) val uncurried_aliasesK = "uncurried_aliases" (*=SMART_BOOL: use fresh function names to alias curried applications*) @@ -50,9 +51,10 @@ val fact_filter_default = "smart" val type_enc_default = "smart" val strict_default = "false" +val stride_default = 1 val max_facts_default = "smart" val slice_default = "true" -val max_calls_default = "10000000" +val max_calls_default = 10000000 val trivial_default = "false" (*If a key is present in args then augment a list with its pair*) @@ -605,37 +607,49 @@ (* crude hack *) val num_sledgehammer_calls = Unsynchronized.ref 0 +val remaining_stride = Unsynchronized.ref stride_default -fun sledgehammer_action args id (st as {pre, name, ...}: Mirabelle.run_args) = - let val goal = Thm.major_prem_of (#goal (Proof.goal pre)) in - if can Logic.dest_conjunction goal orelse can Logic.dest_equals goal - then () else - let - val max_calls = - AList.lookup (op =) args max_callsK |> the_default max_calls_default - |> Int.fromString |> the - val _ = num_sledgehammer_calls := !num_sledgehammer_calls + 1; - in - if !num_sledgehammer_calls > max_calls then () - else - let - val meth = Unsynchronized.ref "" - val named_thms = - Unsynchronized.ref (NONE : ((string * stature) * thm list) list option) - val trivial = - if AList.lookup (op =) args check_trivialK |> the_default trivial_default - |> Value.parse_bool then - Try0.try0 (SOME try_timeout) ([], [], [], []) pre - handle Timeout.TIMEOUT _ => false - else false - fun apply_method () = - (Mirabelle.catch_result (proof_method_tag meth) false - (run_proof_method trivial false name meth (these (!named_thms))) id st; ()) - in - Mirabelle.catch sh_tag (run_sledgehammer trivial args meth named_thms) id st; - if is_some (!named_thms) then apply_method () else () - end - end +fun sledgehammer_action args = + let + val stride = Mirabelle.get_int_setting args (strideK, stride_default) + val max_calls = Mirabelle.get_int_setting args (max_callsK, max_calls_default) + in + fn id => fn (st as {pre, name, log, ...}: Mirabelle.run_args) => + let val goal = Thm.major_prem_of (#goal (Proof.goal pre)) in + if can Logic.dest_conjunction goal orelse can Logic.dest_equals goal then + () + else if !remaining_stride > 1 then + (* We still have some steps to do *) + (remaining_stride := !remaining_stride - 1; + log "Skipping because of stride") + else + (* This was the last step, now run the action *) + let + val _ = remaining_stride := stride + val _ = num_sledgehammer_calls := !num_sledgehammer_calls + 1 + in + if !num_sledgehammer_calls > max_calls then + log "Skipping because max number of calls reached" + else + let + val meth = Unsynchronized.ref "" + val named_thms = + Unsynchronized.ref (NONE : ((string * stature) * thm list) list option) + val trivial = + if AList.lookup (op =) args check_trivialK |> the_default trivial_default + |> Value.parse_bool then + Try0.try0 (SOME try_timeout) ([], [], [], []) pre + handle Timeout.TIMEOUT _ => false + else false + fun apply_method () = + (Mirabelle.catch_result (proof_method_tag meth) false + (run_proof_method trivial false name meth (these (!named_thms))) id st; ()) + in + Mirabelle.catch sh_tag (run_sledgehammer trivial args meth named_thms) id st; + if is_some (!named_thms) then apply_method () else () + end + end + end end fun invoke args =