# HG changeset patch # User blanchet # Date 1403782430 -7200 # Node ID 9816f692b0ca2343bdfc0841f59cc24298738875 # Parent a9e0f9d35125e518341ba5e8ce88779994762093 refactoring diff -r a9e0f9d35125 -r 9816f692b0ca src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML --- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Thu Jun 26 13:33:27 2014 +0200 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Thu Jun 26 13:33:50 2014 +0200 @@ -60,12 +60,12 @@ structure MaSh_SML : sig - val k_nearest_neighbors : int -> int -> (int -> int list) -> (int -> (int * real) list) -> + val k_nearest_neighbors : int -> (int -> int list) -> (int -> (int * real) list) -> int -> + int list -> (int * real) list -> (int * real) list + val naive_bayes : (bool * bool) -> int -> (int -> int list) -> (int -> int list) -> int -> int -> (int * real) list -> (int * real) list - val naive_bayes : (bool * bool) -> int -> int -> (int -> int list) -> (int -> int list) -> + val naive_bayes_py : Proof.context -> bool -> int -> (int -> int list) -> (int -> int list) -> int -> int -> (int * real) list -> (int * real) list - val naive_bayes_py : Proof.context -> bool -> int -> int -> (int -> int list) -> - (int -> int list) -> int -> int -> (int * real) list -> (int * real) list val query : Proof.context -> bool -> mash_engine -> string list -> int -> (string * (string * real) list * string list) list * string list * (string * real) list -> string list @@ -423,14 +423,12 @@ (* num_facts = maximum number of theorems to check dependencies and symbols - num_visible_facts = do not return theorems over or equal to this number. - Must satisfy: num_visible_facts <= num_facts. get_deps = returns dependencies of a theorem get_sym_ths = get theorems that have this feature max_suggs = number of suggestions to return feats = features of the goal *) -fun k_nearest_neighbors num_facts num_visible_facts get_deps get_sym_ths max_suggs feats = +fun k_nearest_neighbors num_facts get_deps get_sym_ths max_suggs visible_facts feats = let (* Can be later used for TFIDF *) fun sym_wght _ = 1.0 @@ -457,7 +455,7 @@ val _ = List.app do_feat feats val _ = heap (Real.compare o pairself snd) num_facts num_facts overlaps_sqr val no_recommends = Unsynchronized.ref 0 - val recommends = Array.tabulate (num_visible_facts, rpair 0.0) + val recommends = Array.tabulate (num_facts, rpair 0.0) val age = Unsynchronized.ref 1000000000.0 fun inc_recommend j v = @@ -470,7 +468,7 @@ val k = Unsynchronized.ref 0 fun do_k k = - if k >= num_visible_facts then + if k >= num_facts then raise EXIT () else let @@ -496,8 +494,8 @@ if at = Array.length recommends then acc else ret (Array.sub (recommends, at) :: acc) (at + 1) in while1 (); while2 (); - heap (Real.compare o pairself snd) max_suggs num_visible_facts recommends; - ret [] (Integer.max 0 (num_visible_facts - max_suggs)) + heap (Real.compare o pairself snd) max_suggs num_facts recommends; + ret [] (Integer.max 0 (num_facts - max_suggs)) end val nb_def_prior_weight = 21 (* FUDGE *) @@ -541,7 +539,7 @@ learn_facts tfreq sfreq dffreq num_facts get_deps get_feats num_feats end -fun naive_bayes_query (kuehlwein_log, kuehlwein_params) num_facts num_visible_facts max_suggs feats +fun naive_bayes_query (kuehlwein_log, kuehlwein_params) num_facts max_suggs feats (tfreq, sfreq, idf) = let val tau = if kuehlwein_params then 0.05 else 0.02 (* FUDGE *) @@ -576,22 +574,21 @@ res + tau * sum_of_weights end - val posterior = Array.tabulate (num_visible_facts, (fn j => (j, log_posterior j))) + val posterior = Array.tabulate (num_facts, (fn j => (j, log_posterior j))) fun ret acc at = - if at = num_visible_facts then acc else ret (Array.sub (posterior, at) :: acc) (at + 1) + if at = num_facts then acc else ret (Array.sub (posterior, at) :: acc) (at + 1) in - heap (Real.compare o pairself snd) max_suggs num_visible_facts posterior; - ret [] (Integer.max 0 (num_visible_facts - max_suggs)) + heap (Real.compare o pairself snd) max_suggs num_facts posterior; + ret [] (Integer.max 0 (num_facts - max_suggs)) end -fun naive_bayes opts num_facts num_visible_facts get_deps get_feats num_feats max_suggs feats = +fun naive_bayes opts num_facts get_deps get_feats num_feats max_suggs feats = learn num_facts get_deps get_feats num_feats - |> naive_bayes_query opts num_facts num_visible_facts max_suggs feats + |> naive_bayes_query opts num_facts max_suggs feats (* experimental *) -fun naive_bayes_py ctxt overlord num_facts num_visible_facts get_deps get_feats num_feats max_suggs - feats = +fun naive_bayes_py ctxt overlord num_facts get_deps get_feats num_feats max_suggs feats = let fun name_of_fact j = "f" ^ string_of_int j fun fact_of_name s = the (Int.fromString (unprefix "f" s)) @@ -600,7 +597,7 @@ val learns = map (fn j => (name_of_fact j, parents_of j, map name_of_feature (get_feats j), map name_of_fact (get_deps j))) (0 upto num_facts - 1) - val parents' = parents_of num_visible_facts + val parents' = parents_of num_facts val feats' = map (apfst name_of_feature) feats in MaSh_Py.unlearn ctxt overlord; @@ -655,10 +652,7 @@ fun query ctxt overlord engine visible_facts max_suggs (learns0, hints, feats) = let - val visible_fact_set = Symtab.make_set visible_facts - val learns = - (learns0 |> List.partition (Symtab.defined visible_fact_set o #1) |> op @) @ - (if null hints then [] else [(".hints", feats, hints)]) + val learns = learns0 @ (if null hints then [] else [(".hints", feats, hints)]) in if engine = MaSh_SML_kNN_Cpp then k_nearest_neighbors_cpp max_suggs learns (map fst feats) @@ -666,7 +660,7 @@ naive_bayes_cpp max_suggs learns (map fst feats) else let - val (rev_depss, rev_featss, (num_facts, _, rev_facts), (num_feats, feat_tab, _)) = + val (rev_depss, rev_featss, (num_facts, fact_tab, rev_facts), (num_feats, feat_tab, _)) = fold (fn (fact, feats, deps) => fn (rev_depss, rev_featss, fact_xtab as (_, fact_tab, _), feat_xtab) => let @@ -687,11 +681,12 @@ val deps_vec = Vector.fromList (rev rev_depss) - val num_visible_facts = length visible_facts val get_deps = curry Vector.sub deps_vec + + val int_visible_facts = map (Symtab.lookup fact_tab) visible_facts in trace_msg ctxt (fn () => "MaSh_SML query " ^ encode_features feats ^ " from {" ^ - elide_string 1000 (space_implode " " (take num_visible_facts facts)) ^ "}"); + elide_string 1000 (space_implode " " (take num_facts facts)) ^ "}"); (if engine = MaSh_SML_kNN then let val facts_ary = Array.array (num_feats, []) @@ -704,10 +699,10 @@ end) rev_featss num_facts val get_facts = curry Array.sub facts_ary - val feats' = map_filter (fn (feat, weight) => + val int_feats = map_filter (fn (feat, weight) => Option.map (rpair weight) (Symtab.lookup feat_tab feat)) feats in - k_nearest_neighbors num_facts num_visible_facts get_deps get_facts max_suggs feats' + k_nearest_neighbors num_facts get_deps get_facts max_suggs int_visible_facts int_feats end else let @@ -717,9 +712,9 @@ in (case engine of MaSh_SML_NB opts => - naive_bayes opts num_facts num_visible_facts get_deps get_unweighted_feats num_feats - max_suggs int_feats - | MaSh_SML_NB_Py => naive_bayes_py ctxt overlord num_facts num_visible_facts get_deps + naive_bayes opts num_facts get_deps get_unweighted_feats num_feats max_suggs + int_feats + | MaSh_SML_NB_Py => naive_bayes_py ctxt overlord num_facts get_deps get_unweighted_feats num_feats max_suggs int_feats) end) |> map (curry Vector.sub fact_vec o fst)