--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Thu Jun 26 13:33:21 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Thu Jun 26 13:33:27 2014 +0200
@@ -352,7 +352,7 @@
exception BOTTOM of int
-fun heap cmp bnd a =
+fun heap cmp bnd al a =
let
fun maxson l i =
let val i31 = i + i + i + 1 in
@@ -394,12 +394,10 @@
Array.update (a, i, e)
end
- val l = Array.length a
-
- fun for i = if i < 0 then () else (trickle l i (Array.sub (a, i)); for (i - 1))
+ fun for i = if i < 0 then () else (trickle al i (Array.sub (a, i)); for (i - 1))
fun for2 i =
- if i < Integer.max 2 (l - bnd) then
+ if i < Integer.max 2 (al - bnd) then
()
else
let val e = Array.sub (a, i) in
@@ -408,9 +406,9 @@
for2 (i - 1)
end
in
- for (((l + 1) div 3) - 1);
- for2 (l - 1);
- if l > 1 then
+ for (((al + 1) div 3) - 1);
+ for2 (al - 1);
+ if al > 1 then
let val e = Array.sub (a, 1) in
Array.update (a, 1, Array.sub (a, 0));
Array.update (a, 0, e)
@@ -457,7 +455,7 @@
end
val _ = List.app do_feat feats
- val _ = heap (Real.compare o pairself snd) num_facts overlaps_sqr
+ val _ = heap (Real.compare o pairself snd) num_facts num_facts overlaps_sqr
val no_recommends = Unsynchronized.ref 0
val recommends = Array.tabulate (num_visible_facts, rpair 0.0)
val age = Unsynchronized.ref 1000000000.0
@@ -498,39 +496,34 @@
if at = Array.length recommends then acc else ret (Array.sub (recommends, at) :: acc) (at + 1)
in
while1 (); while2 ();
- heap (Real.compare o pairself snd) max_suggs recommends;
+ heap (Real.compare o pairself snd) max_suggs num_visible_facts recommends;
ret [] (Integer.max 0 (num_visible_facts - max_suggs))
end
val nb_def_prior_weight = 21 (* FUDGE *)
-fun naive_bayes_learn_fact tfreq sfreq dffreq th feats deps =
+fun learn_facts tfreq sfreq dffreq num_facts get_deps get_feats num_feats =
let
- fun add_th weight t =
+ fun learn_fact th feats deps =
let
- val im = Array.sub (sfreq, t)
- fun fold_fn s sf = Inttab.map_default (s, 0) (Integer.add weight) sf
+ fun add_th weight t =
+ let
+ val im = Array.sub (sfreq, t)
+ fun fold_fn s sf = Inttab.map_default (s, 0) (Integer.add weight) sf
+ in
+ Array.update (tfreq, t, weight + Array.sub (tfreq, t));
+ Array.update (sfreq, t, fold fold_fn feats im)
+ end
+
+ fun add_sym s = Array.update (dffreq, s, 1 + Array.sub (dffreq, s))
in
- Array.update (tfreq, t, weight + Array.sub (tfreq, t));
- Array.update (sfreq, t, fold fold_fn feats im)
+ add_th nb_def_prior_weight th;
+ List.app (add_th 1) deps;
+ List.app add_sym feats
end
- fun add_sym s = Array.update (dffreq, s, 1 + Array.sub (dffreq, s))
- in
- add_th nb_def_prior_weight th;
- List.app (add_th 1) deps;
- List.app add_sym feats
- end
-
-fun naive_bayes_learn num_facts get_deps get_feats num_feats =
- let
- val tfreq = Array.array (num_facts, 0)
- val sfreq = Array.array (num_facts, Inttab.empty)
- val dffreq = Array.array (num_feats, 0)
-
fun for i =
- if i = num_facts then ()
- else (naive_bayes_learn_fact tfreq sfreq dffreq i (get_feats i) (get_deps i); for (i + 1))
+ if i = num_facts then () else (learn_fact i (get_feats i) (get_deps i); for (i + 1))
val ln_afreq = Math.ln (Real.fromInt num_facts)
in
@@ -539,6 +532,15 @@
Vector.map (fn i => ln_afreq - Math.ln (Real.fromInt i)) (Array.vector dffreq))
end
+fun learn num_facts get_deps get_feats num_feats =
+ let
+ val tfreq = Array.array (num_facts, 0)
+ val sfreq = Array.array (num_facts, Inttab.empty)
+ val dffreq = Array.array (num_feats, 0)
+ in
+ learn_facts tfreq sfreq dffreq num_facts get_deps get_feats num_feats
+ end
+
fun naive_bayes_query (kuehlwein_log, kuehlwein_params) num_facts num_visible_facts max_suggs feats
(tfreq, sfreq, idf) =
let
@@ -579,12 +581,12 @@
fun ret acc at =
if at = num_visible_facts then acc else ret (Array.sub (posterior, at) :: acc) (at + 1)
in
- heap (Real.compare o pairself snd) max_suggs posterior;
+ heap (Real.compare o pairself snd) max_suggs num_visible_facts posterior;
ret [] (Integer.max 0 (num_visible_facts - max_suggs))
end
fun naive_bayes opts num_facts num_visible_facts get_deps get_feats num_feats max_suggs feats =
- naive_bayes_learn num_facts get_deps get_feats num_feats
+ learn num_facts get_deps get_feats num_feats
|> naive_bayes_query opts num_facts num_visible_facts max_suggs feats
(* experimental *)