--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Tue Jul 01 16:47:10 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Tue Jul 01 16:47:10 2014 +0200
@@ -34,10 +34,10 @@
val decode_strs : string -> string list
datatype mash_engine =
- MaSh_kNN
+ MaSh_NB
+ | MaSh_NB_Ext
+ | MaSh_kNN
| MaSh_kNN_Ext
- | MaSh_NB
- | MaSh_NB_Ext
val is_mash_enabled : unit -> bool
val the_mash_engine : unit -> mash_engine
@@ -138,20 +138,20 @@
end
datatype mash_engine =
- MaSh_kNN
+ MaSh_NB
+| MaSh_NB_Ext
+| MaSh_kNN
| MaSh_kNN_Ext
-| MaSh_NB
-| MaSh_NB_Ext
fun mash_engine () =
let val flag1 = Options.default_string @{system_option MaSh} in
(case if flag1 <> "none" (* default *) then flag1 else getenv "MASH" of
"yes" => SOME MaSh_NB
| "sml" => SOME MaSh_NB
+ | "nb" => SOME MaSh_NB (* explicit name for the engine that "yes" and "sml" also select *)
+ | "nb_ext" => SOME MaSh_NB_Ext
| "knn" => SOME MaSh_kNN
| "knn_ext" => SOME MaSh_kNN_Ext
- | "nb" => SOME MaSh_NB
- | "nb_ext" => SOME MaSh_NB_Ext
| _ => NONE)
end
@@ -231,86 +231,12 @@
()
end
-val number_of_nearest_neighbors = 10 (* FUDGE *)
-
fun select_visible_facts big_number recommends =
List.app (fn at =>
let val (j, ov) = Array.sub (recommends, at) in
Array.update (recommends, at, (j, big_number + ov))
end)
-fun k_nearest_neighbors dffreq num_facts depss feat_facts max_suggs visible_facts goal_feats =
- let
- exception EXIT of unit
-
- val ln_afreq = Math.ln (Real.fromInt num_facts)
- fun tfidf feat = ln_afreq - Math.ln (Real.fromInt (Vector.sub (dffreq, feat)))
-
- val overlaps_sqr = Array.tabulate (num_facts, rpair 0.0)
-
- fun do_feat (s, sw0) =
- let
- val sw = sw0 * tfidf s
- val w2 = sw * sw
-
- fun inc_overlap j =
- let val (_, ov) = Array.sub (overlaps_sqr, j) in
- Array.update (overlaps_sqr, j, (j, w2 + ov))
- end
- in
- List.app inc_overlap (Array.sub (feat_facts, s))
- end
-
- val _ = List.app do_feat goal_feats
- val _ = heap (Real.compare o pairself snd) num_facts num_facts overlaps_sqr
- val no_recommends = Unsynchronized.ref 0
- val recommends = Array.tabulate (num_facts, rpair 0.0)
- val age = Unsynchronized.ref 500000000.0
-
- fun inc_recommend j v =
- let val (_, ov) = Array.sub (recommends, j) in
- if ov <= 0.0 then
- (no_recommends := !no_recommends + 1; Array.update (recommends, j, (j, !age + ov)))
- else if ov < !age + 1000.0 then
- Array.update (recommends, j, (j, v + ov))
- else
- ()
- end
-
- val k = Unsynchronized.ref 0
- fun do_k k =
- if k >= num_facts then
- raise EXIT ()
- else
- let
- val (j, o2) = Array.sub (overlaps_sqr, num_facts - k - 1)
- val o1 = Math.sqrt o2
- val _ = inc_recommend j o1
- val ds = Vector.sub (depss, j)
- val l = Real.fromInt (length ds)
- in
- List.app (fn d => inc_recommend d (o1 / l)) ds
- end
-
- fun while1 () =
- if !k = number_of_nearest_neighbors then () else (do_k (!k); k := !k + 1; while1 ())
- handle EXIT () => ()
-
- fun while2 () =
- if !no_recommends >= max_suggs then ()
- else (do_k (!k); k := !k + 1; age := !age - 10000.0; while2 ())
- handle EXIT () => ()
-
- fun ret acc at =
- if at = num_facts then acc else ret (Array.sub (recommends, at) :: acc) (at + 1)
- in
- while1 ();
- while2 ();
- select_visible_facts 1000000000.0 recommends visible_facts;
- heap (Real.compare o pairself snd) max_suggs num_facts recommends;
- ret [] (Integer.max 0 (num_facts - max_suggs))
- end
-
fun wider_array_of_vector init vec =
let val ary = Array.array init in
Array.copyVec {src = vec, dst = ary, di = 0};
@@ -393,6 +319,80 @@
ret (Integer.max 0 (num_facts - max_suggs)) []
end
+val number_of_nearest_neighbors = 10 (* FUDGE *)
+(* Ranks facts for goal_feats by tf-idf-weighted feature overlap; yields (fact index, score) pairs. *)
+fun k_nearest_neighbors dffreq num_facts depss feat_facts max_suggs visible_facts goal_feats =
+ let
+ exception EXIT of unit (* raised by do_k once every fact has been visited *)
+ (* tfidf feat = ln num_facts - ln dffreq[feat], so a larger dffreq entry yields a smaller weight *)
+ val ln_afreq = Math.ln (Real.fromInt num_facts)
+ fun tfidf feat = ln_afreq - Math.ln (Real.fromInt (Vector.sub (dffreq, feat)))
+ (* overlaps_sqr[j] holds (j, sum of squared weights of the goal features listed for fact j) *)
+ val overlaps_sqr = Array.tabulate (num_facts, rpair 0.0)
+ (* do_feat: add the squared weight of goal feature s to every fact in feat_facts[s] *)
+ fun do_feat (s, sw0) =
+ let
+ val sw = sw0 * tfidf s
+ val w2 = sw * sw
+
+ fun inc_overlap j =
+ let val (_, ov) = Array.sub (overlaps_sqr, j) in
+ Array.update (overlaps_sqr, j, (j, w2 + ov))
+ end
+ in
+ List.app inc_overlap (Array.sub (feat_facts, s))
+ end
+ (* presumably heap orders overlaps_sqr ascending by score, as do_k reads from the end; TODO confirm *)
+ val _ = List.app do_feat goal_feats
+ val _ = heap (Real.compare o pairself snd) num_facts num_facts overlaps_sqr
+ val no_recommends = Unsynchronized.ref 0
+ val recommends = Array.tabulate (num_facts, rpair 0.0)
+ val age = Unsynchronized.ref 500000000.0
+ (* the value v comes first so that do_k can partially apply inc_recommend over a dependency list *)
+ fun inc_recommend v j =
+ let val (_, ov) = Array.sub (recommends, j) in
+ if ov <= 0.0 then (* first hit: count it and seed the score with !age; v is not added *)
+ (no_recommends := !no_recommends + 1; Array.update (recommends, j, (j, !age + ov)))
+ else if ov < !age + 1000.0 then
+ Array.update (recommends, j, (j, v + ov))
+ else
+ ()
+ end
+ (* do_k k: credit the k-th entry from the end of overlaps_sqr and each of its dependencies *)
+ val k = Unsynchronized.ref 0
+ fun do_k k =
+ if k >= num_facts then
+ raise EXIT ()
+ else
+ let
+ val (j, o2) = Array.sub (overlaps_sqr, num_facts - k - 1)
+ val o1 = Math.sqrt o2
+ val _ = inc_recommend o1 j
+ val ds = Vector.sub (depss, j)
+ val l = Real.fromInt (length ds)
+ in
+ List.app (inc_recommend (o1 / l)) ds (* every dependency receives the same share o1 / l *)
+ end
+ (* while1: process entries until k reaches number_of_nearest_neighbors, or until EXIT *)
+ fun while1 () =
+ if !k = number_of_nearest_neighbors then () else (do_k (!k); k := !k + 1; while1 ())
+ handle EXIT () => ()
+ (* while2: continue, lowering age by 10000.0 per step, until max_suggs facts were hit, or EXIT *)
+ fun while2 () =
+ if !no_recommends >= max_suggs then ()
+ else (do_k (!k); k := !k + 1; age := !age - 10000.0; while2 ())
+ handle EXIT () => ()
+ (* ret: cons recommends[at .. num_facts - 1] onto acc, so the entry at the last index comes first *)
+ fun ret acc at =
+ if at = num_facts then acc else ret (Array.sub (recommends, at) :: acc) (at + 1)
+ in
+ while1 ();
+ while2 ();
+ select_visible_facts 1000000000.0 recommends visible_facts; (* adds this bonus to each visible fact *)
+ heap (Real.compare o pairself snd) max_suggs num_facts recommends;
+ ret [] (Integer.max 0 (num_facts - max_suggs))
+ end
+
(* experimental *)
fun external_tool tool max_suggs learns goal_feats =
let
@@ -435,15 +435,16 @@
fun query_external ctxt engine max_suggs learns goal_feats =
(trace_msg ctxt (fn () => "MaSh query external " ^ commas (map fst goal_feats));
(case engine of
- MaSh_kNN_Ext => k_nearest_neighbors_ext max_suggs learns goal_feats
- | MaSh_NB_Ext => naive_bayes_ext max_suggs learns goal_feats))
+ MaSh_NB_Ext => naive_bayes_ext max_suggs learns goal_feats
+ | MaSh_kNN_Ext => k_nearest_neighbors_ext max_suggs learns goal_feats))
fun query_internal ctxt engine num_facts num_feats (fact_names, featss, depss)
(freqs as (_, _, dffreq)) visible_facts max_suggs goal_feats int_goal_feats =
(trace_msg ctxt (fn () => "MaSh query internal " ^ commas (map fst goal_feats) ^ " from {" ^
elide_string 1000 (space_implode " " (Vector.foldr (op ::) [] fact_names)) ^ "}");
(case engine of
- MaSh_kNN =>
+ MaSh_NB => naive_bayes freqs num_facts max_suggs visible_facts int_goal_feats
+ | MaSh_kNN =>
let
val feat_facts = Array.array (num_feats, [])
val _ =
@@ -452,8 +453,7 @@
0 featss
in
k_nearest_neighbors dffreq num_facts depss feat_facts max_suggs visible_facts int_goal_feats
- end
- | MaSh_NB => naive_bayes freqs num_facts max_suggs visible_facts int_goal_feats)
+ end)
|> map (curry Vector.sub fact_names o fst))
end;
@@ -1178,7 +1178,7 @@
val (parents, goal_feats) = query_args access_G
val visible_facts = map_filter (Symtab.lookup fact_tab) (Graph.all_preds access_G parents)
in
- if engine = MaSh_kNN_Ext orelse engine = MaSh_NB_Ext then
+ if engine = MaSh_NB_Ext orelse engine = MaSh_kNN_Ext then
let
val learns =
Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G