added option "cache_dir" to Sledgehammer
authordesharna
Wed, 18 Dec 2024 10:21:58 +0100
changeset 81610 ed9ffd8e9e40
parent 81581 8a3608933607
child 81611 2a0276c40989
added option "cache_dir" to Sledgehammer
src/HOL/Tools/SMT/smt_solver.ML
src/HOL/Tools/Sledgehammer/sledgehammer.ML
src/HOL/Tools/Sledgehammer/sledgehammer_commands.ML
src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
src/HOL/Tools/Sledgehammer/sledgehammer_prover.ML
src/HOL/Tools/Sledgehammer/sledgehammer_prover_atp.ML
src/HOL/Tools/Sledgehammer/sledgehammer_prover_minimize.ML
src/HOL/Tools/Sledgehammer/sledgehammer_prover_smt.ML
src/HOL/Tools/Sledgehammer/sledgehammer_tactics.ML
--- a/src/HOL/Tools/SMT/smt_solver.ML	Wed Dec 11 12:04:27 2024 +0100
+++ b/src/HOL/Tools/SMT/smt_solver.ML	Wed Dec 18 10:21:58 2024 +0100
@@ -35,7 +35,7 @@
 
   (*filter*)
   val smt_filter: Proof.context -> thm -> ((string * ATP_Problem_Generate.stature) * thm) list ->
-    int -> Time.time -> string list -> parsed_proof
+    int -> Time.time -> ((string -> string) -> string -> string) -> string list -> parsed_proof
 
   (*tactic*)
   val smt_tac: Proof.context -> thm list -> int -> tactic
@@ -136,7 +136,7 @@
 
 in
 
-fun invoke name command options smt_options ithms ctxt =
+fun invoke memoize_fun_call name command options smt_options ithms ctxt =
   let
     val options = options @ SMT_Config.solver_options_of ctxt
     val comments = [implode_space options]
@@ -146,7 +146,14 @@
       |> tap (trace_assms ctxt)
       |> SMT_Translate.translate ctxt name smt_options comments
       ||> tap trace_replay_data
-  in (run_solver ctxt' name (make_command command options) str, replay_data) end
+
+    val run_solver = run_solver ctxt' name (make_command command options)
+
+    val output_lines =
+      (case memoize_fun_call of
+        NONE => run_solver str
+      | SOME memoize => split_lines (memoize (cat_lines o run_solver) str))
+  in (output_lines, replay_data) end
 
 end
 
@@ -264,13 +271,13 @@
     val thms = map (pair SMT_Util.Axiom o check_topsort ctxt) thms0
     val (name, {command, smt_options, replay, ...}) = name_and_info_of ctxt
     val (output, replay_data) =
-      invoke name command [] smt_options (SMT_Normalize.normalize ctxt thms) ctxt
+      invoke NONE name command [] smt_options (SMT_Normalize.normalize ctxt thms) ctxt
   in replay ctxt replay_data output end
 
 
 (* filter (for Sledgehammer) *)
 
-fun smt_filter ctxt0 goal xfacts i time_limit options =
+fun smt_filter ctxt0 goal xfacts i time_limit memoize_fun_call options =
   let
     val ctxt = ctxt0 |> Config.put SMT_Config.timeout (Time.toReal time_limit)
 
@@ -290,7 +297,8 @@
 
     val (name, {command, smt_options, parse_proof, ...}) = name_and_info_of ctxt
     val (output, replay_data) =
-      invoke name command options smt_options (SMT_Normalize.normalize ctxt thms) ctxt
+      invoke (SOME memoize_fun_call) name command options smt_options
+        (SMT_Normalize.normalize ctxt thms) ctxt
   in
     parse_proof ctxt replay_data xfacts (map Thm.prop_of prems) (Thm.term_of concl) output
   end
--- a/src/HOL/Tools/Sledgehammer/sledgehammer.ML	Wed Dec 11 12:04:27 2024 +0100
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer.ML	Wed Dec 18 10:21:58 2024 +0100
@@ -320,7 +320,7 @@
     val ctxt = Proof.context_of state
     val hard_timeout = Time.scale 5.0 timeout
 
-    fun flip_problem {comment, state, goal, subgoal, factss = factss, ...} =
+    fun flip_problem {comment, state, goal, subgoal, factss, memoize_fun_call, ...} =
       let
         val thy = Proof_Context.theory_of ctxt
         val assms = Assumption.all_assms_of ctxt
@@ -340,7 +340,8 @@
         {comment = comment, state = state, goal = Thm.trivial @{cprop False}, subgoal = 1,
          subgoal_count = 1, factss = map (apsnd (append new_facts)) factss,
          has_already_found_something = has_already_found_something,
-         found_something = found_something "a falsification"}
+         found_something = found_something "a falsification",
+         memoize_fun_call = memoize_fun_call}
       end
 
     val problem as {goal, ...} = problem |> falsify ? flip_problem
@@ -486,8 +487,35 @@
     |> distinct (op =)
   end
 
+local
+
+fun memoize verbose cache_dir f arg =
+  let
+    val hash = SHA1.rep (SHA1.digest arg)
+    val file = cache_dir + Path.explode hash
+  in
+    if File.is_file file then
+      let
+        val () =
+          if verbose then
+            writeln ("Found problem with key " ^ hash ^ " in cache.")
+          else
+            ()
+      in
+        File.read file
+      end
+    else
+      let
+        val result = f arg
+      in
+        File.write file result;
+        result
+      end
+  end
+in
+
 fun run_sledgehammer (params as {verbose, spy, provers, falsify, induction_rules, max_facts,
-    max_proofs, slices, timeout, ...}) mode writeln_result i (fact_override as {only, ...}) state =
+    max_proofs, slices, timeout, cache_dir, ...}) mode writeln_result i (fact_override as {only, ...}) state =
   if null provers then
     error "No prover is set"
   else
@@ -589,13 +617,23 @@
             factss
           end
 
+        val memoize_fun_call =
+          (case cache_dir of
+            NONE => (fn f => fn arg => f arg)
+          | SOME path =>
+            (if File.is_dir path then
+              memoize verbose path
+            else
+              (warning ("No such directory: " ^ quote (Path.print path));
+              fn f => fn arg => f arg)))
+
         fun launch_provers () =
           let
             val factss = get_factss provers
             val problem =
               {comment = "", state = state, goal = goal, subgoal = i, subgoal_count = n,
                factss = factss, has_already_found_something = has_already_found_something,
-               found_something = found_something "a proof"}
+               found_something = found_something "a proof", memoize_fun_call = memoize_fun_call}
             val learn = mash_learn_proof ctxt params (Thm.prop_of goal)
             val launch = launch_prover_and_preplay params mode has_already_found_something
               found_something massage_message writeln_result learn
@@ -652,4 +690,6 @@
             else (the_default writeln writeln_result ("Warning: " ^ message); false)))
       end)
 
+end
+
 end;
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_commands.ML	Wed Dec 11 12:04:27 2024 +0100
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_commands.ML	Wed Dec 18 10:21:58 2024 +0100
@@ -74,7 +74,8 @@
    ("suggest_of", "smart"),
    ("minimize", "true"),
    ("slices", string_of_int (12 * Multithreading.max_threads ())),
-   ("preplay_timeout", "1")]
+   ("preplay_timeout", "1"),
+   ("cache_dir", "")]
 
 val alias_params =
   [("prover", ("provers", [])), (* undocumented *)
@@ -272,6 +273,8 @@
     val timeout = lookup_time "timeout"
     val preplay_timeout = lookup_time "preplay_timeout"
     val expect = lookup_string "expect"
+    val cache_dir = Option.mapPartial
+      (fn str => if str = "" then NONE else SOME (Path.explode str)) (lookup "cache_dir")
   in
     {debug = debug, verbose = verbose, overlord = overlord, spy = spy, provers = provers,
      abduce = abduce, falsify = falsify, type_enc = type_enc, strict = strict,
@@ -281,7 +284,7 @@
      max_new_mono_instances = max_new_mono_instances, max_proofs = max_proofs,
      isar_proofs = isar_proofs, compress = compress, try0 = try0, smt_proofs = smt_proofs,
      suggest_of = suggest_of, minimize = minimize, slices = slices, timeout = timeout,
-     preplay_timeout = preplay_timeout, expect = expect}
+     preplay_timeout = preplay_timeout, expect = expect, cache_dir = cache_dir}
   end
 
 fun get_params mode = extract_params mode o default_raw_params
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Wed Dec 11 12:04:27 2024 +0100
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Wed Dec 18 10:21:58 2024 +0100
@@ -866,7 +866,7 @@
     val problem =
       {comment = "Goal: " ^ goal_name, state = Proof.init ctxt, goal = goal, subgoal = 1,
        subgoal_count = 1, factss = [("", facts)], has_already_found_something = K false,
-       found_something = K ()}
+       found_something = K (), memoize_fun_call = (fn f => f)}
     val slice = hd (get_slices ctxt prover)
   in
     get_minimizing_prover ctxt MaSh (K ()) prover params problem slice
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_prover.ML	Wed Dec 11 12:04:27 2024 +0100
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_prover.ML	Wed Dec 18 10:21:58 2024 +0100
@@ -52,7 +52,8 @@
      slices : int,
      timeout : Time.time,
      preplay_timeout : Time.time,
-     expect : string}
+     expect : string,
+     cache_dir : Path.T option}
 
   val string_of_params : params -> string
   val slice_timeout : int -> int -> Time.time -> Time.time
@@ -65,7 +66,8 @@
      subgoal_count : int,
      factss : (string * fact list) list,
      has_already_found_something : unit -> bool,
-     found_something : string -> unit}
+     found_something : string -> unit,
+     memoize_fun_call : (string -> string) -> string -> string}
 
   datatype prover_slice_extra =
     ATP_Slice of atp_slice
@@ -165,7 +167,8 @@
    slices : int,
    timeout : Time.time,
    preplay_timeout : Time.time,
-   expect : string}
+   expect : string,
+   cache_dir : Path.T option}
 
 fun string_of_params (params : params) =
   Pretty.pure_string_of (Pretty.from_ML (ML_system_pretty (params, 100)))
@@ -190,7 +193,8 @@
    subgoal_count : int,
    factss : (string * fact list) list,
    has_already_found_something : unit -> bool,
-   found_something : string -> unit}
+   found_something : string -> unit,
+   memoize_fun_call : (string -> string) -> string -> string}
 
 datatype prover_slice_extra =
   ATP_Slice of atp_slice
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_prover_atp.ML	Wed Dec 11 12:04:27 2024 +0100
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_prover_atp.ML	Wed Dec 18 10:21:58 2024 +0100
@@ -103,7 +103,7 @@
       max_new_mono_instances, isar_proofs, compress, try0, smt_proofs, minimize, slices, timeout,
       preplay_timeout, spy, ...} : params)
     ({comment, state, goal, subgoal, subgoal_count, factss, has_already_found_something,
-      found_something} : prover_problem)
+      found_something, memoize_fun_call} : prover_problem)
     slice =
   let
     val (basic_slice as (slice_size, abduce, _, _, _),
@@ -130,10 +130,18 @@
         (Config.get ctxt atp_problem_dest_dir,
          Config.get ctxt atp_proof_dest_dir,
          Config.get ctxt atp_problem_prefix)
-    val problem_file_name =
-      Path.basic (problem_prefix ^ (if overlord then "" else serial_string ()) ^
-        suffix_of_mode mode ^ "_" ^ string_of_int subgoal)
-      |> Path.ext "p"
+
+    val (problem_file_name, proof_file_name) =
+    let
+      val base_name =
+        problem_prefix ^ (if overlord then "" else serial_string ()) ^
+        suffix_of_mode mode ^ "_" ^ string_of_int subgoal
+    in
+      (base_name, suffix "_proof" base_name)
+      |> apply2 Path.basic
+      |> apply2 (Path.ext "p")
+    end
+
     val prob_path =
       if problem_dest_dir = "" then
         File.tmp_path problem_file_name
@@ -141,6 +149,7 @@
         Path.explode problem_dest_dir + problem_file_name
       else
         error ("No such directory: " ^ quote problem_dest_dir)
+
     val executable =
       (case find_first (fn var => getenv var <> "") (fst exec) of
         SOME var =>
@@ -206,32 +215,50 @@
         val args = arguments abduce full_proofs extra run_timeout prob_path
         val command = implode_space (File.bash_path executable :: args)
 
-        fun run_command () =
-          if exec = isabelle_scala_function then
-            let val {output, timing} = SystemOnTPTP.run_system_encoded args
-            in (output, timing) end
-          else
-            let val res = Isabelle_System.bash_process (Bash.script command |> Bash.redirect)
-            in (Process_Result.out res, Process_Result.timing_elapsed res) end
+        val lines_of_atp_problem =
+          lines_of_atp_problem good_format (fn () => atp_problem_term_order_info atp_problem)
+            atp_problem
 
-        val _ = atp_problem
-          |> lines_of_atp_problem good_format (fn () => atp_problem_term_order_info atp_problem)
+        val () = lines_of_atp_problem
           |> (exec <> isabelle_scala_function) ?
             cons ("% " ^ command ^ "\n" ^ (if comment = "" then "" else "% " ^ comment ^ "\n"))
           |> File.write_list prob_path
 
+        fun run_command () =
+          let
+            val f = fn _ =>
+              if exec = isabelle_scala_function then
+                  let val {output, ...} = SystemOnTPTP.run_system_encoded args
+                  in output end
+              else
+                  let val res = Isabelle_System.bash_process (Bash.script command |> Bash.redirect)
+                  in Process_Result.out res end
+            (* Hackish: This removes the two first lines that contain call-specific information
+            such as timestamp. *)
+            val arg = cat_lines (drop 2 lines_of_atp_problem)
+          in
+            Timing.timing (memoize_fun_call f) arg
+          end
+
         val local_name = name |> perhaps (try (unprefix remote_prefix))
 
         val ((output, run_time), ((atp_proof, tstplike_proof), outcome)) =
-          Timeout.apply generous_run_timeout run_command ()
-          |>> overlord ?
-            (fn output => prefix ("% " ^ command ^ "\n% " ^ timestamp () ^ "\n") output)
-          |> (fn accum as (output, _) =>
-            (accum,
-             extract_tstplike_proof_and_outcome verbose proof_delims known_failures output
-             |>> `(atp_proof_of_tstplike_proof false local_name atp_problem)
-             handle UNRECOGNIZED_ATP_PROOF () => (([], ""), SOME ProofUnparsable)))
-          handle Timeout.TIMEOUT _ => (("", run_timeout), (([], ""), SOME TimedOut))
+          let
+            val ({elapsed, ...}, output) = Timeout.apply generous_run_timeout run_command ()
+            val output =
+              if overlord then
+                prefix ("% " ^ command ^ "\n% " ^ timestamp () ^ "\n") output
+              else
+                output
+            val output2 =
+              extract_tstplike_proof_and_outcome verbose proof_delims known_failures output
+              |>> `(atp_proof_of_tstplike_proof false local_name atp_problem)
+              handle UNRECOGNIZED_ATP_PROOF () => (([], ""), SOME ProofUnparsable)
+          in
+            ((output, elapsed), output2)
+          end
+          handle
+              Timeout.TIMEOUT _ => (("", run_timeout), (([], ""), SOME TimedOut))
             | ERROR msg => (("", Time.zeroTime), (([], ""), SOME (UnknownError msg)))
 
         val atp_abduce_candidates =
@@ -259,18 +286,11 @@
        too. *)
     fun clean_up () = if problem_dest_dir = "" then (try File.rm prob_path; ()) else ()
     fun export (_, (output, _, _, _, _, _, _, _), _) =
-      let
-        val proof_dest_dir_path = Path.explode proof_dest_dir
-        val make_export_file_name =
-          Path.split_ext
-          #> apfst (Path.explode o suffix "_proof" o Path.implode)
-          #> swap
-          #> uncurry Path.ext
-      in
+      let val proof_dest_dir_path = Path.explode proof_dest_dir in
         if proof_dest_dir = "" then
           Output.system_message "don't export proof"
         else if File.exists proof_dest_dir_path then
-          File.write (proof_dest_dir_path + make_export_file_name problem_file_name) output
+          File.write (proof_dest_dir_path + proof_file_name) output
         else
           error ("No such directory: " ^ quote proof_dest_dir)
       end
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_prover_minimize.ML	Wed Dec 11 12:04:27 2024 +0100
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_prover_minimize.ML	Wed Dec 18 10:21:58 2024 +0100
@@ -82,7 +82,7 @@
 
 fun test_facts ({debug, verbose, overlord, spy, provers, max_mono_iters, max_new_mono_instances,
       type_enc, strict, lam_trans, uncurried_aliases, isar_proofs, compress, try0, smt_proofs,
-      suggest_of, minimize, preplay_timeout, induction_rules, ...} : params)
+      suggest_of, minimize, preplay_timeout, induction_rules, cache_dir, ...} : params)
     (slice as ((_, _, falsify, _, fact_filter), slice_extra)) silent (prover : prover) timeout i n
     state goal facts =
   let
@@ -99,10 +99,11 @@
        max_new_mono_instances = max_new_mono_instances, max_proofs = 1,
        isar_proofs = isar_proofs, compress = compress, try0 = try0, smt_proofs = smt_proofs,
        suggest_of = suggest_of, minimize = minimize, slices = 1, timeout = timeout,
-       preplay_timeout = preplay_timeout, expect = ""}
+       preplay_timeout = preplay_timeout, expect = "", cache_dir = cache_dir}
     val problem =
       {comment = "", state = state, goal = goal, subgoal = i, subgoal_count = n,
-       factss = [("", facts)], has_already_found_something = K false, found_something = K ()}
+       factss = [("", facts)], has_already_found_something = K false, found_something = K (),
+       memoize_fun_call = (fn f => f)}
     val result0 as {outcome = outcome0, used_facts, used_from, preferred_methss, run_time,
         message} =
       prover params problem ((1, false, false, length facts, fact_filter), slice_extra)
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_prover_smt.ML	Wed Dec 11 12:04:27 2024 +0100
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_prover_smt.ML	Wed Dec 18 10:21:58 2024 +0100
@@ -70,7 +70,7 @@
   not o exists_subtype (member (op =) [\<^typ>\<open>nat\<close>, \<^typ>\<open>int\<close>, HOLogic.realT])
 
 fun smt_filter name ({debug, overlord, max_mono_iters, max_new_mono_instances, type_enc, slices,
-    timeout, ...} : params) state goal i slice_size facts options =
+    timeout, ...} : params) memoize_fun_call state goal i slice_size facts options =
   let
     val run_timeout = slice_timeout slice_size slices timeout
     val (higher_order, nat_as_int) =
@@ -104,7 +104,7 @@
     val birth = Timer.checkRealTimer timer
 
     val filter_result as {outcome, ...} =
-      \<^try>\<open>SMT_Solver.smt_filter ctxt goal facts i run_timeout options
+      \<^try>\<open>SMT_Solver.smt_filter ctxt goal facts i run_timeout memoize_fun_call options
         catch exn =>
           {outcome = SOME (SMT_Failure.Other_Failure (Runtime.exn_message exn)), fact_ids = NONE,
            atp_proof = K []}\<close>
@@ -115,9 +115,11 @@
     {outcome = outcome, filter_result = filter_result, used_from = facts, run_time = run_time}
   end
 
-fun run_smt_solver mode name (params as {debug, verbose, isar_proofs, compress, try0,
-      smt_proofs, minimize, preplay_timeout, ...})
-    ({state, goal, subgoal, subgoal_count, factss, found_something, ...} : prover_problem)
+fun run_smt_solver mode name
+    (params as {debug, verbose, isar_proofs, compress, try0, smt_proofs, minimize, preplay_timeout,
+      ...} : params)
+    ({state, goal, subgoal, subgoal_count, factss, found_something, memoize_fun_call,
+      ...} : prover_problem)
     slice =
   let
     val (basic_slice as (slice_size, _, _, _, _), SMT_Slice options) = slice
@@ -125,7 +127,7 @@
     val ctxt = Proof.context_of state
 
     val {outcome, filter_result = {fact_ids, atp_proof, ...}, used_from, run_time} =
-      smt_filter name params state goal subgoal slice_size facts options
+      smt_filter name params memoize_fun_call state goal subgoal slice_size facts options
     val used_facts =
       (case fact_ids of
         NONE => map fst used_from
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_tactics.ML	Wed Dec 11 12:04:27 2024 +0100
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_tactics.ML	Wed Dec 18 10:21:58 2024 +0100
@@ -45,7 +45,8 @@
       |> hd |> snd
     val problem =
       {comment = "", state = Proof.init ctxt, goal = goal, subgoal = i, subgoal_count = n,
-       factss = [("", facts)], has_already_found_something = K false, found_something = K ()}
+       factss = [("", facts)], has_already_found_something = K false, found_something = K (),
+       memoize_fun_call = (fn f => f)}
     val slice = hd (get_slices ctxt name)
   in
     (case prover params problem slice of