add self dependency to naive Bayes
authorblanchet
Thu, 22 May 2014 13:07:52 +0200
changeset 57059 fcd25f2e3da6
parent 57058 b1ae5079b795
child 57060 7a1167331c8c
add self dependency to naive Bayes
src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Thu May 22 13:07:51 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Thu May 22 13:07:52 2014 +0200
@@ -438,21 +438,20 @@
             val im = Array.sub (sfreq, hpis)
             val v = the_default 0 (Inttab.lookup im sym)
           in
-            Array.update(sfreq, hpis, Inttab.update (sym, v + 1) im)
+            Array.update (sfreq, hpis, Inttab.update (sym, v + 1) im)
           end
 
         fun add_th t =
           (Array.update (tfreq, t, Array.sub (tfreq, t) + 1); List.app (add_sym t) syms)
       in
-        afreq := !afreq + 1;
-        List.app add_th ts
+        afreq := !afreq + 1; List.app add_th ts
       end
 
     fun nb_eval syms =
       let
         fun log_posterior i =
           let
-            val symh = fold (fn s => fn sf => Inttab.update (s, ()) sf) syms Inttab.empty
+            val symh = fold (Inttab.update o rpair ()) syms Inttab.empty
             val n = Real.fromInt (Array.sub (tfreq, i))
             val sfreqh = Array.sub (sfreq, i)
             val p = if prior > 0.0 then prior else ess / Real.fromInt (!afreq)
@@ -473,21 +472,21 @@
             val len_nomem = sym_num - len_mem - length (Inttab.keys sfreqh)
           in
             postsfreqh + Real.fromInt len_mem * logmp + Real.fromInt len_nomem * lognmp -
-              Real.fromInt sym_num * Math.ln(n + ess)
+              Real.fromInt sym_num * Math.ln (n + ess)
           end
 
         val posterior = Array.tabulate (adv_max, swap o `log_posterior)
 
         fun ret acc at =
           if at = Array.length posterior then acc
-          else ret (Array.sub (posterior,at) :: acc) (at + 1)
+          else ret (Array.sub (posterior, at) :: acc) (at + 1)
       in
         heap (Real.compare o pairself snd) advno posterior;
         ret [] (Integer.max 0 (adv_max - advno))
       end
 
     fun for i =
-      if i = avail_num then () else (nb_learn (get_th_syms i) (get_deps i); for (i + 1))
+      if i = avail_num then () else (nb_learn (get_th_syms i) (i :: get_deps i); for (i + 1))
   in
     for 0; nb_eval syms
   end