--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Tue Jul 01 16:47:10 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Tue Jul 01 16:47:10 2014 +0200
@@ -35,16 +35,17 @@
datatype mash_engine =
MaSh_NB
+ | MaSh_kNN
+ | MaSh_NB_kNN
| MaSh_NB_Ext
- | MaSh_kNN
| MaSh_kNN_Ext
val is_mash_enabled : unit -> bool
val the_mash_engine : unit -> mash_engine
+ val mesh_facts : ('a * 'a -> bool) -> int -> (real * (('a * real) list * 'a list)) list -> 'a list
val nickname_of_thm : thm -> string
val find_suggested_facts : Proof.context -> ('b * thm) list -> string list -> ('b * thm) list
- val mesh_facts : ('a * 'a -> bool) -> int -> (real * (('a * real) list * 'a list)) list -> 'a list
val crude_thm_ord : thm * thm -> order
val thm_less : thm * thm -> bool
val goal_of_thm : theory -> thm -> thm
@@ -139,8 +140,9 @@
datatype mash_engine =
MaSh_NB
+| MaSh_kNN (* handled by k_nearest_neighbors in query_internal *)
+| MaSh_NB_kNN (* mesh of the naive Bayes and kNN rankings in query_internal *)
| MaSh_NB_Ext
-| MaSh_kNN
| MaSh_kNN_Ext
fun mash_engine () =
@@ -149,8 +151,9 @@
"yes" => SOME MaSh_NB
| "sml" => SOME MaSh_NB
| "nb" => SOME MaSh_NB
+ | "knn" => SOME MaSh_kNN
+ | "nb_knn" => SOME MaSh_NB_kNN
| "nb_ext" => SOME MaSh_NB_Ext
- | "knn" => SOME MaSh_kNN
| "knn_ext" => SOME MaSh_kNN_Ext
| _ => NONE)
end
@@ -158,6 +161,42 @@
val is_mash_enabled = is_some o mash_engine
val the_mash_engine = the_default MaSh_NB o mash_engine
+fun scaled_avg [] = 0 (* int result: ceil (10^8 * sum of xs) div length xs; 0 for the empty list *)
+ | scaled_avg xs = Real.ceil (100000000.0 * fold (curry (op +)) xs 0.0) div length xs
+
+fun avg [] = 0.0 (* arithmetic mean of xs as a real; 0.0 for the empty list *)
+ | avg xs = fold (curry (op +)) xs 0.0 / Real.fromInt (length xs)
+
+fun normalize_scores _ [] = [] (* scales every score by 1 / avg of the first max_facts scores *)
+ | normalize_scores max_facts xs = (* NOTE(review): if that average is 0.0 the factor is infinite *)
+ map (apsnd (curry Real.* (1.0 / avg (map snd (take max_facts xs))))) xs
+
+fun mesh_facts fact_eq max_facts [(_, (sels, unks))] = (* single source: its weight is ignored *)
+ distinct fact_eq (map fst (take max_facts sels) @ take (max_facts - length sels) unks)
+ | mesh_facts fact_eq max_facts mess = (* otherwise: order facts by averaged weighted scores *)
+ let
+ val mess = mess |> map (apsnd (apfst (normalize_scores max_facts)))
+
+ fun score_in fact (global_weight, (sels, unks)) = (* contribution of one source to fact *)
+ let val score_at = try (nth sels) #> Option.map (fn (_, score) => global_weight * score) in
+ (case find_index (curry fact_eq fact o fst) sels of
+ ~1 => if member fact_eq unks fact then NONE else SOME 0.0 (* fact is not in sels *)
+ | rank => score_at rank)
+ end
+
+ fun weight_of fact = mess |> map_filter (score_in fact) |> scaled_avg
+ in
+ fold (union fact_eq o map fst o take max_facts o fst o snd) mess []
+ |> map (`weight_of) |> sort (int_ord o swap o pairself fst) (* highest weight first *)
+ |> map snd |> take max_facts
+ end
+
+fun smooth_weight_of_fact rank = Math.pow (1.3, 15.5 - 0.2 * Real.fromInt rank) + 15.0 (* FUDGE; decreasing in rank, always above 15.0 *)
+fun steep_weight_of_fact rank = Math.pow (0.62, log2 (Real.fromInt (rank + 1))) (* FUDGE; assumes log2 is the base-2 logarithm *)
+(* Pair each fact with the weight of its 0-based position in the list. *)
+fun weight_facts_smoothly facts = facts ~~ map smooth_weight_of_fact (0 upto length facts - 1)
+fun weight_facts_steeply facts = facts ~~ map steep_weight_of_fact (0 upto length facts - 1)
+
(*** Isabelle-agnostic machine learning ***)
@@ -321,7 +360,7 @@
val number_of_nearest_neighbors = 10 (* FUDGE *)
-fun k_nearest_neighbors dffreq num_facts depss feat_facts max_suggs visible_facts goal_feats =
+fun k_nearest_neighbors dffreq num_facts num_feats depss featss max_suggs visible_facts goal_feats =
let
exception EXIT of unit
@@ -330,6 +369,10 @@
val overlaps_sqr = Array.tabulate (num_facts, rpair 0.0)
+ val feat_facts = Array.array (num_feats, [])
+ val _ = Vector.foldl (fn (feats, fact) =>
+ (List.app (map_array_at feat_facts (cons fact)) feats; fact + 1)) 0 featss
+
fun do_feat (s, sw0) =
let
val sw = sw0 * tfidf s
@@ -440,21 +483,30 @@
fun query_internal ctxt engine num_facts num_feats (fact_names, featss, depss)
(freqs as (_, _, dffreq)) visible_facts max_suggs goal_feats int_goal_feats =
- (trace_msg ctxt (fn () => "MaSh query internal " ^ commas (map fst goal_feats) ^ " from {" ^
- elide_string 1000 (space_implode " " (Vector.foldr (op ::) [] fact_names)) ^ "}");
- (case engine of
- MaSh_NB => naive_bayes freqs num_facts max_suggs visible_facts int_goal_feats
- | MaSh_kNN =>
- let
- val feat_facts = Array.array (num_feats, [])
- val _ =
- Vector.foldl (fn (feats, fact) =>
- (List.app (map_array_at feat_facts (cons fact)) feats; fact + 1))
- 0 featss
- in
- k_nearest_neighbors dffreq num_facts depss feat_facts max_suggs visible_facts int_goal_feats
- end)
- |> map (curry Vector.sub fact_names o fst))
+ let (* nb and knn are thunks, so only the engines used by the chosen case branch are run *)
+ fun nb () =
+ naive_bayes freqs num_facts max_suggs visible_facts int_goal_feats
+ |> map fst (* keep the first components, used below as indices into fact_names *)
+ fun knn () =
+ k_nearest_neighbors dffreq num_facts num_feats depss featss max_suggs visible_facts
+ int_goal_feats
+ |> map fst (* keep the first components, used below as indices into fact_names *)
+ in
+ (trace_msg ctxt (fn () => "MaSh query internal " ^ commas (map fst goal_feats) ^ " from {" ^
+ elide_string 1000 (space_implode " " (Vector.foldr (op ::) [] fact_names)) ^ "}");
+ (case engine of (* NOTE(review): no branch for MaSh_NB_Ext / MaSh_kNN_Ext; presumably never passed here *)
+ MaSh_NB => nb ()
+ | MaSh_kNN => knn ()
+ | MaSh_NB_kNN => (* both rankings get steep rank weights and equal global weight 0.5 *)
+ let
+ val mess =
+ [(0.5 (* FUDGE *), (weight_facts_steeply (nb ()), [])),
+ (0.5 (* FUDGE *), (weight_facts_steeply (knn ()), []))]
+ in
+ mesh_facts (op =) max_suggs mess
+ end)
+ |> map (curry Vector.sub fact_names)) (* look up each index in fact_names *)
+ end
end;
@@ -706,36 +758,6 @@
|> tap (fn NONE => trace_msg ctxt (fn () => "Cannot find " ^ quote nick) | _ => ())
in map_filter lookup end
-fun scaled_avg [] = 0
- | scaled_avg xs = Real.ceil (100000000.0 * fold (curry (op +)) xs 0.0) div length xs
-
-fun avg [] = 0.0
- | avg xs = fold (curry (op +)) xs 0.0 / Real.fromInt (length xs)
-
-fun normalize_scores _ [] = []
- | normalize_scores max_facts xs =
- map (apsnd (curry Real.* (1.0 / avg (map snd (take max_facts xs))))) xs
-
-fun mesh_facts fact_eq max_facts [(_, (sels, unks))] =
- distinct fact_eq (map fst (take max_facts sels) @ take (max_facts - length sels) unks)
- | mesh_facts fact_eq max_facts mess =
- let
- val mess = mess |> map (apsnd (apfst (normalize_scores max_facts)))
-
- fun score_in fact (global_weight, (sels, unks)) =
- let val score_at = try (nth sels) #> Option.map (fn (_, score) => global_weight * score) in
- (case find_index (curry fact_eq fact o fst) sels of
- ~1 => if member fact_eq unks fact then NONE else SOME 0.0
- | rank => score_at rank)
- end
-
- fun weight_of fact = mess |> map_filter (score_in fact) |> scaled_avg
- in
- fold (union fact_eq o map fst o take max_facts o fst o snd) mess []
- |> map (`weight_of) |> sort (int_ord o swap o pairself fst)
- |> map snd |> take max_facts
- end
-
fun free_feature_of s = "f" ^ s
fun thy_feature_of s = "y" ^ s
fun type_feature_of s = "t" ^ s
@@ -1098,20 +1120,6 @@
val extra_feature_factor = 0.1 (* FUDGE *)
val num_extra_feature_facts = 10 (* FUDGE *)
-(* FUDGE *)
-fun weight_of_proximity_fact rank =
- Math.pow (1.3, 15.5 - 0.2 * Real.fromInt rank) + 15.0
-
-fun weight_facts_smoothly facts =
- facts ~~ map weight_of_proximity_fact (0 upto length facts - 1)
-
-(* FUDGE *)
-fun steep_weight_of_fact rank =
- Math.pow (0.62, log2 (Real.fromInt (rank + 1)))
-
-fun weight_facts_steeply facts =
- facts ~~ map steep_weight_of_fact (0 upto length facts - 1)
-
val max_proximity_facts = 100
fun find_mash_suggestions ctxt max_facts suggs facts chained raw_unknown =
@@ -1587,7 +1595,6 @@
end
fun kill_learners () = Async_Manager.kill_threads MaShN "learner"
-
fun running_learners () = Async_Manager.running_threads MaShN "learner"
end;