more flexible environment variable
authorblanchet
Tue May 20 16:31:39 2014 +0200 (2014-05-20)
changeset 57018142950e9c7e2
parent 57017 afdf75c0de58
child 57019 f013e3a830c3
more flexible environment variable
src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
     1.1 --- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Tue May 20 16:11:37 2014 +0200
     1.2 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Tue May 20 16:31:39 2014 +0200
     1.3 @@ -15,7 +15,6 @@
     1.4    type prover_result = Sledgehammer_Prover.prover_result
     1.5  
     1.6    val trace : bool Config.T
     1.7 -  val sml : bool Config.T
     1.8    val MePoN : string
     1.9    val MaShN : string
    1.10    val MeShN : string
    1.11 @@ -37,7 +36,6 @@
    1.12    val extract_suggestions : string -> string * string list
    1.13  
    1.14    val mash_unlearn : Proof.context -> params -> unit
    1.15 -  val is_mash_enabled : unit -> bool
    1.16    val nickname_of_thm : thm -> string
    1.17    val find_suggested_facts : Proof.context -> ('b * thm) list -> string list -> ('b * thm) list
    1.18    val mesh_facts : ('a * 'a -> bool) -> int -> (real * (('a * real) list * 'a list)) list -> 'a list
    1.19 @@ -88,7 +86,6 @@
    1.20  open Sledgehammer_MePo
    1.21  
    1.22  val trace = Attrib.setup_config_bool @{binding sledgehammer_mash_trace} (K false)
    1.23 -val sml = Attrib.setup_config_bool @{binding sledgehammer_mash_sml} (K false)
    1.24  
    1.25  fun trace_msg ctxt msg = if Config.get ctxt trace then tracing (msg ()) else ()
    1.26  
    1.27 @@ -118,6 +115,25 @@
    1.28      ()
    1.29    end
    1.30  
    1.31 +datatype mash_flavor = MaSh_Py | MaSh_SML_KNN | MaSh_SML_NB
    1.32 +
    1.33 +fun mash_flavor () =
    1.34 +  (case getenv "MASH" of
    1.35 +    "yes" => SOME MaSh_Py
    1.36 +  | "py" => SOME MaSh_Py
    1.37 +  | "sml" => SOME MaSh_SML_KNN
    1.38 +  | "sml_knn" => SOME MaSh_SML_KNN
    1.39 +  | "sml_nb" => SOME MaSh_SML_NB
    1.40 +  | _ => NONE)
    1.41 +
    1.42 +val is_mash_enabled = is_some o mash_flavor
    1.43 +
    1.44 +fun is_mash_sml_enabled () =
    1.45 +  (case mash_flavor () of
    1.46 +    SOME MaSh_SML_KNN => true
    1.47 +  | SOME MaSh_SML_NB => true
    1.48 +  | _ => false)
    1.49 +
    1.50  
    1.51  (*** Low-level communication with Python version of MaSh ***)
    1.52  
    1.53 @@ -578,7 +594,7 @@
    1.54                    fold extract_line_and_add_node node_lines Graph.empty),
    1.55                  length node_lines)
    1.56               | LESS =>
    1.57 -               (if Config.get ctxt sml then wipe_out_mash_state_dir ()
    1.58 +               (if is_mash_sml_enabled () then wipe_out_mash_state_dir ()
    1.59                  else MaSh_Py.unlearn ctxt overlord; (Graph.empty, 0)) (* cannot parse old file *)
    1.60               | GREATER => raise FILE_VERSION_TOO_NEW ())
    1.61           in
    1.62 @@ -627,8 +643,8 @@
    1.63  fun clear_state ctxt overlord =
    1.64    (* "MaSh_Py.unlearn" also removes the state file *)
    1.65    Synchronized.change global_state (fn _ =>
    1.66 -    (if Config.get ctxt sml then wipe_out_mash_state_dir ()
    1.67 -     else MaSh_Py.unlearn ctxt overlord; (false, empty_state)))
    1.68 +    (if is_mash_sml_enabled () then wipe_out_mash_state_dir () else MaSh_Py.unlearn ctxt overlord;
    1.69 +     (false, empty_state)))
    1.70  
    1.71  end
    1.72  
    1.73 @@ -638,8 +654,6 @@
    1.74  
    1.75  (*** Isabelle helpers ***)
    1.76  
    1.77 -fun is_mash_enabled () = (getenv "MASH" = "yes")
    1.78 -
    1.79  val local_prefix = "local" ^ Long_Name.separator
    1.80  
    1.81  fun elided_backquote_thm threshold th =
    1.82 @@ -1208,7 +1222,7 @@
    1.83          (parents, hints, feats)
    1.84        end
    1.85  
    1.86 -    val sml = Config.get ctxt sml
    1.87 +    val sml = is_mash_sml_enabled ()
    1.88  
    1.89      val (access_G, py_suggs) =
    1.90        peek_state ctxt overlord (fn {access_G, ...} =>
    1.91 @@ -1293,7 +1307,7 @@
    1.92                |> filter (is_fact_in_graph access_G)
    1.93                |> map nickname_of_thm
    1.94            in
    1.95 -            if Config.get ctxt sml then
    1.96 +            if is_mash_sml_enabled () then
    1.97                let val access_G = access_G |> add_node Automatic_Proof name parents feats deps in
    1.98                  {access_G = access_G, num_known_facts = num_known_facts + 1,
    1.99                   dirty = Option.map (cons name) dirty}
   1.100 @@ -1318,6 +1332,7 @@
   1.101      val timer = Timer.startRealTimer ()
   1.102      fun next_commit_time () = Time.+ (Timer.checkRealTimer timer, commit_timeout)
   1.103  
   1.104 +    val sml = is_mash_sml_enabled ()
   1.105      val {access_G, ...} = peek_state ctxt overlord I
   1.106      val is_in_access_G = is_fact_in_graph access_G o snd
   1.107      val no_new_facts = forall is_in_access_G facts
   1.108 @@ -1359,7 +1374,7 @@
   1.109                    (false, SOME names, []) => SOME (map #1 learns @ names)
   1.110                  | _ => NONE)
   1.111              in
   1.112 -              if Config.get ctxt sml then
   1.113 +              if sml then
   1.114                  ()
   1.115                else
   1.116                  (MaSh_Py.learn ctxt overlord (save andalso null relearns) (rev learns);
   1.117 @@ -1536,6 +1551,7 @@
   1.118            end
   1.119          else
   1.120            ()
   1.121 +
   1.122        fun maybe_learn () =
   1.123          if is_mash_enabled () andalso learn then
   1.124            let
   1.125 @@ -1557,6 +1573,7 @@
   1.126            end
   1.127          else
   1.128            false
   1.129 +
   1.130        val (save, effective_fact_filter) =
   1.131          (case fact_filter of
   1.132            SOME ff => (ff <> mepoN andalso maybe_learn (), ff)
   1.133 @@ -1571,18 +1588,22 @@
   1.134        val add_ths = Attrib.eval_thms ctxt add
   1.135  
   1.136        fun in_add (_, th) = member Thm.eq_thm_prop add_ths th
   1.137 +
   1.138        fun add_and_take accepts =
   1.139          (case add_ths of
   1.140             [] => accepts
   1.141           | _ => (unique_facts |> filter in_add |> map fact_of_raw_fact) @
   1.142                  (accepts |> filter_out in_add))
   1.143          |> take max_facts
   1.144 +
   1.145        fun mepo () =
   1.146          (mepo_suggested_facts ctxt params max_facts NONE hyp_ts concl_t unique_facts
   1.147           |> weight_facts_steeply, [])
   1.148 +
   1.149        fun mash () =
   1.150          mash_suggested_facts ctxt params (generous_max_facts max_facts) hyp_ts concl_t facts
   1.151          |>> weight_facts_steeply
   1.152 +
   1.153        val mess =
   1.154          (* the order is important for the "case" expression below *)
   1.155          [] |> effective_fact_filter <> mepoN ? cons (mash_weight, mash)
   1.156 @@ -1590,7 +1611,7 @@
   1.157             |> Par_List.map (apsnd (fn f => f ()))
   1.158        val mesh = mesh_facts (eq_snd Thm.eq_thm_prop) max_facts mess |> add_and_take
   1.159      in
   1.160 -      if Config.get ctxt sml orelse not save then () else MaSh_Py.save ctxt overlord;
   1.161 +      if is_mash_sml_enabled () orelse not save then () else MaSh_Py.save ctxt overlord;
   1.162        (case (fact_filter, mess) of
   1.163          (NONE, [(_, (mepo, _)), (_, (mash, _))]) =>
   1.164          [(meshN, mesh), (mepoN, mepo |> map fst |> add_and_take),
   1.165 @@ -1600,7 +1621,7 @@
   1.166  
   1.167  fun kill_learners ctxt ({overlord, ...} : params) =
   1.168    (Async_Manager.kill_threads MaShN "learner";
   1.169 -   if Config.get ctxt sml then () else MaSh_Py.shutdown ctxt overlord)
   1.170 +   if is_mash_sml_enabled () then () else MaSh_Py.shutdown ctxt overlord)
   1.171  
   1.172  fun running_learners () = Async_Manager.running_threads MaShN "learner"
   1.173