--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Wed Jun 18 17:42:24 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Wed Jun 18 17:42:24 2014 +0200
@@ -35,7 +35,11 @@
val encode_features : (string * real) list -> string
val extract_suggestions : string -> string * (string * real) list
- datatype mash_engine = MaSh_Py | MaSh_SML_kNN | MaSh_SML_NB | MaSh_SML_NB_Py
+ datatype mash_engine =
+ MaSh_Py
+ | MaSh_SML_kNN
+ | MaSh_SML_NB of bool * bool (* (kuehlwein_log, kuehlwein_params) *)
+ | MaSh_SML_NB_Py
val is_mash_enabled : unit -> bool
val the_mash_engine : unit -> mash_engine
@@ -56,8 +60,8 @@
sig
val k_nearest_neighbors : int -> int -> (int -> int list) -> (int -> (int * real) list) ->
int -> (int * real) list -> (int * real) list
- val naive_bayes : int -> int -> (int -> int list) -> (int -> int list) -> int -> int ->
- (int * real) list -> (int * real) list
+ val naive_bayes : (bool * bool) -> int -> int -> (int -> int list) -> (int -> int list) ->
+ int -> int -> (int * real) list -> (int * real) list (* bool pair: MaSh_SML_NB payload *)
val naive_bayes_py : Proof.context -> bool -> int -> int -> (int -> int list) ->
(int -> int list) -> int -> int -> (int * real) list -> (int * real) list
val query : Proof.context -> bool -> mash_engine -> string list -> int ->
@@ -150,22 +154,32 @@
()
end
-datatype mash_engine = MaSh_Py | MaSh_SML_kNN | MaSh_SML_NB | MaSh_SML_NB_Py
+datatype mash_engine =
+ MaSh_Py
+| MaSh_SML_kNN
+| MaSh_SML_NB of bool * bool (* (kuehlwein_log, kuehlwein_params) for naive_bayes_query *)
+| MaSh_SML_NB_Py
+
+val default_MaSh_SML_NB = MaSh_SML_NB (false, false) (* both Kuehlwein-style flags off *)
fun mash_engine () =
let val flag1 = Options.default_string @{system_option MaSh} in
(case if flag1 <> "none" (* default *) then flag1 else getenv "MASH" of
- "yes" => SOME MaSh_SML_NB
+ "yes" => SOME default_MaSh_SML_NB
| "py" => SOME MaSh_Py
- | "sml" => SOME MaSh_SML_NB
+ | "sml" => SOME default_MaSh_SML_NB
| "sml_knn" => SOME MaSh_SML_kNN
- | "sml_nb" => SOME MaSh_SML_NB
+ | "sml_nb" => SOME default_MaSh_SML_NB
+ | "sml_nbCC" => SOME (MaSh_SML_NB (false, false)) (* same value as default_MaSh_SML_NB *)
+ | "sml_nbCD" => SOME (MaSh_SML_NB (false, true)) (* kuehlwein_params only *)
+ | "sml_nbDC" => SOME (MaSh_SML_NB (true, false)) (* kuehlwein_log only *)
+ | "sml_nbDD" => SOME (MaSh_SML_NB (true, true)) (* kuehlwein_log and kuehlwein_params *)
| "sml_nb_py" => SOME MaSh_SML_NB_Py
| _ => NONE)
end
val is_mash_enabled = is_some o mash_engine
-val the_mash_engine = the_default MaSh_SML_NB o mash_engine
+val the_mash_engine = the_default default_MaSh_SML_NB o mash_engine (* fallback: flags off *)
(*** Low-level communication with the Python version of MaSh ***)
@@ -498,15 +512,13 @@
for 0; (Array.vector tfreq, Array.vector sfreq (*, Array.vector dffreq *))
end
-val nb_kuehlwein_style_log = false
-val nb_kuehlwein_style_params = false
+fun naive_bayes_query (kuehlwein_log, kuehlwein_params) _ (* num_facts *) num_visible_facts
+ max_suggs feats (tfreq, sfreq (*, dffreq*)) =
+ let
+ val tau = if kuehlwein_params then 0.05 else 0.02 (* FUDGE *)
+ val pos_weight = if kuehlwein_params then 10.0 else 2.0 (* FUDGE *)
+ val def_val = ~15.0 (* FUDGE *)
-val nb_tau = if nb_kuehlwein_style_params then 0.05 else 0.02 (* FUDGE *)
-val nb_pos_weight = if nb_kuehlwein_style_params then 10.0 else 2.0 (* FUDGE *)
-val nb_def_val = ~15.0 (* FUDGE *)
-
-fun naive_bayes_query _ (* num_facts *) num_visible_facts max_suggs feats (tfreq, sfreq (*, dffreq*)) =
- let
(*
val afreq = Real.fromInt num_facts
fun tfidf feat = Math.ln afreq - Math.ln (Real.fromInt (Vector.sub (dffreq, feat)))
@@ -520,14 +532,14 @@
fun fold_feats (f, fw) (res, sfh) =
(case Inttab.lookup sfh f of
SOME sf =>
- (res + tfidf f * fw * Math.ln (nb_pos_weight * Real.fromInt sf / tfreq),
+ (res + tfidf f * fw * Math.ln (pos_weight * Real.fromInt sf / tfreq),
Inttab.delete f sfh)
- | NONE => (res + fw * nb_def_val, sfh))
+ | NONE => (res + fw * def_val, sfh))
val (res, sfh) = fold fold_feats feats (Math.ln tfreq, Vector.sub (sfreq, i))
val fold_sfh =
- if nb_kuehlwein_style_log then
+ if kuehlwein_log then
(fn (f, sf) => fn sow => sow - tfidf f * Math.ln (Real.fromInt sf / tfreq))
else
(fn (f, sf) => fn sow =>
@@ -535,7 +547,7 @@
val sum_of_weights = Inttab.fold fold_sfh sfh 0.0
in
- res + nb_tau * sum_of_weights
+ res + tau * sum_of_weights
end
val posterior = Array.tabulate (num_visible_facts, (fn j => (j, log_posterior j)))
@@ -547,9 +559,9 @@
ret [] (Integer.max 0 (num_visible_facts - max_suggs))
end
-fun naive_bayes num_facts num_visible_facts get_deps get_feats num_feats max_suggs feats =
+fun naive_bayes opts num_facts num_visible_facts get_deps get_feats num_feats max_suggs feats =
naive_bayes_learn num_facts get_deps get_feats num_feats
- |> naive_bayes_query num_facts num_visible_facts max_suggs feats
+ |> naive_bayes_query opts num_facts num_visible_facts max_suggs feats (* opts: flag pair *)
(* experimental *)
fun naive_bayes_py ctxt overlord num_facts num_visible_facts get_deps get_feats num_feats max_suggs
@@ -633,7 +645,9 @@
val get_unweighted_feats = curry Vector.sub unweighted_feats_ary
val feats' = map (apfst (the_default ~1 o Symtab.lookup feat_tab)) feats
in
- (if engine = MaSh_SML_NB then naive_bayes else naive_bayes_py ctxt overlord)
+ (case engine of
+ MaSh_SML_NB opts => naive_bayes opts
+ | _ => naive_bayes_py ctxt overlord)
num_facts num_visible_facts get_deps get_unweighted_feats num_feats max_suggs feats'
end)
|> map (curry Vector.sub fact_vec o fst)