--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Thu Jun 26 13:35:00 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Thu Jun 26 13:35:07 2014 +0200
@@ -39,7 +39,7 @@
MaSh_Py
| MaSh_SML_kNN
| MaSh_SML_kNN_Cpp
- | MaSh_SML_NB of bool * bool
+ | MaSh_SML_NB
| MaSh_SML_NB_Cpp
| MaSh_SML_NB_Py
@@ -62,8 +62,8 @@
sig
val k_nearest_neighbors : int -> (int -> int list) -> (int -> int list) -> int -> int list ->
int -> int list -> (int * real) list
- val naive_bayes : (bool * bool) -> int -> (int -> int list) -> (int -> int list) -> int ->
- int -> int list -> int list -> (int * real) list
+ val naive_bayes : int -> (int -> int list) -> (int -> int list) -> int -> int -> int list ->
+ int list -> (int * real) list
val naive_bayes_py : Proof.context -> bool -> int -> (int -> int list) -> (int -> int list) ->
int -> int -> int list -> (int * real) list
val query : Proof.context -> bool -> mash_engine -> string list -> int ->
@@ -159,32 +159,26 @@
MaSh_Py
| MaSh_SML_kNN
| MaSh_SML_kNN_Cpp
-| MaSh_SML_NB of bool * bool
+| MaSh_SML_NB
| MaSh_SML_NB_Cpp
| MaSh_SML_NB_Py
-val default_MaSh_SML_NB = MaSh_SML_NB (false, true)
-
fun mash_engine () =
let val flag1 = Options.default_string @{system_option MaSh} in
(case if flag1 <> "none" (* default *) then flag1 else getenv "MASH" of
- "yes" => SOME default_MaSh_SML_NB
+ "yes" => SOME MaSh_SML_NB
| "py" => SOME MaSh_Py
- | "sml" => SOME default_MaSh_SML_NB
+ | "sml" => SOME MaSh_SML_NB
| "sml_knn" => SOME MaSh_SML_kNN
| "sml_knn_cpp" => SOME MaSh_SML_kNN_Cpp
- | "sml_nb" => SOME default_MaSh_SML_NB
- | "sml_nbCC" => SOME (MaSh_SML_NB (false, false))
- | "sml_nbCD" => SOME (MaSh_SML_NB (false, true))
- | "sml_nbDC" => SOME (MaSh_SML_NB (true, false))
- | "sml_nbDD" => SOME (MaSh_SML_NB (true, true))
+ | "sml_nb" => SOME MaSh_SML_NB
| "sml_nb_cpp" => SOME MaSh_SML_NB_Cpp
| "sml_nb_py" => SOME MaSh_SML_NB_Py
| _ => NONE)
end
val is_mash_enabled = is_some o mash_engine
-val the_mash_engine = the_default default_MaSh_SML_NB o mash_engine
+val the_mash_engine = the_default MaSh_SML_NB o mash_engine
(*** Low-level communication with the Python version of MaSh ***)
@@ -527,11 +521,10 @@
learn_facts tfreq sfreq dffreq num_facts get_deps get_feats num_feats
end
-fun naive_bayes_query (kuehlwein_log, kuehlwein_params) num_facts max_suggs visible_facts feats
- (tfreq, sfreq, idf) =
+fun naive_bayes_query num_facts max_suggs visible_facts feats (tfreq, sfreq, idf) =
let
- val tau = if kuehlwein_params then 0.05 else 0.02 (* FUDGE *)
- val pos_weight = if kuehlwein_params then 10.0 else 2.0 (* FUDGE *)
+ val tau = 0.05 (* FUDGE *)
+ val pos_weight = 10.0 (* FUDGE *)
val def_val = ~15.0 (* FUDGE *)
fun tfidf feat = Vector.sub (idf, feat)
@@ -549,12 +542,7 @@
val (res, sfh) = fold fold_feats feats (Math.ln tfreq, Vector.sub (sfreq, i))
- val fold_sfh =
- if kuehlwein_log then
- (fn (f, sf) => fn sow => sow - tfidf f * Math.ln (Real.fromInt sf / tfreq))
- else
- (fn (f, sf) => fn sow =>
- sow + tfidf f * Math.ln (1.0 + (1.0 - Real.fromInt sf) / tfreq))
+ fun fold_sfh (f, sf) sow = sow + tfidf f * Math.ln (1.0 + (1.0 - Real.fromInt sf) / tfreq)
val sum_of_weights = Inttab.fold fold_sfh sfh 0.0
in
@@ -570,9 +558,9 @@
ret [] (Integer.max 0 (num_facts - max_suggs))
end
-fun naive_bayes opts num_facts get_deps get_feats num_feats max_suggs visible_facts feats =
+fun naive_bayes num_facts get_deps get_feats num_feats max_suggs visible_facts feats =
learn num_facts get_deps get_feats num_feats
- |> naive_bayes_query opts num_facts max_suggs visible_facts feats
+ |> naive_bayes_query num_facts max_suggs visible_facts feats
(* experimental *)
fun naive_bayes_py ctxt overlord num_facts get_deps get_feats num_feats max_suggs feats =
@@ -694,8 +682,8 @@
val get_unweighted_feats = curry Vector.sub unweighted_feats_ary
in
(case engine of
- MaSh_SML_NB opts =>
- naive_bayes opts num_facts get_deps get_unweighted_feats num_feats max_suggs
+ MaSh_SML_NB =>
+ naive_bayes num_facts get_deps get_unweighted_feats num_feats max_suggs
int_visible_facts int_feats
| MaSh_SML_NB_Py => naive_bayes_py ctxt overlord num_facts get_deps
get_unweighted_feats num_feats max_suggs int_feats)