# HG changeset patch # User blanchet # Date 1400596299 -7200 # Node ID 142950e9c7e279153dd1f51475397cde3c713145 # Parent afdf75c0de588d320c9f70b5b9431312761ac9d3 more flexible environment variable diff -r afdf75c0de58 -r 142950e9c7e2 src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML --- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Tue May 20 16:11:37 2014 +0200 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Tue May 20 16:31:39 2014 +0200 @@ -15,7 +15,6 @@ type prover_result = Sledgehammer_Prover.prover_result val trace : bool Config.T - val sml : bool Config.T val MePoN : string val MaShN : string val MeShN : string @@ -37,7 +36,6 @@ val extract_suggestions : string -> string * string list val mash_unlearn : Proof.context -> params -> unit - val is_mash_enabled : unit -> bool val nickname_of_thm : thm -> string val find_suggested_facts : Proof.context -> ('b * thm) list -> string list -> ('b * thm) list val mesh_facts : ('a * 'a -> bool) -> int -> (real * (('a * real) list * 'a list)) list -> 'a list @@ -88,7 +86,6 @@ open Sledgehammer_MePo val trace = Attrib.setup_config_bool @{binding sledgehammer_mash_trace} (K false) -val sml = Attrib.setup_config_bool @{binding sledgehammer_mash_sml} (K false) fun trace_msg ctxt msg = if Config.get ctxt trace then tracing (msg ()) else () @@ -118,6 +115,25 @@ () end +datatype mash_flavor = MaSh_Py | MaSh_SML_KNN | MaSh_SML_NB + +fun mash_flavor () = + (case getenv "MASH" of + "yes" => SOME MaSh_Py + | "py" => SOME MaSh_Py + | "sml" => SOME MaSh_SML_KNN + | "sml_knn" => SOME MaSh_SML_KNN + | "sml_nb" => SOME MaSh_SML_NB + | _ => NONE) + +val is_mash_enabled = is_some o mash_flavor + +fun is_mash_sml_enabled () = + (case mash_flavor () of + SOME MaSh_SML_KNN => true + | SOME MaSh_SML_NB => true + | _ => false) + (*** Low-level communication with Python version of MaSh ***) @@ -578,7 +594,7 @@ fold extract_line_and_add_node node_lines Graph.empty), length node_lines) | LESS => - (if Config.get ctxt sml then wipe_out_mash_state_dir () + (if is_mash_sml_enabled () then wipe_out_mash_state_dir () else MaSh_Py.unlearn ctxt overlord; (Graph.empty, 0)) (* cannot parse old file *) | GREATER => raise FILE_VERSION_TOO_NEW ()) in @@ -627,8 +643,8 @@ fun clear_state ctxt overlord = (* "MaSh_Py.unlearn" also removes the state file *) Synchronized.change global_state (fn _ => - (if Config.get ctxt sml then wipe_out_mash_state_dir () - else MaSh_Py.unlearn ctxt overlord; (false, empty_state))) + (if is_mash_sml_enabled () then wipe_out_mash_state_dir () else MaSh_Py.unlearn ctxt overlord; + (false, empty_state))) end @@ -638,8 +654,6 @@ (*** Isabelle helpers ***) -fun is_mash_enabled () = (getenv "MASH" = "yes") - val local_prefix = "local" ^ Long_Name.separator fun elided_backquote_thm threshold th = @@ -1208,7 +1222,7 @@ (parents, hints, feats) end - val sml = Config.get ctxt sml + val sml = is_mash_sml_enabled () val (access_G, py_suggs) = peek_state ctxt overlord (fn {access_G, ...} => @@ -1293,7 +1307,7 @@ |> filter (is_fact_in_graph access_G) |> map nickname_of_thm in - if Config.get ctxt sml then + if is_mash_sml_enabled () then let val access_G = access_G |> add_node Automatic_Proof name parents feats deps in {access_G = access_G, num_known_facts = num_known_facts + 1, dirty = Option.map (cons name) dirty} @@ -1318,6 +1332,7 @@ val timer = Timer.startRealTimer () fun next_commit_time () = Time.+ (Timer.checkRealTimer timer, commit_timeout) + val sml = is_mash_sml_enabled () val {access_G, ...} = peek_state ctxt overlord I val is_in_access_G = is_fact_in_graph access_G o snd val no_new_facts = forall is_in_access_G facts @@ -1359,7 +1374,7 @@ (false, SOME names, []) => SOME (map #1 learns @ names) | _ => NONE) in - if Config.get ctxt sml then + if sml then () else (MaSh_Py.learn ctxt overlord (save andalso null relearns) (rev learns); @@ -1536,6 +1551,7 @@ end else () + fun maybe_learn () = if is_mash_enabled () andalso learn then let @@ -1557,6 +1573,7 @@ end else false + val (save, effective_fact_filter) = (case fact_filter of SOME ff => (ff <> mepoN andalso maybe_learn (), ff) @@ -1571,18 +1588,22 @@ val add_ths = Attrib.eval_thms ctxt add fun in_add (_, th) = member Thm.eq_thm_prop add_ths th + fun add_and_take accepts = (case add_ths of [] => accepts | _ => (unique_facts |> filter in_add |> map fact_of_raw_fact) @ (accepts |> filter_out in_add)) |> take max_facts + fun mepo () = (mepo_suggested_facts ctxt params max_facts NONE hyp_ts concl_t unique_facts |> weight_facts_steeply, []) + fun mash () = mash_suggested_facts ctxt params (generous_max_facts max_facts) hyp_ts concl_t facts |>> weight_facts_steeply + val mess = (* the order is important for the "case" expression below *) [] |> effective_fact_filter <> mepoN ? cons (mash_weight, mash) @@ -1590,7 +1611,7 @@ |> Par_List.map (apsnd (fn f => f ())) val mesh = mesh_facts (eq_snd Thm.eq_thm_prop) max_facts mess |> add_and_take in - if Config.get ctxt sml orelse not save then () else MaSh_Py.save ctxt overlord; + if is_mash_sml_enabled () orelse not save then () else MaSh_Py.save ctxt overlord; (case (fact_filter, mess) of (NONE, [(_, (mepo, _)), (_, (mash, _))]) => [(meshN, mesh), (mepoN, mepo |> map fst |> add_and_take), @@ -1600,7 +1621,7 @@ fun kill_learners ctxt ({overlord, ...} : params) = (Async_Manager.kill_threads MaShN "learner"; - if Config.get ctxt sml then () else MaSh_Py.shutdown ctxt overlord) + if is_mash_sml_enabled () then () else MaSh_Py.shutdown ctxt overlord) fun running_learners () = Async_Manager.running_threads MaShN "learner"