# HG changeset patch # User blanchet # Date 1403782546 -7200 # Node ID 6d422f19cefb5159587944ccb73171ba910c2d5e # Parent b89937ed60992ceb015ae337de038f843893aa30 tuning diff -r b89937ed6099 -r 6d422f19cefb src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML --- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Thu Jun 26 13:35:39 2014 +0200 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Thu Jun 26 13:35:46 2014 +0200 @@ -46,30 +46,6 @@ val is_mash_enabled : unit -> bool val the_mash_engine : unit -> mash_engine - structure MaSh_Py : - sig - val unlearn : Proof.context -> bool -> unit - val learn : Proof.context -> bool -> bool -> - (string * string list * string list * string list) list -> unit - val relearn : Proof.context -> bool -> bool -> (string * string list) list -> unit - val query : Proof.context -> bool -> int -> - (string * string list * string list * string list) list * string list * string list - * (string * real) list -> - (string * real) list - end - - structure MaSh_SML : - sig - val k_nearest_neighbors : int -> (int -> int list) -> (int -> int list) -> int -> int list -> - int -> int list -> (int * real) list - val naive_bayes : int -> (int -> int list) -> (int -> int list) -> int -> int -> int list -> - int list -> (int * real) list - val naive_bayes_py : Proof.context -> bool -> int -> (int -> int list) -> (int -> int list) -> - int -> int -> int list -> (int * real) list - val query : Proof.context -> bool -> mash_engine -> string list -> int -> - (string * string list * string list) list * string list * string list -> string list - end - val mash_unlearn : Proof.context -> params -> unit val nickname_of_thm : thm -> string val find_suggested_facts : Proof.context -> ('b * thm) list -> string list -> ('b * thm) list @@ -492,7 +468,7 @@ val nb_def_prior_weight = 21 (* FUDGE *) -fun learn_facts tfreq sfreq dffreq num_facts get_deps get_feats num_feats = +fun learn_facts tfreq sfreq dffreq num_facts get_deps get_feats = let fun learn_fact th feats deps = let @@ -525,7 +501,7 @@ val sfreq = Array.array (num_facts, Inttab.empty) val dffreq = Array.array (num_feats, 0) in - learn_facts tfreq sfreq dffreq num_facts get_deps get_feats num_feats + learn_facts tfreq sfreq dffreq num_facts get_deps get_feats end fun naive_bayes_query num_facts max_suggs visible_facts feats (tfreq, sfreq, dffreq) = @@ -574,7 +550,7 @@ |> naive_bayes_query num_facts max_suggs visible_facts feats (* experimental *) -fun naive_bayes_py ctxt overlord num_facts get_deps get_feats num_feats max_suggs feats = +fun naive_bayes_py ctxt overlord num_facts get_deps get_feats max_suggs feats = let fun name_of_fact j = "f" ^ string_of_int j fun fact_of_name s = the (Int.fromString (unprefix "f" s)) @@ -631,66 +607,54 @@ c_plus_plus_tool ("newknn/knn" ^ " " ^ string_of_int number_of_nearest_neighbors) val naive_bayes_cpp = c_plus_plus_tool "predict/nbayes" -val empty_xtab = (0, Symtab.empty) - -fun add_to_xtab key (next, tab) = (next + 1, Symtab.update_new (key, next) tab) -fun maybe_add_to_xtab key = perhaps (try (add_to_xtab key)) +fun query ctxt overlord engine (num_facts, fact_tab) (num_feats, feat_tab) visible_facts max_suggs + learns conj_feats = + if engine = MaSh_SML_kNN_Cpp then + k_nearest_neighbors_cpp max_suggs learns conj_feats + else if engine = MaSh_SML_NB_Cpp then + naive_bayes_cpp max_suggs learns conj_feats + else + let + val facts = map #1 learns + val featss = map (map_filter (Symtab.lookup feat_tab) o #2) learns + val depss = map (map_filter (Symtab.lookup fact_tab) o #3) learns -fun query ctxt overlord engine visible_facts max_suggs (learns0, hints, feats) = - let - val learns = learns0 @ (if null hints then [] else [(".hints", feats, hints)]) - in - if engine = MaSh_SML_kNN_Cpp then - k_nearest_neighbors_cpp max_suggs learns feats - else if engine = MaSh_SML_NB_Cpp then - naive_bayes_cpp max_suggs learns feats - else - let - val facts = map #1 learns - val fact_vec = Vector.fromList facts + val fact_vec = Vector.fromList facts + val deps_vec = Vector.fromList depss - val fact_xtab as (num_facts, fact_tab) = fold add_to_xtab facts empty_xtab - val feat_xtab as (num_feats, feat_tab) = fold (fold maybe_add_to_xtab o #2) learns empty_xtab - - val featss = map (map_filter (Symtab.lookup feat_tab) o #2) learns - - val deps_vec = Vector.fromList (map (map_filter (Symtab.lookup fact_tab) o #3) learns) - - val int_visible_facts = map_filter (Symtab.lookup fact_tab) visible_facts - - val get_deps = curry Vector.sub deps_vec + val get_deps = curry Vector.sub deps_vec - val int_feats = map_filter (Symtab.lookup feat_tab) feats - in - trace_msg ctxt (fn () => "MaSh_SML query " ^ encode_strs feats ^ " from {" ^ - elide_string 1000 (space_implode " " (take num_facts facts)) ^ "}"); - (if engine = MaSh_SML_kNN then - let - val facts_ary = Array.array (num_feats, []) - val _ = - fold (fn feats => fn fact => - (List.app (map_array_at facts_ary (cons fact)) feats; fact + 1)) - featss 0 - val get_facts = curry Array.sub facts_ary - in - k_nearest_neighbors num_facts get_deps get_facts max_suggs int_visible_facts num_feats - int_feats - end - else - let - val unweighted_feats_ary = Vector.fromList featss - val get_unweighted_feats = curry Vector.sub unweighted_feats_ary - in - (case engine of - MaSh_SML_NB => - naive_bayes num_facts get_deps get_unweighted_feats num_feats max_suggs - int_visible_facts int_feats - | MaSh_SML_NB_Py => naive_bayes_py ctxt overlord num_facts get_deps - get_unweighted_feats num_feats max_suggs int_feats) - end) - |> map (curry Vector.sub fact_vec o fst) - end - end + val int_visible_facts = map_filter (Symtab.lookup fact_tab) visible_facts + val int_conj_feats = map_filter (Symtab.lookup feat_tab) conj_feats + in + trace_msg ctxt (fn () => "MaSh_SML query " ^ encode_strs conj_feats ^ " from {" ^ + elide_string 1000 (space_implode " " (take num_facts facts)) ^ "}"); + (if engine = MaSh_SML_kNN then + let + val facts_ary = Array.array (num_feats, []) + val _ = + fold (fn feats => fn fact => + (List.app (map_array_at facts_ary (cons fact)) feats; fact + 1)) + featss 0 + val get_facts = curry Array.sub facts_ary + in + k_nearest_neighbors num_facts get_deps get_facts max_suggs int_visible_facts num_feats + int_conj_feats + end + else + let + val feats_ary = Vector.fromList featss + val get_feats = curry Vector.sub feats_ary + in + (case engine of + MaSh_SML_NB => + naive_bayes num_facts get_deps get_feats num_feats max_suggs int_visible_facts + int_conj_feats + | MaSh_SML_NB_Py => + naive_bayes_py ctxt overlord num_facts get_deps get_feats max_suggs int_conj_feats) + end) + |> map (curry Vector.sub fact_vec o fst) + end end; @@ -1328,6 +1292,11 @@ fun add_const_counts t = fold (fn s => Symtab.map_default (s, 0) (Integer.add 1)) (Term.add_const_names t []) +val empty_xtab = (0, Symtab.empty) + +fun add_to_xtab key (next, tab) = (next + 1, Symtab.update_new (key, next) tab) +fun maybe_add_to_xtab key = perhaps (try (add_to_xtab key)) + fun mash_suggested_facts ctxt ({debug, overlord, ...} : params) max_facts hyp_ts concl_t facts = let val thy = Proof_Context.theory_of ctxt @@ -1395,12 +1364,18 @@ [] else let - val (parents, hints, feats) = query_args access_G + val (parents, hints, feats0) = query_args access_G + val feats = map fst feats0 val visible_facts = Graph.all_preds access_G parents val learns = - Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G + Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G @ + (if null hints then [] else [(".hints", feats, hints)]) + + val fact_xtab = fold (add_to_xtab o #1) learns empty_xtab + val feat_xtab = fold (fold maybe_add_to_xtab o #2) learns empty_xtab in - MaSh_SML.query ctxt overlord engine visible_facts max_facts (learns, hints, map fst feats) + MaSh_SML.query ctxt overlord engine fact_xtab feat_xtab visible_facts max_facts learns + feats end val unknown = filter_out (is_fact_in_graph access_G o snd) facts