src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
changeset 57459 22023ab4df3c
parent 57458 419180c354c0
child 57460 9cc802a8ab06
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Tue Jul 01 16:47:10 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Tue Jul 01 16:47:10 2014 +0200
@@ -35,16 +35,17 @@
 
   datatype mash_engine =
     MaSh_NB
+  | MaSh_kNN
+  | MaSh_NB_kNN
   | MaSh_NB_Ext
-  | MaSh_kNN
   | MaSh_kNN_Ext
 
   val is_mash_enabled : unit -> bool
   val the_mash_engine : unit -> mash_engine
 
+  val mesh_facts : ('a * 'a -> bool) -> int -> (real * (('a * real) list * 'a list)) list -> 'a list
   val nickname_of_thm : thm -> string
   val find_suggested_facts : Proof.context -> ('b * thm) list -> string list -> ('b * thm) list
-  val mesh_facts : ('a * 'a -> bool) -> int -> (real * (('a * real) list * 'a list)) list -> 'a list
   val crude_thm_ord : thm * thm -> order
   val thm_less : thm * thm -> bool
   val goal_of_thm : theory -> thm -> thm
@@ -139,8 +140,9 @@
 
 datatype mash_engine =
   MaSh_NB
+| MaSh_kNN
+| MaSh_NB_kNN
 | MaSh_NB_Ext
-| MaSh_kNN
 | MaSh_kNN_Ext
 
 fun mash_engine () =
@@ -149,8 +151,9 @@
       "yes" => SOME MaSh_NB
     | "sml" => SOME MaSh_NB
     | "nb" => SOME MaSh_NB
+    | "knn" => SOME MaSh_kNN
+    | "nb_knn" => SOME MaSh_NB_kNN
     | "nb_ext" => SOME MaSh_NB_Ext
-    | "knn" => SOME MaSh_kNN
     | "knn_ext" => SOME MaSh_kNN_Ext
     | _ => NONE)
   end
@@ -158,6 +161,42 @@
 val is_mash_enabled = is_some o mash_engine
 val the_mash_engine = the_default MaSh_NB o mash_engine
 
+fun scaled_avg [] = 0
+  | scaled_avg xs = Real.ceil (100000000.0 * fold (curry (op +)) xs 0.0) div length xs
+
+fun avg [] = 0.0
+  | avg xs = fold (curry (op +)) xs 0.0 / Real.fromInt (length xs)
+
+fun normalize_scores _ [] = []
+  | normalize_scores max_facts xs =
+    map (apsnd (curry Real.* (1.0 / avg (map snd (take max_facts xs))))) xs
+
+fun mesh_facts fact_eq max_facts [(_, (sels, unks))] =
+    distinct fact_eq (map fst (take max_facts sels) @ take (max_facts - length sels) unks)
+  | mesh_facts fact_eq max_facts mess =
+    let
+      val mess = mess |> map (apsnd (apfst (normalize_scores max_facts)))
+
+      fun score_in fact (global_weight, (sels, unks)) =
+        let val score_at = try (nth sels) #> Option.map (fn (_, score) => global_weight * score) in
+          (case find_index (curry fact_eq fact o fst) sels of
+            ~1 => if member fact_eq unks fact then NONE else SOME 0.0
+          | rank => score_at rank)
+        end
+
+      fun weight_of fact = mess |> map_filter (score_in fact) |> scaled_avg
+    in
+      fold (union fact_eq o map fst o take max_facts o fst o snd) mess []
+      |> map (`weight_of) |> sort (int_ord o swap o pairself fst)
+      |> map snd |> take max_facts
+    end
+
+fun smooth_weight_of_fact rank = Math.pow (1.3, 15.5 - 0.2 * Real.fromInt rank) + 15.0 (* FUDGE *)
+fun steep_weight_of_fact rank = Math.pow (0.62, log2 (Real.fromInt (rank + 1))) (* FUDGE *)
+
+fun weight_facts_smoothly facts = facts ~~ map smooth_weight_of_fact (0 upto length facts - 1)
+fun weight_facts_steeply facts = facts ~~ map steep_weight_of_fact (0 upto length facts - 1)
+
 
 (*** Isabelle-agnostic machine learning ***)
 
@@ -321,7 +360,7 @@
 
 val number_of_nearest_neighbors = 10 (* FUDGE *)
 
-fun k_nearest_neighbors dffreq num_facts depss feat_facts max_suggs visible_facts goal_feats =
+fun k_nearest_neighbors dffreq num_facts num_feats depss featss max_suggs visible_facts goal_feats =
   let
     exception EXIT of unit
 
@@ -330,6 +369,10 @@
 
     val overlaps_sqr = Array.tabulate (num_facts, rpair 0.0)
 
+    val feat_facts = Array.array (num_feats, [])
+    val _ = Vector.foldl (fn (feats, fact) =>
+      (List.app (map_array_at feat_facts (cons fact)) feats; fact + 1)) 0 featss
+
     fun do_feat (s, sw0) =
       let
         val sw = sw0 * tfidf s
@@ -440,21 +483,30 @@
 
 fun query_internal ctxt engine num_facts num_feats (fact_names, featss, depss)
     (freqs as (_, _, dffreq)) visible_facts max_suggs goal_feats int_goal_feats =
-  (trace_msg ctxt (fn () => "MaSh query internal " ^ commas (map fst goal_feats) ^ " from {" ^
-     elide_string 1000 (space_implode " " (Vector.foldr (op ::) [] fact_names)) ^ "}");
-   (case engine of
-     MaSh_NB => naive_bayes freqs num_facts max_suggs visible_facts int_goal_feats
-   | MaSh_kNN =>
-     let
-       val feat_facts = Array.array (num_feats, [])
-       val _ =
-         Vector.foldl (fn (feats, fact) =>
-             (List.app (map_array_at feat_facts (cons fact)) feats; fact + 1))
-           0 featss
-     in
-       k_nearest_neighbors dffreq num_facts depss feat_facts max_suggs visible_facts int_goal_feats
-     end)
-   |> map (curry Vector.sub fact_names o fst))
+  let
+    fun nb () =
+      naive_bayes freqs num_facts max_suggs visible_facts int_goal_feats
+      |> map fst
+    fun knn () =
+      k_nearest_neighbors dffreq num_facts num_feats depss featss max_suggs visible_facts
+        int_goal_feats
+      |> map fst
+  in
+    (trace_msg ctxt (fn () => "MaSh query internal " ^ commas (map fst goal_feats) ^ " from {" ^
+       elide_string 1000 (space_implode " " (Vector.foldr (op ::) [] fact_names)) ^ "}");
+     (case engine of
+       MaSh_NB => nb ()
+     | MaSh_kNN => knn ()
+     | MaSh_NB_kNN =>
+       let
+         val mess =
+           [(0.5 (* FUDGE *), (weight_facts_steeply (nb ()), [])),
+            (0.5 (* FUDGE *), (weight_facts_steeply (knn ()), []))]
+       in
+         mesh_facts (op =) max_suggs mess
+       end)
+     |> map (curry Vector.sub fact_names))
+   end
 
 end;
 
@@ -706,36 +758,6 @@
       |> tap (fn NONE => trace_msg ctxt (fn () => "Cannot find " ^ quote nick) | _ => ())
   in map_filter lookup end
 
-fun scaled_avg [] = 0
-  | scaled_avg xs = Real.ceil (100000000.0 * fold (curry (op +)) xs 0.0) div length xs
-
-fun avg [] = 0.0
-  | avg xs = fold (curry (op +)) xs 0.0 / Real.fromInt (length xs)
-
-fun normalize_scores _ [] = []
-  | normalize_scores max_facts xs =
-    map (apsnd (curry Real.* (1.0 / avg (map snd (take max_facts xs))))) xs
-
-fun mesh_facts fact_eq max_facts [(_, (sels, unks))] =
-    distinct fact_eq (map fst (take max_facts sels) @ take (max_facts - length sels) unks)
-  | mesh_facts fact_eq max_facts mess =
-    let
-      val mess = mess |> map (apsnd (apfst (normalize_scores max_facts)))
-
-      fun score_in fact (global_weight, (sels, unks)) =
-        let val score_at = try (nth sels) #> Option.map (fn (_, score) => global_weight * score) in
-          (case find_index (curry fact_eq fact o fst) sels of
-            ~1 => if member fact_eq unks fact then NONE else SOME 0.0
-          | rank => score_at rank)
-        end
-
-      fun weight_of fact = mess |> map_filter (score_in fact) |> scaled_avg
-    in
-      fold (union fact_eq o map fst o take max_facts o fst o snd) mess []
-      |> map (`weight_of) |> sort (int_ord o swap o pairself fst)
-      |> map snd |> take max_facts
-    end
-
 fun free_feature_of s = "f" ^ s
 fun thy_feature_of s = "y" ^ s
 fun type_feature_of s = "t" ^ s
@@ -1098,20 +1120,6 @@
 val extra_feature_factor = 0.1 (* FUDGE *)
 val num_extra_feature_facts = 10 (* FUDGE *)
 
-(* FUDGE *)
-fun weight_of_proximity_fact rank =
-  Math.pow (1.3, 15.5 - 0.2 * Real.fromInt rank) + 15.0
-
-fun weight_facts_smoothly facts =
-  facts ~~ map weight_of_proximity_fact (0 upto length facts - 1)
-
-(* FUDGE *)
-fun steep_weight_of_fact rank =
-  Math.pow (0.62, log2 (Real.fromInt (rank + 1)))
-
-fun weight_facts_steeply facts =
-  facts ~~ map steep_weight_of_fact (0 upto length facts - 1)
-
 val max_proximity_facts = 100
 
 fun find_mash_suggestions ctxt max_facts suggs facts chained raw_unknown =
@@ -1587,7 +1595,6 @@
     end
 
 fun kill_learners () = Async_Manager.kill_threads MaShN "learner"
-
 fun running_learners () = Async_Manager.running_threads MaShN "learner"
 
 end;