# HG changeset patch # User blanchet # Date 1403590798 -7200 # Node ID ef9d4e1ceb00a3ab83d327538f188aa74f5ed7f7 # Parent 4e619ee65a6125146a442f35f84274c523ae1be2 use strings to communicate with external process, to ease debugging diff -r 4e619ee65a61 -r ef9d4e1ceb00 src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML --- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Tue Jun 24 08:19:57 2014 +0200 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Tue Jun 24 08:19:58 2014 +0200 @@ -587,22 +587,20 @@ end (* experimental *) -fun k_nearest_neighbors_cpp avail_num adv_max get_deps get_syms advno syms = +fun k_nearest_neighbors_cpp max_suggs learns cfeats = let val ocs = TextIO.openOut "adv_syms" val ocd = TextIO.openOut "adv_deps" val ocq = TextIO.openOut "adv_seq" val occ = TextIO.openOut "adv_conj" fun os oc s = TextIO.output (oc, s) - fun oi oc i = os oc (Int.toString i) - fun ol _ _ _ [] = () - | ol _ f _ [e] = f e + fun ol _ _ _ [] = () + | ol _ f _ [e] = f e | ol oc f sep (h :: t) = (f h; os oc sep; ol oc f sep t) - fun do_n n = - (oi ocs n; os ocs ":"; ol ocs (fn i => (os ocs "\""; oi ocs i; os ocs "\"")) ", " (get_syms n); os ocs "\n"; - oi ocd n; os ocd ":"; ol ocd (fn i => oi ocd i) " " (get_deps n); os ocd "\n"; - oi ocq n; os ocq "\n") - fun for n = if n = avail_num then () else (do_n n; for (n + 1)) + fun do_learn (name, feats, deps) = + (os ocs name; os ocs ":"; ol ocs (fn (sy, _) => (os ocs "\""; os ocs sy; os ocs "\"")) ", " feats; os ocs "\n"; + os ocd name; os ocd ":"; ol ocd (os ocd) " " deps; os ocd "\n"; + os ocq name; os ocq "\n") fun forkexec no = let val cmd = @@ -611,86 +609,87 @@ in fst (Isabelle_System.bash_output cmd) |> space_explode " " - |> map_filter (Option.map (rpair 1.0) o Int.fromString) + |> filter_out (curry (op =) "") end in - (for 0; ol occ (fn i => (os occ "\""; oi occ i; os occ "\"")) ", " syms; TextIO.closeOut ocs; + (List.app do_learn learns; ol occ (fn sy => (os occ "\""; os occ sy; os occ "\"")) ", " cfeats; TextIO.closeOut ocs; TextIO.closeOut ocd; TextIO.closeOut ocq; TextIO.closeOut occ; - forkexec (advno + avail_num - adv_max)) + forkexec max_suggs) end fun add_to_xtab key (next, tab, keys) = (next + 1, Symtab.update_new (key, next) tab, key :: keys) fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i))) -fun query ctxt overlord engine visible_facts max_suggs (learns, hints, feats) = +fun query ctxt overlord engine visible_facts max_suggs (learns0, hints, feats) = let val visible_fact_set = Symtab.make_set visible_facts - - val learns' = - (learns |> List.partition (Symtab.defined visible_fact_set o #1) |> op @) @ + val learns = + (learns0 |> List.partition (Symtab.defined visible_fact_set o #1) |> op @) @ (if null hints then [] else [(".goal", feats, hints)]) - - val (rev_depss, rev_featss, (num_facts, _, rev_facts), (num_feats, feat_tab, _)) = - fold (fn (fact, feats, deps) => - fn (rev_depss, rev_featss, fact_xtab as (_, fact_tab, _), feat_xtab) => - let - fun add_feat (feat, weight) (xtab as (n, tab, _)) = - (case Symtab.lookup tab feat of - SOME i => ((i, weight), xtab) - | NONE => ((n, weight), add_to_xtab feat xtab)) - - val (feats', feat_xtab') = fold_map add_feat feats feat_xtab - in - (map_filter (Symtab.lookup fact_tab) deps :: rev_depss, feats' :: rev_featss, - add_to_xtab fact fact_xtab, feat_xtab') - end) - learns' ([], [], (0, Symtab.empty, []), (0, Symtab.empty, [])) - - val facts = rev rev_facts - val fact_vec = Vector.fromList facts - - val deps_vec = Vector.fromList (rev rev_depss) - - val num_visible_facts = length visible_facts - val get_deps = curry Vector.sub deps_vec in - trace_msg ctxt (fn () => "MaSh_SML " ^ " query " ^ encode_features feats ^ " from {" ^ - elide_string 1000 (space_implode " " (take num_visible_facts facts)) ^ "}"); - (if engine = MaSh_SML_kNN then - let - val facts_ary = Array.array (num_feats, []) - val _ = - fold (fn feats => fn fact => - let val fact' = fact - 1 in - List.app (fn (feat, weight) => map_array_at facts_ary (cons (fact', weight)) feat) - feats; - fact' - end) - rev_featss num_facts - val get_facts = curry Array.sub facts_ary - val feats' = map_filter (fn (feat, weight) => - Option.map (rpair weight) (Symtab.lookup feat_tab feat)) feats - in - k_nearest_neighbors num_facts num_visible_facts get_deps get_facts max_suggs feats' - end - else - let - val unweighted_feats_ary = Vector.fromList (map (map fst) (rev rev_featss)) - val get_unweighted_feats = curry Vector.sub unweighted_feats_ary - val feats' = map (apfst (the_default ~1 o Symtab.lookup feat_tab)) feats - in - (case engine of - MaSh_SML_kNN_Cpp => - k_nearest_neighbors_cpp num_facts num_visible_facts get_deps get_unweighted_feats - max_suggs (map fst feats') - | MaSh_SML_NB opts => - naive_bayes opts num_facts num_visible_facts get_deps get_unweighted_feats num_feats - max_suggs feats' - | MaSh_SML_NB_Py => naive_bayes_py ctxt overlord num_facts num_visible_facts get_deps - get_unweighted_feats num_feats max_suggs feats') - end) - |> map (curry Vector.sub fact_vec o fst) + if engine = MaSh_SML_kNN_Cpp then + k_nearest_neighbors_cpp max_suggs learns (map fst feats) + else + let + val (rev_depss, rev_featss, (num_facts, _, rev_facts), (num_feats, feat_tab, _)) = + fold (fn (fact, feats, deps) => + fn (rev_depss, rev_featss, fact_xtab as (_, fact_tab, _), feat_xtab) => + let + fun add_feat (feat, weight) (xtab as (n, tab, _)) = + (case Symtab.lookup tab feat of + SOME i => ((i, weight), xtab) + | NONE => ((n, weight), add_to_xtab feat xtab)) + + val (feats', feat_xtab') = fold_map add_feat feats feat_xtab + in + (map_filter (Symtab.lookup fact_tab) deps :: rev_depss, feats' :: rev_featss, + add_to_xtab fact fact_xtab, feat_xtab') + end) + learns ([], [], (0, Symtab.empty, []), (0, Symtab.empty, [])) + + val facts = rev rev_facts + val fact_vec = Vector.fromList facts + + val deps_vec = Vector.fromList (rev rev_depss) + + val num_visible_facts = length visible_facts + val get_deps = curry Vector.sub deps_vec + in + trace_msg ctxt (fn () => "MaSh_SML query " ^ encode_features feats ^ " from {" ^ + elide_string 1000 (space_implode " " (take num_visible_facts facts)) ^ "}"); + (if engine = MaSh_SML_kNN then + let + val facts_ary = Array.array (num_feats, []) + val _ = + fold (fn feats => fn fact => + let val fact' = fact - 1 in + List.app (fn (feat, weight) => + map_array_at facts_ary (cons (fact', weight)) feat) feats; + fact' + end) + rev_featss num_facts + val get_facts = curry Array.sub facts_ary + val feats' = map_filter (fn (feat, weight) => + Option.map (rpair weight) (Symtab.lookup feat_tab feat)) feats + in + k_nearest_neighbors num_facts num_visible_facts get_deps get_facts max_suggs feats' + end + else + let + val unweighted_feats_ary = Vector.fromList (map (map fst) (rev rev_featss)) + val get_unweighted_feats = curry Vector.sub unweighted_feats_ary + val int_feats = map (apfst (the_default ~1 o Symtab.lookup feat_tab)) feats + in + (case engine of + MaSh_SML_NB opts => + naive_bayes opts num_facts num_visible_facts get_deps get_unweighted_feats num_feats + max_suggs int_feats + | MaSh_SML_NB_Py => naive_bayes_py ctxt overlord num_facts num_visible_facts get_deps + get_unweighted_feats num_feats max_suggs int_feats) + end) + |> map (curry Vector.sub fact_vec o fst) + end end end;