updated naive Bayes
author blanchet
Tue, 27 May 2014 17:48:11 +0200
changeset 57095 001ec97c3e59
parent 57094 589ec121ce1a
child 57096 e4074b91b2a6
updated naive Bayes
src/Doc/Sledgehammer/document/root.tex
src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
--- a/src/Doc/Sledgehammer/document/root.tex	Tue May 27 17:32:42 2014 +0200
+++ b/src/Doc/Sledgehammer/document/root.tex	Tue May 27 17:48:11 2014 +0200
@@ -1070,12 +1070,11 @@
 The experimental MaSh machine learner. Three learning engines are provided:
 
 \begin{enum}
-\item[\labelitemi] \textbf{\textit{sml\_knn}} (also called
-\textbf{\textit{sml}}) is a Standard ML implementation of $k$-nearest
-neighbors.
+\item[\labelitemi] \textbf{\textit{sml\_nb}} (also called \textbf{\textit{sml}})
+is a Standard ML implementation of naive Bayes.
 
-\item[\labelitemi] \textbf{\textit{sml\_nb}} is a Standard ML implementation of
-naive Bayes.
+\item[\labelitemi] \textbf{\textit{sml\_knn}} is a Standard ML implementation of
+$k$-nearest neighbors.
 
 \item[\labelitemi] \textbf{\textit{py}} (also called \textbf{\textit{yes}}) is a
 Python implementation of naive Bayes. The program is included with Isabelle as
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Tue May 27 17:32:42 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Tue May 27 17:48:11 2014 +0200
@@ -121,7 +121,7 @@
     (case if flag1 <> "none" (* default *) then flag1 else getenv "MASH" of
       "yes" => SOME MaSh_Py
     | "py" => SOME MaSh_Py
-    | "sml" => SOME MaSh_SML_kNN
+    | "sml" => SOME MaSh_SML_NB
     | "sml_knn" => SOME MaSh_SML_kNN
     | "sml_nb" => SOME MaSh_SML_NB
     | _ => NONE)
@@ -422,78 +422,72 @@
     ret [] (Integer.max 0 (adv_max - advno))
   end
 
-(* Two arguments control the behaviour of naive Bayes: prior and ess. Prior expresses our belief in
-   usefulness of unknown features, and ess (equivalent sample size) expresses our confidence in the
-   prior. *)
-fun naive_bayes avail_num adv_max get_deps get_th_syms sym_num ess prior advno syms =
+val tau = 0.02
+val posWeight = 2.0
+val defVal = ~15.0
+val defPriWei = 20
+
+fun naive_bayes avail_num adv_max get_deps get_th_syms sym_num advno syms =
   let
-    val afreq = Unsynchronized.ref 0
-    val tfreq = Array.array (avail_num, 0)
-    val sfreq = Array.array (avail_num, Inttab.empty)
+    val afreq = Unsynchronized.ref 0;
+    val tfreq = Array.array (avail_num, 0);
+    val sfreq = Array.array (avail_num, Inttab.empty);
+    val dffreq = Array.array (sym_num, 0);
 
-    fun nb_learn syms ts =
+    fun learn th syms deps =
       let
-        fun add_sym hpis sym =
+        fun add_th t =
           let
-            val im = Array.sub (sfreq, hpis)
-            val v = the_default 0 (Inttab.lookup im sym)
+            val im = Array.sub (sfreq, t);
+            fun fold_fn s sf = Inttab.update (s, 1 + the_default 0 (Inttab.lookup im s)) sf;
           in
-            Array.update (sfreq, hpis, Inttab.update (sym, v + 1) im)
-          end
+            Array.update (tfreq, t, 1 + Array.sub (tfreq, t));
+            Array.update (sfreq, t, fold fold_fn syms im)
+          end;
+        fun add_sym s = Array.update (dffreq, s, 1 + Array.sub (dffreq, s));
+      in
+        List.app add_th (replicate defPriWei th);
+        List.app add_th deps;
+        List.app add_sym syms;
+        afreq := !afreq + 1
+      end;
 
-        fun add_th t =
-          (Array.update (tfreq, t, Array.sub (tfreq, t) + 1); List.app (add_sym t) syms)
-      in
-        afreq := !afreq + 1; List.app add_th ts
-      end
+    fun tfidf _ = 1.0;
+    (*fun tfidf sym = Math.ln (Real.fromInt (!afreq)) - Math.ln (Real.fromInt (Array.sub (dffreq, sym)));*)
 
-    fun nb_eval syms =
+    fun eval syms =
       let
         fun log_posterior i =
           let
-            val symh = fold (Inttab.update o rpair ()) syms Inttab.empty
-            val n = Real.fromInt (Array.sub (tfreq, i))
-            val sfreqh = Array.sub (sfreq, i)
-            val p = if prior > 0.0 then prior else ess / Real.fromInt (!afreq)
-            val mp = ess * p
-            val logmp = Math.ln mp
-            val lognmp = Math.ln (n + mp)
-
-            fun in_sfreqh (s, sfreqv) (sofar, sfsymh) =
-              let val sfreqv = Real.fromInt sfreqv in
-                if Inttab.defined sfsymh s then
-                  (sofar + Math.ln (sfreqv + mp), Inttab.delete s sfsymh)
-                else
-                  (sofar + Math.ln (n - sfreqv + mp), sfsymh)
-              end
-
-            val (postsfreqh, symh) = Inttab.fold in_sfreqh sfreqh (Math.ln n, symh)
-            val len_mem = length (Inttab.keys symh)
-            val len_nomem = sym_num - len_mem - length (Inttab.keys sfreqh)
+            val tfreq = Real.fromInt (Array.sub (tfreq, i));
+            fun fold_syms (f, fw) (res, sfh) =
+              (case Inttab.lookup sfh f of
+                SOME sf =>
+                (res + tfidf f * fw * Math.ln (posWeight * Real.fromInt sf / tfreq),
+                 Inttab.delete f sfh)
+              | NONE => (res + fw * defVal, sfh));
+            val (res, sfh) = fold fold_syms syms (Math.ln tfreq, Array.sub (sfreq,i));
+            fun fold_sfh (f, sf) sow =
+              sow + tfidf f * (Math.ln (1.0 + (1.0 - Real.fromInt sf) / tfreq));
+            val sumOfWei = Inttab.fold fold_sfh sfh 0.0;
           in
-            postsfreqh + Real.fromInt len_mem * logmp + Real.fromInt len_nomem * lognmp -
-              Real.fromInt sym_num * Math.ln (n + ess)
+            res + tau * sumOfWei
           end
-
-        val posterior = Array.tabulate (adv_max, swap o `log_posterior)
-
+        val posterior = Array.tabulate (adv_max, (fn j => (j, log_posterior j)));
         fun ret acc at =
-          if at = Array.length posterior then acc
-          else ret (Array.sub (posterior, at) :: acc) (at + 1)
+          if at = adv_max then acc else ret (Array.sub (posterior,at) :: acc) (at + 1)
       in
         heap (Real.compare o pairself snd) advno posterior;
         ret [] (Integer.max 0 (adv_max - advno))
-      end
+      end;
 
     fun for i =
-      if i = avail_num then () else (nb_learn (get_th_syms i) (i :: get_deps i); for (i + 1))
+      if i = avail_num then () else (learn i (get_th_syms i) (get_deps i); for (i + 1))
   in
-    for 0; nb_eval syms
+    for 0; eval syms
   end
 
 val knns = 40 (* FUDGE *)
-val ess = 0.00001 (* FUDGE *)
-val prior = 0.001 (* FUDGE *)
 
 fun add_to_xtab key (next, tab, keys) = (next + 1, Symtab.update_new (key, next) tab, key :: keys)
 
@@ -532,6 +526,8 @@
 
     val num_visible_facts = length visible_facts
     val get_deps = curry Vector.sub deps_vec
+    val syms = map_filter (fn (feat, weight) =>
+      Option.map (rpair weight) (Symtab.lookup feat_tab feat)) feats
   in
     trace_msg ctxt (fn () => "MaSh_SML " ^ (if engine = MaSh_SML_kNN then "kNN" else "NB") ^
       " query " ^ encode_features feats ^ " from {" ^
@@ -548,8 +544,6 @@
                end)
              rev_featss num_facts
          val get_facts = curry Array.sub facts_ary
-         val syms = map_filter (fn (feat, weight) =>
-           Option.map (rpair weight) (Symtab.lookup feat_tab feat)) feats
        in
          k_nearest_neighbors num_facts num_visible_facts get_deps get_facts knns max_suggs syms
        end
@@ -557,10 +551,9 @@
        let
          val unweighted_feats_ary = Vector.fromList (map (map fst) (rev rev_featss))
          val get_unweighted_feats = curry Vector.sub unweighted_feats_ary
-         val unweighted_syms = map_filter (Symtab.lookup feat_tab o fst) feats
        in
-         naive_bayes num_facts num_visible_facts get_deps get_unweighted_feats num_feats ess prior
-           max_suggs unweighted_syms
+         naive_bayes num_facts num_visible_facts get_deps get_unweighted_feats num_feats max_suggs
+           syms
        end)
     |> map (curry Vector.sub fact_vec o fst)
   end
@@ -1237,17 +1230,12 @@
           |> map (chained_or_extra_features_of chained_feature_factor)
           |> rpair [] |-> fold (union (eq_fst (op =)))
         val extra_feats =
-          (* As long as SML NB does not support weights, it makes little sense to include these
-             extra features. *)
-          if engine = MaSh_SML_NB then
-            []
-          else
-            facts
-            |> take (Int.max (0, num_extra_feature_facts - length chained))
-            |> filter fact_has_right_theory
-            |> weight_facts_steeply
-            |> map (chained_or_extra_features_of extra_feature_factor)
-            |> rpair [] |-> fold (union (eq_fst (op =)))
+          facts
+          |> take (Int.max (0, num_extra_feature_facts - length chained))
+          |> filter fact_has_right_theory
+          |> weight_facts_steeply
+          |> map (chained_or_extra_features_of extra_feature_factor)
+          |> rpair [] |-> fold (union (eq_fst (op =)))
         val feats = fold (union (eq_fst (op =))) [chained_feats, extra_feats] goal_feats
           |> debug ? sort (Real.compare o swap o pairself snd)
       in