--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Tue May 20 22:28:08 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Tue May 20 22:28:44 2014 +0200
@@ -115,26 +115,21 @@
()
end
-datatype mash_engine = MaSh_Py | MaSh_SML_KNN | MaSh_SML_NB
+datatype mash_engine = MaSh_Py | MaSh_SML_kNN | MaSh_SML_NB
fun mash_engine () =
let val flag1 = Options.default_string @{system_option maSh} in
(case if flag1 <> "none" (* default *) then flag1 else getenv "MASH" of
"yes" => SOME MaSh_Py
| "py" => SOME MaSh_Py
- | "sml" => SOME MaSh_SML_KNN
- | "sml_knn" => SOME MaSh_SML_KNN
+ | "sml" => SOME MaSh_SML_kNN
+ | "sml_knn" => SOME MaSh_SML_kNN
| "sml_nb" => SOME MaSh_SML_NB
| _ => NONE)
end
val is_mash_enabled = is_some o mash_engine
-
-fun is_mash_sml_enabled () =
- (case mash_engine () of
- SOME MaSh_SML_KNN => true
- | SOME MaSh_SML_NB => true
- | _ => false)
+val the_mash_engine = the_default MaSh_SML_kNN o mash_engine
(*** Low-level communication with Python version of MaSh ***)
@@ -320,71 +315,55 @@
end
fun trickledown l i e =
- let
- val j = maxson l i
- in
+ let val j = maxson l i in
if cmp (Array.sub (a, j), e) = GREATER then
- let val _ = Array.update (a, i, Array.sub (a, j)) in trickledown l j e end
- else Array.update (a, i, e)
+ (Array.update (a, i, Array.sub (a, j)); trickledown l j e)
+ else
+ Array.update (a, i, e)
end
fun trickle l i e = trickledown l i e handle BOTTOM i => Array.update (a, i, e)
fun bubbledown l i =
- let
- val j = maxson l i
- val _ = Array.update (a, i, Array.sub (a, j))
- in
+ let val j = maxson l i in
+ Array.update (a, i, Array.sub (a, j));
bubbledown l j
end
fun bubble l i = bubbledown l i handle BOTTOM i => i
fun trickleup i e =
- let
- val father = (i - 1) div 3
- in
+ let val father = (i - 1) div 3 in
if cmp (Array.sub (a, father), e) = LESS then
- let
- val _ = Array.update (a, i, Array.sub (a, father))
- in
- if father > 0 then trickleup father e else Array.update (a, 0, e)
- end
- else Array.update (a, i, e)
+ (Array.update (a, i, Array.sub (a, father));
+ if father > 0 then trickleup father e else Array.update (a, 0, e))
+ else
+ Array.update (a, i, e)
end
val l = Array.length a
- fun for i =
- if i < 0 then () else
- let
- val _ = trickle l i (Array.sub (a, i))
- in
- for (i - 1)
- end
-
- val _ = for (((l + 1) div 3) - 1)
+ fun for i = if i < 0 then () else (trickle l i (Array.sub (a, i)); for (i - 1))
fun for2 i =
- if i < Integer.max 2 (l - bnd) then () else
- let
- val e = Array.sub (a, i)
- val _ = Array.update (a, i, Array.sub (a, 0))
- val _ = trickleup (bubble i 0) e
- in
- for2 (i - 1)
- end
-
- val _ = for2 (l - 1)
+ if i < Integer.max 2 (l - bnd) then
+ ()
+ else
+ let val e = Array.sub (a, i) in
+ Array.update (a, i, Array.sub (a, 0));
+ trickleup (bubble i 0) e;
+ for2 (i - 1)
+ end
in
+ for (((l + 1) div 3) - 1);
+ for2 (l - 1);
if l > 1 then
- let
- val e = Array.sub (a, 1)
- val _ = Array.update (a, 1, Array.sub (a, 0))
- in
+ let val e = Array.sub (a, 1) in
+ Array.update (a, 1, Array.sub (a, 0));
Array.update (a, 0, e)
end
- else ()
+ else
+ ()
end
(*
@@ -421,7 +400,7 @@
end
val _ = List.app do_sym syms
- val _ = heap (fn (a, b) => Real.compare (snd a, snd b)) knns overlaps_sqr
+ val _ = heap (Real.compare o pairself snd) knns overlaps_sqr
val recommends = Array.tabulate (adv_max, rpair 0.0)
fun inc_recommend j v =
@@ -438,27 +417,97 @@
val _ = inc_recommend j o1
val ds = get_deps j
val l = Real.fromInt (length ds)
- val _ = map (fn d => inc_recommend d (o1 / l)) ds
in
- for (k + 1)
+ List.app (fn d => inc_recommend d (o1 / l)) ds; for (k + 1)
end
- val _ = for 0
- val _ = heap (Real.compare o pairself snd) advno recommends
-
fun ret acc at =
if at = Array.length recommends then acc else ret (Array.sub (recommends, at) :: acc) (at + 1)
in
+ for 0;
+ heap (Real.compare o pairself snd) advno recommends;
ret [] (Integer.max 0 (adv_max - advno))
end
+(* Two arguments control the behaviour of nbayes: prior and ess. Prior expresses our belief in
+ usefulness of unknown features, and ess (equivalent sample size) expresses our confidence in the
+ prior. *)
+fun nbayes avail_num adv_max get_deps get_th_syms sym_num ess prior advno syms =
+ let
+ val afreq = Unsynchronized.ref 0
+ val tfreq = Array.array (avail_num, 0)
+ val sfreq = Array.array (avail_num, Inttab.empty)
+
+ fun nb_learn syms ts =
+ let
+ fun add_sym hpis sym =
+ let
+ val im = Array.sub (sfreq, hpis)
+ val v = the_default 0 (Inttab.lookup im sym)
+ in
+ Array.update(sfreq, hpis, Inttab.update (sym, v + 1) im)
+ end
+
+ fun add_th t =
+ (Array.update (tfreq, t, Array.sub (tfreq, t) + 1); List.app (add_sym t) syms)
+ in
+ afreq := !afreq + 1;
+ List.app add_th ts
+ end
+
+ fun nb_eval syms =
+ let
+ fun log_posterior i =
+ let
+ val symh = fold (fn s => fn sf => Inttab.update (s, ()) sf) syms Inttab.empty
+ val n = Real.fromInt (Array.sub (tfreq, i))
+ val sfreqh = Array.sub (sfreq, i)
+ val p = if prior > 0.0 then prior else ess / Real.fromInt (!afreq)
+ val mp = ess * p
+ val logmp = Math.ln mp
+ val lognmp = Math.ln (n + mp)
+
+ fun in_sfreqh (s, sfreqv) (sofar, sfsymh) =
+ let val sfreqv = Real.fromInt sfreqv in
+ if Inttab.defined sfsymh s then
+ (sofar + Math.ln (sfreqv + mp), Inttab.delete s sfsymh)
+ else
+ (sofar + Math.ln (n - sfreqv + mp), sfsymh)
+ end
+
+ val (postsfreqh, symh) = Inttab.fold in_sfreqh sfreqh (Math.ln n, symh)
+ val len_mem = length (Inttab.keys symh)
+ val len_nomem = sym_num - len_mem - length (Inttab.keys sfreqh)
+ in
+ postsfreqh + Real.fromInt len_mem * logmp + Real.fromInt len_nomem * lognmp -
+ Real.fromInt sym_num * Math.ln(n + ess)
+ end
+
+ val posterior = Array.tabulate (adv_max, swap o `log_posterior)
+
+ fun ret acc at =
+ if at = Array.length posterior then acc
+ else ret (Array.sub (posterior,at) :: acc) (at + 1)
+ in
+ heap (Real.compare o pairself snd) advno posterior;
+ ret [] (Integer.max 0 (adv_max - advno))
+ end
+
+ fun for i =
+ if i = avail_num then () else (nb_learn (get_th_syms i) (get_deps i); for (i + 1))
+ in
+ for 0; nb_eval syms
+ end
+
val knns = 40 (* FUDGE *)
+val ess = 0.00001 (* FUDGE *)
+val prior = 0.001 (* FUDGE *)
fun add_to_xtab key (next, tab, keys) = (next + 1, Symtab.update_new (key, next) tab, key :: keys)
fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i)))
-fun query ctxt parents access_G max_suggs hints feats =
+fun query ctxt engine parents access_G max_suggs hints feats =
let
val str_of_feat = space_implode "|"
@@ -470,9 +519,9 @@
|> List.partition (Symtab.defined visible_fact_set o #1) |> op @) @
(if null hints then [] else [(".goal", feats, hints)])
- val (rev_depss, featss, (_, _, rev_facts), (num_feats, feat_tab, _)) =
+ val (rev_depss, rev_featss, (num_facts, _, rev_facts), (num_feats, feat_tab, _)) =
fold (fn (fact, feats, deps) =>
- fn (depss, featss, fact_xtab as (_, fact_tab, _), feat_xtab) =>
+ fn (rev_depss, rev_featss, fact_xtab as (_, fact_tab, _), feat_xtab) =>
let
fun add_feat (feat, weight) (xtab as (n, tab, _)) =
(case Symtab.lookup tab feat of
@@ -481,7 +530,7 @@
val (feats', feat_xtab') = fold_map (add_feat o apfst str_of_feat) feats feat_xtab
in
- (map_filter (Symtab.lookup fact_tab) deps :: depss, feats' :: featss,
+ (map_filter (Symtab.lookup fact_tab) deps :: rev_depss, feats' :: rev_featss,
add_to_xtab fact fact_xtab, feat_xtab')
end)
all_nodes ([], [], (0, Symtab.empty, []), (0, Symtab.empty, []))
@@ -490,22 +539,40 @@
val fact_vec = Vector.fromList facts
val deps_vec = Vector.fromList (rev rev_depss)
- val facts_ary = Array.array (num_feats, [])
- val _ =
- fold (fn feats => fn fact =>
- let val fact' = fact - 1 in
- List.app (fn (feat, weight) => map_array_at facts_ary (cons (fact', weight)) feat)
- feats;
- fact'
- end)
- featss (length featss)
+
+ val avail_num = Vector.length deps_vec
+ val adv_max = length visible_facts
+ val get_deps = curry Vector.sub deps_vec
+ val advno = max_suggs
in
trace_msg ctxt (fn () => "MaSh_SML query " ^ encode_features feats ^ " from {" ^
elide_string 1000 (space_implode " " facts) ^ "}");
- knn (Vector.length deps_vec) (length visible_facts) (curry Vector.sub deps_vec)
- (curry Array.sub facts_ary) knns max_suggs
- (map_filter (fn (feat, weight) =>
- Option.map (rpair weight) (Symtab.lookup feat_tab (str_of_feat feat))) feats)
+ (if engine = MaSh_SML_kNN then
+ let
+ val facts_ary = Array.array (num_feats, [])
+ val _ =
+ fold (fn feats => fn fact =>
+ let val fact' = fact - 1 in
+ List.app (fn (feat, weight) => map_array_at facts_ary (cons (fact', weight)) feat)
+ feats;
+ fact'
+ end)
+ rev_featss num_facts
+ val get_sym_ths = curry Array.sub facts_ary
+ val syms = map_filter (fn (feat, weight) =>
+ Option.map (rpair weight) (Symtab.lookup feat_tab (str_of_feat feat))) feats
+ in
+ knn avail_num adv_max get_deps get_sym_ths knns advno syms
+ end
+ else
+ let
+ val unweighted_feats_ary = Vector.fromList (map (map fst) (rev rev_featss))
+ val get_th_syms = curry Vector.sub unweighted_feats_ary
+ val sym_num = num_feats
+ val unweighted_syms = map_filter (Symtab.lookup feat_tab o str_of_feat o fst) feats
+ in
+ nbayes avail_num adv_max get_deps get_th_syms sym_num ess prior advno unweighted_syms
+ end)
|> map (curry Vector.sub fact_vec o fst)
end
@@ -596,8 +663,10 @@
fold extract_line_and_add_node node_lines Graph.empty),
length node_lines)
| LESS =>
- (if is_mash_sml_enabled () then wipe_out_mash_state_dir ()
- else MaSh_Py.unlearn ctxt overlord; (Graph.empty, 0)) (* cannot parse old file *)
+ (* cannot parse old file *)
+ (if the_mash_engine () = MaSh_Py then MaSh_Py.unlearn ctxt overlord
+ else wipe_out_mash_state_dir ();
+ (Graph.empty, 0))
| GREATER => raise FILE_VERSION_TOO_NEW ())
in
trace_msg ctxt (fn () => "Loaded fact graph (" ^ graph_info access_G ^ ")");
@@ -645,7 +714,8 @@
fun clear_state ctxt overlord =
(* "MaSh_Py.unlearn" also removes the state file *)
Synchronized.change global_state (fn _ =>
- (if is_mash_sml_enabled () then wipe_out_mash_state_dir () else MaSh_Py.unlearn ctxt overlord;
+ (if the_mash_engine () = MaSh_Py then MaSh_Py.unlearn ctxt overlord
+ else wipe_out_mash_state_dir ();
(false, empty_state)))
end
@@ -1224,7 +1294,7 @@
(parents, hints, feats)
end
- val sml = is_mash_sml_enabled ()
+ val engine = the_mash_engine ()
val (access_G, py_suggs) =
peek_state ctxt overlord (fn {access_G, ...} =>
@@ -1232,20 +1302,20 @@
(trace_msg ctxt (K "Nothing has been learned yet"); (access_G, []))
else
(access_G,
- if sml then
- []
- else
+ if engine = MaSh_Py then
let val (parents, hints, feats) = query_args access_G in
MaSh_Py.query ctxt overlord max_facts ([], hints, parents, feats)
- end))
+ end
+ else
+ []))
val sml_suggs =
- if sml then
+ if engine = MaSh_Py then
+ []
+ else
let val (parents, hints, feats) = query_args access_G in
- MaSh_SML.query ctxt parents access_G max_facts hints feats
+ MaSh_SML.query ctxt engine parents access_G max_facts hints feats
end
- else
- []
val unknown = filter_out (is_fact_in_graph access_G o snd) facts
in
@@ -1309,13 +1379,13 @@
|> filter (is_fact_in_graph access_G)
|> map nickname_of_thm
in
- if is_mash_sml_enabled () then
+ if the_mash_engine () = MaSh_Py then
+ (MaSh_Py.learn ctxt overlord true [("", parents, map fst feats, deps)]; state)
+ else
let val access_G = access_G |> add_node Automatic_Proof name parents feats deps in
{access_G = access_G, num_known_facts = num_known_facts + 1,
dirty = Option.map (cons name) dirty}
end
- else
- (MaSh_Py.learn ctxt overlord true [("", parents, map fst feats, deps)]; state)
end);
(true, "")
end)
@@ -1334,7 +1404,7 @@
val timer = Timer.startRealTimer ()
fun next_commit_time () = Time.+ (Timer.checkRealTimer timer, commit_timeout)
- val sml = is_mash_sml_enabled ()
+ val engine = the_mash_engine ()
val {access_G, ...} = peek_state ctxt overlord I
val is_in_access_G = is_fact_in_graph access_G o snd
val no_new_facts = forall is_in_access_G facts
@@ -1376,11 +1446,11 @@
(false, SOME names, []) => SOME (map #1 learns @ names)
| _ => NONE)
in
- if sml then
- ()
+ if engine = MaSh_Py then
+ (MaSh_Py.learn ctxt overlord (save andalso null relearns) (rev learns);
+ MaSh_Py.relearn ctxt overlord save relearns)
else
- (MaSh_Py.learn ctxt overlord (save andalso null relearns) (rev learns);
- MaSh_Py.relearn ctxt overlord save relearns);
+ ();
{access_G = access_G, num_known_facts = num_known_facts, dirty = dirty}
end
@@ -1613,7 +1683,7 @@
|> Par_List.map (apsnd (fn f => f ()))
val mesh = mesh_facts (eq_snd Thm.eq_thm_prop) max_facts mess |> add_and_take
in
- if is_mash_sml_enabled () orelse not save then () else MaSh_Py.save ctxt overlord;
+ if the_mash_engine () = MaSh_Py andalso save then MaSh_Py.save ctxt overlord else ();
(case (fact_filter, mess) of
(NONE, [(_, (mepo, _)), (_, (mash, _))]) =>
[(meshN, mesh), (mepoN, mepo |> map fst |> add_and_take),
@@ -1623,7 +1693,7 @@
fun kill_learners ctxt ({overlord, ...} : params) =
(Async_Manager.kill_threads MaShN "learner";
- if is_mash_sml_enabled () then () else MaSh_Py.shutdown ctxt overlord)
+ if the_mash_engine () = MaSh_Py then MaSh_Py.shutdown ctxt overlord else ())
fun running_learners () = Async_Manager.running_threads MaShN "learner"