merged
authordesharna
Tue, 23 Feb 2021 10:13:09 +0100
changeset 73293 8b6fa865bac4
parent 73292 f84a93f1de2f (diff)
parent 73287 04c9a2cd7686 (current diff)
child 73294 f0210642e43f
merged
src/HOL/Tools/Sledgehammer/sledgehammer_atp_systems.ML
--- a/src/HOL/Mirabelle/Tools/mirabelle.ML	Mon Feb 22 23:31:59 2021 +0100
+++ b/src/HOL/Mirabelle/Tools/mirabelle.ML	Tue Feb 23 10:13:09 2021 +0100
@@ -32,6 +32,7 @@
   val theorems_of_sucessful_proof : Toplevel.state option -> thm list
   val get_setting : (string * string) list -> string * string -> string
   val get_int_setting : (string * string) list -> string * int -> int
+  val get_bool_setting : (string * string) list -> string * bool -> bool
   val cpu_time : ('a -> 'b) -> 'a -> 'b * int
 end
 
@@ -209,6 +210,12 @@
   | SOME NONE => error ("bad option: " ^ key)
   | NONE => default)
 
+fun get_bool_setting settings (key, default) =
+  (case Option.map Bool.fromString (AList.lookup (op =) settings key) of
+    SOME (SOME i) => i
+  | SOME NONE => error ("bad option: " ^ key)
+  | NONE => default)
+
 fun cpu_time f x =
   let val ({cpu, ...}, y) = Timing.timing f x
   in (y, Time.toMilliseconds cpu) end
--- a/src/HOL/Mirabelle/Tools/mirabelle_sledgehammer.ML	Mon Feb 22 23:31:59 2021 +0100
+++ b/src/HOL/Mirabelle/Tools/mirabelle_sledgehammer.ML	Tue Feb 23 10:13:09 2021 +0100
@@ -33,6 +33,7 @@
 val sliceK = "slice" (*=BOOL: allow sledgehammer-level strategy-scheduling*)
 val smt_proofsK = "smt_proofs" (*=BOOL: enable SMT proof generation*)
 val strictK = "strict" (*=BOOL: run in strict mode*)
+val strideK = "stride" (*=NUM: run every nth goal*)
 val term_orderK = "term_order" (*=STRING: term order (in E)*)
 val type_encK = "type_enc" (*=STRING: type encoding scheme*)
 val uncurried_aliasesK = "uncurried_aliases" (*=SMART_BOOL: use fresh function names to alias curried applications*)
@@ -50,10 +51,11 @@
 val fact_filter_default = "smart"
 val type_enc_default = "smart"
 val strict_default = "false"
+val stride_default = 1
 val max_facts_default = "smart"
 val slice_default = "true"
-val max_calls_default = "10000000"
-val trivial_default = "false"
+val max_calls_default = 10000000
+val check_trivial_default = false
 
 (*If a key is present in args then augment a list with its pair*)
 (*This is used to avoid fixing default values at the Mirabelle level, and
@@ -605,37 +607,49 @@
 
 (* crude hack *)
 val num_sledgehammer_calls = Unsynchronized.ref 0
+val remaining_stride = Unsynchronized.ref stride_default
 
-fun sledgehammer_action args id (st as {pre, name, ...}: Mirabelle.run_args) =
-  let val goal = Thm.major_prem_of (#goal (Proof.goal pre)) in
-    if can Logic.dest_conjunction goal orelse can Logic.dest_equals goal
-    then () else
-    let
-      val max_calls =
-        AList.lookup (op =) args max_callsK |> the_default max_calls_default
-        |> Int.fromString |> the
-      val _ = num_sledgehammer_calls := !num_sledgehammer_calls + 1;
-    in
-      if !num_sledgehammer_calls > max_calls then ()
-      else
-        let
-          val meth = Unsynchronized.ref ""
-          val named_thms =
-            Unsynchronized.ref (NONE : ((string * stature) * thm list) list option)
-          val trivial =
-            if AList.lookup (op =) args check_trivialK |> the_default trivial_default
-                            |> Value.parse_bool then
-              Try0.try0 (SOME try_timeout) ([], [], [], []) pre
-              handle Timeout.TIMEOUT _ => false
-            else false
-          fun apply_method () =
-            (Mirabelle.catch_result (proof_method_tag meth) false
-              (run_proof_method trivial false name meth (these (!named_thms))) id st; ())
-        in
-          Mirabelle.catch sh_tag (run_sledgehammer trivial args meth named_thms) id st;
-          if is_some (!named_thms) then apply_method () else ()
-        end
-    end
+fun sledgehammer_action args =
+  let
+    val stride = Mirabelle.get_int_setting args (strideK, stride_default)
+    val max_calls = Mirabelle.get_int_setting args (max_callsK, max_calls_default)
+    val check_trivial = Mirabelle.get_bool_setting args (check_trivialK, check_trivial_default)
+  in
+    fn id => fn (st as {pre, name, log, ...}: Mirabelle.run_args) =>
+      let val goal = Thm.major_prem_of (#goal (Proof.goal pre)) in
+        if can Logic.dest_conjunction goal orelse can Logic.dest_equals goal then
+          ()
+        else if !remaining_stride > 1 then
+          (* We still have some steps to do *)
+          (remaining_stride := !remaining_stride - 1;
+          log "Skipping because of stride")
+        else
+          (* This was the last step, now run the action *)
+          let
+            val _ = remaining_stride := stride
+            val _ = num_sledgehammer_calls := !num_sledgehammer_calls + 1
+          in
+            if !num_sledgehammer_calls > max_calls then
+              log "Skipping because max number of calls reached"
+            else
+              let
+                val meth = Unsynchronized.ref ""
+                val named_thms =
+                  Unsynchronized.ref (NONE : ((string * stature) * thm list) list option)
+                val trivial =
+                  if check_trivial then
+                    Try0.try0 (SOME try_timeout) ([], [], [], []) pre
+                    handle Timeout.TIMEOUT _ => false
+                  else false
+                fun apply_method () =
+                  (Mirabelle.catch_result (proof_method_tag meth) false
+                    (run_proof_method trivial false name meth (these (!named_thms))) id st; ())
+              in
+                Mirabelle.catch sh_tag (run_sledgehammer trivial args meth named_thms) id st;
+                if is_some (!named_thms) then apply_method () else ()
+              end
+          end
+      end
   end
 
 fun invoke args =
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_atp_systems.ML	Mon Feb 22 23:31:59 2021 +0100
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_atp_systems.ML	Tue Feb 23 10:13:09 2021 +0100
@@ -553,9 +553,9 @@
    prem_role = Conjecture,
    best_slices = fn _ =>
      (* FUDGE *)
-     [(0.333, (((128, "meshN"), THF (Without_FOOL, Monomorphic, THF_Without_Choice), "mono_native_higher", keep_lamsN, false), zipperposition_blsimp)),
+     [(0.333, (((128, "meshN"), THF (Without_FOOL, Polymorphic, THF_Without_Choice), "mono_native_higher", keep_lamsN, false), zipperposition_blsimp)),
       (0.333, (((32, "meshN"), THF (Without_FOOL, Polymorphic, THF_Without_Choice), "poly_native_higher", keep_lamsN, false), zipperposition_s6)),
-      (0.334, (((512, "meshN"), THF (Without_FOOL, Monomorphic, THF_Without_Choice), "mono_native_higher", keep_lamsN, false), zipperposition_cdots))],
+      (0.334, (((512, "meshN"), THF (Without_FOOL, Polymorphic, THF_Without_Choice), "mono_native_higher", keep_lamsN, false), zipperposition_cdots))],
    best_max_mono_iters = default_max_mono_iters,
    best_max_new_mono_instances = default_max_new_mono_instances}