--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Tue May 27 17:32:42 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Tue May 27 17:48:11 2014 +0200
@@ -121,7 +121,7 @@
(case if flag1 <> "none" (* default *) then flag1 else getenv "MASH" of
"yes" => SOME MaSh_Py
| "py" => SOME MaSh_Py
- | "sml" => SOME MaSh_SML_kNN
+ | "sml" => SOME MaSh_SML_NB
| "sml_knn" => SOME MaSh_SML_kNN
| "sml_nb" => SOME MaSh_SML_NB
| _ => NONE)
@@ -422,78 +422,72 @@
ret [] (Integer.max 0 (adv_max - advno))
end
-(* Two arguments control the behaviour of naive Bayes: prior and ess. Prior expresses our belief in
- usefulness of unknown features, and ess (equivalent sample size) expresses our confidence in the
- prior. *)
-fun naive_bayes avail_num adv_max get_deps get_th_syms sym_num ess prior advno syms =
+val tau = 0.02
+val posWeight = 2.0
+val defVal = ~15.0
+val defPriWei = 20
+
+fun naive_bayes avail_num adv_max get_deps get_th_syms sym_num advno syms =
let
- val afreq = Unsynchronized.ref 0
- val tfreq = Array.array (avail_num, 0)
- val sfreq = Array.array (avail_num, Inttab.empty)
+ val afreq = Unsynchronized.ref 0;
+ val tfreq = Array.array (avail_num, 0);
+ val sfreq = Array.array (avail_num, Inttab.empty);
+ val dffreq = Array.array (sym_num, 0);
- fun nb_learn syms ts =
+ fun learn th syms deps =
let
- fun add_sym hpis sym =
+ fun add_th t =
let
- val im = Array.sub (sfreq, hpis)
- val v = the_default 0 (Inttab.lookup im sym)
+ val im = Array.sub (sfreq, t);
+ fun fold_fn s sf = Inttab.update (s, 1 + the_default 0 (Inttab.lookup im s)) sf;
in
- Array.update (sfreq, hpis, Inttab.update (sym, v + 1) im)
- end
+ Array.update (tfreq, t, 1 + Array.sub (tfreq, t));
+ Array.update (sfreq, t, fold fold_fn syms im)
+ end;
+ fun add_sym s = Array.update (dffreq, s, 1 + Array.sub (dffreq, s));
+ in
+ List.app add_th (replicate defPriWei th);
+ List.app add_th deps;
+ List.app add_sym syms;
+ afreq := !afreq + 1
+ end;
- fun add_th t =
- (Array.update (tfreq, t, Array.sub (tfreq, t) + 1); List.app (add_sym t) syms)
- in
- afreq := !afreq + 1; List.app add_th ts
- end
+ fun tfidf _ = 1.0;
+ (*fun tfidf sym = Math.ln (Real.fromInt (!afreq)) - Math.ln (Real.fromInt (Array.sub (dffreq, sym)));*)
- fun nb_eval syms =
+ fun eval syms =
let
fun log_posterior i =
let
- val symh = fold (Inttab.update o rpair ()) syms Inttab.empty
- val n = Real.fromInt (Array.sub (tfreq, i))
- val sfreqh = Array.sub (sfreq, i)
- val p = if prior > 0.0 then prior else ess / Real.fromInt (!afreq)
- val mp = ess * p
- val logmp = Math.ln mp
- val lognmp = Math.ln (n + mp)
-
- fun in_sfreqh (s, sfreqv) (sofar, sfsymh) =
- let val sfreqv = Real.fromInt sfreqv in
- if Inttab.defined sfsymh s then
- (sofar + Math.ln (sfreqv + mp), Inttab.delete s sfsymh)
- else
- (sofar + Math.ln (n - sfreqv + mp), sfsymh)
- end
-
- val (postsfreqh, symh) = Inttab.fold in_sfreqh sfreqh (Math.ln n, symh)
- val len_mem = length (Inttab.keys symh)
- val len_nomem = sym_num - len_mem - length (Inttab.keys sfreqh)
+ val tfreq = Real.fromInt (Array.sub (tfreq, i));
+ fun fold_syms (f, fw) (res, sfh) =
+ (case Inttab.lookup sfh f of
+ SOME sf =>
+ (res + tfidf f * fw * Math.ln (posWeight * Real.fromInt sf / tfreq),
+ Inttab.delete f sfh)
+ | NONE => (res + fw * defVal, sfh));
+ val (res, sfh) = fold fold_syms syms (Math.ln tfreq, Array.sub (sfreq,i));
+ fun fold_sfh (f, sf) sow =
+ sow + tfidf f * (Math.ln (1.0 + (1.0 - Real.fromInt sf) / tfreq));
+ val sumOfWei = Inttab.fold fold_sfh sfh 0.0;
in
- postsfreqh + Real.fromInt len_mem * logmp + Real.fromInt len_nomem * lognmp -
- Real.fromInt sym_num * Math.ln (n + ess)
+ res + tau * sumOfWei
end
-
- val posterior = Array.tabulate (adv_max, swap o `log_posterior)
-
+ val posterior = Array.tabulate (adv_max, (fn j => (j, log_posterior j)));
fun ret acc at =
- if at = Array.length posterior then acc
- else ret (Array.sub (posterior, at) :: acc) (at + 1)
+ if at = adv_max then acc else ret (Array.sub (posterior,at) :: acc) (at + 1)
in
heap (Real.compare o pairself snd) advno posterior;
ret [] (Integer.max 0 (adv_max - advno))
- end
+ end;
fun for i =
- if i = avail_num then () else (nb_learn (get_th_syms i) (i :: get_deps i); for (i + 1))
+ if i = avail_num then () else (learn i (get_th_syms i) (get_deps i); for (i + 1))
in
- for 0; nb_eval syms
+ for 0; eval syms
end
val knns = 40 (* FUDGE *)
-val ess = 0.00001 (* FUDGE *)
-val prior = 0.001 (* FUDGE *)
fun add_to_xtab key (next, tab, keys) = (next + 1, Symtab.update_new (key, next) tab, key :: keys)
@@ -532,6 +526,8 @@
val num_visible_facts = length visible_facts
val get_deps = curry Vector.sub deps_vec
+ val syms = map_filter (fn (feat, weight) =>
+ Option.map (rpair weight) (Symtab.lookup feat_tab feat)) feats
in
trace_msg ctxt (fn () => "MaSh_SML " ^ (if engine = MaSh_SML_kNN then "kNN" else "NB") ^
" query " ^ encode_features feats ^ " from {" ^
@@ -548,8 +544,6 @@
end)
rev_featss num_facts
val get_facts = curry Array.sub facts_ary
- val syms = map_filter (fn (feat, weight) =>
- Option.map (rpair weight) (Symtab.lookup feat_tab feat)) feats
in
k_nearest_neighbors num_facts num_visible_facts get_deps get_facts knns max_suggs syms
end
@@ -557,10 +551,9 @@
let
val unweighted_feats_ary = Vector.fromList (map (map fst) (rev rev_featss))
val get_unweighted_feats = curry Vector.sub unweighted_feats_ary
- val unweighted_syms = map_filter (Symtab.lookup feat_tab o fst) feats
in
- naive_bayes num_facts num_visible_facts get_deps get_unweighted_feats num_feats ess prior
- max_suggs unweighted_syms
+ naive_bayes num_facts num_visible_facts get_deps get_unweighted_feats num_feats max_suggs
+ syms
end)
|> map (curry Vector.sub fact_vec o fst)
end
@@ -1237,17 +1230,12 @@
|> map (chained_or_extra_features_of chained_feature_factor)
|> rpair [] |-> fold (union (eq_fst (op =)))
val extra_feats =
- (* As long as SML NB does not support weights, it makes little sense to include these
- extra features. *)
- if engine = MaSh_SML_NB then
- []
- else
- facts
- |> take (Int.max (0, num_extra_feature_facts - length chained))
- |> filter fact_has_right_theory
- |> weight_facts_steeply
- |> map (chained_or_extra_features_of extra_feature_factor)
- |> rpair [] |-> fold (union (eq_fst (op =)))
+ facts
+ |> take (Int.max (0, num_extra_feature_facts - length chained))
+ |> filter fact_has_right_theory
+ |> weight_facts_steeply
+ |> map (chained_or_extra_features_of extra_feature_factor)
+ |> rpair [] |-> fold (union (eq_fst (op =)))
val feats = fold (union (eq_fst (op =))) [chained_feats, extra_feats] goal_feats
|> debug ? sort (Real.compare o swap o pairself snd)
in