reused Sledgehammer code to parse parameters of sledgehammer action in Mirabelle
authordesharna
Sun, 28 Nov 2021 14:15:01 +0100
changeset 74897 8b1ab558e3ee
parent 74896 f9908452b282
child 74898 e83224066f19
reused Sledgehammer code to parse parameters of sledgehammer action in Mirabelle
src/HOL/Tools/Mirabelle/mirabelle_sledgehammer.ML
src/HOL/Tools/Sledgehammer/sledgehammer_prover.ML
--- a/src/HOL/Tools/Mirabelle/mirabelle_sledgehammer.ML	Sun Nov 21 11:21:16 2021 +0100
+++ b/src/HOL/Tools/Mirabelle/mirabelle_sledgehammer.ML	Sun Nov 28 14:15:01 2021 +0100
@@ -3,7 +3,7 @@
     Author:     Sascha Boehme, TU Munich
     Author:     Tobias Nipkow, TU Munich
     Author:     Makarius
-    Author:     Martin Desharnais, UniBw Munich
+    Author:     Martin Desharnais, UniBw Munich, MPI-INF Saarbruecken
 
 Mirabelle action: "sledgehammer".
 *)
@@ -21,48 +21,16 @@
 
 val check_trivialK = "check_trivial" (*=BOOL: check if goals are "trivial"*)
 val e_selection_heuristicK = "e_selection_heuristic" (*=STRING: E clause selection heuristic*)
-val fact_filterK = "fact_filter" (*=STRING: fact filter*)
 val force_sosK = "force_sos" (*=BOOL: use set-of-support (in Vampire)*)
-val isar_proofsK = "isar_proofs" (*=SMART_BOOL: enable Isar proof generation*)
 val keepK = "keep" (*=BOOL: keep temporary files created by sledgehammer*)
-val lam_transK = "lam_trans" (*=STRING: lambda translation scheme*)
-val max_factsK = "max_facts" (*=NUM: max. relevant clauses to use*)
-val max_mono_itersK = "max_mono_iters" (*=NUM: max. iterations of monomorphiser*)
-val max_new_mono_instancesK = "max_new_mono_instances" (*=NUM: max. new monomorphic instances*)
-val max_relevantK = "max_relevant" (*=NUM: max. relevant clauses to use*)
-val minimizeK = "minimize" (*=BOOL: instruct sledgehammer to run its minimizer*)
-val preplay_timeoutK = "preplay_timeout" (*=TIME: timeout for finding reconstructed proof*)
 val proof_methodK = "proof_method" (*=STRING: how to reconstruct proofs (e.g. using metis)*)
 val proverK = "prover" (*=STRING: name of the external prover to call*)
-val prover_timeoutK = "prover_timeout" (*=TIME: timeout for invoked ATP (seconds of process time)*)
-val sliceK = "slice" (*=BOOL: allow sledgehammer-level strategy-scheduling*)
-val smt_proofsK = "smt_proofs" (*=BOOL: enable SMT proof generation*)
-val strictK = "strict" (*=BOOL: run in strict mode*)
 val term_orderK = "term_order" (*=STRING: term order (in E)*)
-val type_encK = "type_enc" (*=STRING: type encoding scheme*)
-val uncurried_aliasesK = "uncurried_aliases" (*=SMART_BOOL: use fresh function names to alias curried applications*)
 
-(*FIXME sensible to have Mirabelle-level Sledgehammer defaults?*)
 (*defaults used in this Mirabelle action*)
-val preplay_timeout_default = "1"
-val lam_trans_default = "smart"
-val uncurried_aliases_default = "smart"
-val fact_filter_default = "smart"
-val type_enc_default = "smart"
-val strict_default = "false"
-val max_facts_default = "smart"
-val slice_default = "true"
 val check_trivial_default = false
 val keep_default = false
 
-(*If a key is present in args then augment a list with its pair*)
-(*This is used to avoid fixing default values at the Mirabelle level, and
-  instead use the default values of the tool (Sledgehammer in this case).*)
-fun available_parameter args key label list =
-  let
-    val value = AList.lookup (op =) args key
-  in if is_some value then (label, the value) :: list else list end
-
 datatype sh_data = ShData of {
   calls: int,
   success: int,
@@ -336,10 +304,8 @@
   SH_FAIL of int * int |
   SH_ERROR
 
-fun run_sh prover_name fact_filter type_enc strict max_facts slice
-      lam_trans uncurried_aliases e_selection_heuristic term_order force_sos
-      hard_timeout timeout preplay_timeout isar_proofsLST smt_proofsLST
-      minimizeLST max_new_mono_instancesLST max_mono_itersLST dir pos st =
+fun run_sh (params as {max_facts, minimize, preplay_timeout, ...}) prover_name e_selection_heuristic
+    term_order force_sos hard_timeout dir pos st =
   let
     val thy = Proof.theory_of st
     val {context = ctxt, facts = chained_ths, goal} = Proof.goal st
@@ -365,30 +331,13 @@
                   term_order |> the_default I)
             #> (Option.map (Config.put Sledgehammer_ATP_Systems.force_sos)
                   force_sos |> the_default I))
-    val params as {max_facts, minimize, preplay_timeout, ...} =
-      Sledgehammer_Commands.default_params thy
-         ([(* ("verbose", "true"), *)
-           ("fact_filter", fact_filter),
-           ("type_enc", type_enc),
-           ("strict", strict),
-           ("lam_trans", lam_trans |> the_default lam_trans_default),
-           ("uncurried_aliases", uncurried_aliases |> the_default uncurried_aliases_default),
-           ("max_facts", max_facts),
-           ("slice", slice),
-           ("timeout", string_of_int timeout),
-           ("preplay_timeout", preplay_timeout)]
-          |> isar_proofsLST
-          |> smt_proofsLST
-          |> minimizeLST (*don't confuse the two minimization flags*)
-          |> max_new_mono_instancesLST
-          |> max_mono_itersLST)
     val default_max_facts =
       Sledgehammer_Prover_Minimize.default_max_facts_of_prover ctxt prover_name
     val (_, hyp_ts, concl_t) = ATP_Util.strip_subgoal goal i ctxt
     val time_limit =
       (case hard_timeout of
         NONE => I
-      | SOME secs => Timeout.apply (Time.fromSeconds secs))
+      | SOME t => Timeout.apply t)
     fun failed failure =
       ({outcome = SOME failure, used_facts = [], used_from = [],
         preferred_methss = (Sledgehammer_Proof_Methods.Auto_Method, []), run_time = Time.zeroTime,
@@ -432,33 +381,17 @@
 
 in
 
-fun run_sledgehammer change_data thy_index trivial output_dir args proof_method named_thms pos st =
+fun run_sledgehammer change_data (params as {provers, timeout, ...}) output_dir
+  e_selection_heuristic term_order force_sos keep proof_method_from_msg thy_index trivial
+  proof_method named_thms pos st =
   let
     val thy = Proof.theory_of st
     val thy_name = Context.theory_name thy
     val triv_str = if trivial then "[T] " else ""
     val _ = change_data inc_sh_calls
     val _ = if trivial then () else change_data inc_sh_nontriv_calls
-    val prover_name = get_prover_name thy args
-    val fact_filter = AList.lookup (op =) args fact_filterK |> the_default fact_filter_default
-    val type_enc = AList.lookup (op =) args type_encK |> the_default type_enc_default
-    val strict = AList.lookup (op =) args strictK |> the_default strict_default
-    val max_facts =
-      (case AList.lookup (op =) args max_factsK of
-        SOME max => max
-      | NONE =>
-        (case AList.lookup (op =) args max_relevantK of
-          SOME max => max
-        | NONE => max_facts_default))
-    val slice = AList.lookup (op =) args sliceK |> the_default slice_default
-    val lam_trans = AList.lookup (op =) args lam_transK
-    val uncurried_aliases = AList.lookup (op =) args uncurried_aliasesK
-    val e_selection_heuristic = AList.lookup (op =) args e_selection_heuristicK
-    val term_order = AList.lookup (op =) args term_orderK
-    val force_sos = AList.lookup (op =) args force_sosK
-      |> Option.map (curry (op <>) "false")
     val keep_dir =
-      if Mirabelle.get_bool_argument args (keepK, keep_default) then
+      if keep then
         let val subdir = StringCvt.padLeft #"0" 4 (string_of_int thy_index) ^ "_" ^ thy_name in
           Path.append output_dir (Path.basic subdir)
           |> Isabelle_System.make_directory
@@ -467,23 +400,13 @@
         end
       else
         NONE
-    val timeout = Mirabelle.get_int_argument args (prover_timeoutK, 30)
     (* always use a hard timeout, but give some slack so that the automatic
        minimizer has a chance to do its magic *)
-    val preplay_timeout = AList.lookup (op =) args preplay_timeoutK
-      |> the_default preplay_timeout_default
-    val isar_proofsLST = available_parameter args isar_proofsK "isar_proofs"
-    val smt_proofsLST = available_parameter args smt_proofsK "smt_proofs"
-    val minimizeLST = available_parameter args minimizeK "minimize"
-    val max_new_mono_instancesLST =
-      available_parameter args max_new_mono_instancesK max_new_mono_instancesK
-    val max_mono_itersLST = available_parameter args max_mono_itersK max_mono_itersK
-    val hard_timeout = SOME (4 * timeout)
+    val hard_timeout = SOME (Time.scale 4.0 timeout)
+    val prover_name = hd provers
     val (msg, result) =
-      run_sh prover_name fact_filter type_enc strict max_facts slice lam_trans
-        uncurried_aliases e_selection_heuristic term_order force_sos
-        hard_timeout timeout preplay_timeout isar_proofsLST smt_proofsLST
-        minimizeLST max_new_mono_instancesLST max_mono_itersLST keep_dir pos st
+      run_sh params prover_name e_selection_heuristic term_order force_sos hard_timeout keep_dir pos
+        st
   in
     (case result of
       SH_OK (time_isa, time_prover, names) =>
@@ -499,7 +422,7 @@
           change_data (inc_sh_max_lems (length names));
           change_data (inc_sh_time_isa time_isa);
           change_data (inc_sh_time_prover time_prover);
-          proof_method := proof_method_from_msg args msg;
+          proof_method := proof_method_from_msg msg;
           named_thms := SOME (map_filter get_thms names);
           triv_str ^ "succeeded (" ^ string_of_int time_isa ^ "+" ^
             string_of_int time_prover ^ ") [" ^ prover_name ^ "]:\n" ^ msg
@@ -593,8 +516,22 @@
 
 fun make_action ({arguments, timeout, output_dir, ...} : Mirabelle.action_context) =
   let
+    (* Parse Mirabelle-specific parameters *)
     val check_trivial =
       Mirabelle.get_bool_argument arguments (check_trivialK, check_trivial_default)
+    val keep = Mirabelle.get_bool_argument arguments (keepK, keep_default)
+    val e_selection_heuristic = AList.lookup (op =) arguments e_selection_heuristicK
+    val term_order = AList.lookup (op =) arguments term_orderK
+    val force_sos = AList.lookup (op =) arguments force_sosK
+      |> Option.map (curry (op <>) "false")
+    val proof_method_from_msg = proof_method_from_msg arguments
+
+    (* Parse Sledgehammer parameters *)
+    val params = Sledgehammer_Commands.default_params \<^theory> arguments
+      |> (fn (params as {provers, ...}) =>
+            (case provers of
+              prover :: _ => Sledgehammer_Prover.set_params_provers params [prover]
+            | _ => error "sledgehammer action requires one prover"))
 
     val data = Synchronized.var "Mirabelle_Sledgehammer.data" empty_data
     val change_data = Synchronized.change data
@@ -612,8 +549,8 @@
               check_trivial andalso Try0.try0 (SOME try_timeout) ([], [], [], []) pre
               handle Timeout.TIMEOUT _ => false
             val log1 =
-              run_sledgehammer change_data theory_index trivial output_dir arguments meth named_thms
-                pos pre
+              run_sledgehammer change_data params output_dir e_selection_heuristic term_order
+                force_sos keep proof_method_from_msg theory_index trivial meth named_thms pos pre
             val log2 =
               (case !named_thms of
                 SOME thms =>
@@ -628,4 +565,4 @@
 
 val () = Mirabelle.register_action "sledgehammer" make_action
 
-end
+end
\ No newline at end of file
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_prover.ML	Sun Nov 21 11:21:16 2021 +0100
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_prover.ML	Sun Nov 28 14:15:01 2021 +0100
@@ -47,6 +47,8 @@
      preplay_timeout : Time.time,
      expect : string}
 
+  val set_params_provers : params -> string list -> params
+
   type prover_problem =
     {comment : string,
      state : Proof.state,
@@ -141,6 +143,33 @@
    preplay_timeout : Time.time,
    expect : string}
 
+fun set_params_provers params provers =
+  {debug = #debug params,
+   verbose = #verbose params,
+   overlord = #overlord params,
+   spy = #spy params,
+   provers = provers,
+   type_enc = #type_enc params,
+   strict = #strict params,
+   lam_trans = #lam_trans params,
+   uncurried_aliases = #uncurried_aliases params,
+   learn = #learn params,
+   fact_filter = #fact_filter params,
+   induction_rules = #induction_rules params,
+   max_facts = #max_facts params,
+   fact_thresholds = #fact_thresholds params,
+   max_mono_iters = #max_mono_iters params,
+   max_new_mono_instances = #max_new_mono_instances params,
+   isar_proofs = #isar_proofs params,
+   compress = #compress params,
+   try0 = #try0 params,
+   smt_proofs = #smt_proofs params,
+   slice = #slice params,
+   minimize = #minimize params,
+   timeout = #timeout params,
+   preplay_timeout = #preplay_timeout params,
+   expect = #expect params}
+
 type prover_problem =
   {comment : string,
    state : Proof.state,