--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Tue Jun 24 08:19:55 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Tue Jun 24 08:19:55 2014 +0200
@@ -38,6 +38,7 @@
datatype mash_engine =
MaSh_Py
| MaSh_SML_kNN
+ | MaSh_SML_kNN_Cpp
| MaSh_SML_NB of bool * bool
| MaSh_SML_NB_Py
@@ -157,6 +158,7 @@
datatype mash_engine =
MaSh_Py
| MaSh_SML_kNN
+| MaSh_SML_kNN_Cpp
| MaSh_SML_NB of bool * bool
| MaSh_SML_NB_Py
@@ -169,6 +171,7 @@
| "py" => SOME MaSh_Py
| "sml" => SOME default_MaSh_SML_NB
| "sml_knn" => SOME MaSh_SML_kNN
+ | "sml_knn_cpp" => SOME MaSh_SML_kNN_Cpp
| "sml_nb" => SOME default_MaSh_SML_NB
| "sml_nbCC" => SOME (MaSh_SML_NB (false, false))
| "sml_nbCD" => SOME (MaSh_SML_NB (false, true))
@@ -583,6 +586,39 @@
|> map (apfst fact_of_name)
end
+(* experimental *)
+fun k_nearest_neighbors_cpp avail_num adv_max get_deps get_syms advno syms =
+ let
+ val ocs = TextIO.openOut "adv_syms"
+ val ocd = TextIO.openOut "adv_deps"
+ val ocq = TextIO.openOut "adv_seq"
+ val occ = TextIO.openOut "adv_conj"
+ fun os oc s = TextIO.output (oc, s)
+ fun oi oc i = os oc (Int.toString i)
+ fun ol _ _ _ [] = ()
+ | ol _ f _ [e] = f e
+ | ol oc f sep (h :: t) = (f h; os oc sep; ol oc f sep t)
+ fun do_n n =
+ (oi ocs n; os ocs ":"; ol ocs (fn i => (os ocs "\""; oi ocs i; os ocs "\"")) ", " (get_syms n); os ocs "\n";
+ oi ocd n; os ocd ":"; ol ocd (fn i => oi ocd i) " " (get_deps n); os ocd "\n";
+ oi ocq n; os ocq "\n")
+ fun for n = if n = avail_num then () else (do_n n; for (n + 1))
+ fun forkexec no =
+ let
+ val cmd =
+ "~/misc/predict/knn " ^ string_of_int number_of_nearest_neighbors ^
+ " adv_syms adv_deps " ^ string_of_int no ^ " adv_seq < adv_conj"
+ in
+ fst (Isabelle_System.bash_output cmd)
+ |> space_explode " "
+ |> map_filter (Option.map (rpair 1.0) o Int.fromString)
+ end
+ in
+ (for 0; ol occ (fn i => (os occ "\""; oi occ i; os occ "\"")) ", " syms; TextIO.closeOut ocs;
+ TextIO.closeOut ocd; TextIO.closeOut ocq; TextIO.closeOut occ;
+ forkexec (advno + avail_num - adv_max))
+ end
+
fun add_to_xtab key (next, tab, keys) = (next + 1, Symtab.update_new (key, next) tab, key :: keys)
fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i)))
@@ -619,8 +655,7 @@
val num_visible_facts = length visible_facts
val get_deps = curry Vector.sub deps_vec
in
- trace_msg ctxt (fn () => "MaSh_SML " ^ (if engine = MaSh_SML_kNN then "kNN" else "NB") ^
- " query " ^ encode_features feats ^ " from {" ^
+ trace_msg ctxt (fn () => "MaSh_SML " ^ " query " ^ encode_features feats ^ " from {" ^
elide_string 1000 (space_implode " " (take num_visible_facts facts)) ^ "}");
(if engine = MaSh_SML_kNN then
let
@@ -646,9 +681,14 @@
val feats' = map (apfst (the_default ~1 o Symtab.lookup feat_tab)) feats
in
(case engine of
- MaSh_SML_NB opts => naive_bayes opts
- | _ => naive_bayes_py ctxt overlord)
- num_facts num_visible_facts get_deps get_unweighted_feats num_feats max_suggs feats'
+ MaSh_SML_kNN_Cpp =>
+ k_nearest_neighbors_cpp num_facts num_visible_facts get_deps get_unweighted_feats
+ max_suggs (map fst feats')
+ | MaSh_SML_NB opts =>
+ naive_bayes opts num_facts num_visible_facts get_deps get_unweighted_feats num_feats
+ max_suggs feats'
+ | MaSh_SML_NB_Py => naive_bayes_py ctxt overlord num_facts num_visible_facts get_deps
+ get_unweighted_feats num_feats max_suggs feats')
end)
|> map (curry Vector.sub fact_vec o fst)
end