# HG changeset patch # User blanchet # Date 1404226030 -7200 # Node ID 22023ab4df3c1d32075f4793a222d45bbbb0e7c2 # Parent 419180c354c04091b622d45cdea517112e29360f mix NB and kNN diff -r 419180c354c0 -r 22023ab4df3c src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML --- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Tue Jul 01 16:47:10 2014 +0200 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Tue Jul 01 16:47:10 2014 +0200 @@ -35,16 +35,17 @@ datatype mash_engine = MaSh_NB + | MaSh_kNN + | MaSh_NB_kNN | MaSh_NB_Ext - | MaSh_kNN | MaSh_kNN_Ext val is_mash_enabled : unit -> bool val the_mash_engine : unit -> mash_engine + val mesh_facts : ('a * 'a -> bool) -> int -> (real * (('a * real) list * 'a list)) list -> 'a list val nickname_of_thm : thm -> string val find_suggested_facts : Proof.context -> ('b * thm) list -> string list -> ('b * thm) list - val mesh_facts : ('a * 'a -> bool) -> int -> (real * (('a * real) list * 'a list)) list -> 'a list val crude_thm_ord : thm * thm -> order val thm_less : thm * thm -> bool val goal_of_thm : theory -> thm -> thm @@ -139,8 +140,9 @@ datatype mash_engine = MaSh_NB +| MaSh_kNN +| MaSh_NB_kNN | MaSh_NB_Ext -| MaSh_kNN | MaSh_kNN_Ext fun mash_engine () = @@ -149,8 +151,9 @@ "yes" => SOME MaSh_NB | "sml" => SOME MaSh_NB | "nb" => SOME MaSh_NB + | "knn" => SOME MaSh_kNN + | "nb_knn" => SOME MaSh_NB_kNN | "nb_ext" => SOME MaSh_NB_Ext - | "knn" => SOME MaSh_kNN | "knn_ext" => SOME MaSh_kNN_Ext | _ => NONE) end @@ -158,6 +161,42 @@ val is_mash_enabled = is_some o mash_engine val the_mash_engine = the_default MaSh_NB o mash_engine +fun scaled_avg [] = 0 + | scaled_avg xs = Real.ceil (100000000.0 * fold (curry (op +)) xs 0.0) div length xs + +fun avg [] = 0.0 + | avg xs = fold (curry (op +)) xs 0.0 / Real.fromInt (length xs) + +fun normalize_scores _ [] = [] + | normalize_scores max_facts xs = + map (apsnd (curry Real.* (1.0 / avg (map snd (take max_facts xs))))) xs + +fun mesh_facts fact_eq max_facts [(_, (sels, unks))] = + distinct fact_eq (map fst (take max_facts sels) @ take (max_facts - length sels) unks) + | mesh_facts fact_eq max_facts mess = + let + val mess = mess |> map (apsnd (apfst (normalize_scores max_facts))) + + fun score_in fact (global_weight, (sels, unks)) = + let val score_at = try (nth sels) #> Option.map (fn (_, score) => global_weight * score) in + (case find_index (curry fact_eq fact o fst) sels of + ~1 => if member fact_eq unks fact then NONE else SOME 0.0 + | rank => score_at rank) + end + + fun weight_of fact = mess |> map_filter (score_in fact) |> scaled_avg + in + fold (union fact_eq o map fst o take max_facts o fst o snd) mess [] + |> map (`weight_of) |> sort (int_ord o swap o pairself fst) + |> map snd |> take max_facts + end + +fun smooth_weight_of_fact rank = Math.pow (1.3, 15.5 - 0.2 * Real.fromInt rank) + 15.0 (* FUDGE *) +fun steep_weight_of_fact rank = Math.pow (0.62, log2 (Real.fromInt (rank + 1))) (* FUDGE *) + +fun weight_facts_smoothly facts = facts ~~ map smooth_weight_of_fact (0 upto length facts - 1) +fun weight_facts_steeply facts = facts ~~ map steep_weight_of_fact (0 upto length facts - 1) + (*** Isabelle-agnostic machine learning ***) @@ -321,7 +360,7 @@ val number_of_nearest_neighbors = 10 (* FUDGE *) -fun k_nearest_neighbors dffreq num_facts depss feat_facts max_suggs visible_facts goal_feats = +fun k_nearest_neighbors dffreq num_facts num_feats depss featss max_suggs visible_facts goal_feats = let exception EXIT of unit @@ -330,6 +369,10 @@ val overlaps_sqr = Array.tabulate (num_facts, rpair 0.0) + val feat_facts = Array.array (num_feats, []) + val _ = Vector.foldl (fn (feats, fact) => + (List.app (map_array_at feat_facts (cons fact)) feats; fact + 1)) 0 featss + fun do_feat (s, sw0) = let val sw = sw0 * tfidf s @@ -440,21 +483,30 @@ fun query_internal ctxt engine num_facts num_feats (fact_names, featss, depss) (freqs as (_, _, dffreq)) visible_facts max_suggs goal_feats int_goal_feats = - (trace_msg ctxt (fn () => "MaSh query internal " ^ commas (map fst goal_feats) ^ " from {" ^ - elide_string 1000 (space_implode " " (Vector.foldr (op ::) [] fact_names)) ^ "}"); - (case engine of - MaSh_NB => naive_bayes freqs num_facts max_suggs visible_facts int_goal_feats - | MaSh_kNN => - let - val feat_facts = Array.array (num_feats, []) - val _ = - Vector.foldl (fn (feats, fact) => - (List.app (map_array_at feat_facts (cons fact)) feats; fact + 1)) - 0 featss - in - k_nearest_neighbors dffreq num_facts depss feat_facts max_suggs visible_facts int_goal_feats - end) - |> map (curry Vector.sub fact_names o fst)) + let + fun nb () = + naive_bayes freqs num_facts max_suggs visible_facts int_goal_feats + |> map fst + fun knn () = + k_nearest_neighbors dffreq num_facts num_feats depss featss max_suggs visible_facts + int_goal_feats + |> map fst + in + (trace_msg ctxt (fn () => "MaSh query internal " ^ commas (map fst goal_feats) ^ " from {" ^ + elide_string 1000 (space_implode " " (Vector.foldr (op ::) [] fact_names)) ^ "}"); + (case engine of + MaSh_NB => nb () + | MaSh_kNN => knn () + | MaSh_NB_kNN => + let + val mess = + [(0.5 (* FUDGE *), (weight_facts_steeply (nb ()), [])), + (0.5 (* FUDGE *), (weight_facts_steeply (knn ()), []))] + in + mesh_facts (op =) max_suggs mess + end) + |> map (curry Vector.sub fact_names)) + end end; @@ -706,36 +758,6 @@ |> tap (fn NONE => trace_msg ctxt (fn () => "Cannot find " ^ quote nick) | _ => ()) in map_filter lookup end -fun scaled_avg [] = 0 - | scaled_avg xs = Real.ceil (100000000.0 * fold (curry (op +)) xs 0.0) div length xs - -fun avg [] = 0.0 - | avg xs = fold (curry (op +)) xs 0.0 / Real.fromInt (length xs) - -fun normalize_scores _ [] = [] - | normalize_scores max_facts xs = - map (apsnd (curry Real.* (1.0 / avg (map snd (take max_facts xs))))) xs - -fun mesh_facts fact_eq max_facts [(_, (sels, unks))] = - distinct fact_eq (map fst (take max_facts sels) @ take (max_facts - length sels) unks) - | mesh_facts fact_eq max_facts mess = - let - val mess = mess |> map (apsnd (apfst (normalize_scores max_facts))) - - fun score_in fact (global_weight, (sels, unks)) = - let val score_at = try (nth sels) #> Option.map (fn (_, score) => global_weight * score) in - (case find_index (curry fact_eq fact o fst) sels of - ~1 => if member fact_eq unks fact then NONE else SOME 0.0 - | rank => score_at rank) - end - - fun weight_of fact = mess |> map_filter (score_in fact) |> scaled_avg - in - fold (union fact_eq o map fst o take max_facts o fst o snd) mess [] - |> map (`weight_of) |> sort (int_ord o swap o pairself fst) - |> map snd |> take max_facts - end - fun free_feature_of s = "f" ^ s fun thy_feature_of s = "y" ^ s fun type_feature_of s = "t" ^ s @@ -1098,20 +1120,6 @@ val extra_feature_factor = 0.1 (* FUDGE *) val num_extra_feature_facts = 10 (* FUDGE *) -(* FUDGE *) -fun weight_of_proximity_fact rank = - Math.pow (1.3, 15.5 - 0.2 * Real.fromInt rank) + 15.0 - -fun weight_facts_smoothly facts = - facts ~~ map weight_of_proximity_fact (0 upto length facts - 1) - -(* FUDGE *) -fun steep_weight_of_fact rank = - Math.pow (0.62, log2 (Real.fromInt (rank + 1))) - -fun weight_facts_steeply facts = - facts ~~ map steep_weight_of_fact (0 upto length facts - 1) - val max_proximity_facts = 100 fun find_mash_suggestions ctxt max_facts suggs facts chained raw_unknown = @@ -1587,7 +1595,6 @@ end fun kill_learners () = Async_Manager.kill_threads MaShN "learner" - fun running_learners () = Async_Manager.running_threads MaShN "learner" end;