src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
changeset 57018 142950e9c7e2
parent 57017 afdf75c0de58
child 57028 e5466055e94f
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Tue May 20 16:11:37 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Tue May 20 16:31:39 2014 +0200
@@ -15,7 +15,6 @@
   type prover_result = Sledgehammer_Prover.prover_result
 
   val trace : bool Config.T
-  val sml : bool Config.T
   val MePoN : string
   val MaShN : string
   val MeShN : string
@@ -37,7 +36,6 @@
   val extract_suggestions : string -> string * string list
 
   val mash_unlearn : Proof.context -> params -> unit
-  val is_mash_enabled : unit -> bool
   val nickname_of_thm : thm -> string
   val find_suggested_facts : Proof.context -> ('b * thm) list -> string list -> ('b * thm) list
   val mesh_facts : ('a * 'a -> bool) -> int -> (real * (('a * real) list * 'a list)) list -> 'a list
@@ -88,7 +86,6 @@
 open Sledgehammer_MePo
 
 val trace = Attrib.setup_config_bool @{binding sledgehammer_mash_trace} (K false)
-val sml = Attrib.setup_config_bool @{binding sledgehammer_mash_sml} (K false)
 
 fun trace_msg ctxt msg = if Config.get ctxt trace then tracing (msg ()) else ()
 
@@ -118,6 +115,25 @@
     ()
   end
 
+datatype mash_flavor = MaSh_Py | MaSh_SML_KNN | MaSh_SML_NB
+
+fun mash_flavor () =
+  (case getenv "MASH" of
+    "yes" => SOME MaSh_Py
+  | "py" => SOME MaSh_Py
+  | "sml" => SOME MaSh_SML_KNN
+  | "sml_knn" => SOME MaSh_SML_KNN
+  | "sml_nb" => SOME MaSh_SML_NB
+  | _ => NONE)
+
+val is_mash_enabled = is_some o mash_flavor
+
+fun is_mash_sml_enabled () =
+  (case mash_flavor () of
+    SOME MaSh_SML_KNN => true
+  | SOME MaSh_SML_NB => true
+  | _ => false)
+
 
 (*** Low-level communication with Python version of MaSh ***)
 
@@ -578,7 +594,7 @@
                   fold extract_line_and_add_node node_lines Graph.empty),
                 length node_lines)
              | LESS =>
-               (if Config.get ctxt sml then wipe_out_mash_state_dir ()
+               (if is_mash_sml_enabled () then wipe_out_mash_state_dir ()
                 else MaSh_Py.unlearn ctxt overlord; (Graph.empty, 0)) (* cannot parse old file *)
              | GREATER => raise FILE_VERSION_TOO_NEW ())
          in
@@ -627,8 +643,8 @@
 fun clear_state ctxt overlord =
   (* "MaSh_Py.unlearn" also removes the state file *)
   Synchronized.change global_state (fn _ =>
-    (if Config.get ctxt sml then wipe_out_mash_state_dir ()
-     else MaSh_Py.unlearn ctxt overlord; (false, empty_state)))
+    (if is_mash_sml_enabled () then wipe_out_mash_state_dir () else MaSh_Py.unlearn ctxt overlord;
+     (false, empty_state)))
 
 end
 
@@ -638,8 +654,6 @@
 
 (*** Isabelle helpers ***)
 
-fun is_mash_enabled () = (getenv "MASH" = "yes")
-
 val local_prefix = "local" ^ Long_Name.separator
 
 fun elided_backquote_thm threshold th =
@@ -1208,7 +1222,7 @@
         (parents, hints, feats)
       end
 
-    val sml = Config.get ctxt sml
+    val sml = is_mash_sml_enabled ()
 
     val (access_G, py_suggs) =
       peek_state ctxt overlord (fn {access_G, ...} =>
@@ -1293,7 +1307,7 @@
               |> filter (is_fact_in_graph access_G)
               |> map nickname_of_thm
           in
-            if Config.get ctxt sml then
+            if is_mash_sml_enabled () then
               let val access_G = access_G |> add_node Automatic_Proof name parents feats deps in
                 {access_G = access_G, num_known_facts = num_known_facts + 1,
                  dirty = Option.map (cons name) dirty}
@@ -1318,6 +1332,7 @@
     val timer = Timer.startRealTimer ()
     fun next_commit_time () = Time.+ (Timer.checkRealTimer timer, commit_timeout)
 
+    val sml = is_mash_sml_enabled ()
     val {access_G, ...} = peek_state ctxt overlord I
     val is_in_access_G = is_fact_in_graph access_G o snd
     val no_new_facts = forall is_in_access_G facts
@@ -1359,7 +1374,7 @@
                   (false, SOME names, []) => SOME (map #1 learns @ names)
                 | _ => NONE)
             in
-              if Config.get ctxt sml then
+              if sml then
                 ()
               else
                 (MaSh_Py.learn ctxt overlord (save andalso null relearns) (rev learns);
@@ -1536,6 +1551,7 @@
           end
         else
           ()
+
       fun maybe_learn () =
         if is_mash_enabled () andalso learn then
           let
@@ -1557,6 +1573,7 @@
           end
         else
           false
+
       val (save, effective_fact_filter) =
         (case fact_filter of
           SOME ff => (ff <> mepoN andalso maybe_learn (), ff)
@@ -1571,18 +1588,22 @@
       val add_ths = Attrib.eval_thms ctxt add
 
       fun in_add (_, th) = member Thm.eq_thm_prop add_ths th
+
       fun add_and_take accepts =
         (case add_ths of
            [] => accepts
          | _ => (unique_facts |> filter in_add |> map fact_of_raw_fact) @
                 (accepts |> filter_out in_add))
         |> take max_facts
+
       fun mepo () =
         (mepo_suggested_facts ctxt params max_facts NONE hyp_ts concl_t unique_facts
          |> weight_facts_steeply, [])
+
       fun mash () =
         mash_suggested_facts ctxt params (generous_max_facts max_facts) hyp_ts concl_t facts
         |>> weight_facts_steeply
+
       val mess =
         (* the order is important for the "case" expression below *)
         [] |> effective_fact_filter <> mepoN ? cons (mash_weight, mash)
@@ -1590,7 +1611,7 @@
            |> Par_List.map (apsnd (fn f => f ()))
       val mesh = mesh_facts (eq_snd Thm.eq_thm_prop) max_facts mess |> add_and_take
     in
-      if Config.get ctxt sml orelse not save then () else MaSh_Py.save ctxt overlord;
+      if is_mash_sml_enabled () orelse not save then () else MaSh_Py.save ctxt overlord;
       (case (fact_filter, mess) of
         (NONE, [(_, (mepo, _)), (_, (mash, _))]) =>
         [(meshN, mesh), (mepoN, mepo |> map fst |> add_and_take),
@@ -1600,7 +1621,7 @@
 
 fun kill_learners ctxt ({overlord, ...} : params) =
   (Async_Manager.kill_threads MaShN "learner";
-   if Config.get ctxt sml then () else MaSh_Py.shutdown ctxt overlord)
+   if is_mash_sml_enabled () then () else MaSh_Py.shutdown ctxt overlord)
 
 fun running_learners () = Async_Manager.running_threads MaShN "learner"