# HG changeset patch # User blanchet # Date 1400617724 -7200 # Node ID 75cc30d2b83f2483b3e523723fec81960412a404 # Parent e5466055e94f2a42bf1e0bc14dc82f46f2cfb60a added naive Bayes ML implementation, due to Cezary Kaliszyk (like k-NN) diff -r e5466055e94f -r 75cc30d2b83f NEWS --- a/NEWS Tue May 20 22:28:08 2014 +0200 +++ b/NEWS Tue May 20 22:28:44 2014 +0200 @@ -386,8 +386,8 @@ - Activation of MaSh now works via the "mash" system option (without requiring restart), instead of former settings variable "MASH". The option can be edited in Isabelle/jEdit menu Plugin - Options / Isabelle / General. Allowed values include "sml" (for the new - SML engine), "py" (for the Python engine), and "no". + Options / Isabelle / General. Allowed values include "sml" (for the + default SML engine), "py" (for the old Python engine), and "none". - New option: smt_proofs - Renamed options: diff -r e5466055e94f -r 75cc30d2b83f src/Doc/Sledgehammer/document/root.tex --- a/src/Doc/Sledgehammer/document/root.tex Tue May 20 22:28:08 2014 +0200 +++ b/src/Doc/Sledgehammer/document/root.tex Tue May 20 22:28:44 2014 +0200 @@ -1070,8 +1070,8 @@ The experimental MaSh machine learner. Three learning engines are provided: \begin{enum} -\item[\labelitemi] \textbf{\textit{sml}} (also called -\textbf{\textit{sml\_knn}}) is a Standard ML implementation of $k$-nearest +\item[\labelitemi] \textbf{\textit{sml\_knn}} (also called +\textbf{\textit{sml}}) is a Standard ML implementation of $k$-nearest neighbors. \item[\labelitemi] \textbf{\textit{sml\_nb}} is a Standard ML implementation of diff -r e5466055e94f -r 75cc30d2b83f src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML --- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Tue May 20 22:28:08 2014 +0200 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Tue May 20 22:28:44 2014 +0200 @@ -115,26 +115,21 @@ () end -datatype mash_engine = MaSh_Py | MaSh_SML_KNN | MaSh_SML_NB +datatype mash_engine = MaSh_Py | MaSh_SML_kNN | MaSh_SML_NB fun mash_engine () = let val flag1 = Options.default_string @{system_option maSh} in (case if flag1 <> "none" (* default *) then flag1 else getenv "MASH" of "yes" => SOME MaSh_Py | "py" => SOME MaSh_Py - | "sml" => SOME MaSh_SML_KNN - | "sml_knn" => SOME MaSh_SML_KNN + | "sml" => SOME MaSh_SML_kNN + | "sml_knn" => SOME MaSh_SML_kNN | "sml_nb" => SOME MaSh_SML_NB | _ => NONE) end val is_mash_enabled = is_some o mash_engine - -fun is_mash_sml_enabled () = - (case mash_engine () of - SOME MaSh_SML_KNN => true - | SOME MaSh_SML_NB => true - | _ => false) +val the_mash_engine = the_default MaSh_SML_kNN o mash_engine (*** Low-level communication with Python version of MaSh ***) @@ -320,71 +315,55 @@ end fun trickledown l i e = - let - val j = maxson l i - in + let val j = maxson l i in if cmp (Array.sub (a, j), e) = GREATER then - let val _ = Array.update (a, i, Array.sub (a, j)) in trickledown l j e end - else Array.update (a, i, e) + (Array.update (a, i, Array.sub (a, j)); trickledown l j e) + else + Array.update (a, i, e) end fun trickle l i e = trickledown l i e handle BOTTOM i => Array.update (a, i, e) fun bubbledown l i = - let - val j = maxson l i - val _ = Array.update (a, i, Array.sub (a, j)) - in + let val j = maxson l i in + Array.update (a, i, Array.sub (a, j)); bubbledown l j end fun bubble l i = bubbledown l i handle BOTTOM i => i fun trickleup i e = - let - val father = (i - 1) div 3 - in + let val father = (i - 1) div 3 in if cmp (Array.sub (a, father), e) = LESS then - let - val _ = Array.update (a, i, Array.sub (a, father)) - in - if father > 0 then trickleup father e else Array.update (a, 0, e) - end - else Array.update (a, i, e) + (Array.update (a, i, Array.sub (a, father)); + if father > 0 then trickleup father e else Array.update (a, 0, e)) + else + Array.update (a, i, e) end val l = Array.length a - fun for i = - if i < 0 then () else - let - val _ = trickle l i (Array.sub (a, i)) - in - for (i - 1) - end - - val _ = for (((l + 1) div 3) - 1) + fun for i = if i < 0 then () else (trickle l i (Array.sub (a, i)); for (i - 1)) fun for2 i = - if i < Integer.max 2 (l - bnd) then () else - let - val e = Array.sub (a, i) - val _ = Array.update (a, i, Array.sub (a, 0)) - val _ = trickleup (bubble i 0) e - in - for2 (i - 1) - end - - val _ = for2 (l - 1) + if i < Integer.max 2 (l - bnd) then + () + else + let val e = Array.sub (a, i) in + Array.update (a, i, Array.sub (a, 0)); + trickleup (bubble i 0) e; + for2 (i - 1) + end in + for (((l + 1) div 3) - 1); + for2 (l - 1); if l > 1 then - let - val e = Array.sub (a, 1) - val _ = Array.update (a, 1, Array.sub (a, 0)) - in + let val e = Array.sub (a, 1) in + Array.update (a, 1, Array.sub (a, 0)); Array.update (a, 0, e) end - else () + else + () end (* @@ -421,7 +400,7 @@ end val _ = List.app do_sym syms - val _ = heap (fn (a, b) => Real.compare (snd a, snd b)) knns overlaps_sqr + val _ = heap (Real.compare o pairself snd) knns overlaps_sqr val recommends = Array.tabulate (adv_max, rpair 0.0) fun inc_recommend j v = @@ -438,27 +417,97 @@ val _ = inc_recommend j o1 val ds = get_deps j val l = Real.fromInt (length ds) - val _ = map (fn d => inc_recommend d (o1 / l)) ds in - for (k + 1) + List.app (fn d => inc_recommend d (o1 / l)) ds; for (k + 1) end - val _ = for 0 - val _ = heap (Real.compare o pairself snd) advno recommends - fun ret acc at = if at = Array.length recommends then acc else ret (Array.sub (recommends, at) :: acc) (at + 1) in + for 0; + heap (Real.compare o pairself snd) advno recommends; ret [] (Integer.max 0 (adv_max - advno)) end +(* Two arguments control the behaviour of nbayes: prior and ess. Prior expresses our belief in + usefulness of unknown features, and ess (equivalent sample size) expresses our confidence in the + prior. *) +fun nbayes avail_num adv_max get_deps get_th_syms sym_num ess prior advno syms = + let + val afreq = Unsynchronized.ref 0 + val tfreq = Array.array (avail_num, 0) + val sfreq = Array.array (avail_num, Inttab.empty) + + fun nb_learn syms ts = + let + fun add_sym hpis sym = + let + val im = Array.sub (sfreq, hpis) + val v = the_default 0 (Inttab.lookup im sym) + in + Array.update(sfreq, hpis, Inttab.update (sym, v + 1) im) + end + + fun add_th t = + (Array.update (tfreq, t, Array.sub (tfreq, t) + 1); List.app (add_sym t) syms) + in + afreq := !afreq + 1; + List.app add_th ts + end + + fun nb_eval syms = + let + fun log_posterior i = + let + val symh = fold (fn s => fn sf => Inttab.update (s, ()) sf) syms Inttab.empty + val n = Real.fromInt (Array.sub (tfreq, i)) + val sfreqh = Array.sub (sfreq, i) + val p = if prior > 0.0 then prior else ess / Real.fromInt (!afreq) + val mp = ess * p + val logmp = Math.ln mp + val lognmp = Math.ln (n + mp) + + fun in_sfreqh (s, sfreqv) (sofar, sfsymh) = + let val sfreqv = Real.fromInt sfreqv in + if Inttab.defined sfsymh s then + (sofar + Math.ln (sfreqv + mp), Inttab.delete s sfsymh) + else + (sofar + Math.ln (n - sfreqv + mp), sfsymh) + end + + val (postsfreqh, symh) = Inttab.fold in_sfreqh sfreqh (Math.ln n, symh) + val len_mem = length (Inttab.keys symh) + val len_nomem = sym_num - len_mem - length (Inttab.keys sfreqh) + in + postsfreqh + Real.fromInt len_mem * logmp + Real.fromInt len_nomem * lognmp - + Real.fromInt sym_num * Math.ln(n + ess) + end + + val posterior = Array.tabulate (adv_max, swap o `log_posterior) + + fun ret acc at = + if at = Array.length posterior then acc + else ret (Array.sub (posterior,at) :: acc) (at + 1) + in + heap (Real.compare o pairself snd) advno posterior; + ret [] (Integer.max 0 (adv_max - advno)) + end + + fun for i = + if i = avail_num then () else (nb_learn (get_th_syms i) (get_deps i); for (i + 1)) + in + for 0; nb_eval syms + end + val knns = 40 (* FUDGE *) +val ess = 0.00001 (* FUDGE *) +val prior = 0.001 (* FUDGE *) fun add_to_xtab key (next, tab, keys) = (next + 1, Symtab.update_new (key, next) tab, key :: keys) fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i))) -fun query ctxt parents access_G max_suggs hints feats = +fun query ctxt engine parents access_G max_suggs hints feats = let val str_of_feat = space_implode "|" @@ -470,9 +519,9 @@ |> List.partition (Symtab.defined visible_fact_set o #1) |> op @) @ (if null hints then [] else [(".goal", feats, hints)]) - val (rev_depss, featss, (_, _, rev_facts), (num_feats, feat_tab, _)) = + val (rev_depss, rev_featss, (num_facts, _, rev_facts), (num_feats, feat_tab, _)) = fold (fn (fact, feats, deps) => - fn (depss, featss, fact_xtab as (_, fact_tab, _), feat_xtab) => + fn (rev_depss, rev_featss, fact_xtab as (_, fact_tab, _), feat_xtab) => let fun add_feat (feat, weight) (xtab as (n, tab, _)) = (case Symtab.lookup tab feat of @@ -481,7 +530,7 @@ val (feats', feat_xtab') = fold_map (add_feat o apfst str_of_feat) feats feat_xtab in - (map_filter (Symtab.lookup fact_tab) deps :: depss, feats' :: featss, + (map_filter (Symtab.lookup fact_tab) deps :: rev_depss, feats' :: rev_featss, add_to_xtab fact fact_xtab, feat_xtab') end) all_nodes ([], [], (0, Symtab.empty, []), (0, Symtab.empty, [])) @@ -490,22 +539,40 @@ val fact_vec = Vector.fromList facts val deps_vec = Vector.fromList (rev rev_depss) - val facts_ary = Array.array (num_feats, []) - val _ = - fold (fn feats => fn fact => - let val fact' = fact - 1 in - List.app (fn (feat, weight) => map_array_at facts_ary (cons (fact', weight)) feat) - feats; - fact' - end) - featss (length featss) + + val avail_num = Vector.length deps_vec + val adv_max = length visible_facts + val get_deps = curry Vector.sub deps_vec + val advno = max_suggs in trace_msg ctxt (fn () => "MaSh_SML query " ^ encode_features feats ^ " from {" ^ elide_string 1000 (space_implode " " facts) ^ "}"); - knn (Vector.length deps_vec) (length visible_facts) (curry Vector.sub deps_vec) - (curry Array.sub facts_ary) knns max_suggs - (map_filter (fn (feat, weight) => - Option.map (rpair weight) (Symtab.lookup feat_tab (str_of_feat feat))) feats) + (if engine = MaSh_SML_kNN then + let + val facts_ary = Array.array (num_feats, []) + val _ = + fold (fn feats => fn fact => + let val fact' = fact - 1 in + List.app (fn (feat, weight) => map_array_at facts_ary (cons (fact', weight)) feat) + feats; + fact' + end) + rev_featss num_facts + val get_sym_ths = curry Array.sub facts_ary + val syms = map_filter (fn (feat, weight) => + Option.map (rpair weight) (Symtab.lookup feat_tab (str_of_feat feat))) feats + in + knn avail_num adv_max get_deps get_sym_ths knns advno syms + end + else + let + val unweighted_feats_ary = Vector.fromList (map (map fst) (rev rev_featss)) + val get_th_syms = curry Vector.sub unweighted_feats_ary + val sym_num = num_feats + val unweighted_syms = map_filter (Symtab.lookup feat_tab o str_of_feat o fst) feats + in + nbayes avail_num adv_max get_deps get_th_syms sym_num ess prior advno unweighted_syms + end) |> map (curry Vector.sub fact_vec o fst) end @@ -596,8 +663,10 @@ fold extract_line_and_add_node node_lines Graph.empty), length node_lines) | LESS => - (if is_mash_sml_enabled () then wipe_out_mash_state_dir () - else MaSh_Py.unlearn ctxt overlord; (Graph.empty, 0)) (* cannot parse old file *) + (* cannot parse old file *) + (if the_mash_engine () = MaSh_Py then MaSh_Py.unlearn ctxt overlord + else wipe_out_mash_state_dir (); + (Graph.empty, 0)) | GREATER => raise FILE_VERSION_TOO_NEW ()) in trace_msg ctxt (fn () => "Loaded fact graph (" ^ graph_info access_G ^ ")"); @@ -645,7 +714,8 @@ fun clear_state ctxt overlord = (* "MaSh_Py.unlearn" also removes the state file *) Synchronized.change global_state (fn _ => - (if is_mash_sml_enabled () then wipe_out_mash_state_dir () else MaSh_Py.unlearn ctxt overlord; + (if the_mash_engine () = MaSh_Py then MaSh_Py.unlearn ctxt overlord + else wipe_out_mash_state_dir (); (false, empty_state))) end @@ -1224,7 +1294,7 @@ (parents, hints, feats) end - val sml = is_mash_sml_enabled () + val engine = the_mash_engine () val (access_G, py_suggs) = peek_state ctxt overlord (fn {access_G, ...} => @@ -1232,20 +1302,20 @@ (trace_msg ctxt (K "Nothing has been learned yet"); (access_G, [])) else (access_G, - if sml then - [] - else + if engine = MaSh_Py then let val (parents, hints, feats) = query_args access_G in MaSh_Py.query ctxt overlord max_facts ([], hints, parents, feats) - end)) + end + else + [])) val sml_suggs = - if sml then + if engine = MaSh_Py then + [] + else let val (parents, hints, feats) = query_args access_G in - MaSh_SML.query ctxt parents access_G max_facts hints feats + MaSh_SML.query ctxt engine parents access_G max_facts hints feats end - else - [] val unknown = filter_out (is_fact_in_graph access_G o snd) facts in @@ -1309,13 +1379,13 @@ |> filter (is_fact_in_graph access_G) |> map nickname_of_thm in - if is_mash_sml_enabled () then + if the_mash_engine () = MaSh_Py then + (MaSh_Py.learn ctxt overlord true [("", parents, map fst feats, deps)]; state) + else let val access_G = access_G |> add_node Automatic_Proof name parents feats deps in {access_G = access_G, num_known_facts = num_known_facts + 1, dirty = Option.map (cons name) dirty} end - else - (MaSh_Py.learn ctxt overlord true [("", parents, map fst feats, deps)]; state) end); (true, "") end) @@ -1334,7 +1404,7 @@ val timer = Timer.startRealTimer () fun next_commit_time () = Time.+ (Timer.checkRealTimer timer, commit_timeout) - val sml = is_mash_sml_enabled () + val engine = the_mash_engine () val {access_G, ...} = peek_state ctxt overlord I val is_in_access_G = is_fact_in_graph access_G o snd val no_new_facts = forall is_in_access_G facts @@ -1376,11 +1446,11 @@ (false, SOME names, []) => SOME (map #1 learns @ names) | _ => NONE) in - if sml then - () + if engine = MaSh_Py then + (MaSh_Py.learn ctxt overlord (save andalso null relearns) (rev learns); + MaSh_Py.relearn ctxt overlord save relearns) else - (MaSh_Py.learn ctxt overlord (save andalso null relearns) (rev learns); - MaSh_Py.relearn ctxt overlord save relearns); + (); {access_G = access_G, num_known_facts = num_known_facts, dirty = dirty} end @@ -1613,7 +1683,7 @@ |> Par_List.map (apsnd (fn f => f ())) val mesh = mesh_facts (eq_snd Thm.eq_thm_prop) max_facts mess |> add_and_take in - if is_mash_sml_enabled () orelse not save then () else MaSh_Py.save ctxt overlord; + if the_mash_engine () = MaSh_Py andalso save then MaSh_Py.save ctxt overlord else (); (case (fact_filter, mess) of (NONE, [(_, (mepo, _)), (_, (mash, _))]) => [(meshN, mesh), (mepoN, mepo |> map fst |> add_and_take), @@ -1623,7 +1693,7 @@ fun kill_learners ctxt ({overlord, ...} : params) = (Async_Manager.kill_threads MaShN "learner"; - if is_mash_sml_enabled () then () else MaSh_Py.shutdown ctxt overlord) + if the_mash_engine () = MaSh_Py then MaSh_Py.shutdown ctxt overlord else ()) fun running_learners () = Async_Manager.running_threads MaShN "learner"