src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
changeset 57124 e4c2c792226f
parent 57122 5f69b8c3af5a
child 57125 2f620ef839ee
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Fri May 30 12:27:51 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Fri May 30 12:27:51 2014 +0200
@@ -53,6 +53,10 @@
 
   structure MaSh_SML :
   sig
+    val k_nearest_neighbors : int -> int -> (int -> int list) -> (int -> (int * real) list) ->
+      int -> (int * real) list -> (int * real) list
+    val naive_bayes : int -> int -> (int -> int list) -> (int -> Inttab.key list) -> int -> int ->
+      (Inttab.key * real) list -> (int * real) list
     val query : Proof.context -> mash_engine -> string list -> int ->
       (string * (string * real) list * string list) list * string list * (string * real) list ->
       string list
@@ -450,9 +454,6 @@
     ret [] (Integer.max 0 (num_visible_facts - max_suggs))
   end
 
-val nb_tau = 0.02 (* FUDGE *)
-val nb_pos_weight = 2.0 (* FUDGE *)
-val nb_def_val = ~15.0 (* FUDGE *)
 val nb_def_prior_weight = 20 (* FUDGE *)
 
 (* TODO: Either use IDF or don't use it. See commented out code portions below. *)
@@ -487,6 +488,12 @@
     for 0; (Array.vector tfreq, Array.vector sfreq (*, Array.vector dffreq *))
   end
 
+val nb_kuehlwein_style = false
+
+val nb_tau = if nb_kuehlwein_style then 0.05 else 0.02 (* FUDGE *)
+val nb_pos_weight = if nb_kuehlwein_style then 20.0 else 2.0 (* FUDGE *)
+val nb_def_val = ~15.0 (* FUDGE *)
+
 fun naive_bayes_query _ (* num_facts *) num_visible_facts max_suggs feats (tfreq, sfreq (*, dffreq*)) =
   let
 (*
@@ -508,8 +515,12 @@
 
         val (res, sfh) = fold fold_feats feats (Math.ln tfreq, Vector.sub (sfreq, i))
 
-        fun fold_sfh (f, sf) sow =
-          sow + tfidf f * (Math.ln (1.0 + (1.0 - Real.fromInt sf) / tfreq))
+        val fold_sfh =
+          if nb_kuehlwein_style then
+            (fn (f, sf) => fn sow => sow - tfidf f * (tfreq / Math.ln (Real.fromInt sf)))
+          else
+            (fn (f, sf) => fn sow =>
+               sow + tfidf f * Math.ln (1.0 + (1.0 - Real.fromInt sf) / tfreq))
 
         val sum_of_weights = Inttab.fold fold_sfh sfh 0.0
       in