src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
changeset 53148 c898409d8630
parent 53142 966a251efd16
child 53150 5565d1b56f84
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Thu Aug 22 21:15:43 2013 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Thu Aug 22 23:03:21 2013 +0200
@@ -36,7 +36,7 @@
 
   structure MaSh:
   sig
-    val unlearn : Proof.context -> unit
+    val unlearn : Proof.context -> bool -> unit
     val learn :
       Proof.context -> bool
       -> (string * string list * string list * string list) list -> unit
@@ -49,7 +49,7 @@
       -> string list
   end
 
-  val mash_unlearn : Proof.context -> unit
+  val mash_unlearn : Proof.context -> params -> unit
   val nickname_of_thm : thm -> string
   val find_suggested_facts :
     Proof.context -> ('b * thm) list -> string list -> ('b * thm) list
@@ -90,14 +90,14 @@
   val mash_learn :
     Proof.context -> params -> fact_override -> thm list -> bool -> unit
   val is_mash_enabled : unit -> bool
-  val mash_can_suggest_facts : Proof.context -> bool
+  val mash_can_suggest_facts : Proof.context -> bool -> bool
   val generous_max_facts : int -> int
   val mepo_weight : real
   val mash_weight : real
   val relevant_facts :
     Proof.context -> params -> string -> int -> fact_override -> term list
     -> term -> raw_fact list -> (string * fact list) list
-  val kill_learners : Proof.context -> unit
+  val kill_learners : Proof.context -> params -> unit
   val running_learners : unit -> unit
 end;
 
@@ -253,13 +253,13 @@
 structure MaSh =
 struct
 
-fun shutdown ctxt =
-  run_mash_tool ctxt false [shutdown_server_arg] ([], K "") (K ())
+fun shutdown ctxt overlord =
+  run_mash_tool ctxt overlord [shutdown_server_arg] ([], K "") (K ())
 
-fun unlearn ctxt =
+fun unlearn ctxt overlord =
   let val path = mash_model_dir () in
     trace_msg ctxt (K "MaSh unlearn");
-    shutdown ctxt;
+    shutdown ctxt overlord;
     try (File.fold_dir (fn file => fn _ =>
                            try File.rm (Path.append path (Path.basic file)))
                        path) NONE;
@@ -359,8 +359,8 @@
      | _ => NONE)
   | _ => NONE
 
-fun load _ (state as (true, _)) = state
-  | load ctxt _ =
+fun load _ _ (state as (true, _)) = state
+  | load ctxt overlord _ =
     let val path = mash_state_file () in
       (true,
        case try File.read_lines path of
@@ -383,7 +383,7 @@
                 length node_lines)
              | LESS =>
                (* can't parse old file *)
-               (MaSh.unlearn ctxt; (Graph.empty, 0))
+               (MaSh.unlearn ctxt overlord; (Graph.empty, 0))
              | GREATER => raise FILE_VERSION_TOO_NEW ()
          in
            trace_msg ctxt (fn () =>
@@ -423,22 +423,22 @@
 
 in
 
-fun map_state ctxt f =
-  Synchronized.change global_state (load ctxt ##> (f #> save ctxt))
+fun map_state ctxt overlord f =
+  Synchronized.change global_state (load ctxt overlord ##> (f #> save ctxt))
   handle FILE_VERSION_TOO_NEW () => ()
 
-fun peek_state ctxt f =
+fun peek_state ctxt overlord f =
   Synchronized.change_result global_state
-      (perhaps (try (load ctxt)) #> `snd #>> f)
+      (perhaps (try (load ctxt overlord)) #> `snd #>> f)
 
-fun clear_state ctxt =
+fun clear_state ctxt overlord =
   Synchronized.change global_state (fn _ =>
-      (MaSh.unlearn ctxt; (* also removes the state file *)
+      (MaSh.unlearn ctxt overlord; (* also removes the state file *)
        (false, empty_state)))
 
 end
 
-val mash_unlearn = clear_state
+fun mash_unlearn ctxt ({overlord, ...} : params) = clear_state ctxt overlord
 
 
 (*** Isabelle helpers ***)
@@ -589,7 +589,7 @@
 fun maybe_singleton_str _ "" = []
   | maybe_singleton_str pref s = [pref ^ s]
 
-val max_pat_breadth = 10
+val max_pat_breadth = 10 (* FUDGE *)
 
 fun term_features_of ctxt prover thy_name num_facts const_tab term_max_depth
                      type_max_depth ts =
@@ -641,7 +641,7 @@
          0.0
        else
          let val count = Symtab.lookup const_tab s |> the_default 1 in
-           (Real.fromInt num_facts / Real.fromInt count) (* FUDGE *)
+           Real.fromInt num_facts / Real.fromInt count (* FUDGE *)
          end)
     fun pattify_term _ _ 0 _ = []
       | pattify_term _ args _ (Const (x as (s, _))) =
@@ -906,7 +906,7 @@
 
 val chained_feature_factor = 0.5
 val extra_feature_factor = 0.1
-val num_extra_feature_facts = 10 (* FUDGE *)
+val num_extra_feature_facts = 0 (* FUDGE *)
 
 (* FUDGE *)
 fun weight_of_proximity_fact rank =
@@ -958,7 +958,7 @@
       |> features_of ctxt prover (theory_of_thm th) num_facts const_tab stature
       |> map (apsnd (fn r => weight * factor * r))
     val (access_G, suggs) =
-      peek_state ctxt (fn {access_G, ...} =>
+      peek_state ctxt overlord (fn {access_G, ...} =>
           if Graph.is_empty access_G then
             (access_G, [])
           else
@@ -974,7 +974,7 @@
                 |> rpair [] |-> fold (union (op = o pairself fst))
               val extra_feats =
                 facts
-                |> take (num_extra_feature_facts - length chained)
+                |> take (Int.max (0, num_extra_feature_facts - length chained))
                 |> weight_facts_steeply
                 |> map (chained_or_extra_features_of extra_feature_factor)
                 |> rpair [] |-> fold (union (op = o pairself fst))
@@ -1036,7 +1036,7 @@
           features_of ctxt prover thy 0 Symtab.empty (Local, General) [t]
           |> map fst
       in
-        peek_state ctxt (fn {access_G, ...} =>
+        peek_state ctxt overlord (fn {access_G, ...} =>
             let
               val parents = maximal_wrt_access_graph access_G facts
               val deps =
@@ -1060,7 +1060,7 @@
     val timer = Timer.startRealTimer ()
     fun next_commit_time () =
       Time.+ (Timer.checkRealTimer timer, commit_timeout)
-    val {access_G, ...} = peek_state ctxt I
+    val {access_G, ...} = peek_state ctxt overlord I
     val is_in_access_G = is_fact_in_graph access_G o snd
     val no_new_facts = forall is_in_access_G facts
   in
@@ -1114,7 +1114,7 @@
              Output.urgent_message "Committing..."
            else
              ();
-           map_state ctxt (do_commit (rev learns) relearns flops);
+           map_state ctxt overlord (do_commit (rev learns) relearns flops);
            if not last andalso auto_level = 0 then
              let val num_proofs = length learns + length relearns in
                "Learned " ^ string_of_int num_proofs ^ " " ^
@@ -1250,8 +1250,8 @@
   end
 
 fun is_mash_enabled () = (getenv "MASH" = "yes")
-fun mash_can_suggest_facts ctxt =
-  not (Graph.is_empty (#access_G (peek_state ctxt I)))
+fun mash_can_suggest_facts ctxt overlord =
+  not (Graph.is_empty (#access_G (peek_state ctxt overlord I)))
 
 (* Generate more suggestions than requested, because some might be thrown out
    later for various reasons. *)
@@ -1264,8 +1264,9 @@
    Sledgehammer and Try. *)
 val min_secs_for_learning = 15
 
-fun relevant_facts ctxt (params as {learn, fact_filter, timeout, ...}) prover
-        max_facts ({add, only, ...} : fact_override) hyp_ts concl_t facts =
+fun relevant_facts ctxt (params as {overlord, learn, fact_filter, timeout, ...})
+        prover max_facts ({add, only, ...} : fact_override) hyp_ts concl_t
+        facts =
   if not (subset (op =) (the_list fact_filter, fact_filters)) then
     error ("Unknown fact filter: " ^ quote (the fact_filter) ^ ".")
   else if only then
@@ -1295,7 +1296,7 @@
         | NONE =>
           if is_mash_enabled () then
             (maybe_learn ();
-             if mash_can_suggest_facts ctxt then meshN else mepoN)
+             if mash_can_suggest_facts ctxt overlord then meshN else mepoN)
           else
             mepoN
       val add_ths = Attrib.eval_thms ctxt add
@@ -1335,8 +1336,8 @@
       | _ => [("", mesh)]
     end
 
-fun kill_learners ctxt =
-  (Async_Manager.kill_threads MaShN "learner"; MaSh.shutdown ctxt)
+fun kill_learners ctxt ({overlord, ...} : params) =
+  (Async_Manager.kill_threads MaShN "learner"; MaSh.shutdown ctxt overlord)
 fun running_learners () = Async_Manager.running_threads MaShN "learner"
 
 end;