# HG changeset patch # User blanchet # Date 1401274967 -7200 # Node ID 3e6af473d666c0a8c4ee617443ed049db895c552 # Parent c881a983a19fa50c397a93ea8688d16eec0302aa tuning diff -r c881a983a19f -r 3e6af473d666 src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML --- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Wed May 28 12:34:26 2014 +0200 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Wed May 28 13:02:47 2014 +0200 @@ -362,19 +362,20 @@ val number_of_nearest_neighbors = 40 (* FUDGE *) (* - avail_num = maximum number of theorems to check dependencies and symbols - adv_max = do not return theorems over or equal to this number. Must satisfy: adv_max <= avail_num + num_facts = maximum number of theorems to check dependencies and symbols + num_visible_facts = do not return theorems over or equal to this number. + Must satisfy: num_visible_facts <= num_facts. get_deps = returns dependencies of a theorem get_sym_ths = get theorems that have this feature - advno = number of predictions to return - syms = symbols of the conjecture + max_suggs = number of suggestions to return + feats = features of the goal *) -fun k_nearest_neighbors avail_num adv_max get_deps get_sym_ths advno syms = +fun k_nearest_neighbors num_facts num_visible_facts get_deps get_sym_ths max_suggs feats = let (* Can be later used for TFIDF *) fun sym_wght _ = 1.0 - val overlaps_sqr = Array.tabulate (avail_num, (fn i => (i, 0.0))) + val overlaps_sqr = Array.tabulate (num_facts, rpair 0.0) fun inc_overlap j v = let @@ -383,30 +384,30 @@ Array.update (overlaps_sqr, j, (j, v + ov)) end - fun do_sym (s, con_wght) = + fun do_feat (s, con_wght) = let val sw = sym_wght s val w2 = sw * sw * con_wght - fun do_th (j, prem_wght) = if j < avail_num then inc_overlap j (w2 * prem_wght) else () + fun do_th (j, prem_wght) = if j < num_facts then inc_overlap j (w2 * prem_wght) else () in List.app do_th (get_sym_ths s) end - val _ = List.app do_sym syms + val _ = List.app do_feat feats val _ = heap (Real.compare o pairself snd) number_of_nearest_neighbors overlaps_sqr - val recommends = Array.tabulate (adv_max, rpair 0.0) + val recommends = Array.tabulate (num_visible_facts, rpair 0.0) fun inc_recommend j v = - if j >= adv_max then () + if j >= num_visible_facts then () else Array.update (recommends, j, (j, v + snd (Array.sub (recommends, j)))) fun for k = - if k = number_of_nearest_neighbors orelse k >= adv_max then + if k = number_of_nearest_neighbors orelse k >= num_visible_facts then () else let - val (j, o2) = Array.sub (overlaps_sqr, avail_num - k - 1) + val (j, o2) = Array.sub (overlaps_sqr, num_facts - k - 1) val o1 = Math.sqrt o2 val _ = inc_recommend j o1 val ds = get_deps j @@ -419,8 +420,8 @@ if at = Array.length recommends then acc else ret (Array.sub (recommends, at) :: acc) (at + 1) in for 0; - heap (Real.compare o pairself snd) advno recommends; - ret [] (Integer.max 0 (adv_max - advno)) + heap (Real.compare o pairself snd) max_suggs recommends; + ret [] (Integer.max 0 (num_visible_facts - max_suggs)) end val nb_tau = 0.02 (* FUDGE *) @@ -428,14 +429,14 @@ val nb_def_val = ~15.0 (* FUDGE *) val nb_def_prior_weight = 20 (* FUDGE *) -fun naive_bayes avail_num adv_max get_deps get_th_syms sym_num advno syms = +fun naive_bayes_learn num_facts get_deps get_th_feats num_feats = let val afreq = Unsynchronized.ref 0 - val tfreq = Array.array (avail_num, 0) - val sfreq = Array.array (avail_num, Inttab.empty) - val dffreq = Array.array (sym_num, 0) + val tfreq = Array.array (num_facts, 0) + val sfreq = Array.array (num_facts, Inttab.empty) + val dffreq = Array.array (num_feats, 0) - fun learn th syms deps = + fun learn th feats deps = let fun add_th weight t = let @@ -443,57 +444,61 @@ fun fold_fn s sf = Inttab.update (s, weight + the_default 0 (Inttab.lookup im s)) sf in Array.update (tfreq, t, weight + Array.sub (tfreq, t)); - Array.update (sfreq, t, fold fold_fn syms im) + Array.update (sfreq, t, fold fold_fn feats im) end fun add_sym s = Array.update (dffreq, s, 1 + Array.sub (dffreq, s)) in add_th nb_def_prior_weight th; List.app (add_th 1) deps; - List.app add_sym syms; + List.app add_sym feats; afreq := !afreq + 1 end - fun tfidf sym = Math.ln (Real.fromInt (!afreq)) - Math.ln (Real.fromInt (Array.sub (dffreq, sym))) + fun for i = + if i = num_facts then () else (learn i (get_th_feats i) (get_deps i); for (i + 1)) + in + for 0; (Real.fromInt (!afreq), Array.vector tfreq, Array.vector sfreq, Array.vector dffreq) + end - fun eval syms = +fun naive_bayes_query num_visible_facts max_suggs feats (afreq, tfreq, sfreq, dffreq) = + let + fun tfidf feat = Math.ln afreq - Math.ln (Real.fromInt (Vector.sub (dffreq, feat))) + + fun log_posterior i = let - fun log_posterior i = - let - val tfreq = Real.fromInt (Array.sub (tfreq, i)) - - fun fold_syms (f, fw) (res, sfh) = - (case Inttab.lookup sfh f of - SOME sf => - (res + tfidf f * fw * Math.ln (nb_pos_weight * Real.fromInt sf / tfreq), - Inttab.delete f sfh) - | NONE => (res + fw * nb_def_val, sfh)) + val tfreq = Real.fromInt (Vector.sub (tfreq, i)) - val (res, sfh) = fold fold_syms syms (Math.ln tfreq, Array.sub (sfreq,i)) - - fun fold_sfh (f, sf) sow = - sow + tfidf f * (Math.ln (1.0 + (1.0 - Real.fromInt sf) / tfreq)) + fun fold_feats (f, fw) (res, sfh) = + (case Inttab.lookup sfh f of + SOME sf => + (res + tfidf f * fw * Math.ln (nb_pos_weight * Real.fromInt sf / tfreq), + Inttab.delete f sfh) + | NONE => (res + fw * nb_def_val, sfh)) - val sum_of_weights = Inttab.fold fold_sfh sfh 0.0 - in - res + nb_tau * sum_of_weights - end + val (res, sfh) = fold fold_feats feats (Math.ln tfreq, Vector.sub (sfreq, i)) - val posterior = Array.tabulate (adv_max, (fn j => (j, log_posterior j))) + fun fold_sfh (f, sf) sow = + sow + tfidf f * (Math.ln (1.0 + (1.0 - Real.fromInt sf) / tfreq)) - fun ret acc at = - if at = adv_max then acc else ret (Array.sub (posterior, at) :: acc) (at + 1) + val sum_of_weights = Inttab.fold fold_sfh sfh 0.0 in - heap (Real.compare o pairself snd) advno posterior; - ret [] (Integer.max 0 (adv_max - advno)) + res + nb_tau * sum_of_weights end - fun for i = - if i = avail_num then () else (learn i (get_th_syms i) (get_deps i); for (i + 1)) + val posterior = Array.tabulate (num_visible_facts, (fn j => (j, log_posterior j))) + + fun ret acc at = + if at = num_visible_facts then acc else ret (Array.sub (posterior, at) :: acc) (at + 1) in - for 0; eval syms + heap (Real.compare o pairself snd) max_suggs posterior; + ret [] (Integer.max 0 (num_visible_facts - max_suggs)) end +fun naive_bayes num_facts num_visible_facts get_deps get_th_feats num_feats max_suggs feats = + naive_bayes_learn num_facts get_deps get_th_feats num_feats + |> naive_bayes_query num_visible_facts max_suggs feats + fun add_to_xtab key (next, tab, keys) = (next + 1, Symtab.update_new (key, next) tab, key :: keys) fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i))) @@ -547,19 +552,19 @@ end) rev_featss num_facts val get_facts = curry Array.sub facts_ary - val syms = map_filter (fn (feat, weight) => + val feats' = map_filter (fn (feat, weight) => Option.map (rpair weight) (Symtab.lookup feat_tab feat)) feats in - k_nearest_neighbors num_facts num_visible_facts get_deps get_facts max_suggs syms + k_nearest_neighbors num_facts num_visible_facts get_deps get_facts max_suggs feats' end else let val unweighted_feats_ary = Vector.fromList (map (map fst) (rev rev_featss)) val get_unweighted_feats = curry Vector.sub unweighted_feats_ary - val syms = map (apfst (the_default ~1 o Symtab.lookup feat_tab)) feats + val feats' = map (apfst (the_default ~1 o Symtab.lookup feat_tab)) feats in naive_bayes num_facts num_visible_facts get_deps get_unweighted_feats num_feats max_suggs - syms + feats' end) |> map (curry Vector.sub fact_vec o fst) end