added another experimental engine
authorblanchet
Tue, 24 Jun 2014 12:35:49 +0200
changeset 57297 3d4647ea3e57
parent 57296 8a98f08a0523
child 57298 2502adc3c3f6
added another experimental engine
src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Tue Jun 24 12:35:43 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Tue Jun 24 12:35:49 2014 +0200
@@ -40,6 +40,7 @@
   | MaSh_SML_kNN
   | MaSh_SML_kNN_Cpp
   | MaSh_SML_NB of bool * bool
+  | MaSh_SML_NB_Cpp
   | MaSh_SML_NB_Py
 
   val is_mash_enabled : unit -> bool
@@ -160,6 +161,7 @@
 | MaSh_SML_kNN
 | MaSh_SML_kNN_Cpp
 | MaSh_SML_NB of bool * bool
+| MaSh_SML_NB_Cpp
 | MaSh_SML_NB_Py
 
 val default_MaSh_SML_NB = MaSh_SML_NB (false, true)
@@ -177,6 +179,7 @@
     | "sml_nbCD" => SOME (MaSh_SML_NB (false, true))
     | "sml_nbDC" => SOME (MaSh_SML_NB (true, false))
     | "sml_nbDD" => SOME (MaSh_SML_NB (true, true))
+    | "sml_nb_cpp" => SOME MaSh_SML_NB_Cpp
     | "sml_nb_py" => SOME MaSh_SML_NB_Py
     | _ => NONE)
   end
@@ -587,42 +590,47 @@
   end
 
 (* experimental *)
-fun k_nearest_neighbors_cpp max_suggs learns cfeats =
+fun c_plus_plus_tool tool max_suggs learns cfeats =
   let
-    val number_of_nearest_neighbors = 10 (* FUDGE *)
-
-    val ocs = TextIO.openOut "adv_syms"
-    val ocd = TextIO.openOut "adv_deps"
-    val ocq = TextIO.openOut "adv_seq"
-    val occ = TextIO.openOut "adv_conj"
+    val ser = string_of_int (serial ()) (* poor person's attempt at thread-safety *)
+    val ocs = TextIO.openOut ("adv_syms" ^ ser)
+    val ocd = TextIO.openOut ("adv_deps" ^ ser)
+    val ocq = TextIO.openOut ("adv_seq" ^ ser)
+    val occ = TextIO.openOut ("adv_conj" ^ ser)
 
     fun os oc s = TextIO.output (oc, s)
 
-    fun ol _ _ _   [] = ()
-      | ol _ f _   [e] = f e
+    fun ol _ _ _ [] = ()
+      | ol _ f _ [e] = f e
       | ol oc f sep (h :: t) = (f h; os oc sep; ol oc f sep t)
 
     fun do_learn (name, feats, deps) =
-      (os ocs name; os ocs ":"; ol ocs (fn (sy, _) => (os ocs "\""; os ocs sy; os ocs "\"")) ", " feats; os ocs "\n";
-       os ocd name; os ocd ":"; ol ocd (os ocd) " " deps; os ocd "\n";
-       os ocq name; os ocq "\n")
+      (os ocs name; os ocs ":";
+       ol ocs (fn (sy, _) => (os ocs "\""; os ocs sy; os ocs "\"")) ", " feats; os ocs "\n";
+       os ocd name; os ocd ":"; ol ocd (os ocd) " " deps; os ocd "\n"; os ocq name; os ocq "\n")
 
     fun forkexec no =
       let
         val cmd =
-          "~/misc/newknn/knn " ^ string_of_int number_of_nearest_neighbors ^
-          " adv_syms adv_deps " ^ string_of_int no ^ " adv_seq < adv_conj"
+          "~/misc/" ^ tool ^ " adv_syms" ^ ser ^ " adv_deps" ^ ser ^ " " ^ string_of_int no ^
+          " adv_seq" ^ ser ^ " < adv_conj" ^ ser
       in
         fst (Isabelle_System.bash_output cmd)
         |> space_explode " "
         |> filter_out (curry (op =) "")
       end
   in
-    (List.app do_learn learns; ol occ (fn sy => (os occ "\""; os occ sy; os occ "\"")) ", " cfeats; TextIO.closeOut ocs;
-     TextIO.closeOut ocd; TextIO.closeOut ocq; TextIO.closeOut occ;
+    (List.app do_learn learns; ol occ (fn sy => (os occ "\""; os occ sy; os occ "\"")) ", " cfeats;
+     TextIO.closeOut ocs; TextIO.closeOut ocd; TextIO.closeOut ocq; TextIO.closeOut occ;
      forkexec max_suggs)
   end
 
+val cpp_number_of_nearest_neighbors = 10 (* FUDGE *)
+
+val k_nearest_neighbors_cpp =
+  c_plus_plus_tool ("newknn/knn" ^ " " ^ string_of_int cpp_number_of_nearest_neighbors)
+val naive_bayes_cpp = c_plus_plus_tool "predict/nbayes"
+
 fun add_to_xtab key (next, tab, keys) = (next + 1, Symtab.update_new (key, next) tab, key :: keys)
 
 fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i)))
@@ -636,6 +644,8 @@
   in
     if engine = MaSh_SML_kNN_Cpp then
       k_nearest_neighbors_cpp max_suggs learns (map fst feats)
+    else if engine = MaSh_SML_NB_Cpp then
+      naive_bayes_cpp max_suggs learns (map fst feats)
     else
       let
         val (rev_depss, rev_featss, (num_facts, _, rev_facts), (num_feats, feat_tab, _)) =