# HG changeset patch
# User blanchet
# Date 1403782468 -7200
# Node ID 30ee18eb23ac5d2c64e71ba11020d6e4eb8a001e
# Parent  9816f692b0ca2343bdfc0841f59cc24298738875
new version of adaptive k-NN with TFIDF

diff -r 9816f692b0ca -r 30ee18eb23ac src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Thu Jun 26 13:33:50 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Thu Jun 26 13:34:28 2014 +0200
@@ -60,15 +60,14 @@
 
 structure MaSh_SML :
 sig
-  val k_nearest_neighbors : int -> (int -> int list) -> (int -> (int * real) list) -> int ->
-    int list -> (int * real) list -> (int * real) list
+  val k_nearest_neighbors : int -> (int -> int list) -> (int -> int list) -> int -> int list ->
+    int -> int list -> (int * real) list
   val naive_bayes : (bool * bool) -> int -> (int -> int list) -> (int -> int list) -> int ->
-    int -> (int * real) list -> (int * real) list
+    int -> int list -> (int * real) list
   val naive_bayes_py : Proof.context -> bool -> int -> (int -> int list) -> (int -> int list) ->
-    int -> int -> (int * real) list -> (int * real) list
+    int -> int -> int list -> (int * real) list
   val query : Proof.context -> bool -> mash_engine -> string list -> int ->
-    (string * (string * real) list * string list) list * string list * (string * real) list ->
-    string list
+    (string * string list * string list) list * string list * string list -> string list
 end
 
 val mash_unlearn : Proof.context -> params -> unit
@@ -428,10 +427,18 @@
    max_suggs = number of suggestions to return
    feats = features of the goal
 *)
-fun k_nearest_neighbors num_facts get_deps get_sym_ths max_suggs visible_facts feats =
+fun k_nearest_neighbors num_facts get_deps get_sym_ths max_suggs visible_facts num_feats feats =
   let
-    (* Can be later used for TFIDF *)
-    fun sym_wght _ = 1.0
+    val dffreq = Array.array (num_feats, 0)
+
+    fun add_sym s = Array.update (dffreq, s, 1 + Array.sub (dffreq, s))
+    fun for1 i =
+      if i = num_feats then () else
+        (List.app (fn _ => add_sym i) (get_sym_ths i); for1 (i + 1))
+    val _ = for1 0
+
+    val ln_afreq = Math.ln (Real.fromInt num_facts)
+    fun tfidf feat = ln_afreq - Math.ln (Real.fromInt (Array.sub (dffreq, feat))) handle Subscript => ln_afreq
 
     val overlaps_sqr = Array.tabulate (num_facts, rpair 0.0)
 
@@ -442,12 +449,11 @@
         Array.update (overlaps_sqr, j, (j, v + ov))
       end
 
-    fun do_feat (s, con_wght) =
+    fun do_feat s =
       let
-        val sw = sym_wght s
-        val w2 = sw * sw * con_wght
-
-        fun do_th (j, prem_wght) = if j < num_facts then inc_overlap j (w2 * prem_wght) else ()
+        val sw = tfidf s
+        val w2 = sw * sw
+        fun do_th j = if j < num_facts then inc_overlap j w2 else ()
       in
         List.app do_th (get_sym_ths s)
       end
@@ -460,11 +466,8 @@
 
     fun inc_recommend j v =
       let val ov = snd (Array.sub (recommends, j)) in
-        if ov <= 0.0 then
-          (no_recommends := !no_recommends + 1; Array.update (recommends, j, (j, !age + ov)))
-        else
-          (if ov < !age + 1000.0 then Array.update (recommends, j, (j, v + ov)) else ())
-      end
+        if ov <= 0.0 then (no_recommends := !no_recommends + 1; Array.update (recommends, j, (j, !age + ov)))
+        else (if ov < !age + 1000.0 then Array.update (recommends, j, (j, v + ov)) else ()) end
 
     val k = Unsynchronized.ref 0
     fun do_k k =
@@ -482,12 +485,13 @@
       end
 
     fun while1 () =
-      if !k = number_of_nearest_neighbors then () else (do_k (!k); k := !k + 1; while1 ())
+      (if !k = number_of_nearest_neighbors then () else
+        (do_k (!k); k := !k + 1; while1 ()))
       handle EXIT () => ()
 
     fun while2 () =
-      if !no_recommends >= max_suggs then ()
-      else (do_k (!k); k := !k + 1; age := !age - 10000.0; while2 ())
+      (if !no_recommends >= max_suggs then () else
+        (do_k (!k); k := !k + 1; age := !age - 10000.0; while2 ()))
       handle EXIT () => ()
 
     fun ret acc at =
@@ -553,7 +557,7 @@
       let
         val tfreq = Real.fromInt (Vector.sub (tfreq, i))
 
-        fun fold_feats (f, _) (res, sfh) =
+        fun fold_feats f (res, sfh) =
           (case Inttab.lookup sfh f of
             SOME sf =>
             (res + tfidf f * Math.ln (pos_weight * Real.fromInt sf / tfreq),
@@ -598,7 +602,7 @@
     val learns = map (fn j => (name_of_fact j, parents_of j, map name_of_feature (get_feats j),
       map name_of_fact (get_deps j))) (0 upto num_facts - 1)
     val parents' = parents_of num_facts
-    val feats' = map (apfst name_of_feature) feats
+    val feats' = map (rpair 1.0 o name_of_feature) feats
   in
     MaSh_Py.unlearn ctxt overlord;
     OS.Process.sleep (seconds 2.0); (* hack *)
@@ -622,8 +626,7 @@
       | ol oc f sep (h :: t) = (f h; os oc sep; ol oc f sep t)
 
     fun do_learn (name, feats, deps) =
-      (os ocs name; os ocs ":";
-       ol ocs (fn (sy, _) => (os ocs "\""; os ocs sy; os ocs "\"")) ", " feats; os ocs "\n";
+      (os ocs name; os ocs ":"; ol ocs (os ocs o quote) ", " feats; os ocs "\n";
        os ocd name; os ocd ":"; ol ocd (os ocd) " " deps; os ocd "\n"; os ocq name; os ocq "\n")
 
     fun forkexec no =
@@ -637,7 +640,7 @@
         |> filter_out (curry (op =) "")
       end
   in
-    (List.app do_learn learns; ol occ (fn sy => (os occ "\""; os occ sy; os occ "\"")) ", " cfeats;
+    (List.app do_learn learns; ol occ (os occ o quote) ", " cfeats;
      TextIO.closeOut ocs; TextIO.closeOut ocd; TextIO.closeOut ocq; TextIO.closeOut occ;
      forkexec max_suggs)
   end
@@ -655,19 +658,19 @@
     val learns = learns0 @ (if null hints then [] else [(".hints", feats, hints)])
   in
     if engine = MaSh_SML_kNN_Cpp then
-      k_nearest_neighbors_cpp max_suggs learns (map fst feats)
+      k_nearest_neighbors_cpp max_suggs learns feats
     else if engine = MaSh_SML_NB_Cpp then
-      naive_bayes_cpp max_suggs learns (map fst feats)
+      naive_bayes_cpp max_suggs learns feats
     else
       let
         val (rev_depss, rev_featss, (num_facts, fact_tab, rev_facts), (num_feats, feat_tab, _)) =
           fold (fn (fact, feats, deps) =>
                 fn (rev_depss, rev_featss, fact_xtab as (_, fact_tab, _), feat_xtab) =>
               let
-                fun add_feat (feat, weight) (xtab as (n, tab, _)) =
+                fun add_feat feat (xtab as (n, tab, _)) =
                   (case Symtab.lookup tab feat of
-                    SOME i => ((i, weight), xtab)
-                  | NONE => ((n, weight), add_to_xtab feat xtab))
+                    SOME i => (i, xtab)
+                  | NONE => (n, add_to_xtab feat xtab))
 
                 val (feats', feat_xtab') = fold_map add_feat feats feat_xtab
               in
@@ -678,14 +681,15 @@
 
         val facts = rev rev_facts
         val fact_vec = Vector.fromList facts
+        val int_visible_facts = map (Symtab.lookup fact_tab) visible_facts
 
         val deps_vec = Vector.fromList (rev rev_depss)
 
         val get_deps = curry Vector.sub deps_vec
 
-        val int_visible_facts = map (Symtab.lookup fact_tab) visible_facts
+        val int_feats = map (the_default ~1 o Symtab.lookup feat_tab) feats
       in
-        trace_msg ctxt (fn () => "MaSh_SML query " ^ encode_features feats ^ " from {" ^
+        trace_msg ctxt (fn () => "MaSh_SML query " ^ encode_strs feats ^ " from {" ^
           elide_string 1000 (space_implode " " (take num_facts facts)) ^ "}");
         (if engine = MaSh_SML_kNN then
            let
@@ -693,22 +697,18 @@
              val _ =
                fold (fn feats => fn fact =>
                    let val fact' = fact - 1 in
-                     List.app (fn (feat, weight) =>
-                       map_array_at facts_ary (cons (fact', weight)) feat) feats;
-                     fact'
+                     List.app (map_array_at facts_ary (cons fact')) feats; fact'
                    end)
                  rev_featss num_facts
              val get_facts = curry Array.sub facts_ary
-             val int_feats = map_filter (fn (feat, weight) =>
-               Option.map (rpair weight) (Symtab.lookup feat_tab feat)) feats
            in
-             k_nearest_neighbors num_facts get_deps get_facts max_suggs int_visible_facts int_feats
+             k_nearest_neighbors num_facts get_deps get_facts max_suggs int_visible_facts num_feats
+               int_feats
            end
          else
            let
-             val unweighted_feats_ary = Vector.fromList (map (map fst) (rev rev_featss))
+             val unweighted_feats_ary = Vector.fromList (rev rev_featss)
              val get_unweighted_feats = curry Vector.sub unweighted_feats_ary
-             val int_feats = map (apfst (the_default ~1 o Symtab.lookup feat_tab)) feats
            in
              (case engine of
                MaSh_SML_NB opts =>
@@ -773,6 +773,8 @@
 
 val empty_state = {access_G = Graph.empty, num_known_facts = 0, dirty = SOME []} : mash_state
 
+(* TODO: get rid of weights in data structure *)
+
 local
 
 val version = "*** MaSh version 20140519 ***"
@@ -1428,9 +1430,10 @@
             val (parents, hints, feats) = query_args access_G
             val visible_facts = Graph.all_preds access_G parents
             val learns =
-              Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G
+              Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, map fst feats, deps))
+                access_G
           in
-            MaSh_SML.query ctxt overlord engine visible_facts max_facts (learns, hints, feats)
+            MaSh_SML.query ctxt overlord engine visible_facts max_facts (learns, hints, map fst feats)
           end
 
         val unknown = filter_out (is_fact_in_graph access_G o snd) facts
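
Note on the weighting (not part of the changeset): the feature weight that replaces the constant `sym_wght` in `k_nearest_neighbors` can be read off the `dffreq`, `ln_afreq`, and `tfidf` definitions in this patch. The sketch below restates it as a formula. The symbols N, df, Q, and F(j) are notation assumed here for illustration, not names from the source, and the sum assumes that `do_feat` is applied once to each goal feature, which the visible hunks do not show.

```latex
% N      = num_facts
% df(f)  = length (get_sym_ths f), the number of facts carrying feature f
% Q      = feats, the features of the goal
% F(j)   = the features of fact j
\[
  w(f) = \ln N - \ln \mathrm{df}(f) = \ln \frac{N}{\mathrm{df}(f)},
  \qquad
  \mathrm{overlap}(j) = \sum_{f \,\in\, Q \cap F(j)} w(f)^{2}
\]
```

As written in the patch, a feature index outside the bounds of `dffreq` raises `Subscript` in the lookup and falls back to w(f) = ln N. Despite the name `tfidf`, the weight shown depends only on document frequency; no per-fact term frequency enters it.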