disable slicing within SMT (in preparation for factoring it out)
authorblanchet
Mon, 31 Jan 2022 16:09:23 +0100
changeset 75017 30ccc472d486
parent 75016 873b581fd690
child 75018 fcfd96a59625
disable slicing within SMT (in preparation for factoring it out)
src/HOL/Tools/Sledgehammer/sledgehammer.ML
src/HOL/Tools/Sledgehammer/sledgehammer_prover.ML
src/HOL/Tools/Sledgehammer/sledgehammer_prover_atp.ML
src/HOL/Tools/Sledgehammer/sledgehammer_prover_minimize.ML
src/HOL/Tools/Sledgehammer/sledgehammer_prover_smt.ML
src/HOL/Tools/Sledgehammer/sledgehammer_tactics.ML
--- a/src/HOL/Tools/Sledgehammer/sledgehammer.ML	Mon Jan 31 16:09:23 2022 +0100
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer.ML	Mon Jan 31 16:09:23 2022 +0100
@@ -137,24 +137,21 @@
         (used_facts |> filter_out (fn (_, (sc, _)) => sc = Chained), (meth, play)))
 
 fun launch_prover (params as {verbose, spy, max_facts, induction_rules, ...}) mode only learn
-    ({comment, state, goal, subgoal, subgoal_count, factss, found_proof} : prover_problem) name =
+    ({comment, state, goal, subgoal, subgoal_count, facts, found_proof} : prover_problem) name =
   let
     val ctxt = Proof.context_of state
 
     val _ = spying spy (fn () => (state, subgoal, name, "Launched"))
     val max_facts = max_facts |> the_default (default_max_facts_of_prover ctxt name)
-    val num_facts =
-      (case factss of
-        (_, facts) :: _ => length facts |> not only ? Integer.min max_facts
-      | _ => 0)
+    val num_facts = length (snd facts) |> not only ? Integer.min max_facts
     val induction_rules = induction_rules_for_prover ctxt name induction_rules
 
     val problem =
       {comment = comment, state = state, goal = goal, subgoal = subgoal,
        subgoal_count = subgoal_count,
-       factss = factss
-         (* We take num_facts because factss contains the maximum of all called provers. *)
-         |> map (apsnd (take num_facts o maybe_filter_out_induction_rules induction_rules)),
+       facts = facts
+         (* We take "num_facts" because "facts" contains the maximum of all called provers. *)
+         |> apsnd (take num_facts o maybe_filter_out_induction_rules induction_rules),
        found_proof = found_proof}
 
     fun print_used_facts used_facts used_from =
@@ -188,7 +185,7 @@
             end
 
           val filter_infos =
-            map filter_info (("actual", used_from) :: factss)
+            map filter_info [("actual", used_from), facts]
             |> AList.group (op =)
             |> map (fn (indices, fact_filters) => commas fact_filters ^ ": " ^ indices)
         in
@@ -348,10 +345,10 @@
 
         fun launch_provers () =
           let
-            val factss = get_factss provers
+            val facts = hd (get_factss provers) (* temporary *)
             val problem =
               {comment = "", state = state, goal = goal, subgoal = i, subgoal_count = n,
-               factss = factss, found_proof = found_proof}
+               facts = facts, found_proof = found_proof}
             val learn = mash_learn_proof ctxt params (Thm.prop_of goal)
             val launch = launch_prover_and_preplay params mode writeln_result only learn
           in
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_prover.ML	Mon Jan 31 16:09:23 2022 +0100
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_prover.ML	Mon Jan 31 16:09:23 2022 +0100
@@ -57,7 +57,7 @@
      goal : thm,
      subgoal : int,
      subgoal_count : int,
-     factss : (string * fact list) list,
+     facts : string * fact list,
      found_proof : unit -> unit}
 
   type prover_result =
@@ -185,7 +185,7 @@
    goal : thm,
    subgoal : int,
    subgoal_count : int,
-   factss : (string * fact list) list,
+   facts : string * fact list,
    found_proof : unit -> unit}
 
 type prover_result =
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_prover_atp.ML	Mon Jan 31 16:09:23 2022 +0100
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_prover_atp.ML	Mon Jan 31 16:09:23 2022 +0100
@@ -106,12 +106,6 @@
 fun get_slices slice slices =
   map_index I slices |> slice = Time.zeroTime ? (List.last #> single)
 
-fun get_facts_of_filter _ [(_, facts)] = facts
-  | get_facts_of_filter fact_filter factss =
-    (case AList.lookup (op =) factss fact_filter of
-      SOME facts => facts
-    | NONE => snd (hd factss))
-
 (* For low values of "max_facts", this fudge value ensures that most slices are invoked with a
    nontrivial amount of facts. *)
 val max_fact_factor_fudge = 5
@@ -124,18 +118,17 @@
   | suffix_of_mode MaSh = ""
   | suffix_of_mode Minimize = "_min"
 
-(* Give the ATPs some slack before interrupting them the hard way. "z3_tptp" on Linux appears to be
-   the only ATP that does not honor its time limit. *)
+(* Give the ATPs some slack before interrupting them the hard way. *)
 val atp_timeout_slack = seconds 1.0
 
 (* Important messages are important but not so important that users want to see them each time. *)
 val atp_important_message_keep_quotient = 25
 
 fun run_atp mode name
-    ({debug, verbose, overlord, type_enc, strict, lam_trans, uncurried_aliases, fact_filter,
-     max_facts, max_mono_iters, max_new_mono_instances, isar_proofs, compress, try0, smt_proofs,
-     slice, minimize, timeout, preplay_timeout, spy, ...} : params)
-    ({comment, state, goal, subgoal, subgoal_count, factss, found_proof, ...} : prover_problem) =
+    ({debug, verbose, overlord, type_enc, strict, lam_trans, uncurried_aliases, max_facts,
+      max_mono_iters, max_new_mono_instances, isar_proofs, compress, try0, smt_proofs, slice,
+      minimize, timeout, preplay_timeout, spy, ...} : params)
+    ({comment, state, goal, subgoal, subgoal_count, facts, found_proof, ...} : prover_problem) =
   let
     val thy = Proof.theory_of state
     val ctxt = Proof.context_of state
@@ -226,12 +219,11 @@
         val slices_timeout_ms = real (Time.toMilliseconds timeout - slices_overhead_ms)
 
         fun run_slice time_left (cache_key, cache_value) (slice, (time_frac,
-            (key as ((best_max_facts, best_fact_filter), format, best_type_enc, best_lam_trans,
-               best_uncurried_aliases),
+            (key as ((best_max_facts, _ (* best_fact_filter *)), format, best_type_enc,
+               best_lam_trans, best_uncurried_aliases),
              extra))) =
           let
-            val effective_fact_filter = fact_filter |> the_default best_fact_filter
-            val facts = get_facts_of_filter effective_fact_filter factss
+            val facts = snd facts
             val num_facts =
               Real.ceil (max_fact_factor * Real.fromInt best_max_facts) + max_fact_factor_fudge
               |> Integer.min (length facts)
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_prover_minimize.ML	Mon Jan 31 16:09:23 2022 +0100
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_prover_minimize.ML	Mon Jan 31 16:09:23 2022 +0100
@@ -98,7 +98,7 @@
        slice = Time.zeroTime, timeout = timeout, preplay_timeout = preplay_timeout, expect = ""}
     val problem =
       {comment = "", state = state, goal = goal, subgoal = i, subgoal_count = n,
-       factss = [("", facts)], found_proof = I}
+       facts = ("", facts), found_proof = I}
     val result0 as {outcome = outcome0, used_facts, used_from, preferred_methss, run_time,
         message} =
       prover params problem
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_prover_smt.ML	Mon Jan 31 16:09:23 2022 +0100
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_prover_smt.ML	Mon Jan 31 16:09:23 2022 +0100
@@ -14,10 +14,6 @@
 
   val smt_builtins : bool Config.T
   val smt_triggers : bool Config.T
-  val smt_max_slices : int Config.T
-  val smt_slice_fact_frac : real Config.T
-  val smt_slice_time_frac : real Config.T
-  val smt_slice_min_secs : int Config.T
 
   val is_smt_prover : Proof.context -> string -> bool
   val run_smt_solver : mode -> string -> prover
@@ -64,137 +60,65 @@
   | failure_of_smt_failure SMT_Failure.Out_Of_Memory = OutOfResources
   | failure_of_smt_failure (SMT_Failure.Other_Failure s) = UnknownError s
 
-(* FUDGE *)
-val smt_max_slices = Attrib.setup_config_int \<^binding>\<open>sledgehammer_smt_max_slices\<close> (K 8)
-val smt_slice_fact_frac =
-  Attrib.setup_config_real \<^binding>\<open>sledgehammer_smt_slice_fact_frac\<close> (K 0.667)
-val smt_slice_time_frac =
-  Attrib.setup_config_real \<^binding>\<open>sledgehammer_smt_slice_time_frac\<close> (K 0.333)
-val smt_slice_min_secs = Attrib.setup_config_int \<^binding>\<open>sledgehammer_smt_slice_min_secs\<close> (K 3)
-
 val is_boring_builtin_typ =
   not o exists_subtype (member (op =) [\<^typ>\<open>nat\<close>, \<^typ>\<open>int\<close>, HOLogic.realT])
 
-fun smt_filter_loop name ({debug, overlord, max_mono_iters, max_new_mono_instances, timeout, slice,
-      type_enc, ...} : params) state goal i =
+fun smt_filter name ({debug, overlord, max_mono_iters, max_new_mono_instances,
+    type_enc, slice, timeout, ...} : params) state goal i facts =
   let
+    val run_timeout = if slice = Time.zeroTime then timeout else slice
     val (higher_order, nat_as_int) =
       (case type_enc of
         SOME s =>  (String.isSubstring "native_higher" s, String.isSubstring "arith" s)
       | NONE => (false, false))
-    fun repair_context ctxt =
-      ctxt |> Context.proof_map (SMT_Config.select_solver name)
-           |> Config.put SMT_Config.verbose debug
-           |> Config.put SMT_Config.higher_order higher_order
-           |> Config.put SMT_Config.nat_as_int nat_as_int
-           |> (if overlord then
-                 Config.put SMT_Config.debug_files
-                   (overlord_file_location_of_prover name |> (fn (path, name) => path ^ "/" ^ name))
-               else
-                 I)
-           |> Config.put SMT_Config.infer_triggers (Config.get ctxt smt_triggers)
-           |> not (Config.get ctxt smt_builtins)
-              ? (SMT_Builtin.filter_builtins is_boring_builtin_typ
-                 #> Config.put SMT_Systems.z3_extensions false)
-           |> repair_monomorph_context max_mono_iters default_max_mono_iters max_new_mono_instances
-                default_max_new_mono_instances
+    fun repair_context ctxt = ctxt
+      |> Context.proof_map (SMT_Config.select_solver name)
+      |> Config.put SMT_Config.verbose debug
+      |> Config.put SMT_Config.higher_order higher_order
+      |> Config.put SMT_Config.nat_as_int nat_as_int
+      |> (if overlord then
+            Config.put SMT_Config.debug_files
+              (overlord_file_location_of_prover name |> (fn (path, name) => path ^ "/" ^ name))
+          else
+            I)
+       |> Config.put SMT_Config.infer_triggers (Config.get ctxt smt_triggers)
+       |> not (Config.get ctxt smt_builtins)
+         ? (SMT_Builtin.filter_builtins is_boring_builtin_typ
+            #> Config.put SMT_Systems.z3_extensions false)
+       |> repair_monomorph_context max_mono_iters default_max_mono_iters max_new_mono_instances
+         default_max_new_mono_instances
 
     val state = Proof.map_context (repair_context) state
     val ctxt = Proof.context_of state
-    val max_slices = if slice = Time.zeroTime then 1 else Config.get ctxt smt_max_slices
 
-    fun do_slice timeout slice outcome0 time_so_far (factss as (fact_filter, facts) :: _) =
-      let
-        val timer = Timer.startRealTimer ()
-        val slice_timeout =
-          if slice < max_slices then
-            let val ms = Time.toMilliseconds timeout in
-              Int.min (ms, Int.max (1000 * Config.get ctxt smt_slice_min_secs,
-                Real.ceil (Config.get ctxt smt_slice_time_frac * Real.fromInt ms)))
-              |> Time.fromMilliseconds
-            end
-          else
-            timeout
-        val num_facts = length facts
-        val _ =
-          if debug then
-            quote name ^ " slice " ^ string_of_int slice ^ " with " ^ string_of_int num_facts ^
-            " fact" ^ plural_s num_facts ^ " for " ^ string_of_time slice_timeout
-            |> writeln
-          else
-            ()
-        val birth = Timer.checkRealTimer timer
-
-        val filter_result as {outcome, ...} =
-          SMT_Solver.smt_filter ctxt goal facts i slice_timeout
-          handle exn =>
-            if Exn.is_interrupt exn orelse debug then
-              Exn.reraise exn
-            else
-              {outcome = SOME (SMT_Failure.Other_Failure (Runtime.exn_message exn)),
-               fact_ids = NONE, atp_proof = K []}
-
-        val death = Timer.checkRealTimer timer
-        val outcome0 = if is_none outcome0 then SOME outcome else outcome0
-        val time_so_far = time_so_far + (death - birth)
-        val timeout = timeout - Timer.checkRealTimer timer
+    val timer = Timer.startRealTimer ()
+    val birth = Timer.checkRealTimer timer
 
-        val too_many_facts_perhaps =
-          (case outcome of
-            NONE => false
-          | SOME (SMT_Failure.Counterexample _) => false
-          | SOME SMT_Failure.Time_Out => slice_timeout <> timeout
-          | SOME (SMT_Failure.Abnormal_Termination _) => true (* kind of *)
-          | SOME SMT_Failure.Out_Of_Memory => true
-          | SOME (SMT_Failure.Other_Failure _) => true)
-      in
-        if too_many_facts_perhaps andalso slice < max_slices andalso num_facts > 0 andalso
-           timeout > Time.zeroTime then
-          let
-            val new_num_facts =
-              Real.ceil (Config.get ctxt smt_slice_fact_frac * Real.fromInt num_facts)
-            val factss as (new_fact_filter, _) :: _ =
-              factss
-              |> (fn (x :: xs) => xs @ [x])
-              |> app_hd (apsnd (take new_num_facts))
-            val show_filter = fact_filter <> new_fact_filter
+    val filter_result as {outcome, ...} =
+      SMT_Solver.smt_filter ctxt goal facts i run_timeout
+      handle exn =>
+      if Exn.is_interrupt exn orelse debug then
+        Exn.reraise exn
+      else
+        {outcome = SOME (SMT_Failure.Other_Failure (Runtime.exn_message exn)), fact_ids = NONE,
+         atp_proof = K []}
 
-            fun num_of_facts fact_filter num_facts =
-              string_of_int num_facts ^ (if show_filter then " " ^ quote fact_filter else "") ^
-              " fact" ^ plural_s num_facts
-
-            val _ =
-              if debug then
-                quote name ^ " invoked with " ^
-                num_of_facts fact_filter num_facts ^ ": " ^
-                string_of_atp_failure (failure_of_smt_failure (the outcome)) ^
-                " Retrying with " ^ num_of_facts new_fact_filter new_num_facts ^
-                "..."
-                |> writeln
-              else
-                ()
-          in
-            do_slice timeout (slice + 1) outcome0 time_so_far factss
-          end
-        else
-          {outcome = if is_none outcome then NONE else the outcome0, filter_result = filter_result,
-           used_from = facts, run_time = time_so_far}
-      end
+    val death = Timer.checkRealTimer timer
+    val run_time = death - birth
   in
-    do_slice timeout 1 NONE Time.zeroTime
+    {outcome = outcome, filter_result = filter_result, used_from = facts, run_time = run_time}
   end
 
-fun run_smt_solver mode name (params as {debug, verbose, isar_proofs, compress, try0, smt_proofs,
-      minimize, preplay_timeout, ...})
-    ({state, goal, subgoal, subgoal_count, factss, found_proof, ...} : prover_problem) =
+fun run_smt_solver mode name (params as {debug, verbose, isar_proofs, compress, try0,
+      smt_proofs, minimize, preplay_timeout, ...})
+    ({state, goal, subgoal, subgoal_count, facts, found_proof, ...} : prover_problem) =
   let
-    val thy = Proof.theory_of state
     val ctxt = Proof.context_of state
 
-    val factss = map (apsnd (map (apsnd (Thm.transfer thy)))) factss
+    val facts = snd facts
 
     val {outcome, filter_result = {fact_ids, atp_proof, ...}, used_from, run_time} =
-      smt_filter_loop name params state goal subgoal factss
+      smt_filter name params state goal subgoal facts
     val used_facts =
       (case fact_ids of
         NONE => map fst used_from
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_tactics.ML	Mon Jan 31 16:09:23 2022 +0100
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_tactics.ML	Mon Jan 31 16:09:23 2022 +0100
@@ -45,7 +45,7 @@
       |> hd |> snd
     val problem =
       {comment = "", state = Proof.init ctxt, goal = goal, subgoal = i, subgoal_count = n,
-       factss = [("", facts)], found_proof = I}
+       facts = ("", facts), found_proof = I}
   in
     (case prover params problem of
       {outcome = NONE, used_facts, ...} => used_facts |> map fst |> SOME