src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
changeset 57366 d01d1befe4a3
parent 57365 d2090a01e920
child 57367 e64c1b174f4b
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Thu Jun 26 13:35:21 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Thu Jun 26 13:35:30 2014 +0200
@@ -512,12 +512,9 @@
 
     fun for i =
       if i = num_facts then () else (learn_fact i (get_feats i) (get_deps i); for (i + 1))
-
-    val ln_afreq = Math.ln (Real.fromInt num_facts)
   in
     for 0;
-    (Array.vector tfreq, Array.vector sfreq,
-     Vector.map (fn i => ln_afreq - Math.ln (Real.fromInt i)) (Array.vector dffreq))
+    (Array.vector tfreq, Array.vector sfreq, Array.vector dffreq)
   end
 
 fun learn num_facts get_deps get_feats num_feats =
@@ -529,12 +526,15 @@
     learn_facts tfreq sfreq dffreq num_facts get_deps get_feats num_feats
   end
 
-fun naive_bayes_query num_facts max_suggs visible_facts feats (tfreq, sfreq, idf) =
+fun naive_bayes_query num_facts max_suggs visible_facts feats (tfreq, sfreq, dffreq) =
   let
     val tau = 0.05 (* FUDGE *)
     val pos_weight = 10.0 (* FUDGE *)
     val def_val = ~15.0 (* FUDGE *)
 
+    val ln_afreq = Math.ln (Real.fromInt num_facts)
+    val idf = Vector.map (fn i => ln_afreq - Math.ln (Real.fromInt i)) dffreq
+
     fun tfidf feat = Vector.sub (idf, feat)
 
     fun log_posterior i =
@@ -629,7 +629,7 @@
   c_plus_plus_tool ("newknn/knn" ^ " " ^ string_of_int number_of_nearest_neighbors)
 val naive_bayes_cpp = c_plus_plus_tool "predict/nbayes"
 
-fun add_to_xtab key (next, tab, keys) = (next + 1, Symtab.update_new (key, next) tab, key :: keys)
+fun add_to_xtab key (next, tab) = (next + 1, Symtab.update_new (key, next) tab)
 
 fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i)))
 
@@ -643,11 +643,11 @@
       naive_bayes_cpp max_suggs learns feats
     else
       let
-        val (rev_depss, rev_featss, (num_facts, fact_tab, rev_facts), (num_feats, feat_tab, _)) =
+        val (rev_depss, rev_featss, ((num_facts, fact_tab), rev_facts), (num_feats, feat_tab)) =
           fold (fn (fact, feats, deps) =>
-                fn (rev_depss, rev_featss, fact_xtab as (_, fact_tab, _), feat_xtab) =>
+                fn (rev_depss, rev_featss, (fact_xtab as (_, fact_tab), rev_facts), feat_xtab) =>
               let
-                fun add_feat feat (xtab as (n, tab, _)) =
+                fun add_feat feat (xtab as (n, tab)) =
                   (case Symtab.lookup tab feat of
                     SOME i => (i, xtab)
                   | NONE => (n, add_to_xtab feat xtab))
@@ -655,9 +655,9 @@
                 val (feats', feat_xtab') = fold_map add_feat feats feat_xtab
               in
                 (map_filter (Symtab.lookup fact_tab) deps :: rev_depss, feats' :: rev_featss,
-                 add_to_xtab fact fact_xtab, feat_xtab')
+                 (add_to_xtab fact fact_xtab, fact :: rev_facts), feat_xtab')
               end)
-            learns ([], [], (0, Symtab.empty, []), (0, Symtab.empty, []))
+            learns ([], [], ((0, Symtab.empty), []), (0, Symtab.empty))
 
         val facts = rev rev_facts
         val fact_vec = Vector.fromList facts