--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Thu Jun 26 13:35:21 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Thu Jun 26 13:35:30 2014 +0200
@@ -512,12 +512,9 @@
fun for i =
if i = num_facts then () else (learn_fact i (get_feats i) (get_deps i); for (i + 1))
-
- val ln_afreq = Math.ln (Real.fromInt num_facts)
in
for 0;
- (Array.vector tfreq, Array.vector sfreq,
- Vector.map (fn i => ln_afreq - Math.ln (Real.fromInt i)) (Array.vector dffreq))
+ (Array.vector tfreq, Array.vector sfreq, Array.vector dffreq)
end
fun learn num_facts get_deps get_feats num_feats =
@@ -529,12 +526,15 @@
learn_facts tfreq sfreq dffreq num_facts get_deps get_feats num_feats
end
-fun naive_bayes_query num_facts max_suggs visible_facts feats (tfreq, sfreq, idf) =
+fun naive_bayes_query num_facts max_suggs visible_facts feats (tfreq, sfreq, dffreq) =
let
val tau = 0.05 (* FUDGE *)
val pos_weight = 10.0 (* FUDGE *)
val def_val = ~15.0 (* FUDGE *)
+ val ln_afreq = Math.ln (Real.fromInt num_facts)
+ val idf = Vector.map (fn i => ln_afreq - Math.ln (Real.fromInt i)) dffreq
+
fun tfidf feat = Vector.sub (idf, feat)
fun log_posterior i =
@@ -629,7 +629,7 @@
c_plus_plus_tool ("newknn/knn" ^ " " ^ string_of_int number_of_nearest_neighbors)
val naive_bayes_cpp = c_plus_plus_tool "predict/nbayes"
-fun add_to_xtab key (next, tab, keys) = (next + 1, Symtab.update_new (key, next) tab, key :: keys)
+fun add_to_xtab key (next, tab) = (next + 1, Symtab.update_new (key, next) tab)
fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i)))
@@ -643,11 +643,11 @@
naive_bayes_cpp max_suggs learns feats
else
let
- val (rev_depss, rev_featss, (num_facts, fact_tab, rev_facts), (num_feats, feat_tab, _)) =
+ val (rev_depss, rev_featss, ((num_facts, fact_tab), rev_facts), (num_feats, feat_tab)) =
fold (fn (fact, feats, deps) =>
- fn (rev_depss, rev_featss, fact_xtab as (_, fact_tab, _), feat_xtab) =>
+ fn (rev_depss, rev_featss, (fact_xtab as (_, fact_tab), rev_facts), feat_xtab) =>
let
- fun add_feat feat (xtab as (n, tab, _)) =
+ fun add_feat feat (xtab as (n, tab)) =
(case Symtab.lookup tab feat of
SOME i => (i, xtab)
| NONE => (n, add_to_xtab feat xtab))
@@ -655,9 +655,9 @@
val (feats', feat_xtab') = fold_map add_feat feats feat_xtab
in
(map_filter (Symtab.lookup fact_tab) deps :: rev_depss, feats' :: rev_featss,
- add_to_xtab fact fact_xtab, feat_xtab')
+ (add_to_xtab fact fact_xtab, fact :: rev_facts), feat_xtab')
end)
- learns ([], [], (0, Symtab.empty, []), (0, Symtab.empty, []))
+ learns ([], [], ((0, Symtab.empty), []), (0, Symtab.empty))
val facts = rev rev_facts
val fact_vec = Vector.fromList facts