--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Wed May 28 03:10:30 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Wed May 28 09:38:39 2014 +0200
@@ -359,16 +359,17 @@
()
end
+val number_of_nearest_neighbors = 40 (* FUDGE *)
+
(*
avail_num = maximum number of theorems to check dependencies and symbols
adv_max = do not return theorems over or equal to this number. Must satisfy: adv_max <= avail_num
get_deps = returns dependencies of a theorem
get_sym_ths = get theorems that have this feature
- knns = number of nearest neighbours
advno = number of predictions to return
syms = symbols of the conjecture
*)
-fun k_nearest_neighbors avail_num adv_max get_deps get_sym_ths knns advno syms =
+fun k_nearest_neighbors avail_num adv_max get_deps get_sym_ths advno syms =
let
(* Can be later used for TFIDF *)
fun sym_wght _ = 1.0
@@ -393,7 +394,7 @@
end
val _ = List.app do_sym syms
- val _ = heap (Real.compare o pairself snd) knns overlaps_sqr
+ val _ = heap (Real.compare o pairself snd) number_of_nearest_neighbors overlaps_sqr
val recommends = Array.tabulate (adv_max, rpair 0.0)
fun inc_recommend j v =
@@ -401,7 +402,7 @@
else Array.update (recommends, j, (j, v + snd (Array.sub (recommends, j))))
fun for k =
- if k = knns orelse k >= adv_max then
+ if k = number_of_nearest_neighbors orelse k >= adv_max then
()
else
let
@@ -422,64 +423,71 @@
ret [] (Integer.max 0 (adv_max - advno))
end
-val tau = 0.02
-val posWeight = 2.0
-val defVal = ~15.0
-val defPriWei = 20
+val nb_tau = 0.02 (* FUDGE *)
+val nb_pos_weight = 2.0 (* FUDGE *)
+val nb_def_val = ~15.0 (* FUDGE *)
+val nb_def_prior_weight = 20 (* FUDGE *)
fun naive_bayes avail_num adv_max get_deps get_th_syms sym_num advno syms =
let
- val afreq = Unsynchronized.ref 0;
- val tfreq = Array.array (avail_num, 0);
- val sfreq = Array.array (avail_num, Inttab.empty);
- val dffreq = Array.array (sym_num, 0);
+ val afreq = Unsynchronized.ref 0
+ val tfreq = Array.array (avail_num, 0)
+ val sfreq = Array.array (avail_num, Inttab.empty)
+ val dffreq = Array.array (sym_num, 0)
fun learn th syms deps =
let
fun add_th t =
let
- val im = Array.sub (sfreq, t);
- fun fold_fn s sf = Inttab.update (s, 1 + the_default 0 (Inttab.lookup im s)) sf;
+ val im = Array.sub (sfreq, t)
+ fun fold_fn s sf = Inttab.update (s, 1 + the_default 0 (Inttab.lookup im s)) sf
in
Array.update (tfreq, t, 1 + Array.sub (tfreq, t));
Array.update (sfreq, t, fold fold_fn syms im)
- end;
- fun add_sym s = Array.update (dffreq, s, 1 + Array.sub (dffreq, s));
+ end
+
+ fun add_sym s = Array.update (dffreq, s, 1 + Array.sub (dffreq, s))
in
- List.app add_th (replicate defPriWei th);
+ List.app add_th (replicate nb_def_prior_weight th);
List.app add_th deps;
List.app add_sym syms;
afreq := !afreq + 1
- end;
+ end
- fun tfidf _ = 1.0;
- (*fun tfidf sym = Math.ln (Real.fromInt (!afreq)) - Math.ln (Real.fromInt (Array.sub (dffreq, sym)));*)
+ fun tfidf _ = 1.0
+ (*fun tfidf sym = Math.ln (Real.fromInt (!afreq)) - Math.ln (Real.fromInt (Array.sub (dffreq, sym)))*)
fun eval syms =
let
fun log_posterior i =
let
- val tfreq = Real.fromInt (Array.sub (tfreq, i));
+ val tfreq = Real.fromInt (Array.sub (tfreq, i))
+
fun fold_syms (f, fw) (res, sfh) =
(case Inttab.lookup sfh f of
SOME sf =>
- (res + tfidf f * fw * Math.ln (posWeight * Real.fromInt sf / tfreq),
+ (res + tfidf f * fw * Math.ln (nb_pos_weight * Real.fromInt sf / tfreq),
Inttab.delete f sfh)
- | NONE => (res + fw * defVal, sfh));
- val (res, sfh) = fold fold_syms syms (Math.ln tfreq, Array.sub (sfreq,i));
+ | NONE => (res + fw * nb_def_val, sfh))
+
+ val (res, sfh) = fold fold_syms syms (Math.ln tfreq, Array.sub (sfreq,i))
+
fun fold_sfh (f, sf) sow =
- sow + tfidf f * (Math.ln (1.0 + (1.0 - Real.fromInt sf) / tfreq));
- val sumOfWei = Inttab.fold fold_sfh sfh 0.0;
+ sow + tfidf f * (Math.ln (1.0 + (1.0 - Real.fromInt sf) / tfreq))
+
+ val sum_of_weights = Inttab.fold fold_sfh sfh 0.0
in
- res + tau * sumOfWei
+ res + nb_tau * sum_of_weights
end
- val posterior = Array.tabulate (adv_max, (fn j => (j, log_posterior j)));
+
+ val posterior = Array.tabulate (adv_max, (fn j => (j, log_posterior j)))
+
fun ret acc at =
- if at = adv_max then acc else ret (Array.sub (posterior,at) :: acc) (at + 1)
+ if at = adv_max then acc else ret (Array.sub (posterior, at) :: acc) (at + 1)
in
heap (Real.compare o pairself snd) advno posterior;
ret [] (Integer.max 0 (adv_max - advno))
- end;
+ end
fun for i =
if i = avail_num then () else (learn i (get_th_syms i) (get_deps i); for (i + 1))
@@ -487,8 +495,6 @@
for 0; eval syms
end
-val knns = 40 (* FUDGE *)
-
fun add_to_xtab key (next, tab, keys) = (next + 1, Symtab.update_new (key, next) tab, key :: keys)
fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i)))
@@ -544,7 +550,7 @@
rev_featss num_facts
val get_facts = curry Array.sub facts_ary
in
- k_nearest_neighbors num_facts num_visible_facts get_deps get_facts knns max_suggs syms
+ k_nearest_neighbors num_facts num_visible_facts get_deps get_facts max_suggs syms
end
else
let
@@ -1317,7 +1323,6 @@
in
map_state ctxt overlord (fn state as {access_G, num_known_facts, dirty} =>
let
- val name = learned_proof_name ()
val parents = maximal_wrt_access_graph access_G facts
val deps = used_ths
|> filter (is_fact_in_graph access_G)
@@ -1326,7 +1331,10 @@
if the_mash_engine () = MaSh_Py then
(MaSh_Py.learn ctxt overlord true [("", parents, map fst feats, deps)]; state)
else
- let val access_G = access_G |> add_node Automatic_Proof name parents feats deps in
+ let
+ val name = learned_proof_name ()
+ val access_G = access_G |> add_node Automatic_Proof name parents feats deps
+ in
{access_G = access_G, num_known_facts = num_known_facts + 1,
dirty = Option.map (cons name) dirty}
end