--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Thu Jun 26 13:33:50 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Thu Jun 26 13:34:28 2014 +0200
@@ -60,15 +60,14 @@
structure MaSh_SML :
sig
- val k_nearest_neighbors : int -> (int -> int list) -> (int -> (int * real) list) -> int ->
- int list -> (int * real) list -> (int * real) list
+ val k_nearest_neighbors : int -> (int -> int list) -> (int -> int list) -> int -> int list ->
+ int -> int list -> (int * real) list
val naive_bayes : (bool * bool) -> int -> (int -> int list) -> (int -> int list) -> int ->
- int -> (int * real) list -> (int * real) list
+ int -> int list -> (int * real) list
val naive_bayes_py : Proof.context -> bool -> int -> (int -> int list) -> (int -> int list) ->
- int -> int -> (int * real) list -> (int * real) list
+ int -> int -> int list -> (int * real) list
val query : Proof.context -> bool -> mash_engine -> string list -> int ->
- (string * (string * real) list * string list) list * string list * (string * real) list ->
- string list
+ (string * string list * string list) list * string list * string list -> string list
end
val mash_unlearn : Proof.context -> params -> unit
@@ -428,10 +427,18 @@
max_suggs = number of suggestions to return
feats = features of the goal
*)
-fun k_nearest_neighbors num_facts get_deps get_sym_ths max_suggs visible_facts feats =
+fun k_nearest_neighbors num_facts get_deps get_sym_ths max_suggs visible_facts num_feats feats =
let
- (* Can be later used for TFIDF *)
- fun sym_wght _ = 1.0
+ val dffreq = Array.array (num_feats, 0)
+
+ fun add_sym s = Array.update (dffreq, s, 1 + Array.sub (dffreq, s))
+ fun for1 i =
+ if i = num_feats then () else
+ (List.app (fn _ => add_sym i) (get_sym_ths i); for1 (i + 1))
+ val _ = for1 0
+
+ val ln_afreq = Math.ln (Real.fromInt num_facts)
+ fun tfidf feat = ln_afreq - Math.ln (Real.fromInt (Array.sub (dffreq, feat))) handle Subscript => ln_afreq (* feat outside the bounds of dffreq: fall back to ln_afreq *)
val overlaps_sqr = Array.tabulate (num_facts, rpair 0.0)
@@ -442,12 +449,11 @@
Array.update (overlaps_sqr, j, (j, v + ov))
end
- fun do_feat (s, con_wght) =
+ fun do_feat s =
let
- val sw = sym_wght s
- val w2 = sw * sw * con_wght
-
- fun do_th (j, prem_wght) = if j < num_facts then inc_overlap j (w2 * prem_wght) else ()
+ val sw = tfidf s
+ val w2 = sw * sw
+ fun do_th j = if j < num_facts then inc_overlap j w2 else ()
in
List.app do_th (get_sym_ths s)
end
@@ -460,11 +466,8 @@
fun inc_recommend j v =
let val ov = snd (Array.sub (recommends, j)) in
- if ov <= 0.0 then
- (no_recommends := !no_recommends + 1; Array.update (recommends, j, (j, !age + ov)))
- else
- (if ov < !age + 1000.0 then Array.update (recommends, j, (j, v + ov)) else ())
- end
+ if ov <= 0.0 then (no_recommends := !no_recommends + 1; Array.update (recommends, j, (j, !age + ov)))
+ else (if ov < !age + 1000.0 then Array.update (recommends, j, (j, v + ov)) else ()) end
val k = Unsynchronized.ref 0
fun do_k k =
@@ -482,12 +485,13 @@
end
fun while1 () =
- if !k = number_of_nearest_neighbors then () else (do_k (!k); k := !k + 1; while1 ())
+ (if !k = number_of_nearest_neighbors then () else
+ (do_k (!k); k := !k + 1; while1 ()))
handle EXIT () => ()
fun while2 () =
- if !no_recommends >= max_suggs then ()
- else (do_k (!k); k := !k + 1; age := !age - 10000.0; while2 ())
+ (if !no_recommends >= max_suggs then () else
+ (do_k (!k); k := !k + 1; age := !age - 10000.0; while2 ()))
handle EXIT () => ()
fun ret acc at =
@@ -553,7 +557,7 @@
let
val tfreq = Real.fromInt (Vector.sub (tfreq, i))
- fun fold_feats (f, _) (res, sfh) =
+ fun fold_feats f (res, sfh) =
(case Inttab.lookup sfh f of
SOME sf =>
(res + tfidf f * Math.ln (pos_weight * Real.fromInt sf / tfreq),
@@ -598,7 +602,7 @@
val learns = map (fn j => (name_of_fact j, parents_of j, map name_of_feature (get_feats j),
map name_of_fact (get_deps j))) (0 upto num_facts - 1)
val parents' = parents_of num_facts
- val feats' = map (apfst name_of_feature) feats
+ val feats' = map (rpair 1.0 o name_of_feature) feats
in
MaSh_Py.unlearn ctxt overlord;
OS.Process.sleep (seconds 2.0); (* hack *)
@@ -622,8 +626,7 @@
| ol oc f sep (h :: t) = (f h; os oc sep; ol oc f sep t)
fun do_learn (name, feats, deps) =
- (os ocs name; os ocs ":";
- ol ocs (fn (sy, _) => (os ocs "\""; os ocs sy; os ocs "\"")) ", " feats; os ocs "\n";
+ (os ocs name; os ocs ":"; ol ocs (os ocs o quote) ", " feats; os ocs "\n";
os ocd name; os ocd ":"; ol ocd (os ocd) " " deps; os ocd "\n"; os ocq name; os ocq "\n")
fun forkexec no =
@@ -637,7 +640,7 @@
|> filter_out (curry (op =) "")
end
in
- (List.app do_learn learns; ol occ (fn sy => (os occ "\""; os occ sy; os occ "\"")) ", " cfeats;
+ (List.app do_learn learns; ol occ (os occ o quote) ", " cfeats;
TextIO.closeOut ocs; TextIO.closeOut ocd; TextIO.closeOut ocq; TextIO.closeOut occ;
forkexec max_suggs)
end
@@ -655,19 +658,19 @@
val learns = learns0 @ (if null hints then [] else [(".hints", feats, hints)])
in
if engine = MaSh_SML_kNN_Cpp then
- k_nearest_neighbors_cpp max_suggs learns (map fst feats)
+ k_nearest_neighbors_cpp max_suggs learns feats
else if engine = MaSh_SML_NB_Cpp then
- naive_bayes_cpp max_suggs learns (map fst feats)
+ naive_bayes_cpp max_suggs learns feats
else
let
val (rev_depss, rev_featss, (num_facts, fact_tab, rev_facts), (num_feats, feat_tab, _)) =
fold (fn (fact, feats, deps) =>
fn (rev_depss, rev_featss, fact_xtab as (_, fact_tab, _), feat_xtab) =>
let
- fun add_feat (feat, weight) (xtab as (n, tab, _)) =
+ fun add_feat feat (xtab as (n, tab, _)) =
(case Symtab.lookup tab feat of
- SOME i => ((i, weight), xtab)
- | NONE => ((n, weight), add_to_xtab feat xtab))
+ SOME i => (i, xtab)
+ | NONE => (n, add_to_xtab feat xtab))
val (feats', feat_xtab') = fold_map add_feat feats feat_xtab
in
@@ -678,14 +681,15 @@
val facts = rev rev_facts
val fact_vec = Vector.fromList facts
+ val int_visible_facts = map (Symtab.lookup fact_tab) visible_facts
val deps_vec = Vector.fromList (rev rev_depss)
val get_deps = curry Vector.sub deps_vec
- val int_visible_facts = map (Symtab.lookup fact_tab) visible_facts
+ val int_feats = map (the_default ~1 o Symtab.lookup feat_tab) feats (* ~1 stands for a feature absent from feat_tab; NOTE(review): the kNN branch passes these indices on together with get_facts (Array.sub) -- confirm ~1 cannot reach it unhandled *)
in
- trace_msg ctxt (fn () => "MaSh_SML query " ^ encode_features feats ^ " from {" ^
+ trace_msg ctxt (fn () => "MaSh_SML query " ^ encode_strs feats ^ " from {" ^
elide_string 1000 (space_implode " " (take num_facts facts)) ^ "}");
(if engine = MaSh_SML_kNN then
let
@@ -693,22 +697,18 @@
val _ =
fold (fn feats => fn fact =>
let val fact' = fact - 1 in
- List.app (fn (feat, weight) =>
- map_array_at facts_ary (cons (fact', weight)) feat) feats;
- fact'
+ List.app (map_array_at facts_ary (cons fact')) feats; fact'
end)
rev_featss num_facts
val get_facts = curry Array.sub facts_ary
- val int_feats = map_filter (fn (feat, weight) =>
- Option.map (rpair weight) (Symtab.lookup feat_tab feat)) feats
in
- k_nearest_neighbors num_facts get_deps get_facts max_suggs int_visible_facts int_feats
+ k_nearest_neighbors num_facts get_deps get_facts max_suggs int_visible_facts num_feats
+ int_feats
end
else
let
- val unweighted_feats_ary = Vector.fromList (map (map fst) (rev rev_featss))
+ val unweighted_feats_ary = Vector.fromList (rev rev_featss)
val get_unweighted_feats = curry Vector.sub unweighted_feats_ary
- val int_feats = map (apfst (the_default ~1 o Symtab.lookup feat_tab)) feats
in
(case engine of
MaSh_SML_NB opts =>
@@ -773,6 +773,8 @@
val empty_state = {access_G = Graph.empty, num_known_facts = 0, dirty = SOME []} : mash_state
+(* TODO: get rid of the feature weights in the data structure (queries already drop them via "map fst") *)
+
local
val version = "*** MaSh version 20140519 ***"
@@ -1428,9 +1430,10 @@
val (parents, hints, feats) = query_args access_G
val visible_facts = Graph.all_preds access_G parents
val learns =
- Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G
+ Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, map fst feats, deps))
+ access_G
in
- MaSh_SML.query ctxt overlord engine visible_facts max_facts (learns, hints, feats)
+ MaSh_SML.query ctxt overlord engine visible_facts max_facts (learns, hints, map fst feats)
end
val unknown = filter_out (is_fact_in_graph access_G o snd) facts