added stride option to Mirabelle
authordesharna
Fri, 12 Feb 2021 11:18:44 +0100
changeset 73289 a34b49841585
parent 73288 f6f1242ed367
child 73290 dcf295994c90
added stride option to Mirabelle
src/HOL/Mirabelle/Tools/mirabelle_sledgehammer.ML
--- a/src/HOL/Mirabelle/Tools/mirabelle_sledgehammer.ML	Fri Feb 12 11:18:12 2021 +0100
+++ b/src/HOL/Mirabelle/Tools/mirabelle_sledgehammer.ML	Fri Feb 12 11:18:44 2021 +0100
@@ -33,6 +33,7 @@
 val sliceK = "slice" (*=BOOL: allow sledgehammer-level strategy-scheduling*)
 val smt_proofsK = "smt_proofs" (*=BOOL: enable SMT proof generation*)
 val strictK = "strict" (*=BOOL: run in strict mode*)
+val strideK = "stride" (*=NUM: run every nth goal*)
 val term_orderK = "term_order" (*=STRING: term order (in E)*)
 val type_encK = "type_enc" (*=STRING: type encoding scheme*)
 val uncurried_aliasesK = "uncurried_aliases" (*=SMART_BOOL: use fresh function names to alias curried applications*)
@@ -50,9 +51,10 @@
 val fact_filter_default = "smart"
 val type_enc_default = "smart"
 val strict_default = "false"
+val stride_default = 1
 val max_facts_default = "smart"
 val slice_default = "true"
-val max_calls_default = "10000000"
+val max_calls_default = 10000000
 val trivial_default = "false"
 
 (*If a key is present in args then augment a list with its pair*)
@@ -605,37 +607,49 @@
 
 (* crude hack *)
 val num_sledgehammer_calls = Unsynchronized.ref 0
+val remaining_stride = Unsynchronized.ref stride_default
 
-fun sledgehammer_action args id (st as {pre, name, ...}: Mirabelle.run_args) =
-  let val goal = Thm.major_prem_of (#goal (Proof.goal pre)) in
-    if can Logic.dest_conjunction goal orelse can Logic.dest_equals goal
-    then () else
-    let
-      val max_calls =
-        AList.lookup (op =) args max_callsK |> the_default max_calls_default
-        |> Int.fromString |> the
-      val _ = num_sledgehammer_calls := !num_sledgehammer_calls + 1;
-    in
-      if !num_sledgehammer_calls > max_calls then ()
-      else
-        let
-          val meth = Unsynchronized.ref ""
-          val named_thms =
-            Unsynchronized.ref (NONE : ((string * stature) * thm list) list option)
-          val trivial =
-            if AList.lookup (op =) args check_trivialK |> the_default trivial_default
-                            |> Value.parse_bool then
-              Try0.try0 (SOME try_timeout) ([], [], [], []) pre
-              handle Timeout.TIMEOUT _ => false
-            else false
-          fun apply_method () =
-            (Mirabelle.catch_result (proof_method_tag meth) false
-              (run_proof_method trivial false name meth (these (!named_thms))) id st; ())
-        in
-          Mirabelle.catch sh_tag (run_sledgehammer trivial args meth named_thms) id st;
-          if is_some (!named_thms) then apply_method () else ()
-        end
-    end
+fun sledgehammer_action args =
+  let
+    val stride = Mirabelle.get_int_setting args (strideK, stride_default)
+    val max_calls = Mirabelle.get_int_setting args (max_callsK, max_calls_default)
+  in
+    fn id => fn (st as {pre, name, log, ...}: Mirabelle.run_args) =>
+      let val goal = Thm.major_prem_of (#goal (Proof.goal pre)) in
+        if can Logic.dest_conjunction goal orelse can Logic.dest_equals goal then
+          ()
+        else if !remaining_stride > 1 then
+          (* We still have some steps to do *)
+          (remaining_stride := !remaining_stride - 1;
+          log "Skipping because of stride")
+        else
+          (* This was the last step, now run the action *)
+          let
+            val _ = remaining_stride := stride
+            val _ = num_sledgehammer_calls := !num_sledgehammer_calls + 1
+          in
+            if !num_sledgehammer_calls > max_calls then
+              log "Skipping because max number of calls reached"
+            else
+              let
+                val meth = Unsynchronized.ref ""
+                val named_thms =
+                  Unsynchronized.ref (NONE : ((string * stature) * thm list) list option)
+                val trivial =
+                  if AList.lookup (op =) args check_trivialK |> the_default trivial_default
+                                  |> Value.parse_bool then
+                    Try0.try0 (SOME try_timeout) ([], [], [], []) pre
+                    handle Timeout.TIMEOUT _ => false
+                  else false
+                fun apply_method () =
+                  (Mirabelle.catch_result (proof_method_tag meth) false
+                    (run_proof_method trivial false name meth (these (!named_thms))) id st; ())
+              in
+                Mirabelle.catch sh_tag (run_sledgehammer trivial args meth named_thms) id st;
+                if is_some (!named_thms) then apply_method () else ()
+              end
+          end
+      end
   end
 
 fun invoke args =