--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Tue Jun 24 08:19:57 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Tue Jun 24 08:19:58 2014 +0200
@@ -587,22 +587,20 @@
end
(* experimental *)
-fun k_nearest_neighbors_cpp avail_num adv_max get_deps get_syms advno syms =
+fun k_nearest_neighbors_cpp max_suggs learns cfeats =
let
val ocs = TextIO.openOut "adv_syms"
val ocd = TextIO.openOut "adv_deps"
val ocq = TextIO.openOut "adv_seq"
val occ = TextIO.openOut "adv_conj"
fun os oc s = TextIO.output (oc, s)
- fun oi oc i = os oc (Int.toString i)
- fun ol _ _ _ [] = ()
- | ol _ f _ [e] = f e
+ fun ol _ _ _ [] = ()
+ | ol _ f _ [e] = f e
| ol oc f sep (h :: t) = (f h; os oc sep; ol oc f sep t)
- fun do_n n =
- (oi ocs n; os ocs ":"; ol ocs (fn i => (os ocs "\""; oi ocs i; os ocs "\"")) ", " (get_syms n); os ocs "\n";
- oi ocd n; os ocd ":"; ol ocd (fn i => oi ocd i) " " (get_deps n); os ocd "\n";
- oi ocq n; os ocq "\n")
- fun for n = if n = avail_num then () else (do_n n; for (n + 1))
+ fun do_learn (name, feats, deps) =
+ (os ocs name; os ocs ":"; ol ocs (fn (sy, _) => (os ocs "\""; os ocs sy; os ocs "\"")) ", " feats; os ocs "\n";
+ os ocd name; os ocd ":"; ol ocd (os ocd) " " deps; os ocd "\n";
+ os ocq name; os ocq "\n")
fun forkexec no =
let
val cmd =
@@ -611,86 +609,87 @@
in
fst (Isabelle_System.bash_output cmd)
|> space_explode " "
- |> map_filter (Option.map (rpair 1.0) o Int.fromString)
+ |> filter_out (curry (op =) "")
end
in
- (for 0; ol occ (fn i => (os occ "\""; oi occ i; os occ "\"")) ", " syms; TextIO.closeOut ocs;
+ (List.app do_learn learns; ol occ (fn sy => (os occ "\""; os occ sy; os occ "\"")) ", " cfeats; TextIO.closeOut ocs;
TextIO.closeOut ocd; TextIO.closeOut ocq; TextIO.closeOut occ;
- forkexec (advno + avail_num - adv_max))
+ forkexec max_suggs)
end
fun add_to_xtab key (next, tab, keys) = (next + 1, Symtab.update_new (key, next) tab, key :: keys)
fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i)))
-fun query ctxt overlord engine visible_facts max_suggs (learns, hints, feats) =
+fun query ctxt overlord engine visible_facts max_suggs (learns0, hints, feats) =
let
val visible_fact_set = Symtab.make_set visible_facts
-
- val learns' =
- (learns |> List.partition (Symtab.defined visible_fact_set o #1) |> op @) @
+ val learns =
+ (learns0 |> List.partition (Symtab.defined visible_fact_set o #1) |> op @) @
(if null hints then [] else [(".goal", feats, hints)])
-
- val (rev_depss, rev_featss, (num_facts, _, rev_facts), (num_feats, feat_tab, _)) =
- fold (fn (fact, feats, deps) =>
- fn (rev_depss, rev_featss, fact_xtab as (_, fact_tab, _), feat_xtab) =>
- let
- fun add_feat (feat, weight) (xtab as (n, tab, _)) =
- (case Symtab.lookup tab feat of
- SOME i => ((i, weight), xtab)
- | NONE => ((n, weight), add_to_xtab feat xtab))
-
- val (feats', feat_xtab') = fold_map add_feat feats feat_xtab
- in
- (map_filter (Symtab.lookup fact_tab) deps :: rev_depss, feats' :: rev_featss,
- add_to_xtab fact fact_xtab, feat_xtab')
- end)
- learns' ([], [], (0, Symtab.empty, []), (0, Symtab.empty, []))
-
- val facts = rev rev_facts
- val fact_vec = Vector.fromList facts
-
- val deps_vec = Vector.fromList (rev rev_depss)
-
- val num_visible_facts = length visible_facts
- val get_deps = curry Vector.sub deps_vec
in
- trace_msg ctxt (fn () => "MaSh_SML " ^ " query " ^ encode_features feats ^ " from {" ^
- elide_string 1000 (space_implode " " (take num_visible_facts facts)) ^ "}");
- (if engine = MaSh_SML_kNN then
- let
- val facts_ary = Array.array (num_feats, [])
- val _ =
- fold (fn feats => fn fact =>
- let val fact' = fact - 1 in
- List.app (fn (feat, weight) => map_array_at facts_ary (cons (fact', weight)) feat)
- feats;
- fact'
- end)
- rev_featss num_facts
- val get_facts = curry Array.sub facts_ary
- val feats' = map_filter (fn (feat, weight) =>
- Option.map (rpair weight) (Symtab.lookup feat_tab feat)) feats
- in
- k_nearest_neighbors num_facts num_visible_facts get_deps get_facts max_suggs feats'
- end
- else
- let
- val unweighted_feats_ary = Vector.fromList (map (map fst) (rev rev_featss))
- val get_unweighted_feats = curry Vector.sub unweighted_feats_ary
- val feats' = map (apfst (the_default ~1 o Symtab.lookup feat_tab)) feats
- in
- (case engine of
- MaSh_SML_kNN_Cpp =>
- k_nearest_neighbors_cpp num_facts num_visible_facts get_deps get_unweighted_feats
- max_suggs (map fst feats')
- | MaSh_SML_NB opts =>
- naive_bayes opts num_facts num_visible_facts get_deps get_unweighted_feats num_feats
- max_suggs feats'
- | MaSh_SML_NB_Py => naive_bayes_py ctxt overlord num_facts num_visible_facts get_deps
- get_unweighted_feats num_feats max_suggs feats')
- end)
- |> map (curry Vector.sub fact_vec o fst)
+ if engine = MaSh_SML_kNN_Cpp then
+ k_nearest_neighbors_cpp max_suggs learns (map fst feats)
+ else
+ let
+ val (rev_depss, rev_featss, (num_facts, _, rev_facts), (num_feats, feat_tab, _)) =
+ fold (fn (fact, feats, deps) =>
+ fn (rev_depss, rev_featss, fact_xtab as (_, fact_tab, _), feat_xtab) =>
+ let
+ fun add_feat (feat, weight) (xtab as (n, tab, _)) =
+ (case Symtab.lookup tab feat of
+ SOME i => ((i, weight), xtab)
+ | NONE => ((n, weight), add_to_xtab feat xtab))
+
+ val (feats', feat_xtab') = fold_map add_feat feats feat_xtab
+ in
+ (map_filter (Symtab.lookup fact_tab) deps :: rev_depss, feats' :: rev_featss,
+ add_to_xtab fact fact_xtab, feat_xtab')
+ end)
+ learns ([], [], (0, Symtab.empty, []), (0, Symtab.empty, []))
+
+ val facts = rev rev_facts
+ val fact_vec = Vector.fromList facts
+
+ val deps_vec = Vector.fromList (rev rev_depss)
+
+ val num_visible_facts = length visible_facts
+ val get_deps = curry Vector.sub deps_vec
+ in
+ trace_msg ctxt (fn () => "MaSh_SML query " ^ encode_features feats ^ " from {" ^
+ elide_string 1000 (space_implode " " (take num_visible_facts facts)) ^ "}");
+ (if engine = MaSh_SML_kNN then
+ let
+ val facts_ary = Array.array (num_feats, [])
+ val _ =
+ fold (fn feats => fn fact =>
+ let val fact' = fact - 1 in
+ List.app (fn (feat, weight) =>
+ map_array_at facts_ary (cons (fact', weight)) feat) feats;
+ fact'
+ end)
+ rev_featss num_facts
+ val get_facts = curry Array.sub facts_ary
+ val feats' = map_filter (fn (feat, weight) =>
+ Option.map (rpair weight) (Symtab.lookup feat_tab feat)) feats
+ in
+ k_nearest_neighbors num_facts num_visible_facts get_deps get_facts max_suggs feats'
+ end
+ else
+ let
+ val unweighted_feats_ary = Vector.fromList (map (map fst) (rev rev_featss))
+ val get_unweighted_feats = curry Vector.sub unweighted_feats_ary
+ val int_feats = map (apfst (the_default ~1 o Symtab.lookup feat_tab)) feats
+ in
+ (case engine of
+ MaSh_SML_NB opts =>
+ naive_bayes opts num_facts num_visible_facts get_deps get_unweighted_feats num_feats
+ max_suggs int_feats
+ | MaSh_SML_NB_Py => naive_bayes_py ctxt overlord num_facts num_visible_facts get_deps
+ get_unweighted_feats num_feats max_suggs int_feats)
+ end)
+ |> map (curry Vector.sub fact_vec o fst)
+ end
end
end;