added naive Bayes ML implementation, due to Cezary Kaliszyk (like k-NN)
author blanchet
Tue, 20 May 2014 22:28:44 +0200
changeset 57029 75cc30d2b83f
parent 57028 e5466055e94f
child 57030 b592202a45cc
added naive Bayes ML implementation, due to Cezary Kaliszyk (like k-NN)
NEWS
src/Doc/Sledgehammer/document/root.tex
src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
--- a/NEWS	Tue May 20 22:28:08 2014 +0200
+++ b/NEWS	Tue May 20 22:28:44 2014 +0200
@@ -386,8 +386,8 @@
       - Activation of MaSh now works via the "mash" system option (without
         requiring restart), instead of former settings variable "MASH".
         The option can be edited in Isabelle/jEdit menu Plugin
-        Options / Isabelle / General. Allowed values include "sml" (for the new
-        SML engine), "py" (for the Python engine), and "no".
+        Options / Isabelle / General. Allowed values include "sml" (for the
+        default SML engine), "py" (for the old Python engine), and "none".
   - New option:
       smt_proofs
   - Renamed options:
--- a/src/Doc/Sledgehammer/document/root.tex	Tue May 20 22:28:08 2014 +0200
+++ b/src/Doc/Sledgehammer/document/root.tex	Tue May 20 22:28:44 2014 +0200
@@ -1070,8 +1070,8 @@
 The experimental MaSh machine learner. Three learning engines are provided:
 
 \begin{enum}
-\item[\labelitemi] \textbf{\textit{sml}} (also called
-\textbf{\textit{sml\_knn}}) is a Standard ML implementation of $k$-nearest
+\item[\labelitemi] \textbf{\textit{sml\_knn}} (also called
+\textbf{\textit{sml}}) is a Standard ML implementation of $k$-nearest
 neighbors.
 
 \item[\labelitemi] \textbf{\textit{sml\_nb}} is a Standard ML implementation of
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Tue May 20 22:28:08 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Tue May 20 22:28:44 2014 +0200
@@ -115,26 +115,21 @@
     ()
   end
 
-datatype mash_engine = MaSh_Py | MaSh_SML_KNN | MaSh_SML_NB
+datatype mash_engine = MaSh_Py | MaSh_SML_kNN | MaSh_SML_NB
 
 fun mash_engine () =
   let val flag1 = Options.default_string @{system_option maSh} in
     (case if flag1 <> "none" (* default *) then flag1 else getenv "MASH" of
       "yes" => SOME MaSh_Py
     | "py" => SOME MaSh_Py
-    | "sml" => SOME MaSh_SML_KNN
-    | "sml_knn" => SOME MaSh_SML_KNN
+    | "sml" => SOME MaSh_SML_kNN
+    | "sml_knn" => SOME MaSh_SML_kNN
     | "sml_nb" => SOME MaSh_SML_NB
     | _ => NONE)
   end
 
 val is_mash_enabled = is_some o mash_engine
-
-fun is_mash_sml_enabled () =
-  (case mash_engine () of
-    SOME MaSh_SML_KNN => true
-  | SOME MaSh_SML_NB => true
-  | _ => false)
+val the_mash_engine = the_default MaSh_SML_kNN o mash_engine
 
 
 (*** Low-level communication with Python version of MaSh ***)
@@ -320,71 +315,55 @@
       end
 
     fun trickledown l i e =
-      let
-        val j = maxson l i
-      in
+      let val j = maxson l i in
         if cmp (Array.sub (a, j), e) = GREATER then
-          let val _ = Array.update (a, i, Array.sub (a, j)) in trickledown l j e end
-        else Array.update (a, i, e)
+          (Array.update (a, i, Array.sub (a, j)); trickledown l j e)
+        else
+          Array.update (a, i, e)
       end
 
     fun trickle l i e = trickledown l i e handle BOTTOM i => Array.update (a, i, e)
 
     fun bubbledown l i =
-      let
-        val j = maxson l i
-        val _ = Array.update (a, i, Array.sub (a, j))
-      in
+      let val j = maxson l i in
+        Array.update (a, i, Array.sub (a, j));
         bubbledown l j
       end
 
     fun bubble l i = bubbledown l i handle BOTTOM i => i
 
     fun trickleup i e =
-      let
-        val father = (i - 1) div 3
-      in
+      let val father = (i - 1) div 3 in
         if cmp (Array.sub (a, father), e) = LESS then
-          let
-            val _ = Array.update (a, i, Array.sub (a, father))
-          in
-            if father > 0 then trickleup father e else Array.update (a, 0, e)
-          end
-        else Array.update (a, i, e)
+          (Array.update (a, i, Array.sub (a, father));
+           if father > 0 then trickleup father e else Array.update (a, 0, e))
+        else
+          Array.update (a, i, e)
       end
 
     val l = Array.length a
 
-    fun for i =
-      if i < 0 then () else
-      let
-        val _ = trickle l i (Array.sub (a, i))
-      in
-        for (i - 1)
-      end
-
-    val _ = for (((l + 1) div 3) - 1)
+    fun for i = if i < 0 then () else (trickle l i (Array.sub (a, i)); for (i - 1))
 
     fun for2 i =
-      if i < Integer.max 2 (l - bnd) then () else
-      let
-        val e = Array.sub (a, i)
-        val _ = Array.update (a, i, Array.sub (a, 0))
-        val _ = trickleup (bubble i 0) e
-      in
-        for2 (i - 1)
-      end
-
-    val _ = for2 (l - 1)
+      if i < Integer.max 2 (l - bnd) then
+        ()
+      else
+        let val e = Array.sub (a, i) in
+          Array.update (a, i, Array.sub (a, 0));
+          trickleup (bubble i 0) e;
+          for2 (i - 1)
+        end
   in
+    for (((l + 1) div 3) - 1);
+    for2 (l - 1);
     if l > 1 then
-      let
-        val e = Array.sub (a, 1)
-        val _ = Array.update (a, 1, Array.sub (a, 0))
-      in
+      let val e = Array.sub (a, 1) in
+        Array.update (a, 1, Array.sub (a, 0));
         Array.update (a, 0, e)
       end
-    else ()
+    else
+      ()
   end
 
 (*
@@ -421,7 +400,7 @@
       end
 
     val _ = List.app do_sym syms
-    val _ = heap (fn (a, b) => Real.compare (snd a, snd b)) knns overlaps_sqr
+    val _ = heap (Real.compare o pairself snd) knns overlaps_sqr
     val recommends = Array.tabulate (adv_max, rpair 0.0)
 
     fun inc_recommend j v =
@@ -438,27 +417,97 @@
           val _ = inc_recommend j o1
           val ds = get_deps j
           val l = Real.fromInt (length ds)
-          val _ = map (fn d => inc_recommend d (o1 / l)) ds
         in
-          for (k + 1)
+          List.app (fn d => inc_recommend d (o1 / l)) ds; for (k + 1)
         end
 
-    val _ = for 0
-    val _ = heap (Real.compare o pairself snd) advno recommends
-
     fun ret acc at =
       if at = Array.length recommends then acc else ret (Array.sub (recommends, at) :: acc) (at + 1)
   in
+    for 0;
+    heap (Real.compare o pairself snd) advno recommends;
     ret [] (Integer.max 0 (adv_max - advno))
   end
 
+(* Two arguments control the behaviour of nbayes: prior and ess. Prior expresses our belief in
+   the usefulness of unknown features, and ess (equivalent sample size) expresses our confidence in
+   the prior. *)
+fun nbayes avail_num adv_max get_deps get_th_syms sym_num ess prior advno syms =
+  let
+    val afreq = Unsynchronized.ref 0
+    val tfreq = Array.array (avail_num, 0)
+    val sfreq = Array.array (avail_num, Inttab.empty)
+
+    fun nb_learn syms ts =
+      let
+        fun add_sym hpis sym =
+          let
+            val im = Array.sub (sfreq, hpis)
+            val v = the_default 0 (Inttab.lookup im sym)
+          in
+            Array.update(sfreq, hpis, Inttab.update (sym, v + 1) im)
+          end
+
+        fun add_th t =
+          (Array.update (tfreq, t, Array.sub (tfreq, t) + 1); List.app (add_sym t) syms)
+      in
+        afreq := !afreq + 1;
+        List.app add_th ts
+      end
+
+    fun nb_eval syms =
+      let
+        fun log_posterior i =
+          let
+            val symh = fold (fn s => fn sf => Inttab.update (s, ()) sf) syms Inttab.empty
+            val n = Real.fromInt (Array.sub (tfreq, i))
+            val sfreqh = Array.sub (sfreq, i)
+            val p = if prior > 0.0 then prior else ess / Real.fromInt (!afreq)
+            val mp = ess * p
+            val logmp = Math.ln mp
+            val lognmp = Math.ln (n + mp)
+
+            fun in_sfreqh (s, sfreqv) (sofar, sfsymh) =
+              let val sfreqv = Real.fromInt sfreqv in
+                if Inttab.defined sfsymh s then
+                  (sofar + Math.ln (sfreqv + mp), Inttab.delete s sfsymh)
+                else
+                  (sofar + Math.ln (n - sfreqv + mp), sfsymh)
+              end
+
+            val (postsfreqh, symh) = Inttab.fold in_sfreqh sfreqh (Math.ln n, symh)
+            val len_mem = length (Inttab.keys symh)
+            val len_nomem = sym_num - len_mem - length (Inttab.keys sfreqh)
+          in
+            postsfreqh + Real.fromInt len_mem * logmp + Real.fromInt len_nomem * lognmp -
+              Real.fromInt sym_num * Math.ln(n + ess)
+          end
+
+        val posterior = Array.tabulate (adv_max, swap o `log_posterior)
+
+        fun ret acc at =
+          if at = Array.length posterior then acc
+          else ret (Array.sub (posterior,at) :: acc) (at + 1)
+      in
+        heap (Real.compare o pairself snd) advno posterior;
+        ret [] (Integer.max 0 (adv_max - advno))
+      end
+
+    fun for i =
+      if i = avail_num then () else (nb_learn (get_th_syms i) (get_deps i); for (i + 1))
+  in
+    for 0; nb_eval syms
+  end
+
 val knns = 40 (* FUDGE *)
+val ess = 0.00001 (* FUDGE *)
+val prior = 0.001 (* FUDGE *)
 
 fun add_to_xtab key (next, tab, keys) = (next + 1, Symtab.update_new (key, next) tab, key :: keys)
 
 fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i)))
 
-fun query ctxt parents access_G max_suggs hints feats =
+fun query ctxt engine parents access_G max_suggs hints feats =
   let
     val str_of_feat = space_implode "|"
 
@@ -470,9 +519,9 @@
        |> List.partition (Symtab.defined visible_fact_set o #1) |> op @) @
       (if null hints then [] else [(".goal", feats, hints)])
 
-    val (rev_depss, featss, (_, _, rev_facts), (num_feats, feat_tab, _)) =
+    val (rev_depss, rev_featss, (num_facts, _, rev_facts), (num_feats, feat_tab, _)) =
       fold (fn (fact, feats, deps) =>
-            fn (depss, featss, fact_xtab as (_, fact_tab, _), feat_xtab) =>
+            fn (rev_depss, rev_featss, fact_xtab as (_, fact_tab, _), feat_xtab) =>
           let
             fun add_feat (feat, weight) (xtab as (n, tab, _)) =
               (case Symtab.lookup tab feat of
@@ -481,7 +530,7 @@
 
             val (feats', feat_xtab') = fold_map (add_feat o apfst str_of_feat) feats feat_xtab
           in
-            (map_filter (Symtab.lookup fact_tab) deps :: depss, feats' :: featss,
+            (map_filter (Symtab.lookup fact_tab) deps :: rev_depss, feats' :: rev_featss,
              add_to_xtab fact fact_xtab, feat_xtab')
           end)
         all_nodes ([], [], (0, Symtab.empty, []), (0, Symtab.empty, []))
@@ -490,22 +539,40 @@
     val fact_vec = Vector.fromList facts
 
     val deps_vec = Vector.fromList (rev rev_depss)
-    val facts_ary = Array.array (num_feats, [])
-    val _ =
-      fold (fn feats => fn fact =>
-          let val fact' = fact - 1 in
-            List.app (fn (feat, weight) => map_array_at facts_ary (cons (fact', weight)) feat)
-              feats;
-            fact'
-          end)
-        featss (length featss)
+
+    val avail_num = Vector.length deps_vec
+    val adv_max = length visible_facts
+    val get_deps = curry Vector.sub deps_vec
+    val advno = max_suggs
   in
     trace_msg ctxt (fn () => "MaSh_SML query " ^ encode_features feats ^ " from {" ^
       elide_string 1000 (space_implode " " facts) ^ "}");
-    knn (Vector.length deps_vec) (length visible_facts) (curry Vector.sub deps_vec)
-      (curry Array.sub facts_ary) knns max_suggs
-      (map_filter (fn (feat, weight) =>
-         Option.map (rpair weight) (Symtab.lookup feat_tab (str_of_feat feat))) feats)
+    (if engine = MaSh_SML_kNN then
+       let
+        val facts_ary = Array.array (num_feats, [])
+        val _ =
+          fold (fn feats => fn fact =>
+              let val fact' = fact - 1 in
+                List.app (fn (feat, weight) => map_array_at facts_ary (cons (fact', weight)) feat)
+                  feats;
+                fact'
+              end)
+            rev_featss num_facts
+         val get_sym_ths = curry Array.sub facts_ary
+         val syms = map_filter (fn (feat, weight) =>
+           Option.map (rpair weight) (Symtab.lookup feat_tab (str_of_feat feat))) feats
+       in
+         knn avail_num adv_max get_deps get_sym_ths knns advno syms
+       end
+     else
+       let
+         val unweighted_feats_ary = Vector.fromList (map (map fst) (rev rev_featss))
+         val get_th_syms = curry Vector.sub unweighted_feats_ary
+         val sym_num = num_feats
+         val unweighted_syms = map_filter (Symtab.lookup feat_tab o str_of_feat o fst) feats
+       in
+         nbayes avail_num adv_max get_deps get_th_syms sym_num ess prior advno unweighted_syms
+       end)
     |> map (curry Vector.sub fact_vec o fst)
   end
 
@@ -596,8 +663,10 @@
                   fold extract_line_and_add_node node_lines Graph.empty),
                 length node_lines)
              | LESS =>
-               (if is_mash_sml_enabled () then wipe_out_mash_state_dir ()
-                else MaSh_Py.unlearn ctxt overlord; (Graph.empty, 0)) (* cannot parse old file *)
+               (* cannot parse old file *)
+               (if the_mash_engine () = MaSh_Py then MaSh_Py.unlearn ctxt overlord
+                else wipe_out_mash_state_dir ();
+                (Graph.empty, 0))
              | GREATER => raise FILE_VERSION_TOO_NEW ())
          in
            trace_msg ctxt (fn () => "Loaded fact graph (" ^ graph_info access_G ^ ")");
@@ -645,7 +714,8 @@
 fun clear_state ctxt overlord =
   (* "MaSh_Py.unlearn" also removes the state file *)
   Synchronized.change global_state (fn _ =>
-    (if is_mash_sml_enabled () then wipe_out_mash_state_dir () else MaSh_Py.unlearn ctxt overlord;
+    (if the_mash_engine () = MaSh_Py then MaSh_Py.unlearn ctxt overlord
+     else wipe_out_mash_state_dir ();
      (false, empty_state)))
 
 end
@@ -1224,7 +1294,7 @@
         (parents, hints, feats)
       end
 
-    val sml = is_mash_sml_enabled ()
+    val engine = the_mash_engine ()
 
     val (access_G, py_suggs) =
       peek_state ctxt overlord (fn {access_G, ...} =>
@@ -1232,20 +1302,20 @@
           (trace_msg ctxt (K "Nothing has been learned yet"); (access_G, []))
         else
           (access_G,
-           if sml then
-             []
-           else
+           if engine = MaSh_Py then
              let val (parents, hints, feats) = query_args access_G in
                MaSh_Py.query ctxt overlord max_facts ([], hints, parents, feats)
-             end))
+             end
+           else
+             []))
 
     val sml_suggs =
-      if sml then
+      if engine = MaSh_Py then
+        []
+      else
         let val (parents, hints, feats) = query_args access_G in
-          MaSh_SML.query ctxt parents access_G max_facts hints feats
+          MaSh_SML.query ctxt engine parents access_G max_facts hints feats
         end
-      else
-        []
 
     val unknown = filter_out (is_fact_in_graph access_G o snd) facts
   in
@@ -1309,13 +1379,13 @@
               |> filter (is_fact_in_graph access_G)
               |> map nickname_of_thm
           in
-            if is_mash_sml_enabled () then
+            if the_mash_engine () = MaSh_Py then
+              (MaSh_Py.learn ctxt overlord true [("", parents, map fst feats, deps)]; state)
+            else
               let val access_G = access_G |> add_node Automatic_Proof name parents feats deps in
                 {access_G = access_G, num_known_facts = num_known_facts + 1,
                  dirty = Option.map (cons name) dirty}
               end
-            else
-              (MaSh_Py.learn ctxt overlord true [("", parents, map fst feats, deps)]; state)
           end);
         (true, "")
       end)
@@ -1334,7 +1404,7 @@
     val timer = Timer.startRealTimer ()
     fun next_commit_time () = Time.+ (Timer.checkRealTimer timer, commit_timeout)
 
-    val sml = is_mash_sml_enabled ()
+    val engine = the_mash_engine ()
     val {access_G, ...} = peek_state ctxt overlord I
     val is_in_access_G = is_fact_in_graph access_G o snd
     val no_new_facts = forall is_in_access_G facts
@@ -1376,11 +1446,11 @@
                   (false, SOME names, []) => SOME (map #1 learns @ names)
                 | _ => NONE)
             in
-              if sml then
-                ()
+              if engine = MaSh_Py then
+                (MaSh_Py.learn ctxt overlord (save andalso null relearns) (rev learns);
+                 MaSh_Py.relearn ctxt overlord save relearns)
               else
-                (MaSh_Py.learn ctxt overlord (save andalso null relearns) (rev learns);
-                 MaSh_Py.relearn ctxt overlord save relearns);
+                ();
               {access_G = access_G, num_known_facts = num_known_facts, dirty = dirty}
             end
 
@@ -1613,7 +1683,7 @@
            |> Par_List.map (apsnd (fn f => f ()))
       val mesh = mesh_facts (eq_snd Thm.eq_thm_prop) max_facts mess |> add_and_take
     in
-      if is_mash_sml_enabled () orelse not save then () else MaSh_Py.save ctxt overlord;
+      if the_mash_engine () = MaSh_Py andalso save then MaSh_Py.save ctxt overlord else ();
       (case (fact_filter, mess) of
         (NONE, [(_, (mepo, _)), (_, (mash, _))]) =>
         [(meshN, mesh), (mepoN, mepo |> map fst |> add_and_take),
@@ -1623,7 +1693,7 @@
 
 fun kill_learners ctxt ({overlord, ...} : params) =
   (Async_Manager.kill_threads MaShN "learner";
-   if is_mash_sml_enabled () then () else MaSh_Py.shutdown ctxt overlord)
+   if the_mash_engine () = MaSh_Py then MaSh_Py.shutdown ctxt overlord else ())
 
 fun running_learners () = Async_Manager.running_threads MaShN "learner"