# HG changeset patch # User blanchet # Date 1401262719 -7200 # Node ID 80b7c07e7a73e09b71d29e65ca76a068868dc535 # Parent e4074b91b2a60d76020be1f91b16dda3f3a191c2 tuning diff -r e4074b91b2a6 -r 80b7c07e7a73 src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML --- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Wed May 28 03:10:30 2014 +0200 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Wed May 28 09:38:39 2014 +0200 @@ -359,16 +359,17 @@ () end +val number_of_nearest_neighbors = 40 (* FUDGE *) + (* avail_num = maximum number of theorems to check dependencies and symbols adv_max = do not return theorems over or equal to this number. Must satisfy: adv_max <= avail_num get_deps = returns dependencies of a theorem get_sym_ths = get theorems that have this feature - knns = number of nearest neighbours advno = number of predictions to return syms = symbols of the conjecture *) -fun k_nearest_neighbors avail_num adv_max get_deps get_sym_ths knns advno syms = +fun k_nearest_neighbors avail_num adv_max get_deps get_sym_ths advno syms = let (* Can be later used for TFIDF *) fun sym_wght _ = 1.0 @@ -393,7 +394,7 @@ end val _ = List.app do_sym syms - val _ = heap (Real.compare o pairself snd) knns overlaps_sqr + val _ = heap (Real.compare o pairself snd) number_of_nearest_neighbors overlaps_sqr val recommends = Array.tabulate (adv_max, rpair 0.0) fun inc_recommend j v = @@ -401,7 +402,7 @@ else Array.update (recommends, j, (j, v + snd (Array.sub (recommends, j)))) fun for k = - if k = knns orelse k >= adv_max then + if k = number_of_nearest_neighbors orelse k >= adv_max then () else let @@ -422,64 +423,71 @@ ret [] (Integer.max 0 (adv_max - advno)) end -val tau = 0.02 -val posWeight = 2.0 -val defVal = ~15.0 -val defPriWei = 20 +val nb_tau = 0.02 (* FUDGE *) +val nb_pos_weight = 2.0 (* FUDGE *) +val nb_def_val = ~15.0 (* FUDGE *) +val nb_def_prior_weight = 20 (* FUDGE *) fun naive_bayes avail_num adv_max get_deps get_th_syms sym_num advno syms = let - val afreq = Unsynchronized.ref 0; - val tfreq = Array.array (avail_num, 0); - val sfreq = Array.array (avail_num, Inttab.empty); - val dffreq = Array.array (sym_num, 0); + val afreq = Unsynchronized.ref 0 + val tfreq = Array.array (avail_num, 0) + val sfreq = Array.array (avail_num, Inttab.empty) + val dffreq = Array.array (sym_num, 0) fun learn th syms deps = let fun add_th t = let - val im = Array.sub (sfreq, t); - fun fold_fn s sf = Inttab.update (s, 1 + the_default 0 (Inttab.lookup im s)) sf; + val im = Array.sub (sfreq, t) + fun fold_fn s sf = Inttab.update (s, 1 + the_default 0 (Inttab.lookup im s)) sf in Array.update (tfreq, t, 1 + Array.sub (tfreq, t)); Array.update (sfreq, t, fold fold_fn syms im) - end; - fun add_sym s = Array.update (dffreq, s, 1 + Array.sub (dffreq, s)); + end + + fun add_sym s = Array.update (dffreq, s, 1 + Array.sub (dffreq, s)) in - List.app add_th (replicate defPriWei th); + List.app add_th (replicate nb_def_prior_weight th); List.app add_th deps; List.app add_sym syms; afreq := !afreq + 1 - end; + end - fun tfidf _ = 1.0; - (*fun tfidf sym = Math.ln (Real.fromInt (!afreq)) - Math.ln (Real.fromInt (Array.sub (dffreq, sym)));*) + fun tfidf _ = 1.0 + (*fun tfidf sym = Math.ln (Real.fromInt (!afreq)) - Math.ln (Real.fromInt (Array.sub (dffreq, sym)))*) fun eval syms = let fun log_posterior i = let - val tfreq = Real.fromInt (Array.sub (tfreq, i)); + val tfreq = Real.fromInt (Array.sub (tfreq, i)) + fun fold_syms (f, fw) (res, sfh) = (case Inttab.lookup sfh f of SOME sf => - (res + tfidf f * fw * Math.ln (posWeight * Real.fromInt sf / tfreq), + (res + tfidf f * fw * Math.ln (nb_pos_weight * Real.fromInt sf / tfreq), Inttab.delete f sfh) - | NONE => (res + fw * defVal, sfh)); - val (res, sfh) = fold fold_syms syms (Math.ln tfreq, Array.sub (sfreq,i)); + | NONE => (res + fw * nb_def_val, sfh)) + + val (res, sfh) = fold fold_syms syms (Math.ln tfreq, Array.sub (sfreq,i)) + fun fold_sfh (f, sf) sow = - sow + tfidf f * (Math.ln (1.0 + (1.0 - Real.fromInt sf) / tfreq)); - val sumOfWei = Inttab.fold fold_sfh sfh 0.0; + sow + tfidf f * (Math.ln (1.0 + (1.0 - Real.fromInt sf) / tfreq)) + + val sum_of_weights = Inttab.fold fold_sfh sfh 0.0 in - res + tau * sumOfWei + res + nb_tau * sum_of_weights end - val posterior = Array.tabulate (adv_max, (fn j => (j, log_posterior j))); + + val posterior = Array.tabulate (adv_max, (fn j => (j, log_posterior j))) + fun ret acc at = - if at = adv_max then acc else ret (Array.sub (posterior,at) :: acc) (at + 1) + if at = adv_max then acc else ret (Array.sub (posterior, at) :: acc) (at + 1) in heap (Real.compare o pairself snd) advno posterior; ret [] (Integer.max 0 (adv_max - advno)) - end; + end fun for i = if i = avail_num then () else (learn i (get_th_syms i) (get_deps i); for (i + 1)) @@ -487,8 +495,6 @@ for 0; eval syms end -val knns = 40 (* FUDGE *) - fun add_to_xtab key (next, tab, keys) = (next + 1, Symtab.update_new (key, next) tab, key :: keys) fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i))) @@ -544,7 +550,7 @@ rev_featss num_facts val get_facts = curry Array.sub facts_ary in - k_nearest_neighbors num_facts num_visible_facts get_deps get_facts knns max_suggs syms + k_nearest_neighbors num_facts num_visible_facts get_deps get_facts max_suggs syms end else let @@ -1317,7 +1323,6 @@ in map_state ctxt overlord (fn state as {access_G, num_known_facts, dirty} => let - val name = learned_proof_name () val parents = maximal_wrt_access_graph access_G facts val deps = used_ths |> filter (is_fact_in_graph access_G) @@ -1326,7 +1331,10 @@ if the_mash_engine () = MaSh_Py then (MaSh_Py.learn ctxt overlord true [("", parents, map fst feats, deps)]; state) else - let val access_G = access_G |> add_node Automatic_Proof name parents feats deps in + let + val name = learned_proof_name () + val access_G = access_G |> add_node Automatic_Proof name parents feats deps + in {access_G = access_G, num_known_facts = num_known_facts + 1, dirty = Option.map (cons name) dirty} end