moved stride option from sledgehammer action to main mirabelle
authordesharna
Fri, 04 Jun 2021 23:03:12 +0200
changeset 73797 f7ea394490f5
parent 73796 56f31baaa837
child 73806 b982362eeca4
moved stride option from sledgehammer action to main mirabelle
src/HOL/Tools/Mirabelle/mirabelle.ML
src/HOL/Tools/Mirabelle/mirabelle.scala
src/HOL/Tools/Mirabelle/mirabelle_sledgehammer.ML
src/HOL/Tools/etc/options
--- a/src/HOL/Tools/Mirabelle/mirabelle.ML	Thu Jun 03 10:58:15 2021 +0100
+++ b/src/HOL/Tools/Mirabelle/mirabelle.ML	Fri Jun 04 23:03:12 2021 +0200
@@ -152,11 +152,21 @@
 
 val whitelist = ["apply", "by", "proof"];
 
+fun filter_index f =
+  let
+    fun filter_aux _ [] = []
+      | filter_aux n (x :: xs) =
+        if f (n, x) then x :: filter_aux (n + 1) xs else filter_aux (n + 1) xs
+  in
+    filter_aux 0
+  end
+
 val _ =
   Theory.setup (Thy_Info.add_presentation (fn context => fn thy =>
     let
       val {options, adjust_pos, segments, ...} = context;
       val mirabelle_timeout = Options.seconds options \<^system_option>\<open>mirabelle_timeout\<close>;
+      val mirabelle_stride = Options.int options \<^system_option>\<open>mirabelle_stride\<close>;
       val mirabelle_actions = Options.string options \<^system_option>\<open>mirabelle_actions\<close>;
       val mirabelle_theories = Options.string options \<^system_option>\<open>mirabelle_theories\<close>;
 
@@ -166,7 +176,8 @@
         | NONE => error ("Failed to parse mirabelle_actions: " ^ quote mirabelle_actions));
       val check = check_theories (space_explode "," mirabelle_theories);
       val commands =
-        segments |> map_filter (fn {command = tr, prev_state = st, state = st', ...} =>
+        segments
+        |> map_filter (fn {command = tr, prev_state = st, state = st', ...} =>
           let
             val name = Toplevel.name_of tr;
             val pos = adjust_pos (Toplevel.pos_of tr);
@@ -176,7 +187,8 @@
               check (Context.theory_long_name thy) pos
             then SOME {name = name, pos = pos, pre = Toplevel.proof_of st, post = st'}
             else NONE
-          end);
+          end)
+        |> filter_index (fn (n, _) => n mod mirabelle_stride = 0);
 
       fun apply (i, (name, arguments)) () =
         apply_action (i + 1) name arguments mirabelle_timeout commands thy;
--- a/src/HOL/Tools/Mirabelle/mirabelle.scala	Thu Jun 03 10:58:15 2021 +0100
+++ b/src/HOL/Tools/Mirabelle/mirabelle.scala	Fri Jun 04 23:03:12 2021 +0200
@@ -172,6 +172,7 @@
     var verbose = false
     var exclude_sessions: List[String] = Nil
 
+    val default_stride = options.int("mirabelle_stride")
     val default_timeout = options.seconds("mirabelle_timeout")
 
     val getopts = Getopts("""
@@ -182,7 +183,7 @@
     -B NAME      include session NAME and all descendants
     -D DIR       include session directory and select its sessions
     -N           cyclic shuffling of NUMA CPU nodes (performance tuning)
-    -O DIR       output directory for log files (default: """ + default_output_dir + """,
+    -O DIR       output directory for log files (default: """ + default_output_dir + """)
     -R           refer to requirements of selected sessions
     -T THEORY    theory restriction: NAME or NAME[LINE:END_LINE]
     -X NAME      exclude sessions from group NAME and all descendants
@@ -191,6 +192,7 @@
     -g NAME      select session group NAME
     -j INT       maximum number of parallel jobs (default 1)
     -o OPTION    override Isabelle system OPTION (via NAME=VAL or NAME)
+    -s INT       run actions on every nth goal (default """ + default_stride + """)
     -t SECONDS   timeout for each action (default """ + default_timeout + """)
     -v           verbose
     -x NAME      exclude session NAME and all descendants
@@ -221,6 +223,7 @@
       "g:" -> (arg => session_groups = session_groups ::: List(arg)),
       "j:" -> (arg => max_jobs = Value.Int.parse(arg)),
       "o:" -> (arg => options = options + arg),
+      "s:" -> (arg => options = options + ("mirabelle_stride=" + arg)),
       "t:" -> (arg => options = options + ("mirabelle_timeout=" + arg)),
       "v" -> (_ => verbose = true),
       "x:" -> (arg => exclude_sessions = exclude_sessions ::: List(arg)))
--- a/src/HOL/Tools/Mirabelle/mirabelle_sledgehammer.ML	Thu Jun 03 10:58:15 2021 +0100
+++ b/src/HOL/Tools/Mirabelle/mirabelle_sledgehammer.ML	Fri Jun 04 23:03:12 2021 +0200
@@ -36,7 +36,6 @@
 val sliceK = "slice" (*=BOOL: allow sledgehammer-level strategy-scheduling*)
 val smt_proofsK = "smt_proofs" (*=BOOL: enable SMT proof generation*)
 val strictK = "strict" (*=BOOL: run in strict mode*)
-val strideK = "stride" (*=NUM: run every nth goal*)
 val term_orderK = "term_order" (*=STRING: term order (in E)*)
 val type_encK = "type_enc" (*=STRING: type encoding scheme*)
 val uncurried_aliasesK = "uncurried_aliases" (*=SMART_BOOL: use fresh function names to alias curried applications*)
@@ -51,7 +50,6 @@
 val fact_filter_default = "smart"
 val type_enc_default = "smart"
 val strict_default = "false"
-val stride_default = 1
 val max_facts_default = "smart"
 val slice_default = "true"
 val max_calls_default = 10000000
@@ -615,7 +613,6 @@
 
 (* crude hack *)
 val num_sledgehammer_calls = Unsynchronized.ref 0
-val remaining_stride = Unsynchronized.ref stride_default
 
 val _ =
   Theory.setup (Mirabelle.theory_action \<^binding>\<open>sledgehammer\<close>
@@ -627,7 +624,6 @@
         val data = Unsynchronized.ref empty_data
         val change_data = Unsynchronized.change data
 
-        val stride = Mirabelle.get_int_argument args (strideK, stride_default)
         val max_calls = Mirabelle.get_int_argument args (max_callsK, max_calls_default)
         val check_trivial = Mirabelle.get_bool_argument args (check_trivialK, check_trivial_default)
 
@@ -638,13 +634,8 @@
               val goal = Thm.major_prem_of (#goal (Proof.goal st))
               val log =
                 if can Logic.dest_conjunction goal orelse can Logic.dest_equals goal then []
-                else if !remaining_stride > 1 then
-                  (* We still have some steps to do *)
-                  (Unsynchronized.dec remaining_stride; ["Skipping because of stride"])
                 else
-                  (* This was the last step, now run the action *)
                   let
-                    val _ = remaining_stride := stride
                     val _ = Unsynchronized.inc num_sledgehammer_calls
                   in
                     if !num_sledgehammer_calls > max_calls then
--- a/src/HOL/Tools/etc/options	Thu Jun 03 10:58:15 2021 +0100
+++ b/src/HOL/Tools/etc/options	Fri Jun 04 23:03:12 2021 +0200
@@ -53,6 +53,9 @@
 option mirabelle_timeout : real = 30
   -- "default timeout for Mirabelle actions"
 
+option mirabelle_stride : int = 1
+  -- "default stride for running Mirabelle actions on every nth goal"
+
 option mirabelle_actions : string = ""
   -- "Mirabelle actions (outer syntax, separated by semicolons)"