added experimental MaSh engine
authorblanchet
Tue, 24 Jun 2014 08:19:55 +0200
changeset 57291 1bac14e0a728
parent 57290 bc06471cb7b7
child 57292 d20cf3ec7fa7
added experimental MaSh engine
src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Tue Jun 24 08:19:55 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Tue Jun 24 08:19:55 2014 +0200
@@ -38,6 +38,7 @@
   datatype mash_engine =
     MaSh_Py
   | MaSh_SML_kNN
+  | MaSh_SML_kNN_Cpp
   | MaSh_SML_NB of bool * bool
   | MaSh_SML_NB_Py
 
@@ -157,6 +158,7 @@
 datatype mash_engine =
   MaSh_Py
 | MaSh_SML_kNN
+| MaSh_SML_kNN_Cpp
 | MaSh_SML_NB of bool * bool
 | MaSh_SML_NB_Py
 
@@ -169,6 +171,7 @@
     | "py" => SOME MaSh_Py
     | "sml" => SOME default_MaSh_SML_NB
     | "sml_knn" => SOME MaSh_SML_kNN
+    | "sml_knn_cpp" => SOME MaSh_SML_kNN_Cpp
     | "sml_nb" => SOME default_MaSh_SML_NB
     | "sml_nbCC" => SOME (MaSh_SML_NB (false, false))
     | "sml_nbCD" => SOME (MaSh_SML_NB (false, true))
@@ -583,6 +586,39 @@
     |> map (apfst fact_of_name)
   end
 
+(* experimental *)
+fun k_nearest_neighbors_cpp avail_num adv_max get_deps get_syms advno syms =
+  let
+    val ocs = TextIO.openOut "adv_syms"
+    val ocd = TextIO.openOut "adv_deps"
+    val ocq = TextIO.openOut "adv_seq"
+    val occ = TextIO.openOut "adv_conj"
+    fun os oc s = TextIO.output (oc, s)
+    fun oi oc i = os oc (Int.toString i)
+    fun ol _  _ _   [] = ()
+      | ol _  f _   [e] = f e
+      | ol oc f sep (h :: t) = (f h; os oc sep; ol oc f sep t)
+    fun do_n n =
+      (oi ocs n; os ocs ":"; ol ocs (fn i => (os ocs "\""; oi ocs i; os ocs "\"")) ", " (get_syms n); os ocs "\n";
+       oi ocd n; os ocd ":"; ol ocd (fn i => oi ocd i) " " (get_deps n); os ocd "\n";
+       oi ocq n; os ocq "\n")
+    fun for n = if n = avail_num then () else (do_n n; for (n + 1))
+    fun forkexec no =
+      let
+        val cmd =
+          "~/misc/predict/knn " ^ string_of_int number_of_nearest_neighbors ^
+          " adv_syms adv_deps " ^ string_of_int no ^ " adv_seq < adv_conj"
+      in
+        fst (Isabelle_System.bash_output cmd)
+        |> space_explode " "
+        |> map_filter (Option.map (rpair 1.0) o Int.fromString)
+      end
+  in
+    (for 0; ol occ (fn i => (os occ "\""; oi occ i; os occ "\"")) ", " syms; TextIO.closeOut ocs;
+     TextIO.closeOut ocd; TextIO.closeOut ocq; TextIO.closeOut occ;
+     forkexec (advno + avail_num - adv_max))
+  end
+
 fun add_to_xtab key (next, tab, keys) = (next + 1, Symtab.update_new (key, next) tab, key :: keys)
 
 fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i)))
@@ -619,8 +655,7 @@
     val num_visible_facts = length visible_facts
     val get_deps = curry Vector.sub deps_vec
   in
-    trace_msg ctxt (fn () => "MaSh_SML " ^ (if engine = MaSh_SML_kNN then "kNN" else "NB") ^
-      " query " ^ encode_features feats ^ " from {" ^
+    trace_msg ctxt (fn () => "MaSh_SML " ^ " query " ^ encode_features feats ^ " from {" ^
       elide_string 1000 (space_implode " " (take num_visible_facts facts)) ^ "}");
     (if engine = MaSh_SML_kNN then
        let
@@ -646,9 +681,14 @@
          val feats' = map (apfst (the_default ~1 o Symtab.lookup feat_tab)) feats
        in
          (case engine of
-           MaSh_SML_NB opts => naive_bayes opts
-         | _ => naive_bayes_py ctxt overlord)
-           num_facts num_visible_facts get_deps get_unweighted_feats num_feats max_suggs feats'
+           MaSh_SML_kNN_Cpp =>
+           k_nearest_neighbors_cpp num_facts num_visible_facts get_deps get_unweighted_feats
+             max_suggs (map fst feats')
+         | MaSh_SML_NB opts =>
+           naive_bayes opts num_facts num_visible_facts get_deps get_unweighted_feats num_feats
+             max_suggs feats'
+         | MaSh_SML_NB_Py => naive_bayes_py ctxt overlord num_facts num_visible_facts get_deps
+             get_unweighted_feats num_feats max_suggs feats')
        end)
     |> map (curry Vector.sub fact_vec o fst)
   end