--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Wed May 28 12:34:26 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Wed May 28 13:02:47 2014 +0200
@@ -362,19 +362,20 @@
val number_of_nearest_neighbors = 40 (* FUDGE *)
(*
- avail_num = maximum number of theorems to check dependencies and symbols
- adv_max = do not return theorems over or equal to this number. Must satisfy: adv_max <= avail_num
+ num_facts = maximum number of theorems to check dependencies and symbols
+ num_visible_facts = do not return theorems over or equal to this number.
+ Must satisfy: num_visible_facts <= num_facts.
get_deps = returns dependencies of a theorem
get_sym_ths = get theorems that have this feature
- advno = number of predictions to return
- syms = symbols of the conjecture
+ max_suggs = number of suggestions to return
+ feats = features of the goal
*)
-fun k_nearest_neighbors avail_num adv_max get_deps get_sym_ths advno syms =
+fun k_nearest_neighbors num_facts num_visible_facts get_deps get_sym_ths max_suggs feats =
let
(* Can be later used for TFIDF *)
fun sym_wght _ = 1.0
- val overlaps_sqr = Array.tabulate (avail_num, (fn i => (i, 0.0)))
+ val overlaps_sqr = Array.tabulate (num_facts, rpair 0.0)
fun inc_overlap j v =
let
@@ -383,30 +384,30 @@
Array.update (overlaps_sqr, j, (j, v + ov))
end
- fun do_sym (s, con_wght) =
+ fun do_feat (s, con_wght) =
let
val sw = sym_wght s
val w2 = sw * sw * con_wght
- fun do_th (j, prem_wght) = if j < avail_num then inc_overlap j (w2 * prem_wght) else ()
+ fun do_th (j, prem_wght) = if j < num_facts then inc_overlap j (w2 * prem_wght) else ()
in
List.app do_th (get_sym_ths s)
end
- val _ = List.app do_sym syms
+ val _ = List.app do_feat feats
val _ = heap (Real.compare o pairself snd) number_of_nearest_neighbors overlaps_sqr
- val recommends = Array.tabulate (adv_max, rpair 0.0)
+ val recommends = Array.tabulate (num_visible_facts, rpair 0.0)
fun inc_recommend j v =
- if j >= adv_max then ()
+ if j >= num_visible_facts then ()
else Array.update (recommends, j, (j, v + snd (Array.sub (recommends, j))))
fun for k =
- if k = number_of_nearest_neighbors orelse k >= adv_max then
+ if k = number_of_nearest_neighbors orelse k >= num_visible_facts then
()
else
let
- val (j, o2) = Array.sub (overlaps_sqr, avail_num - k - 1)
+ val (j, o2) = Array.sub (overlaps_sqr, num_facts - k - 1)
val o1 = Math.sqrt o2
val _ = inc_recommend j o1
val ds = get_deps j
@@ -419,8 +420,8 @@
if at = Array.length recommends then acc else ret (Array.sub (recommends, at) :: acc) (at + 1)
in
for 0;
- heap (Real.compare o pairself snd) advno recommends;
- ret [] (Integer.max 0 (adv_max - advno))
+ heap (Real.compare o pairself snd) max_suggs recommends;
+ ret [] (Integer.max 0 (num_visible_facts - max_suggs))
end
val nb_tau = 0.02 (* FUDGE *)
@@ -428,14 +429,14 @@
val nb_def_val = ~15.0 (* FUDGE *)
val nb_def_prior_weight = 20 (* FUDGE *)
-fun naive_bayes avail_num adv_max get_deps get_th_syms sym_num advno syms =
+fun naive_bayes_learn num_facts get_deps get_th_feats num_feats =
let
val afreq = Unsynchronized.ref 0
- val tfreq = Array.array (avail_num, 0)
- val sfreq = Array.array (avail_num, Inttab.empty)
- val dffreq = Array.array (sym_num, 0)
+ val tfreq = Array.array (num_facts, 0)
+ val sfreq = Array.array (num_facts, Inttab.empty)
+ val dffreq = Array.array (num_feats, 0)
- fun learn th syms deps =
+ fun learn th feats deps =
let
fun add_th weight t =
let
@@ -443,57 +444,61 @@
fun fold_fn s sf = Inttab.update (s, weight + the_default 0 (Inttab.lookup im s)) sf
in
Array.update (tfreq, t, weight + Array.sub (tfreq, t));
- Array.update (sfreq, t, fold fold_fn syms im)
+ Array.update (sfreq, t, fold fold_fn feats im)
end
fun add_sym s = Array.update (dffreq, s, 1 + Array.sub (dffreq, s))
in
add_th nb_def_prior_weight th;
List.app (add_th 1) deps;
- List.app add_sym syms;
+ List.app add_sym feats;
afreq := !afreq + 1
end
- fun tfidf sym = Math.ln (Real.fromInt (!afreq)) - Math.ln (Real.fromInt (Array.sub (dffreq, sym)))
+ fun for i =
+ if i = num_facts then () else (learn i (get_th_feats i) (get_deps i); for (i + 1))
+ in
+ for 0; (Real.fromInt (!afreq), Array.vector tfreq, Array.vector sfreq, Array.vector dffreq)
+ end
- fun eval syms =
+fun naive_bayes_query num_visible_facts max_suggs feats (afreq, tfreq, sfreq, dffreq) =
+ let
+ fun tfidf feat = Math.ln afreq - Math.ln (Real.fromInt (Vector.sub (dffreq, feat)))
+
+ fun log_posterior i =
let
- fun log_posterior i =
- let
- val tfreq = Real.fromInt (Array.sub (tfreq, i))
-
- fun fold_syms (f, fw) (res, sfh) =
- (case Inttab.lookup sfh f of
- SOME sf =>
- (res + tfidf f * fw * Math.ln (nb_pos_weight * Real.fromInt sf / tfreq),
- Inttab.delete f sfh)
- | NONE => (res + fw * nb_def_val, sfh))
+ val tfreq = Real.fromInt (Vector.sub (tfreq, i))
- val (res, sfh) = fold fold_syms syms (Math.ln tfreq, Array.sub (sfreq,i))
-
- fun fold_sfh (f, sf) sow =
- sow + tfidf f * (Math.ln (1.0 + (1.0 - Real.fromInt sf) / tfreq))
+ fun fold_feats (f, fw) (res, sfh) =
+ (case Inttab.lookup sfh f of
+ SOME sf =>
+ (res + tfidf f * fw * Math.ln (nb_pos_weight * Real.fromInt sf / tfreq),
+ Inttab.delete f sfh)
+ | NONE => (res + fw * nb_def_val, sfh))
- val sum_of_weights = Inttab.fold fold_sfh sfh 0.0
- in
- res + nb_tau * sum_of_weights
- end
+ val (res, sfh) = fold fold_feats feats (Math.ln tfreq, Vector.sub (sfreq, i))
- val posterior = Array.tabulate (adv_max, (fn j => (j, log_posterior j)))
+ fun fold_sfh (f, sf) sow =
+ sow + tfidf f * (Math.ln (1.0 + (1.0 - Real.fromInt sf) / tfreq))
- fun ret acc at =
- if at = adv_max then acc else ret (Array.sub (posterior, at) :: acc) (at + 1)
+ val sum_of_weights = Inttab.fold fold_sfh sfh 0.0
in
- heap (Real.compare o pairself snd) advno posterior;
- ret [] (Integer.max 0 (adv_max - advno))
+ res + nb_tau * sum_of_weights
end
- fun for i =
- if i = avail_num then () else (learn i (get_th_syms i) (get_deps i); for (i + 1))
+ val posterior = Array.tabulate (num_visible_facts, (fn j => (j, log_posterior j)))
+
+ fun ret acc at =
+ if at = num_visible_facts then acc else ret (Array.sub (posterior, at) :: acc) (at + 1)
in
- for 0; eval syms
+ heap (Real.compare o pairself snd) max_suggs posterior;
+ ret [] (Integer.max 0 (num_visible_facts - max_suggs))
end
+fun naive_bayes num_facts num_visible_facts get_deps get_th_feats num_feats max_suggs feats =
+ naive_bayes_learn num_facts get_deps get_th_feats num_feats
+ |> naive_bayes_query num_visible_facts max_suggs feats
+
fun add_to_xtab key (next, tab, keys) = (next + 1, Symtab.update_new (key, next) tab, key :: keys)
fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i)))
@@ -547,19 +552,19 @@
end)
rev_featss num_facts
val get_facts = curry Array.sub facts_ary
- val syms = map_filter (fn (feat, weight) =>
+ val feats' = map_filter (fn (feat, weight) =>
Option.map (rpair weight) (Symtab.lookup feat_tab feat)) feats
in
- k_nearest_neighbors num_facts num_visible_facts get_deps get_facts max_suggs syms
+ k_nearest_neighbors num_facts num_visible_facts get_deps get_facts max_suggs feats'
end
else
let
val unweighted_feats_ary = Vector.fromList (map (map fst) (rev rev_featss))
val get_unweighted_feats = curry Vector.sub unweighted_feats_ary
- val syms = map (apfst (the_default ~1 o Symtab.lookup feat_tab)) feats
+ val feats' = map (apfst (the_default ~1 o Symtab.lookup feat_tab)) feats
in
naive_bayes num_facts num_visible_facts get_deps get_unweighted_feats num_feats max_suggs
- syms
+ feats'
end)
|> map (curry Vector.sub fact_vec o fst)
end