# HG changeset patch # User blanchet # Date 1403590795 -7200 # Node ID 1bac14e0a7289535674795f45b83c3cdc15dfc8d # Parent bc06471cb7b7dc565c3029a260ec7b022a2606b6 added experimental MaSh engine diff -r bc06471cb7b7 -r 1bac14e0a728 src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML --- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Tue Jun 24 08:19:55 2014 +0200 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Tue Jun 24 08:19:55 2014 +0200 @@ -38,6 +38,7 @@ datatype mash_engine = MaSh_Py | MaSh_SML_kNN + | MaSh_SML_kNN_Cpp | MaSh_SML_NB of bool * bool | MaSh_SML_NB_Py @@ -157,6 +158,7 @@ datatype mash_engine = MaSh_Py | MaSh_SML_kNN +| MaSh_SML_kNN_Cpp | MaSh_SML_NB of bool * bool | MaSh_SML_NB_Py @@ -169,6 +171,7 @@ | "py" => SOME MaSh_Py | "sml" => SOME default_MaSh_SML_NB | "sml_knn" => SOME MaSh_SML_kNN + | "sml_knn_cpp" => SOME MaSh_SML_kNN_Cpp | "sml_nb" => SOME default_MaSh_SML_NB | "sml_nbCC" => SOME (MaSh_SML_NB (false, false)) | "sml_nbCD" => SOME (MaSh_SML_NB (false, true)) @@ -583,6 +586,39 @@ |> map (apfst fact_of_name) end +(* experimental *) +fun k_nearest_neighbors_cpp avail_num adv_max get_deps get_syms advno syms = + let + val ocs = TextIO.openOut "adv_syms" + val ocd = TextIO.openOut "adv_deps" + val ocq = TextIO.openOut "adv_seq" + val occ = TextIO.openOut "adv_conj" + fun os oc s = TextIO.output (oc, s) + fun oi oc i = os oc (Int.toString i) + fun ol _ _ _ [] = () + | ol _ f _ [e] = f e + | ol oc f sep (h :: t) = (f h; os oc sep; ol oc f sep t) + fun do_n n = + (oi ocs n; os ocs ":"; ol ocs (fn i => (os ocs "\""; oi ocs i; os ocs "\"")) ", " (get_syms n); os ocs "\n"; + oi ocd n; os ocd ":"; ol ocd (fn i => oi ocd i) " " (get_deps n); os ocd "\n"; + oi ocq n; os ocq "\n") + fun for n = if n = avail_num then () else (do_n n; for (n + 1)) + fun forkexec no = + let + val cmd = + "~/misc/predict/knn " ^ string_of_int number_of_nearest_neighbors ^ + " adv_syms adv_deps " ^ string_of_int no ^ " adv_seq < adv_conj" + in + fst (Isabelle_System.bash_output cmd) + |> space_explode " " + |> map_filter (Option.map (rpair 1.0) o Int.fromString) + end + in + (for 0; ol occ (fn i => (os occ "\""; oi occ i; os occ "\"")) ", " syms; TextIO.closeOut ocs; + TextIO.closeOut ocd; TextIO.closeOut ocq; TextIO.closeOut occ; + forkexec (advno + avail_num - adv_max)) + end + fun add_to_xtab key (next, tab, keys) = (next + 1, Symtab.update_new (key, next) tab, key :: keys) fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i))) @@ -619,8 +655,7 @@ val num_visible_facts = length visible_facts val get_deps = curry Vector.sub deps_vec in - trace_msg ctxt (fn () => "MaSh_SML " ^ (if engine = MaSh_SML_kNN then "kNN" else "NB") ^ - " query " ^ encode_features feats ^ " from {" ^ + trace_msg ctxt (fn () => "MaSh_SML " ^ " query " ^ encode_features feats ^ " from {" ^ elide_string 1000 (space_implode " " (take num_visible_facts facts)) ^ "}"); (if engine = MaSh_SML_kNN then let @@ -646,9 +681,14 @@ val feats' = map (apfst (the_default ~1 o Symtab.lookup feat_tab)) feats in (case engine of - MaSh_SML_NB opts => naive_bayes opts - | _ => naive_bayes_py ctxt overlord) - num_facts num_visible_facts get_deps get_unweighted_feats num_feats max_suggs feats' + MaSh_SML_kNN_Cpp => + k_nearest_neighbors_cpp num_facts num_visible_facts get_deps get_unweighted_feats + max_suggs (map fst feats') + | MaSh_SML_NB opts => + naive_bayes opts num_facts num_visible_facts get_deps get_unweighted_feats num_feats + max_suggs feats' + | MaSh_SML_NB_Py => naive_bayes_py ctxt overlord num_facts num_visible_facts get_deps + get_unweighted_feats num_feats max_suggs feats') end) |> map (curry Vector.sub fact_vec o fst) end