merged
author wenzelm
Tue, 20 May 2014 16:52:59 +0200
changeset 57024 c9e98c2498fd
parent 57020 f7cf92543e6c (diff)
parent 57023 0662ccd94158 (current diff)
child 57025 e7fd64f82876
merged
--- a/NEWS	Tue May 20 16:28:05 2014 +0200
+++ b/NEWS	Tue May 20 16:52:59 2014 +0200
@@ -380,6 +380,10 @@
 
 * Sledgehammer:
   - New prover "z3_new" with support for Isar proofs
+  - MaSh overhaul:
+      - A new SML-based learning engine eliminates the dependency on Python
+        and increases performance and reliability. See the Sledgehammer
+        documentation for details.
   - New option:
       smt_proofs
   - Renamed options:
--- a/src/Doc/Sledgehammer/document/root.tex	Tue May 20 16:28:05 2014 +0200
+++ b/src/Doc/Sledgehammer/document/root.tex	Tue May 20 16:52:59 2014 +0200
@@ -403,8 +403,7 @@
 \item[\labelitemi]
 An experimental alternative to MePo is \emph{MaSh}
 (\underline{Ma}chine Learner for \underline{S}ledge\underline{h}ammer). It
-relies on an external Python tool that applies machine learning to
-the problem of finding relevant facts.
+applies machine learning to the problem of finding relevant facts.
 
 \item[\labelitemi] The \emph{MeSh} filter combines MePo and MaSh.
 \end{enum}
@@ -1068,10 +1067,19 @@
 The traditional memoryless MePo relevance filter.
 
 \item[\labelitemi] \textbf{\textit{mash}:}
-The experimental MaSh machine learner. MaSh relies on the external Python
-program \texttt{mash.py}, which is part of Isabelle. To enable MaSh, set the
-environment variable \texttt{MASH} to \texttt{yes}. Persistent data is stored in
-the directory \texttt{\$ISABELLE\_HOME\_USER/mash}.
+The experimental MaSh machine learner.
+Two learning engines are provided:
+
+\begin{enum}
+\item[\labelitemi] \emph{sml} (also called \emph{sml\_knn}) refers to a Standard ML implementation
+of $k$-nearest neighbors.
+
+\item[\labelitemi] \emph{py} (also called \emph{yes}) refers to a Python implementation of naive
+Bayes. The program is included with Isabelle as \texttt{mash.py}.
+\end{enum}
+
+To enable MaSh, set the environment variable \texttt{MASH} to the name of the desired engine.
+Persistent data for both engines is stored in the directory \texttt{\$ISABELLE\_HOME\_USER/mash}.
 
 \item[\labelitemi] \textbf{\textit{mesh}:} The MeSh filter, which combines the
 rankings from MePo and MaSh.
--- a/src/HOL/Tools/ATP/atp_proof.ML	Tue May 20 16:28:05 2014 +0200
+++ b/src/HOL/Tools/ATP/atp_proof.ML	Tue May 20 16:52:59 2014 +0200
@@ -183,8 +183,7 @@
   |> find_first (fn (_, pattern) => String.isSubstring pattern output)
   |> Option.map fst
 
-fun extract_tstplike_proof_and_outcome verbose proof_delims known_failures
-                                       output =
+fun extract_tstplike_proof_and_outcome verbose proof_delims known_failures output =
   (case (extract_tstplike_proof proof_delims output,
       extract_known_atp_failure known_failures output) of
     (_, SOME ProofIncomplete) => ("", NONE)
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Tue May 20 16:28:05 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Tue May 20 16:52:59 2014 +0200
@@ -15,7 +15,6 @@
   type prover_result = Sledgehammer_Prover.prover_result
 
   val trace : bool Config.T
-  val sml : bool Config.T
   val MePoN : string
   val MaShN : string
   val MeShN : string
@@ -37,7 +36,6 @@
   val extract_suggestions : string -> string * string list
 
   val mash_unlearn : Proof.context -> params -> unit
-  val is_mash_enabled : unit -> bool
   val nickname_of_thm : thm -> string
   val find_suggested_facts : Proof.context -> ('b * thm) list -> string list -> ('b * thm) list
   val mesh_facts : ('a * 'a -> bool) -> int -> (real * (('a * real) list * 'a list)) list -> 'a list
@@ -88,7 +86,6 @@
 open Sledgehammer_MePo
 
 val trace = Attrib.setup_config_bool @{binding sledgehammer_mash_trace} (K false)
-val sml = Attrib.setup_config_bool @{binding sledgehammer_mash_sml} (K false)
 
 fun trace_msg ctxt msg = if Config.get ctxt trace then tracing (msg ()) else ()
 
@@ -118,6 +115,25 @@
     ()
   end
 
+datatype mash_flavor = MaSh_Py | MaSh_SML_KNN | MaSh_SML_NB  (* selectable learning engines *)
+
+fun mash_flavor () =  (* engine named by the MASH setting, if any *)
+  (case getenv "MASH" of
+    "yes" => SOME MaSh_Py  (* same engine as "py" *)
+  | "py" => SOME MaSh_Py
+  | "sml" => SOME MaSh_SML_KNN  (* same engine as "sml_knn" *)
+  | "sml_knn" => SOME MaSh_SML_KNN
+  | "sml_nb" => SOME MaSh_SML_NB
+  | _ => NONE)  (* any other value: no engine selected *)
+
+val is_mash_enabled = is_some o mash_flavor  (* true iff MASH names one of the engines above *)
+
+fun is_mash_sml_enabled () =  (* true iff the selected engine is one of the two SML flavors *)
+  (case mash_flavor () of
+    SOME MaSh_SML_KNN => true
+  | SOME MaSh_SML_NB => true
+  | _ => false)  (* MaSh_Py or no engine *)
+
 
 (*** Low-level communication with Python version of MaSh ***)
 
@@ -284,22 +300,16 @@
 structure MaSh_SML =
 struct
 
-fun max a b = if a > b then a else b
-
 exception BOTTOM of int
 
 fun heap cmp bnd a =
   let
     fun maxson l i =
-      let
-        val i31 = i + i + i + 1
-      in
+      let val i31 = i + i + i + 1 in
         if i31 + 2 < l then
-          let
-            val x = Unsynchronized.ref i31;
-            val _ = if cmp (Array.sub(a,i31),(Array.sub(a,i31+1))) = LESS then x := i31+1 else ();
-            val _ = if cmp (Array.sub(a,!x),(Array.sub (a,i31+2))) = LESS then x := i31+2 else ()
-          in
+          let val x = Unsynchronized.ref i31 in
+            if cmp (Array.sub (a, i31), Array.sub (a, i31 + 1)) = LESS then x := i31 + 1 else ();
+            if cmp (Array.sub (a, !x), Array.sub (a, i31 + 2)) = LESS then x := i31 + 2 else ();
             !x
           end
         else
@@ -354,7 +364,7 @@
     val _ = for (((l + 1) div 3) - 1)
 
     fun for2 i =
-      if i < max 2 (l - bnd) then () else
+      if i < Integer.max 2 (l - bnd) then () else
       let
         val e = Array.sub (a, i)
         val _ = Array.update (a, i, Array.sub (a, 0))
@@ -387,51 +397,57 @@
 fun knn avail_num adv_max get_deps get_sym_ths knns advno syms =
   let
     (* Can be later used for TFIDF *)
-    fun sym_wght _ = 1.0;
-    val overlaps_sqr = Array.tabulate (avail_num, (fn i => (i, 0.0)));
+    fun sym_wght _ = 1.0
+
+    val overlaps_sqr = Array.tabulate (avail_num, (fn i => (i, 0.0)))
+
     fun inc_overlap j v =
       let
-        val ov = snd (Array.sub (overlaps_sqr,j))
+        val ov = snd (Array.sub (overlaps_sqr, j))
       in
         Array.update (overlaps_sqr, j, (j, v + ov))
-      end;
+      end
+
     fun do_sym (s, con_wght) =
       let
-        val sw = sym_wght s;
-        val w2 = sw * sw * con_wght;
+        val sw = sym_wght s
+        val w2 = sw * sw * con_wght
+
         fun do_th (j, prem_wght) = if j < avail_num then inc_overlap j (w2 * prem_wght) else ()
       in
-        ignore (map do_th (get_sym_ths s))
-      end;
-    val () = ignore (map do_sym syms);
-    val () = heap (fn (a, b) => Real.compare (snd a, snd b)) knns overlaps_sqr;
-    val recommends = Array.tabulate (adv_max, (fn j => (j, 0.0)));
+        List.app do_th (get_sym_ths s)
+      end
+
+    val _ = List.app do_sym syms
+    val _ = heap (fn (a, b) => Real.compare (snd a, snd b)) knns overlaps_sqr
+    val recommends = Array.tabulate (adv_max, rpair 0.0)
+
     fun inc_recommend j v =
-      if j >= adv_max then () else
-      let
-        val ov = snd (Array.sub (recommends,j))
-      in
-        Array.update (recommends, j, (j, v + ov))
-      end;
+      if j >= adv_max then ()
+      else Array.update (recommends, j, (j, v + snd (Array.sub (recommends, j))))
+
     fun for k =
-      if k = knns then () else
-      if k >= adv_max then () else
-      let
-        val (j, o2) = Array.sub (overlaps_sqr, avail_num - k - 1);
-        val o1 = Math.sqrt o2;
-        val () = inc_recommend j o1;
-        val ds = get_deps j;
-        val l = Real.fromInt (length ds);
-        val _ = map (fn d => inc_recommend d (o1 / l)) ds
-      in
-        for (k + 1)
-      end;
-    val () = for 0;
-    val () = heap (fn (a, b) => Real.compare (snd a, snd b)) advno recommends;
+      if k = knns orelse k >= adv_max then  (* stop after knns entries or adv_max steps *)
+        ()
+      else
+        let
+          val (j, o2) = Array.sub (overlaps_sqr, avail_num - k - 1)
+          val o1 = Math.sqrt o2
+          val _ = inc_recommend j o1
+          val ds = get_deps j
+          val l = Real.fromInt (length ds)  (* each dependency of j is credited o1 / l *)
+          val _ = List.app (fn d => inc_recommend d (o1 / l)) ds
+        in
+          for (k + 1)
+        end
+
+    val _ = for 0
+    val _ = heap (Real.compare o pairself snd) advno recommends
+
     fun ret acc at =
-      if at = Array.length recommends then acc else ret (Array.sub (recommends,at) :: acc) (at + 1)
+      if at = Array.length recommends then acc else ret (Array.sub (recommends, at) :: acc) (at + 1)
   in
-    ret [] (max 0 (adv_max - advno))
+    ret [] (Integer.max 0 (adv_max - advno))
   end
 
 val knns = 40 (* FUDGE *)
@@ -440,7 +456,7 @@
 
 fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i)))
 
-fun learn_and_query ctxt parents access_G max_suggs hints feats =
+fun query ctxt parents access_G max_suggs hints feats =
   let
     val str_of_feat = space_implode "|"
 
@@ -469,9 +485,9 @@
         all_nodes ([], [], (0, Symtab.empty, []), (0, Symtab.empty, []))
 
     val facts = rev rev_facts
-    val fact_ary = Array.fromList facts
+    val fact_vec = Vector.fromList facts
 
-    val deps_ary = Array.fromList (rev rev_depss)
+    val deps_vec = Vector.fromList (rev rev_depss)
     val facts_ary = Array.array (num_feats, [])
     val _ =
       fold (fn feats => fn fact =>
@@ -484,11 +500,11 @@
   in
     trace_msg ctxt (fn () => "MaSh_SML query " ^ encode_features feats ^ " from {" ^
       elide_string 1000 (space_implode " " facts) ^ "}");
-    knn (Array.length deps_ary) (length visible_facts) (curry Array.sub deps_ary)
+    knn (Vector.length deps_vec) (length visible_facts) (curry Vector.sub deps_vec)
       (curry Array.sub facts_ary) knns max_suggs
       (map_filter (fn (feat, weight) =>
          Option.map (rpair weight) (Symtab.lookup feat_tab (str_of_feat feat))) feats)
-    |> map ((fn i => Array.sub (fact_ary, i)) o fst)
+    |> map (curry Vector.sub fact_vec o fst)
   end
 
 end;
@@ -578,7 +594,7 @@
                   fold extract_line_and_add_node node_lines Graph.empty),
                 length node_lines)
              | LESS =>
-               (if Config.get ctxt sml then wipe_out_mash_state_dir ()
+               (if is_mash_sml_enabled () then wipe_out_mash_state_dir ()
                 else MaSh_Py.unlearn ctxt overlord; (Graph.empty, 0)) (* cannot parse old file *)
              | GREATER => raise FILE_VERSION_TOO_NEW ())
          in
@@ -625,10 +641,10 @@
   Synchronized.change_result global_state (perhaps (try (load_state ctxt overlord)) #> `snd #>> f)
 
 fun clear_state ctxt overlord =
-  (* "unlearn" also removes the state file *)
+  (* "MaSh_Py.unlearn" also removes the state file *)
   Synchronized.change global_state (fn _ =>
-    (if Config.get ctxt sml then wipe_out_mash_state_dir ()
-     else MaSh_Py.unlearn ctxt overlord; (false, empty_state)))
+    (if is_mash_sml_enabled () then wipe_out_mash_state_dir () else MaSh_Py.unlearn ctxt overlord;
+     (false, empty_state)))
 
 end
 
@@ -638,8 +654,6 @@
 
 (*** Isabelle helpers ***)
 
-fun is_mash_enabled () = (getenv "MASH" = "yes")
-
 val local_prefix = "local" ^ Long_Name.separator
 
 fun elided_backquote_thm threshold th =
@@ -971,9 +985,6 @@
     | NONE => false)
   | is_size_def _ _ = false
 
-fun no_dependencies_for_status status =
-  status = Non_Rec_Def orelse status = Rec_Def
-
 fun trim_dependencies deps =
   if length deps > max_dependencies then NONE else SOME deps
 
@@ -1022,18 +1033,17 @@
       val num_isar_deps = length isar_deps
     in
       if verbose andalso auto_level = 0 then
-        "MaSh: " ^ quote prover ^ " on " ^ quote name ^ " with " ^ string_of_int num_isar_deps ^
-        " + " ^ string_of_int (length facts - num_isar_deps) ^ " facts."
-        |> Output.urgent_message
+        Output.urgent_message ("MaSh: " ^ quote prover ^ " on " ^ quote name ^ " with " ^
+          string_of_int num_isar_deps ^ " + " ^ string_of_int (length facts - num_isar_deps) ^
+          " facts.")
       else
         ();
       (case run_prover_for_mash ctxt params prover name facts goal of
         {outcome = NONE, used_facts, ...} =>
         (if verbose andalso auto_level = 0 then
            let val num_facts = length used_facts in
-             "Found proof with " ^ string_of_int num_facts ^ " fact" ^
-             plural_s num_facts ^ "."
-             |> Output.urgent_message
+             Output.urgent_message ("Found proof with " ^ string_of_int num_facts ^ " fact" ^
+               plural_s num_facts ^ ".")
            end
          else
            ();
@@ -1187,40 +1197,57 @@
       |> features_of ctxt (theory_of_thm th) num_facts const_tab stature false
       |> map (apsnd (fn r => weight * factor * r))
 
-    val (access_G, suggs) =
+    fun query_args access_G =  (* builds the (parents, hints, feats) triple for a MaSh query *)
+      let
+        val parents = maximal_wrt_access_graph access_G facts
+        val hints = chained  (* nicknames of the chained facts that occur in access_G *)
+          |> filter (is_fact_in_graph access_G o snd)
+          |> map (nickname_of_thm o snd)
+
+        val goal_feats =
+          features_of ctxt thy num_facts const_tab (Local, General) true (concl_t :: hyp_ts)
+        val chained_feats = chained
+          |> map (rpair 1.0)
+          |> map (chained_or_extra_features_of chained_feature_factor)
+          |> rpair [] |-> fold (union (eq_fst (op =)))
+        val extra_feats = facts
+          |> take (Int.max (0, num_extra_feature_facts - length chained))
+          |> filter fact_has_right_theory
+          |> weight_facts_steeply
+          |> map (chained_or_extra_features_of extra_feature_factor)
+          |> rpair [] |-> fold (union (eq_fst (op =)))
+        val feats = fold (union (eq_fst (op =))) [chained_feats, extra_feats] goal_feats
+          |> debug ? sort (Real.compare o swap o pairself snd)  (* sorted only in debug mode *)
+      in
+        (parents, hints, feats)
+      end
+
+    val sml = is_mash_sml_enabled ()
+
+    val (access_G, py_suggs) =
       peek_state ctxt overlord (fn {access_G, ...} =>
         if Graph.is_empty access_G then
           (trace_msg ctxt (K "Nothing has been learned yet"); (access_G, []))
         else
-          let
-            val parents = maximal_wrt_access_graph access_G facts
-            val goal_feats =
-              features_of ctxt thy num_facts const_tab (Local, General) true (concl_t :: hyp_ts)
-            val chained_feats = chained
-              |> map (rpair 1.0)
-              |> map (chained_or_extra_features_of chained_feature_factor)
-              |> rpair [] |-> fold (union (eq_fst (op =)))
-            val extra_feats = facts
-              |> take (Int.max (0, num_extra_feature_facts - length chained))
-              |> filter fact_has_right_theory
-              |> weight_facts_steeply
-              |> map (chained_or_extra_features_of extra_feature_factor)
-              |> rpair [] |-> fold (union (eq_fst (op =)))
-            val feats = fold (union (eq_fst (op =))) [chained_feats, extra_feats] goal_feats
-              |> debug ? sort (Real.compare o swap o pairself snd)
-            val hints = chained
-              |> filter (is_fact_in_graph access_G o snd)
-              |> map (nickname_of_thm o snd)
-          in
-            (access_G,
-             if Config.get ctxt sml then
-               MaSh_SML.learn_and_query ctxt parents access_G max_facts hints feats
-             else
-               MaSh_Py.query ctxt overlord max_facts ([], hints, parents, feats))
-          end)
+          (access_G,
+           if sml then
+             []
+           else
+             let val (parents, hints, feats) = query_args access_G in
+               MaSh_Py.query ctxt overlord max_facts ([], hints, parents, feats)
+             end))
+
+    val sml_suggs =
+      if sml then
+        let val (parents, hints, feats) = query_args access_G in
+          MaSh_SML.query ctxt parents access_G max_facts hints feats
+        end
+      else
+        []
+
     val unknown = filter_out (is_fact_in_graph access_G o snd) facts
   in
-    find_mash_suggestions ctxt max_facts suggs facts chained unknown
+    find_mash_suggestions ctxt max_facts (py_suggs @ sml_suggs) facts chained unknown
     |> pairself (map fact_of_raw_fact)
   end
 
@@ -1280,7 +1307,7 @@
               |> filter (is_fact_in_graph access_G)
               |> map nickname_of_thm
           in
-            if Config.get ctxt sml then
+            if is_mash_sml_enabled () then
               let val access_G = access_G |> add_node Automatic_Proof name parents feats deps in
                 {access_G = access_G, num_known_facts = num_known_facts + 1,
                  dirty = Option.map (cons name) dirty}
@@ -1305,6 +1332,7 @@
     val timer = Timer.startRealTimer ()
     fun next_commit_time () = Time.+ (Timer.checkRealTimer timer, commit_timeout)
 
+    val sml = is_mash_sml_enabled ()
     val {access_G, ...} = peek_state ctxt overlord I
     val is_in_access_G = is_fact_in_graph access_G o snd
     val no_new_facts = forall is_in_access_G facts
@@ -1323,7 +1351,7 @@
         val name_tabs = build_name_tables nickname_of_thm facts
 
         fun deps_of status th =
-          if no_dependencies_for_status status then
+          if status = Non_Rec_Def orelse status = Rec_Def then
             SOME []
           else if run_prover then
             prover_dependencies_of ctxt params prover auto_level facts name_tabs th
@@ -1346,7 +1374,7 @@
                   (false, SOME names, []) => SOME (map #1 learns @ names)
                 | _ => NONE)
             in
-              if Config.get ctxt sml then
+              if sml then
                 ()
               else
                 (MaSh_Py.learn ctxt overlord (save andalso null relearns) (rev learns);
@@ -1355,18 +1383,13 @@
             end
 
         fun commit last learns relearns flops =
-          (if debug andalso auto_level = 0 then
-             Output.urgent_message "Committing..."
-           else
-             ();
+          (if debug andalso auto_level = 0 then Output.urgent_message "Committing..." else ();
            map_state ctxt overlord (do_commit (rev learns) relearns flops);
            if not last andalso auto_level = 0 then
              let val num_proofs = length learns + length relearns in
-               "Learned " ^ string_of_int num_proofs ^ " " ^
-               (if run_prover then "automatic" else "Isar") ^ " proof" ^
-               plural_s num_proofs ^ " in the last " ^
-               string_of_time commit_timeout ^ "."
-               |> Output.urgent_message
+               Output.urgent_message ("Learned " ^ string_of_int num_proofs ^ " " ^
+                 (if run_prover then "automatic" else "Isar") ^ " proof" ^
+                 plural_s num_proofs ^ " in the last " ^ string_of_time commit_timeout ^ ".")
              end
            else
              ())
@@ -1478,14 +1501,12 @@
       |> Output.urgent_message
   in
     if run_prover then
-      ("MaShing through " ^ string_of_int num_facts ^ " fact" ^ plural_s num_facts ^
-       " for automatic proofs (" ^ quote prover ^ " timeout: " ^ string_of_time timeout ^
-       ").\n\nCollecting Isar proofs first..."
-       |> Output.urgent_message;
+      (Output.urgent_message ("MaShing through " ^ string_of_int num_facts ^ " fact" ^
+         plural_s num_facts ^ " for automatic proofs (" ^ quote prover ^ " timeout: " ^
+         string_of_time timeout ^ ").\n\nCollecting Isar proofs first...");
        learn 1 false;
-       "Now collecting automatic proofs. This may take several hours. You can safely stop the \
-       \learning process at any point."
-       |> Output.urgent_message;
+       Output.urgent_message "Now collecting automatic proofs. This may take several hours. You \
+         \can safely stop the learning process at any point.";
        learn 0 true)
     else
       (Output.urgent_message ("MaShing through " ^ string_of_int num_facts ^ " fact" ^
@@ -1530,6 +1551,7 @@
           end
         else
           ()
+
       fun maybe_learn () =
         if is_mash_enabled () andalso learn then
           let
@@ -1551,6 +1573,7 @@
           end
         else
           false
+
       val (save, effective_fact_filter) =
         (case fact_filter of
           SOME ff => (ff <> mepoN andalso maybe_learn (), ff)
@@ -1565,18 +1588,22 @@
       val add_ths = Attrib.eval_thms ctxt add
 
       fun in_add (_, th) = member Thm.eq_thm_prop add_ths th
+
       fun add_and_take accepts =
         (case add_ths of
            [] => accepts
          | _ => (unique_facts |> filter in_add |> map fact_of_raw_fact) @
                 (accepts |> filter_out in_add))
         |> take max_facts
+
       fun mepo () =
         (mepo_suggested_facts ctxt params max_facts NONE hyp_ts concl_t unique_facts
          |> weight_facts_steeply, [])
+
       fun mash () =
         mash_suggested_facts ctxt params (generous_max_facts max_facts) hyp_ts concl_t facts
         |>> weight_facts_steeply
+
       val mess =
         (* the order is important for the "case" expression below *)
         [] |> effective_fact_filter <> mepoN ? cons (mash_weight, mash)
@@ -1584,7 +1611,7 @@
            |> Par_List.map (apsnd (fn f => f ()))
       val mesh = mesh_facts (eq_snd Thm.eq_thm_prop) max_facts mess |> add_and_take
     in
-      if Config.get ctxt sml orelse not save then () else MaSh_Py.save ctxt overlord;
+      if is_mash_sml_enabled () orelse not save then () else MaSh_Py.save ctxt overlord;
       (case (fact_filter, mess) of
         (NONE, [(_, (mepo, _)), (_, (mash, _))]) =>
         [(meshN, mesh), (mepoN, mepo |> map fst |> add_and_take),
@@ -1594,7 +1621,7 @@
 
 fun kill_learners ctxt ({overlord, ...} : params) =
   (Async_Manager.kill_threads MaShN "learner";
-   if Config.get ctxt sml then () else MaSh_Py.shutdown ctxt overlord)
+   if is_mash_sml_enabled () then () else MaSh_Py.shutdown ctxt overlord)
 
 fun running_learners () = Async_Manager.running_threads MaShN "learner"