# HG changeset patch
# User blanchet
# Date 1400722175 -7200
# Node ID ea5912e3b00819294d40f47ee9941d2597b61fa9
# Parent 5e30ffe79980867dfae24d83c35930118a0339f2
until naive Bayes supports weights, don't incorporate 'extra' low-weight features

diff -r 5e30ffe79980 -r ea5912e3b008 src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Wed May 21 22:06:10 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Thu May 22 03:29:35 2014 +0200
@@ -375,7 +375,7 @@
   advno = number of predictions to return
   syms = symbols of the conjecture
 *)
-fun knn avail_num adv_max get_deps get_sym_ths knns advno syms =
+fun k_nearest_neighbors avail_num adv_max get_deps get_sym_ths knns advno syms =
   let
     (* Can be later used for TFIDF *)
     fun sym_wght _ = 1.0
@@ -429,10 +429,10 @@
     ret [] (Integer.max 0 (adv_max - advno))
   end
 
-(* Two arguments control the behaviour of nbayes: prior and ess. Prior expresses our belief in
+(* Two arguments control the behaviour of naive Bayes: prior and ess. Prior expresses our belief in
 usefulness of unknown features, and ess (equivalent sample size) expresses our confidence in the
 prior. *)
-fun nbayes avail_num adv_max get_deps get_th_syms sym_num ess prior advno syms =
+fun naive_bayes avail_num adv_max get_deps get_th_syms sym_num ess prior advno syms =
   let
     val afreq = Unsynchronized.ref 0
     val tfreq = Array.array (avail_num, 0)
@@ -540,13 +540,11 @@
 
     val deps_vec = Vector.fromList (rev rev_depss)
 
-    val avail_num = Vector.length deps_vec
-    val adv_max = length visible_facts
+    val num_visible_facts = length visible_facts
     val get_deps = curry Vector.sub deps_vec
-    val advno = max_suggs
   in
     trace_msg ctxt (fn () => "MaSh_SML query " ^ encode_features feats ^ " from {" ^
-      elide_string 1000 (space_implode " " facts) ^ "}");
+      elide_string 1000 (space_implode " " (take num_visible_facts facts)) ^ "}");
     (if engine = MaSh_SML_kNN then
        let
          val facts_ary = Array.array (num_feats, [])
@@ -558,20 +556,20 @@
                fact'
              end)
            rev_featss num_facts
-         val get_sym_ths = curry Array.sub facts_ary
+         val get_facts = curry Array.sub facts_ary
          val syms = map_filter (fn (feat, weight) =>
            Option.map (rpair weight) (Symtab.lookup feat_tab (str_of_feat feat))) feats
        in
-         knn avail_num adv_max get_deps get_sym_ths knns advno syms
+         k_nearest_neighbors num_facts num_visible_facts get_deps get_facts knns max_suggs syms
        end
      else
        let
          val unweighted_feats_ary = Vector.fromList (map (map fst) (rev rev_featss))
-         val get_th_syms = curry Vector.sub unweighted_feats_ary
-         val sym_num = num_feats
+         val get_unweighted_feats = curry Vector.sub unweighted_feats_ary
          val unweighted_syms = map_filter (Symtab.lookup feat_tab o str_of_feat o fst) feats
        in
-         nbayes avail_num adv_max get_deps get_th_syms sym_num ess prior advno unweighted_syms
+         naive_bayes num_facts num_visible_facts get_deps get_unweighted_feats num_feats ess prior
+           max_suggs unweighted_syms
        end)
     |> map (curry Vector.sub fact_vec o fst)
   end
@@ -1258,6 +1256,8 @@
   let
     val thy = Proof_Context.theory_of ctxt
     val thy_name = Context.theory_name thy
+    val engine = the_mash_engine ()
+
     val facts = facts |> sort (crude_thm_ord o pairself snd o swap)
     val chained = facts |> filter (fn ((_, (scope, _)), _) => scope = Chained)
     val num_facts = length facts
@@ -1284,20 +1284,24 @@
           |> map (rpair 1.0)
           |> map (chained_or_extra_features_of chained_feature_factor)
           |> rpair [] |-> fold (union (eq_fst (op =)))
-        val extra_feats = facts
-          |> take (Int.max (0, num_extra_feature_facts - length chained))
-          |> filter fact_has_right_theory
-          |> weight_facts_steeply
-          |> map (chained_or_extra_features_of extra_feature_factor)
-          |> rpair [] |-> fold (union (eq_fst (op =)))
+        val extra_feats =
+          (* As long as SML NB does not support weights, it makes little sense to include these
+             extra features. *)
+          if engine = MaSh_SML_NB then
+            []
+          else
+            facts
+            |> take (Int.max (0, num_extra_feature_facts - length chained))
+            |> filter fact_has_right_theory
+            |> weight_facts_steeply
+            |> map (chained_or_extra_features_of extra_feature_factor)
+            |> rpair [] |-> fold (union (eq_fst (op =)))
         val feats = fold (union (eq_fst (op =))) [chained_feats, extra_feats] goal_feats
           |> debug ? sort (Real.compare o swap o pairself snd)
       in
         (parents, hints, feats)
       end
 
-    val engine = the_mash_engine ()
-
     val (access_G, py_suggs) =
       peek_state ctxt overlord (fn {access_G, ...} =>
         if Graph.is_empty access_G then