added option to control which fact filter is used
authorblanchet
Wed, 18 Jul 2012 08:44:04 +0200
changeset 48314 ee33ba3c0e05
parent 48313 0faafdffa662
child 48315 82d6e46c673f
added option to control which fact filter is used
src/HOL/Tools/Sledgehammer/sledgehammer_filter_mash.ML
src/HOL/Tools/Sledgehammer/sledgehammer_isar.ML
src/HOL/Tools/Sledgehammer/sledgehammer_minimize.ML
src/HOL/Tools/Sledgehammer/sledgehammer_provers.ML
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_filter_mash.ML	Wed Jul 18 08:44:04 2012 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_filter_mash.ML	Wed Jul 18 08:44:04 2012 +0200
@@ -15,6 +15,10 @@
   type prover_result = Sledgehammer_Provers.prover_result
 
   val trace : bool Config.T
+  val meshN : string
+  val iterN : string
+  val mashN : string
+  val fact_filters : string list
   val escape_meta : string -> string
   val escape_metas : string list -> string
   val unescape_meta : string -> string
@@ -67,10 +71,17 @@
   Attrib.setup_config_bool @{binding sledgehammer_filter_mash_trace} (K false)
 fun trace_msg ctxt msg = if Config.get ctxt trace then tracing (msg ()) else ()
 
-fun mash_dir () =
+val meshN = "mesh"
+val iterN = "iter"
+val mashN = "mash"
+
+val fact_filters = [meshN, iterN, mashN]
+
+fun mash_home () = getenv "MASH_HOME"
+fun mash_state_dir () =
   getenv "ISABELLE_HOME_USER" ^ "/mash"
   |> tap (fn dir => Isabelle_System.mkdir (Path.explode dir))
-fun mash_state_path () = mash_dir () ^ "/state" |> Path.explode
+fun mash_state_path () = mash_state_dir () ^ "/state" |> Path.explode
 
 (*** Isabelle helpers ***)
 
@@ -109,20 +120,21 @@
 fun sum_avg n xs =
   fold (Integer.add o Integer.mult n) xs 0 div (length xs)
 
-fun mesh_facts max_facts mess =
-  let
-    val n = length mess
-    val fact_eq = Thm.eq_thm o pairself snd
-    fun score_in fact (facts, def) =
-      case find_index (curry fact_eq fact) facts of
-        ~1 => def
-      | j => SOME j
-    fun score_of fact = mess |> map_filter (score_in fact) |> sum_avg n
-    val facts = fold (union fact_eq o take max_facts o fst) mess []
-  in
-    facts |> map (`score_of) |> sort (int_ord o pairself fst) |> map snd
-          |> take max_facts
-  end
+fun mesh_facts max_facts [(facts, _)] = facts |> take max_facts
+  | mesh_facts max_facts mess =
+    let
+      val n = length mess
+      val fact_eq = Thm.eq_thm o pairself snd
+      fun score_in fact (facts, def) =
+        case find_index (curry fact_eq fact) facts of
+          ~1 => def
+        | j => SOME j
+      fun score_of fact = mess |> map_filter (score_in fact) |> sum_avg n
+      val facts = fold (union fact_eq o take max_facts o fst) mess []
+    in
+      facts |> map (`score_of) |> sort (int_ord o pairself fst) |> map snd
+            |> take max_facts
+    end
 
 val thy_feature_prefix = "y_"
 
@@ -319,10 +331,10 @@
     val pred_file = temp_dir ^ "/mash_preds." ^ serial
     val log_file = temp_dir ^ "/mash_log." ^ serial
     val command =
-      getenv "MASH_HOME" ^ "/mash.py --inputFile " ^ cmd_file ^
-      " --outputDir " ^ mash_dir () ^ " --predictions " ^ pred_file ^
+      mash_home () ^ "/mash.py --quiet --inputFile " ^ cmd_file ^
+      " --outputDir " ^ mash_state_dir () ^ " --predictions " ^ pred_file ^
       " --log " ^ log_file ^ " --numberOfPredictions 1000" ^
-      (if save then " --saveModel" else "") ^ " > /dev/null"
+      (if save then " --saveModel" else "")
     val _ = File.write cmd_path ""
     val _ = write_cmds (File.append cmd_path)
     val _ = trace_msg ctxt (fn () => "  running " ^ command)
@@ -337,7 +349,7 @@
   "? " ^ escape_metas access ^ "; " ^ escape_metas feats
 
 fun mash_RESET ctxt =
-  let val path = mash_dir () |> Path.explode in
+  let val path = mash_state_dir () |> Path.explode in
     trace_msg ctxt (K "MaSh RESET");
     File.fold_dir (fn file => fn () =>
                       File.rm (Path.append path (Path.basic file)))
@@ -355,7 +367,7 @@
       "MaSh DEL " ^ escape_metas facts ^ "; " ^ escape_metas feats)
 
 fun mash_QUERY ctxt (query as (_, feats)) =
-  (trace_msg ctxt (fn () => "MaSh SUGGEST " ^ space_implode " " feats);
+  (trace_msg ctxt (fn () => "MaSh QUERY " ^ space_implode " " feats);
    run_mash ctxt false (fn append => append (str_of_query query))
                  (fn preds => snd (extract_query (List.last (preds ()))))
    handle List.Empty => [])
@@ -509,25 +521,34 @@
         state
     end)
 
-fun relevant_facts ctxt params prover max_facts
+fun relevant_facts ctxt (params as {fact_filter, ...}) prover max_facts
         ({add, only, ...} : fact_override) hyp_ts concl_t facts =
-  if only then
+  if not (subset (op =) (the_list fact_filter, fact_filters)) then
+    error ("Unknown fact filter: " ^ quote (the fact_filter) ^ ".")
+  else if only then
     facts
   else if max_facts <= 0 then
     []
   else
     let
+      val fact_filter =
+        case fact_filter of
+          SOME ff => ff
+        | NONE => if mash_home () = "" then iterN else meshN
       val add_ths = Attrib.eval_thms ctxt add
       fun prepend_facts ths accepts =
         ((facts |> filter (member Thm.eq_thm_prop ths o snd)) @
          (accepts |> filter_out (member Thm.eq_thm_prop ths o snd)))
         |> take max_facts
-      val iter_facts =
+      fun iter () =
         iterative_relevant_facts ctxt params prover max_facts NONE hyp_ts
                                  concl_t facts
-      val mash_facts =
-        facts |> mash_suggest_facts ctxt params prover hyp_ts concl_t
-      val mess = [(iter_facts, SOME (length iter_facts)), (mash_facts, NONE)]
+        |> (fn facts => (facts, SOME (length facts)))
+      fun mash () =
+        (facts |> mash_suggest_facts ctxt params prover hyp_ts concl_t, NONE)
+      val mess =
+        [] |> (if fact_filter <> mashN then cons (iter ()) else I)
+           |> (if fact_filter <> iterN then cons (mash ()) else I)
     in
       mesh_facts max_facts mess
       |> not (null add_ths) ? prepend_facts add_ths
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_isar.ML	Wed Jul 18 08:44:04 2012 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_isar.ML	Wed Jul 18 08:44:04 2012 +0200
@@ -84,8 +84,9 @@
    ("strict", "false"),
    ("lam_trans", "smart"),
    ("uncurried_aliases", "smart"),
+   ("fact_filter", "smart"),
+   ("max_facts", "smart"),
    ("fact_thresholds", "0.45 0.85"),
-   ("max_facts", "smart"),
    ("max_mono_iters", "smart"),
    ("max_new_mono_instances", "smart"),
    ("isar_proof", "false"),
@@ -147,8 +148,8 @@
 
 val any_type_enc = type_enc_from_string Strict "erased"
 
-(* "provers =", "type_enc =", and "lam_trans" can be omitted. For the last two,
-   this is a secret feature. *)
+(* "provers =", "type_enc =", "lam_trans =", and "fact_filter =" can be omitted.
+   For the last three, this is a secret feature. *)
 fun normalize_raw_param ctxt =
   unalias_raw_param
   #> (fn (name, value) =>
@@ -161,6 +162,8 @@
          else if can (trans_lams_from_string ctxt any_type_enc) name andalso
                  null value then
            ("lam_trans", [name])
+         else if member (op =) fact_filters name then
+           ("fact_filter", [name])
          else
            error ("Unknown parameter: " ^ quote name ^ "."))
 
@@ -291,8 +294,9 @@
     val strict = mode = Auto_Try orelse lookup_bool "strict"
     val lam_trans = lookup_option lookup_string "lam_trans"
     val uncurried_aliases = lookup_option lookup_bool "uncurried_aliases"
+    val fact_filter = lookup_option lookup_string "fact_filter"
+    val max_facts = lookup_option lookup_int "max_facts"
     val fact_thresholds = lookup_real_pair "fact_thresholds"
-    val max_facts = lookup_option lookup_int "max_facts"
     val max_mono_iters = lookup_option lookup_int "max_mono_iters"
     val max_new_mono_instances =
       lookup_option lookup_int "max_new_mono_instances"
@@ -311,8 +315,8 @@
     {debug = debug, verbose = verbose, overlord = overlord, blocking = blocking,
      provers = provers, type_enc = type_enc, strict = strict,
      lam_trans = lam_trans, uncurried_aliases = uncurried_aliases,
-     fact_thresholds = fact_thresholds, max_facts = max_facts,
-     max_mono_iters = max_mono_iters,
+     fact_filter = fact_filter, max_facts = max_facts,
+     fact_thresholds = fact_thresholds, max_mono_iters = max_mono_iters,
      max_new_mono_instances = max_new_mono_instances,  isar_proof = isar_proof,
      isar_shrink_factor = isar_shrink_factor, slice = slice,
      minimize = minimize, timeout = timeout, preplay_timeout = preplay_timeout,
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_minimize.ML	Wed Jul 18 08:44:04 2012 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_minimize.ML	Wed Jul 18 08:44:04 2012 +0200
@@ -68,8 +68,8 @@
       {debug = debug, verbose = verbose, overlord = overlord, blocking = true,
        provers = provers, type_enc = type_enc, strict = strict,
        lam_trans = lam_trans, uncurried_aliases = uncurried_aliases,
-       fact_thresholds = (1.01, 1.01), max_facts = SOME (length facts),
-       max_mono_iters = max_mono_iters,
+       fact_filter = NONE, max_facts = SOME (length facts),
+       fact_thresholds = (1.01, 1.01), max_mono_iters = max_mono_iters,
        max_new_mono_instances = max_new_mono_instances, isar_proof = isar_proof,
        isar_shrink_factor = isar_shrink_factor, slice = false,
        minimize = SOME false, timeout = timeout,
@@ -225,7 +225,7 @@
 
 fun adjust_reconstructor_params override_params
         ({debug, verbose, overlord, blocking, provers, type_enc, strict,
-         lam_trans, uncurried_aliases, fact_thresholds, max_facts,
+         lam_trans, uncurried_aliases, fact_filter, max_facts, fact_thresholds,
          max_mono_iters, max_new_mono_instances, isar_proof, isar_shrink_factor,
          slice, minimize, timeout, preplay_timeout, expect} : params) =
   let
@@ -241,8 +241,8 @@
     {debug = debug, verbose = verbose, overlord = overlord, blocking = blocking,
      provers = provers, type_enc = type_enc, strict = strict,
      lam_trans = lam_trans, uncurried_aliases = uncurried_aliases,
-     max_facts = max_facts, fact_thresholds = fact_thresholds,
-     max_mono_iters = max_mono_iters,
+     fact_filter = fact_filter, max_facts = max_facts,
+     fact_thresholds = fact_thresholds, max_mono_iters = max_mono_iters,
      max_new_mono_instances = max_new_mono_instances, isar_proof = isar_proof,
      isar_shrink_factor = isar_shrink_factor, slice = slice,
      minimize = minimize, timeout = timeout, preplay_timeout = preplay_timeout,
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_provers.ML	Wed Jul 18 08:44:04 2012 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_provers.ML	Wed Jul 18 08:44:04 2012 +0200
@@ -27,8 +27,9 @@
      strict: bool,
      lam_trans: string option,
      uncurried_aliases: bool option,
+     fact_filter: string option,
+     max_facts: int option,
      fact_thresholds: real * real,
-     max_facts: int option,
      max_mono_iters: int option,
      max_new_mono_instances: int option,
      isar_proof: bool,
@@ -314,8 +315,9 @@
    strict: bool,
    lam_trans: string option,
    uncurried_aliases: bool option,
+   fact_filter: string option,
+   max_facts: int option,
    fact_thresholds: real * real,
-   max_facts: int option,
    max_mono_iters: int option,
    max_new_mono_instances: int option,
    isar_proof: bool,