# HG changeset patch # User blanchet # Date 1403106144 -7200 # Node ID 8f7d6f01a77588f588208edf020693d893442efb # Parent 31b35f5a5fdb845926fe6decdb799fe3503fe4d8 more MaSh engine variations, for evaluations diff -r 31b35f5a5fdb -r 8f7d6f01a775 src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML --- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Wed Jun 18 17:42:24 2014 +0200 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Wed Jun 18 17:42:24 2014 +0200 @@ -35,7 +35,11 @@ val encode_features : (string * real) list -> string val extract_suggestions : string -> string * (string * real) list - datatype mash_engine = MaSh_Py | MaSh_SML_kNN | MaSh_SML_NB | MaSh_SML_NB_Py + datatype mash_engine = + MaSh_Py + | MaSh_SML_kNN + | MaSh_SML_NB of bool * bool + | MaSh_SML_NB_Py val is_mash_enabled : unit -> bool val the_mash_engine : unit -> mash_engine @@ -56,8 +60,8 @@ sig val k_nearest_neighbors : int -> int -> (int -> int list) -> (int -> (int * real) list) -> int -> (int * real) list -> (int * real) list - val naive_bayes : int -> int -> (int -> int list) -> (int -> int list) -> int -> int -> - (int * real) list -> (int * real) list + val naive_bayes : (bool * bool) -> int -> int -> (int -> int list) -> (int -> int list) -> + int -> int -> (int * real) list -> (int * real) list val naive_bayes_py : Proof.context -> bool -> int -> int -> (int -> int list) -> (int -> int list) -> int -> int -> (int * real) list -> (int * real) list val query : Proof.context -> bool -> mash_engine -> string list -> int -> @@ -150,22 +154,32 @@ () end -datatype mash_engine = MaSh_Py | MaSh_SML_kNN | MaSh_SML_NB | MaSh_SML_NB_Py +datatype mash_engine = + MaSh_Py +| MaSh_SML_kNN +| MaSh_SML_NB of bool * bool +| MaSh_SML_NB_Py + +val default_MaSh_SML_NB = MaSh_SML_NB (false, false) fun mash_engine () = let val flag1 = Options.default_string @{system_option MaSh} in (case if flag1 <> "none" (* default *) then flag1 else getenv "MASH" of - "yes" => SOME MaSh_SML_NB + "yes" => SOME default_MaSh_SML_NB | "py" => SOME MaSh_Py - | "sml" => SOME MaSh_SML_NB + | "sml" => SOME default_MaSh_SML_NB | "sml_knn" => SOME MaSh_SML_kNN - | "sml_nb" => SOME MaSh_SML_NB + | "sml_nb" => SOME default_MaSh_SML_NB + | "sml_nbCC" => SOME (MaSh_SML_NB (false, false)) + | "sml_nbCD" => SOME (MaSh_SML_NB (false, true)) + | "sml_nbDC" => SOME (MaSh_SML_NB (true, false)) + | "sml_nbDD" => SOME (MaSh_SML_NB (true, true)) | "sml_nb_py" => SOME MaSh_SML_NB_Py | _ => NONE) end val is_mash_enabled = is_some o mash_engine -val the_mash_engine = the_default MaSh_SML_NB o mash_engine +val the_mash_engine = the_default default_MaSh_SML_NB o mash_engine (*** Low-level communication with the Python version of MaSh ***) @@ -498,15 +512,13 @@ for 0; (Array.vector tfreq, Array.vector sfreq (*, Array.vector dffreq *)) end -val nb_kuehlwein_style_log = false -val nb_kuehlwein_style_params = false +fun naive_bayes_query (kuehlwein_log, kuehlwein_params) _ (* num_facts *) num_visible_facts + max_suggs feats (tfreq, sfreq (*, dffreq*)) = + let + val tau = if kuehlwein_params then 0.05 else 0.02 (* FUDGE *) + val pos_weight = if kuehlwein_params then 10.0 else 2.0 (* FUDGE *) + val def_val = ~15.0 (* FUDGE *) -val nb_tau = if nb_kuehlwein_style_params then 0.05 else 0.02 (* FUDGE *) -val nb_pos_weight = if nb_kuehlwein_style_params then 10.0 else 2.0 (* FUDGE *) -val nb_def_val = ~15.0 (* FUDGE *) - -fun naive_bayes_query _ (* num_facts *) num_visible_facts max_suggs feats (tfreq, sfreq (*, dffreq*)) = - let (* val afreq = Real.fromInt num_facts fun tfidf feat = Math.ln afreq - Math.ln (Real.fromInt (Vector.sub (dffreq, feat))) @@ -520,14 +532,14 @@ fun fold_feats (f, fw) (res, sfh) = (case Inttab.lookup sfh f of SOME sf => - (res + tfidf f * fw * Math.ln (nb_pos_weight * Real.fromInt sf / tfreq), + (res + tfidf f * fw * Math.ln (pos_weight * Real.fromInt sf / tfreq), Inttab.delete f sfh) - | NONE => (res + fw * nb_def_val, sfh)) + | NONE => (res + fw * def_val, sfh)) val (res, sfh) = fold fold_feats feats (Math.ln tfreq, Vector.sub (sfreq, i)) val fold_sfh = - if nb_kuehlwein_style_log then + if kuehlwein_log then (fn (f, sf) => fn sow => sow - tfidf f * Math.ln (Real.fromInt sf / tfreq)) else (fn (f, sf) => fn sow => @@ -535,7 +547,7 @@ val sum_of_weights = Inttab.fold fold_sfh sfh 0.0 in - res + nb_tau * sum_of_weights + res + tau * sum_of_weights end val posterior = Array.tabulate (num_visible_facts, (fn j => (j, log_posterior j))) @@ -547,9 +559,9 @@ ret [] (Integer.max 0 (num_visible_facts - max_suggs)) end -fun naive_bayes num_facts num_visible_facts get_deps get_feats num_feats max_suggs feats = +fun naive_bayes opts num_facts num_visible_facts get_deps get_feats num_feats max_suggs feats = naive_bayes_learn num_facts get_deps get_feats num_feats - |> naive_bayes_query num_facts num_visible_facts max_suggs feats + |> naive_bayes_query opts num_facts num_visible_facts max_suggs feats (* experimental *) fun naive_bayes_py ctxt overlord num_facts num_visible_facts get_deps get_feats num_feats max_suggs @@ -633,7 +645,9 @@ val get_unweighted_feats = curry Vector.sub unweighted_feats_ary val feats' = map (apfst (the_default ~1 o Symtab.lookup feat_tab)) feats in - (if engine = MaSh_SML_NB then naive_bayes else naive_bayes_py ctxt overlord) + (case engine of + MaSh_SML_NB opts => naive_bayes opts + | _ => naive_bayes_py ctxt overlord) num_facts num_visible_facts get_deps get_unweighted_feats num_feats max_suggs feats' end) |> map (curry Vector.sub fact_vec o fst)