--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Thu Jun 26 13:36:06 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Thu Jun 26 13:36:13 2014 +0200
@@ -398,10 +398,10 @@
exception EXIT of unit
-fun k_nearest_neighbors dffreq num_facts deps_vec get_sym_ths max_suggs visible_facts conj_feats =
+fun k_nearest_neighbors dffreq num_facts depss feat_facts max_suggs visible_facts conj_feats =
let
val ln_afreq = Math.ln (Real.fromInt num_facts)
- fun tfidf feat = ln_afreq - Math.ln (Real.fromInt (Array.sub (dffreq, feat)))
+ fun tfidf feat = ln_afreq - Math.ln (Real.fromInt (Vector.sub (dffreq, feat)))
val overlaps_sqr = Array.tabulate (num_facts, rpair 0.0)
@@ -416,7 +416,7 @@
val w2 = sw * sw
fun do_th j = if j < num_facts then inc_overlap j w2 else ()
in
- List.app do_th (get_sym_ths s)
+ List.app do_th (Array.sub (feat_facts, s))
end
val _ = List.app do_feat conj_feats
@@ -427,8 +427,13 @@
fun inc_recommend j v =
let val ov = snd (Array.sub (recommends, j)) in
- if ov <= 0.0 then (no_recommends := !no_recommends + 1; Array.update (recommends, j, (j, !age + ov)))
- else (if ov < !age + 1000.0 then Array.update (recommends, j, (j, v + ov)) else ()) end
+ if ov <= 0.0 then
+ (no_recommends := !no_recommends + 1; Array.update (recommends, j, (j, !age + ov)))
+ else if ov < !age + 1000.0 then
+ Array.update (recommends, j, (j, v + ov))
+ else
+ ()
+ end
val k = Unsynchronized.ref 0
fun do_k k =
@@ -439,7 +444,7 @@
val (j, o2) = Array.sub (overlaps_sqr, num_facts - k - 1)
val o1 = Math.sqrt o2
val _ = inc_recommend j o1
- val ds = Vector.sub (deps_vec, j)
+ val ds = Vector.sub (depss, j)
val l = Real.fromInt (length ds)
in
List.app (fn d => inc_recommend d (o1 / l)) ds
@@ -464,16 +469,26 @@
ret [] (Integer.max 0 (num_facts - max_suggs))
end
+fun wider_array_of_vector init vec =
+ let val ary = Array.array init in
+ Array.copyVec {src = vec, dst = ary, di = 0};
+ ary
+ end
+
val nb_def_prior_weight = 21 (* FUDGE *)
-fun learn_facts tfreq sfreq dffreq num_facts depss featss =
+fun learn_facts (tfreq0, sfreq0, dffreq0) num_facts0 num_facts num_feats depss featss =
let
- fun learn_fact th feats deps =
+ val tfreq = wider_array_of_vector (num_facts, 0) tfreq0
+ val sfreq = wider_array_of_vector (num_facts, Inttab.empty) sfreq0
+ val dffreq = wider_array_of_vector (num_feats, 0) dffreq0
+
+ fun learn_one th feats deps =
let
fun add_th weight t =
let
val im = Array.sub (sfreq, t)
- fun fold_fn s sf = Inttab.map_default (s, 0) (Integer.add weight) sf
+ fun fold_fn s = Inttab.map_default (s, 0) (Integer.add weight)
in
map_array_at tfreq (Integer.add weight) t;
Array.update (sfreq, t, fold fold_fn feats im)
@@ -487,26 +502,30 @@
end
fun for i =
- if i = num_facts then ()
- else (learn_fact i (Vector.sub (featss, i)) (Vector.sub (depss, i)); for (i + 1))
+ if i = num_facts then
+ ()
+ else
+ (learn_one (num_facts0 + i) (Vector.sub (featss, i)) (Vector.sub (depss, i));
+ for (i + 1))
in
- for 0
+ for 0;
+ (Array.vector tfreq, Array.vector sfreq, Array.vector dffreq)
end
-fun naive_bayes_query tfreq sfreq dffreq num_facts max_suggs visible_facts conj_feats =
+fun naive_bayes tfreq sfreq dffreq num_facts max_suggs visible_facts conj_feats =
let
val tau = 0.05 (* FUDGE *)
val pos_weight = 10.0 (* FUDGE *)
val def_val = ~15.0 (* FUDGE *)
val ln_afreq = Math.ln (Real.fromInt num_facts)
- val idf = Vector.map (fn i => ln_afreq - Math.ln (Real.fromInt i)) (Array.vector dffreq)
+ val idf = Vector.map (fn i => ln_afreq - Math.ln (Real.fromInt i)) dffreq
fun tfidf feat = Vector.sub (idf, feat)
fun log_posterior i =
let
- val tfreq = Real.fromInt (Array.sub (tfreq, i))
+ val tfreq = Real.fromInt (Vector.sub (tfreq, i))
fun fold_feats f (res, sfh) =
(case Inttab.lookup sfh f of
@@ -515,7 +534,7 @@
Inttab.delete f sfh)
| NONE => (res + tfidf f * def_val, sfh))
- val (res, sfh) = fold fold_feats conj_feats (Math.ln tfreq, Array.sub (sfreq, i))
+ val (res, sfh) = fold fold_feats conj_feats (Math.ln tfreq, Vector.sub (sfreq, i))
fun fold_sfh (f, sf) sow = sow + tfidf f * Math.ln (1.0 + (1.0 - Real.fromInt sf) / tfreq)
@@ -593,56 +612,57 @@
c_plus_plus_tool ("newknn/knn" ^ " " ^ string_of_int number_of_nearest_neighbors)
val naive_bayes_cpp = c_plus_plus_tool "predict/nbayes"
-fun query ctxt engine (num_facts, fact_tab) (num_feats, feat_tab) visible_facts max_suggs learns0
- conj_feats =
+fun reorder_learns (num_facts, fact_tab) learns0 =
+ let
+ val learns = Array.array (num_facts, ("", [], []))
+ in
+ List.app (fn learn as (fact, _, _) =>
+ Array.update (learns, the (Symtab.lookup fact_tab fact), learn))
+ learns0;
+ Array.foldr (op ::) [] learns
+ end
+
+fun query ctxt engine (fact_xtab as (num_facts, fact_tab)) (num_feats, feat_tab) visible_facts
+ max_suggs learns0 conj_feats =
if engine = MaSh_SML_kNN_Cpp then
k_nearest_neighbors_cpp max_suggs learns0 conj_feats
else if engine = MaSh_SML_NB_Cpp then
naive_bayes_cpp max_suggs learns0 conj_feats
else
let
- val learn_ary = Array.array (num_facts, ("", [], []))
- val _ =
- List.app (fn entry as (fact, _, _) =>
- Array.update (learn_ary, the (Symtab.lookup fact_tab fact), entry))
- learns0
- val learns = Array.foldr (op ::) [] learn_ary
+ val learns = reorder_learns fact_xtab learns0
+
+ val facts = Vector.fromList (map #1 learns)
+ val featss = Vector.fromList (map (map_filter (Symtab.lookup feat_tab) o #2) learns)
+ val depss = Vector.fromList (map (map_filter (Symtab.lookup fact_tab) o #3) learns)
- val facts = map #1 learns
- val featss = map (map_filter (Symtab.lookup feat_tab) o #2) learns
- val depss = map (map_filter (Symtab.lookup fact_tab) o #3) learns
+ val tfreq = Vector.tabulate (num_facts, K 0)
+ val sfreq = Vector.tabulate (num_facts, K Inttab.empty)
+ val dffreq = Vector.tabulate (num_feats, K 0)
- val fact_vec = Vector.fromList facts
- val feats_vec = Vector.fromList featss
- val deps_vec = Vector.fromList depss
-
- val tfreq = Array.array (num_facts, 0)
- val sfreq = Array.array (num_facts, Inttab.empty)
- val dffreq = Array.array (num_feats, 0)
-
- val _ = learn_facts tfreq sfreq dffreq num_facts deps_vec feats_vec
+ val (tfreq, sfreq, dffreq) =
+ learn_facts (tfreq, sfreq, dffreq) 0 num_facts num_feats depss featss
val int_visible_facts = map_filter (Symtab.lookup fact_tab) visible_facts
val int_conj_feats = map_filter (Symtab.lookup feat_tab) conj_feats
in
trace_msg ctxt (fn () => "MaSh_SML query " ^ encode_strs conj_feats ^ " from {" ^
- elide_string 1000 (space_implode " " (take num_facts facts)) ^ "}");
+ elide_string 1000 (space_implode " " (Vector.foldr (op ::) [] facts)) ^ "}");
(case engine of
MaSh_SML_kNN =>
let
- val facts_ary = Array.array (num_feats, [])
+ val feat_facts = Array.array (num_feats, [])
val _ =
- fold (fn feats => fn fact =>
- (List.app (map_array_at facts_ary (cons fact)) feats; fact + 1))
- featss 0
- val get_facts = curry Array.sub facts_ary
+ Vector.foldl (fn (feats, fact) =>
+ (List.app (map_array_at feat_facts (cons fact)) feats; fact + 1))
+ 0 featss
in
- k_nearest_neighbors dffreq num_facts deps_vec get_facts max_suggs int_visible_facts
+ k_nearest_neighbors dffreq num_facts depss feat_facts max_suggs int_visible_facts
int_conj_feats
end
| MaSh_SML_NB =>
- naive_bayes_query tfreq sfreq dffreq num_facts max_suggs int_visible_facts int_conj_feats)
- |> map (curry Vector.sub fact_vec o fst)
+ naive_bayes tfreq sfreq dffreq num_facts max_suggs int_visible_facts int_conj_feats)
+ |> map (curry Vector.sub facts o fst)
end
end;
@@ -664,12 +684,12 @@
Graph.default_node (parent, (Isar_Proof, [], []))
#> Graph.add_edge (parent, name)
-fun add_node kind name parents feats deps (access_G, fact_xtab, feat_xtab) =
+fun add_node kind name parents feats deps (access_G, (fact_xtab, feat_xtab)) =
((Graph.new_node (name, (kind, feats, deps)) access_G
handle Graph.DUP _ => Graph.map_node name (K (kind, feats, deps)) access_G)
|> fold (add_edge_to name) parents,
- maybe_add_to_xtab name fact_xtab,
- fold maybe_add_to_xtab feats feat_xtab)
+ (maybe_add_to_xtab name fact_xtab,
+ fold maybe_add_to_xtab feats feat_xtab))
fun try_graph ctxt when def f =
f ()
@@ -694,19 +714,17 @@
type mash_state =
{access_G : (proof_kind * string list * string list) Graph.T,
- fact_xtab : xtab,
- feat_xtab : xtab,
+ xtabs : xtab * xtab,
num_known_facts : int, (* ### FIXME: kill *)
dirty_facts : string list option}
val empty_state =
{access_G = Graph.empty,
- fact_xtab = empty_xtab,
- feat_xtab = empty_xtab,
+ xtabs = (empty_xtab, empty_xtab),
num_known_facts = 0,
dirty_facts = SOME []} : mash_state
-val empty_graphxx = (Graph.empty, empty_xtab, empty_xtab)
+val empty_graphxx = (Graph.empty, (empty_xtab, empty_xtab))
local
@@ -741,7 +759,7 @@
NONE => I (* should not happen *)
| SOME (kind, name, parents, feats, deps) => add_node kind name parents feats deps)
- val ((access_G, fact_xtab, feat_xtab), num_known_facts) =
+ val ((access_G, xtabs), num_known_facts) =
(case string_ord (version', version) of
EQUAL =>
(try_graph ctxt "loading state" empty_graphxx (fn () =>
@@ -755,8 +773,8 @@
| GREATER => raise FILE_VERSION_TOO_NEW ())
in
trace_msg ctxt (fn () => "Loaded fact graph (" ^ graph_info access_G ^ ")");
- {access_G = access_G, fact_xtab = fact_xtab, feat_xtab = feat_xtab,
- num_known_facts = num_known_facts, dirty_facts = SOME []}
+ {access_G = access_G, xtabs = xtabs, num_known_facts = num_known_facts,
+ dirty_facts = SOME []}
end
| _ => empty_state)))
end
@@ -766,7 +784,7 @@
encode_strs feats ^ "; " ^ encode_strs deps ^ "\n"
fun save_state _ (time_state as (_, {dirty_facts = SOME [], ...})) = time_state
- | save_state ctxt (memory_time, {access_G, fact_xtab, feat_xtab, num_known_facts, dirty_facts}) =
+ | save_state ctxt (memory_time, {access_G, xtabs, num_known_facts, dirty_facts}) =
let
fun append_entry (name, ((kind, feats, deps), (parents, _))) =
cons (kind, name, Graph.Keys.dest parents, feats, deps)
@@ -788,8 +806,8 @@
SOME dirty_facts => "; " ^ string_of_int (length dirty_facts) ^ " dirty fact(s)"
| _ => "") ^ ")");
(Time.now (),
- {access_G = access_G, fact_xtab = fact_xtab, feat_xtab = feat_xtab,
- num_known_facts = num_known_facts, dirty_facts = SOME []})
+ {access_G = access_G, xtabs = xtabs, num_known_facts = num_known_facts,
+ dirty_facts = SOME []})
end
val global_state = Synchronized.var "Sledgehammer_MaSh.global_state" (Time.zeroTime, empty_state)
@@ -1342,9 +1360,9 @@
(parents, hints, feats)
end
- val ((access_G, fact_xtab, feat_xtab), py_suggs) =
- peek_state ctxt overlord (fn {access_G, fact_xtab, feat_xtab, ...} =>
- ((access_G, fact_xtab, feat_xtab),
+ val ((access_G, (fact_xtab, feat_xtab)), py_suggs) =
+ peek_state ctxt overlord (fn {access_G, xtabs, ...} =>
+ ((access_G, xtabs),
if Graph.is_empty access_G then
(trace_msg ctxt (K "Nothing has been learned yet"); [])
else if engine = MaSh_Py then
@@ -1377,7 +1395,7 @@
end
fun learn_wrt_access_graph ctxt (name, parents, feats, deps)
- (learns, (access_G, fact_xtab, feat_xtab)) =
+ (learns, (access_G, (fact_xtab, feat_xtab))) =
let
fun maybe_learn_from from (accum as (parents, access_G)) =
try_graph ctxt "updating graph" accum (fn () =>
@@ -1390,7 +1408,7 @@
val fact_xtab = maybe_add_to_xtab name fact_xtab
val feat_xtab = fold maybe_add_to_xtab feats feat_xtab
in
- ((name, parents, feats, deps) :: learns, (access_G, fact_xtab, feat_xtab))
+ ((name, parents, feats, deps) :: learns, (access_G, (fact_xtab, feat_xtab)))
end
fun relearn_wrt_access_graph ctxt (name, deps) (relearns, access_G) =
@@ -1431,7 +1449,7 @@
val feats = map fst (features_of ctxt thy 0 Symtab.empty (Local, General) [t])
in
map_state ctxt overlord
- (fn state as {access_G, fact_xtab, feat_xtab, num_known_facts, dirty_facts} =>
+ (fn state as {access_G, xtabs, num_known_facts, dirty_facts} =>
let
val parents = maximal_wrt_access_graph access_G facts
val deps = used_ths
@@ -1443,12 +1461,10 @@
else
let
val name = learned_proof_name ()
- val (access_G, fact_xtab, feat_xtab) =
- add_node Automatic_Proof name parents feats deps
- (access_G, fact_xtab, feat_xtab)
+ val (access_G, xtabs) =
+ add_node Automatic_Proof name parents feats deps (access_G, xtabs)
in
- {access_G = access_G, fact_xtab = fact_xtab, feat_xtab = feat_xtab,
- num_known_facts = num_known_facts + 1,
+ {access_G = access_G, xtabs = xtabs, num_known_facts = num_known_facts + 1,
dirty_facts = Option.map (cons name) dirty_facts}
end
end);
@@ -1496,11 +1512,10 @@
isar_dependencies_of name_tabs th
fun do_commit [] [] [] state = state
- | do_commit learns relearns flops
- {access_G, fact_xtab, feat_xtab, num_known_facts, dirty_facts} =
+ | do_commit learns relearns flops {access_G, xtabs, num_known_facts, dirty_facts} =
let
- val (learns, (access_G, fact_xtab, feat_xtab)) =
- fold (learn_wrt_access_graph ctxt) learns ([], (access_G, fact_xtab, feat_xtab))
+ val (learns, (access_G, xtabs)) =
+ fold (learn_wrt_access_graph ctxt) learns ([], (access_G, xtabs))
val (relearns, access_G) =
fold (relearn_wrt_access_graph ctxt) relearns ([], access_G)
@@ -1517,8 +1532,8 @@
MaSh_Py.relearn ctxt overlord save relearns)
else
();
- {access_G = access_G, fact_xtab = fact_xtab, feat_xtab = feat_xtab,
- num_known_facts = num_known_facts, dirty_facts = dirty_facts}
+ {access_G = access_G, xtabs = xtabs, num_known_facts = num_known_facts,
+ dirty_facts = dirty_facts}
end
fun commit last learns relearns flops =