--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Tue May 20 16:11:37 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Tue May 20 16:31:39 2014 +0200
@@ -15,7 +15,6 @@
type prover_result = Sledgehammer_Prover.prover_result
val trace : bool Config.T
- val sml : bool Config.T
val MePoN : string
val MaShN : string
val MeShN : string
@@ -37,7 +36,6 @@
val extract_suggestions : string -> string * string list
val mash_unlearn : Proof.context -> params -> unit
- val is_mash_enabled : unit -> bool
val nickname_of_thm : thm -> string
val find_suggested_facts : Proof.context -> ('b * thm) list -> string list -> ('b * thm) list
val mesh_facts : ('a * 'a -> bool) -> int -> (real * (('a * real) list * 'a list)) list -> 'a list
@@ -88,7 +86,6 @@
open Sledgehammer_MePo
val trace = Attrib.setup_config_bool @{binding sledgehammer_mash_trace} (K false)
-val sml = Attrib.setup_config_bool @{binding sledgehammer_mash_sml} (K false)
fun trace_msg ctxt msg = if Config.get ctxt trace then tracing (msg ()) else ()
@@ -118,6 +115,25 @@
()
end
+datatype mash_flavor = MaSh_Py | MaSh_SML_KNN | MaSh_SML_NB
+
+fun mash_flavor () =
+ (case getenv "MASH" of
+ "yes" => SOME MaSh_Py
+ | "py" => SOME MaSh_Py
+ | "sml" => SOME MaSh_SML_KNN
+ | "sml_knn" => SOME MaSh_SML_KNN
+ | "sml_nb" => SOME MaSh_SML_NB
+ | _ => NONE)
+
+val is_mash_enabled = is_some o mash_flavor
+
+fun is_mash_sml_enabled () =
+ (case mash_flavor () of
+ SOME MaSh_SML_KNN => true
+ | SOME MaSh_SML_NB => true
+ | _ => false)
+
(*** Low-level communication with Python version of MaSh ***)
@@ -578,7 +594,7 @@
fold extract_line_and_add_node node_lines Graph.empty),
length node_lines)
| LESS =>
- (if Config.get ctxt sml then wipe_out_mash_state_dir ()
+ (if is_mash_sml_enabled () then wipe_out_mash_state_dir ()
else MaSh_Py.unlearn ctxt overlord; (Graph.empty, 0)) (* cannot parse old file *)
| GREATER => raise FILE_VERSION_TOO_NEW ())
in
@@ -627,8 +643,8 @@
fun clear_state ctxt overlord =
(* "MaSh_Py.unlearn" also removes the state file *)
Synchronized.change global_state (fn _ =>
- (if Config.get ctxt sml then wipe_out_mash_state_dir ()
- else MaSh_Py.unlearn ctxt overlord; (false, empty_state)))
+ (if is_mash_sml_enabled () then wipe_out_mash_state_dir () else MaSh_Py.unlearn ctxt overlord;
+ (false, empty_state)))
end
@@ -638,8 +654,6 @@
(*** Isabelle helpers ***)
-fun is_mash_enabled () = (getenv "MASH" = "yes")
-
val local_prefix = "local" ^ Long_Name.separator
fun elided_backquote_thm threshold th =
@@ -1208,7 +1222,7 @@
(parents, hints, feats)
end
- val sml = Config.get ctxt sml
+ val sml = is_mash_sml_enabled ()
val (access_G, py_suggs) =
peek_state ctxt overlord (fn {access_G, ...} =>
@@ -1293,7 +1307,7 @@
|> filter (is_fact_in_graph access_G)
|> map nickname_of_thm
in
- if Config.get ctxt sml then
+ if is_mash_sml_enabled () then
let val access_G = access_G |> add_node Automatic_Proof name parents feats deps in
{access_G = access_G, num_known_facts = num_known_facts + 1,
dirty = Option.map (cons name) dirty}
@@ -1318,6 +1332,7 @@
val timer = Timer.startRealTimer ()
fun next_commit_time () = Time.+ (Timer.checkRealTimer timer, commit_timeout)
+ val sml = is_mash_sml_enabled ()
val {access_G, ...} = peek_state ctxt overlord I
val is_in_access_G = is_fact_in_graph access_G o snd
val no_new_facts = forall is_in_access_G facts
@@ -1359,7 +1374,7 @@
(false, SOME names, []) => SOME (map #1 learns @ names)
| _ => NONE)
in
- if Config.get ctxt sml then
+ if sml then
()
else
(MaSh_Py.learn ctxt overlord (save andalso null relearns) (rev learns);
@@ -1536,6 +1551,7 @@
end
else
()
+
fun maybe_learn () =
if is_mash_enabled () andalso learn then
let
@@ -1557,6 +1573,7 @@
end
else
false
+
val (save, effective_fact_filter) =
(case fact_filter of
SOME ff => (ff <> mepoN andalso maybe_learn (), ff)
@@ -1571,18 +1588,22 @@
val add_ths = Attrib.eval_thms ctxt add
fun in_add (_, th) = member Thm.eq_thm_prop add_ths th
+
fun add_and_take accepts =
(case add_ths of
[] => accepts
| _ => (unique_facts |> filter in_add |> map fact_of_raw_fact) @
(accepts |> filter_out in_add))
|> take max_facts
+
fun mepo () =
(mepo_suggested_facts ctxt params max_facts NONE hyp_ts concl_t unique_facts
|> weight_facts_steeply, [])
+
fun mash () =
mash_suggested_facts ctxt params (generous_max_facts max_facts) hyp_ts concl_t facts
|>> weight_facts_steeply
+
val mess =
(* the order is important for the "case" expression below *)
[] |> effective_fact_filter <> mepoN ? cons (mash_weight, mash)
@@ -1590,7 +1611,7 @@
|> Par_List.map (apsnd (fn f => f ()))
val mesh = mesh_facts (eq_snd Thm.eq_thm_prop) max_facts mess |> add_and_take
in
- if Config.get ctxt sml orelse not save then () else MaSh_Py.save ctxt overlord;
+ if is_mash_sml_enabled () orelse not save then () else MaSh_Py.save ctxt overlord;
(case (fact_filter, mess) of
(NONE, [(_, (mepo, _)), (_, (mash, _))]) =>
[(meshN, mesh), (mepoN, mepo |> map fst |> add_and_take),
@@ -1600,7 +1621,7 @@
fun kill_learners ctxt ({overlord, ...} : params) =
(Async_Manager.kill_threads MaShN "learner";
- if Config.get ctxt sml then () else MaSh_Py.shutdown ctxt overlord)
+ if is_mash_sml_enabled () then () else MaSh_Py.shutdown ctxt overlord)
fun running_learners () = Async_Manager.running_threads MaShN "learner"