--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Thu Jun 26 13:33:27 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Thu Jun 26 13:33:50 2014 +0200
@@ -60,12 +60,12 @@
structure MaSh_SML :
sig
- val k_nearest_neighbors : int -> int -> (int -> int list) -> (int -> (int * real) list) ->
+ val k_nearest_neighbors : int -> (int -> int list) -> (int -> (int * real) list) -> int ->
+ int list -> (int * real) list -> (int * real) list
+ val naive_bayes : (bool * bool) -> int -> (int -> int list) -> (int -> int list) -> int ->
int -> (int * real) list -> (int * real) list
- val naive_bayes : (bool * bool) -> int -> int -> (int -> int list) -> (int -> int list) ->
+ val naive_bayes_py : Proof.context -> bool -> int -> (int -> int list) -> (int -> int list) ->
int -> int -> (int * real) list -> (int * real) list
- val naive_bayes_py : Proof.context -> bool -> int -> int -> (int -> int list) ->
- (int -> int list) -> int -> int -> (int * real) list -> (int * real) list
val query : Proof.context -> bool -> mash_engine -> string list -> int ->
(string * (string * real) list * string list) list * string list * (string * real) list ->
string list
@@ -423,14 +423,13 @@
(*
num_facts = maximum number of theorems to check dependencies and symbols
- num_visible_facts = do not return theorems over or equal to this number.
- Must satisfy: num_visible_facts <= num_facts.
get_deps = returns dependencies of a theorem
get_sym_ths = get theorems that have this feature
max_suggs = number of suggestions to return
+ visible_facts = indices of the visible facts, as looked up in the caller's fact table
feats = features of the goal
*)
-fun k_nearest_neighbors num_facts num_visible_facts get_deps get_sym_ths max_suggs feats =
+fun k_nearest_neighbors num_facts get_deps get_sym_ths max_suggs visible_facts feats =
let
(* Can be later used for TFIDF *)
fun sym_wght _ = 1.0
@@ -457,7 +455,7 @@
val _ = List.app do_feat feats
val _ = heap (Real.compare o pairself snd) num_facts num_facts overlaps_sqr
val no_recommends = Unsynchronized.ref 0
- val recommends = Array.tabulate (num_visible_facts, rpair 0.0)
+ val recommends = Array.tabulate (num_facts, rpair 0.0)
val age = Unsynchronized.ref 1000000000.0
fun inc_recommend j v =
@@ -470,7 +468,7 @@
val k = Unsynchronized.ref 0
fun do_k k =
- if k >= num_visible_facts then
+ if k >= num_facts then
raise EXIT ()
else
let
@@ -496,8 +494,8 @@
if at = Array.length recommends then acc else ret (Array.sub (recommends, at) :: acc) (at + 1)
in
while1 (); while2 ();
- heap (Real.compare o pairself snd) max_suggs num_visible_facts recommends;
- ret [] (Integer.max 0 (num_visible_facts - max_suggs))
+ heap (Real.compare o pairself snd) max_suggs num_facts recommends;
+ ret [] (Integer.max 0 (num_facts - max_suggs))
end
val nb_def_prior_weight = 21 (* FUDGE *)
@@ -541,7 +539,7 @@
learn_facts tfreq sfreq dffreq num_facts get_deps get_feats num_feats
end
-fun naive_bayes_query (kuehlwein_log, kuehlwein_params) num_facts num_visible_facts max_suggs feats
+fun naive_bayes_query (kuehlwein_log, kuehlwein_params) num_facts max_suggs feats
(tfreq, sfreq, idf) =
let
val tau = if kuehlwein_params then 0.05 else 0.02 (* FUDGE *)
@@ -576,22 +574,21 @@
res + tau * sum_of_weights
end
- val posterior = Array.tabulate (num_visible_facts, (fn j => (j, log_posterior j)))
+ val posterior = Array.tabulate (num_facts, (fn j => (j, log_posterior j)))
fun ret acc at =
- if at = num_visible_facts then acc else ret (Array.sub (posterior, at) :: acc) (at + 1)
+ if at = num_facts then acc else ret (Array.sub (posterior, at) :: acc) (at + 1)
in
- heap (Real.compare o pairself snd) max_suggs num_visible_facts posterior;
- ret [] (Integer.max 0 (num_visible_facts - max_suggs))
+ heap (Real.compare o pairself snd) max_suggs num_facts posterior;
+ ret [] (Integer.max 0 (num_facts - max_suggs))
end
-fun naive_bayes opts num_facts num_visible_facts get_deps get_feats num_feats max_suggs feats =
+fun naive_bayes opts num_facts get_deps get_feats num_feats max_suggs feats =
learn num_facts get_deps get_feats num_feats
- |> naive_bayes_query opts num_facts num_visible_facts max_suggs feats
+ |> naive_bayes_query opts num_facts max_suggs feats
(* experimental *)
-fun naive_bayes_py ctxt overlord num_facts num_visible_facts get_deps get_feats num_feats max_suggs
- feats =
+fun naive_bayes_py ctxt overlord num_facts get_deps get_feats num_feats max_suggs feats =
let
fun name_of_fact j = "f" ^ string_of_int j
fun fact_of_name s = the (Int.fromString (unprefix "f" s))
@@ -600,7 +597,7 @@
val learns = map (fn j => (name_of_fact j, parents_of j, map name_of_feature (get_feats j),
map name_of_fact (get_deps j))) (0 upto num_facts - 1)
- val parents' = parents_of num_visible_facts
+ val parents' = parents_of num_facts
val feats' = map (apfst name_of_feature) feats
in
MaSh_Py.unlearn ctxt overlord;
@@ -655,10 +652,7 @@
fun query ctxt overlord engine visible_facts max_suggs (learns0, hints, feats) =
let
- val visible_fact_set = Symtab.make_set visible_facts
- val learns =
- (learns0 |> List.partition (Symtab.defined visible_fact_set o #1) |> op @) @
- (if null hints then [] else [(".hints", feats, hints)])
+ val learns = learns0 @ (if null hints then [] else [(".hints", feats, hints)])
in
if engine = MaSh_SML_kNN_Cpp then
k_nearest_neighbors_cpp max_suggs learns (map fst feats)
@@ -666,7 +660,7 @@
naive_bayes_cpp max_suggs learns (map fst feats)
else
let
- val (rev_depss, rev_featss, (num_facts, _, rev_facts), (num_feats, feat_tab, _)) =
+ val (rev_depss, rev_featss, (num_facts, fact_tab, rev_facts), (num_feats, feat_tab, _)) =
fold (fn (fact, feats, deps) =>
fn (rev_depss, rev_featss, fact_xtab as (_, fact_tab, _), feat_xtab) =>
let
@@ -687,11 +681,12 @@
val deps_vec = Vector.fromList (rev rev_depss)
- val num_visible_facts = length visible_facts
val get_deps = curry Vector.sub deps_vec
+ (* Symtab.lookup returns an option: map_filter keeps the indices of known facts only *)
+ val int_visible_facts = map_filter (Symtab.lookup fact_tab) visible_facts
in
trace_msg ctxt (fn () => "MaSh_SML query " ^ encode_features feats ^ " from {" ^
- elide_string 1000 (space_implode " " (take num_visible_facts facts)) ^ "}");
+ elide_string 1000 (space_implode " " (take num_facts facts)) ^ "}");
(if engine = MaSh_SML_kNN then
let
val facts_ary = Array.array (num_feats, [])
@@ -704,10 +699,10 @@
end)
rev_featss num_facts
val get_facts = curry Array.sub facts_ary
- val feats' = map_filter (fn (feat, weight) =>
+ val int_feats = map_filter (fn (feat, weight) =>
Option.map (rpair weight) (Symtab.lookup feat_tab feat)) feats
in
- k_nearest_neighbors num_facts num_visible_facts get_deps get_facts max_suggs feats'
+ k_nearest_neighbors num_facts get_deps get_facts max_suggs int_visible_facts int_feats
end
else
let
@@ -717,9 +712,9 @@
in
(case engine of
MaSh_SML_NB opts =>
- naive_bayes opts num_facts num_visible_facts get_deps get_unweighted_feats num_feats
- max_suggs int_feats
- | MaSh_SML_NB_Py => naive_bayes_py ctxt overlord num_facts num_visible_facts get_deps
+ naive_bayes opts num_facts get_deps get_unweighted_feats num_feats max_suggs
+ int_feats
+ | MaSh_SML_NB_Py => naive_bayes_py ctxt overlord num_facts get_deps
get_unweighted_feats num_feats max_suggs int_feats)
end)
|> map (curry Vector.sub fact_vec o fst)