--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Tue Jun 24 12:35:43 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Tue Jun 24 12:35:49 2014 +0200
@@ -40,6 +40,7 @@
| MaSh_SML_kNN
| MaSh_SML_kNN_Cpp
| MaSh_SML_NB of bool * bool
+ | MaSh_SML_NB_Cpp
| MaSh_SML_NB_Py
val is_mash_enabled : unit -> bool
@@ -160,6 +161,7 @@
| MaSh_SML_kNN
| MaSh_SML_kNN_Cpp
| MaSh_SML_NB of bool * bool
+| MaSh_SML_NB_Cpp
| MaSh_SML_NB_Py
val default_MaSh_SML_NB = MaSh_SML_NB (false, true)
@@ -177,6 +179,7 @@
| "sml_nbCD" => SOME (MaSh_SML_NB (false, true))
| "sml_nbDC" => SOME (MaSh_SML_NB (true, false))
| "sml_nbDD" => SOME (MaSh_SML_NB (true, true))
+ | "sml_nb_cpp" => SOME MaSh_SML_NB_Cpp
| "sml_nb_py" => SOME MaSh_SML_NB_Py
| _ => NONE)
end
@@ -587,42 +590,47 @@
end
(* experimental *)
-fun k_nearest_neighbors_cpp max_suggs learns cfeats =
+fun c_plus_plus_tool tool max_suggs learns cfeats =
let
- val number_of_nearest_neighbors = 10 (* FUDGE *)
-
- val ocs = TextIO.openOut "adv_syms"
- val ocd = TextIO.openOut "adv_deps"
- val ocq = TextIO.openOut "adv_seq"
- val occ = TextIO.openOut "adv_conj"
+ val ser = string_of_int (serial ()) (* poor person's attempt at thread-safety *)
+ val ocs = TextIO.openOut ("adv_syms" ^ ser)
+ val ocd = TextIO.openOut ("adv_deps" ^ ser)
+ val ocq = TextIO.openOut ("adv_seq" ^ ser)
+ val occ = TextIO.openOut ("adv_conj" ^ ser)
fun os oc s = TextIO.output (oc, s)
- fun ol _ _ _ [] = ()
- | ol _ f _ [e] = f e
+ fun ol _ _ _ [] = ()
+ | ol _ f _ [e] = f e
| ol oc f sep (h :: t) = (f h; os oc sep; ol oc f sep t)
fun do_learn (name, feats, deps) =
- (os ocs name; os ocs ":"; ol ocs (fn (sy, _) => (os ocs "\""; os ocs sy; os ocs "\"")) ", " feats; os ocs "\n";
- os ocd name; os ocd ":"; ol ocd (os ocd) " " deps; os ocd "\n";
- os ocq name; os ocq "\n")
+ (os ocs name; os ocs ":";
+ ol ocs (fn (sy, _) => (os ocs "\""; os ocs sy; os ocs "\"")) ", " feats; os ocs "\n";
+ os ocd name; os ocd ":"; ol ocd (os ocd) " " deps; os ocd "\n"; os ocq name; os ocq "\n")
fun forkexec no =
let
val cmd =
- "~/misc/newknn/knn " ^ string_of_int number_of_nearest_neighbors ^
- " adv_syms adv_deps " ^ string_of_int no ^ " adv_seq < adv_conj"
+ "~/misc/" ^ tool ^ " adv_syms" ^ ser ^ " adv_deps" ^ ser ^ " " ^ string_of_int no ^
+ " adv_seq" ^ ser ^ " < adv_conj" ^ ser
in
fst (Isabelle_System.bash_output cmd)
|> space_explode " "
|> filter_out (curry (op =) "")
end
in
- (List.app do_learn learns; ol occ (fn sy => (os occ "\""; os occ sy; os occ "\"")) ", " cfeats; TextIO.closeOut ocs;
- TextIO.closeOut ocd; TextIO.closeOut ocq; TextIO.closeOut occ;
+ (List.app do_learn learns; ol occ (fn sy => (os occ "\""; os occ sy; os occ "\"")) ", " cfeats;
+ TextIO.closeOut ocs; TextIO.closeOut ocd; TextIO.closeOut ocq; TextIO.closeOut occ;
forkexec max_suggs)
end
+val cpp_number_of_nearest_neighbors = 10 (* FUDGE *)
+
+val k_nearest_neighbors_cpp =
+ c_plus_plus_tool ("newknn/knn" ^ " " ^ string_of_int cpp_number_of_nearest_neighbors)
+val naive_bayes_cpp = c_plus_plus_tool "predict/nbayes"
+
fun add_to_xtab key (next, tab, keys) = (next + 1, Symtab.update_new (key, next) tab, key :: keys)
fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i)))
@@ -636,6 +644,8 @@
in
if engine = MaSh_SML_kNN_Cpp then
k_nearest_neighbors_cpp max_suggs learns (map fst feats)
+ else if engine = MaSh_SML_NB_Cpp then
+ naive_bayes_cpp max_suggs learns (map fst feats)
else
let
val (rev_depss, rev_featss, (num_facts, _, rev_facts), (num_feats, feat_tab, _)) =