src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
changeset 57278 8f7d6f01a775
parent 57277 31b35f5a5fdb
child 57281 bb671e6b740d
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Wed Jun 18 17:42:24 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Wed Jun 18 17:42:24 2014 +0200
@@ -35,7 +35,11 @@
   val encode_features : (string * real) list -> string
   val extract_suggestions : string -> string * (string * real) list
 
-  datatype mash_engine = MaSh_Py | MaSh_SML_kNN | MaSh_SML_NB | MaSh_SML_NB_Py
+  datatype mash_engine =
+    MaSh_Py
+  | MaSh_SML_kNN
+  | MaSh_SML_NB of bool * bool (* (kuehlwein_log, kuehlwein_params) *)
+  | MaSh_SML_NB_Py
 
   val is_mash_enabled : unit -> bool
   val the_mash_engine : unit -> mash_engine
@@ -56,8 +60,8 @@
   sig
     val k_nearest_neighbors : int -> int -> (int -> int list) -> (int -> (int * real) list) ->
       int -> (int * real) list -> (int * real) list
-    val naive_bayes : int -> int -> (int -> int list) -> (int -> int list) -> int -> int ->
-      (int * real) list -> (int * real) list
+    val naive_bayes : (bool * bool) -> int -> int -> (int -> int list) -> (int -> int list) ->
+      int -> int -> (int * real) list -> (int * real) list (* bool pair = kuehlwein (log, params) *)
     val naive_bayes_py : Proof.context -> bool -> int -> int -> (int -> int list) ->
       (int -> int list) -> int -> int -> (int * real) list -> (int * real) list
     val query : Proof.context -> bool -> mash_engine -> string list -> int ->
@@ -150,22 +154,32 @@
     ()
   end
 
-datatype mash_engine = MaSh_Py | MaSh_SML_kNN | MaSh_SML_NB | MaSh_SML_NB_Py
+datatype mash_engine =
+  MaSh_Py
+| MaSh_SML_kNN
+| MaSh_SML_NB of bool * bool (* (kuehlwein_log, kuehlwein_params) *)
+| MaSh_SML_NB_Py
+
+val default_MaSh_SML_NB = MaSh_SML_NB (false, false) (* (kuehlwein_log, kuehlwein_params) *)
 
 fun mash_engine () =
   let val flag1 = Options.default_string @{system_option MaSh} in
     (case if flag1 <> "none" (* default *) then flag1 else getenv "MASH" of
-      "yes" => SOME MaSh_SML_NB
+      "yes" => SOME default_MaSh_SML_NB
     | "py" => SOME MaSh_Py
-    | "sml" => SOME MaSh_SML_NB
+    | "sml" => SOME default_MaSh_SML_NB
     | "sml_knn" => SOME MaSh_SML_kNN
-    | "sml_nb" => SOME MaSh_SML_NB
+    | "sml_nb" => SOME default_MaSh_SML_NB
+    | "sml_nbCC" => SOME (MaSh_SML_NB (false, false)) (* (kuehlwein_log, kuehlwein_params) *)
+    | "sml_nbCD" => SOME (MaSh_SML_NB (false, true)) (* suffix letters: C = false, D = true *)
+    | "sml_nbDC" => SOME (MaSh_SML_NB (true, false))
+    | "sml_nbDD" => SOME (MaSh_SML_NB (true, true))
     | "sml_nb_py" => SOME MaSh_SML_NB_Py
     | _ => NONE)
   end
 
 val is_mash_enabled = is_some o mash_engine
-val the_mash_engine = the_default MaSh_SML_NB o mash_engine
+val the_mash_engine = the_default default_MaSh_SML_NB o mash_engine (* fallback when NONE *)
 
 
 (*** Low-level communication with the Python version of MaSh ***)
@@ -498,15 +512,13 @@
     for 0; (Array.vector tfreq, Array.vector sfreq (*, Array.vector dffreq *))
   end
 
-val nb_kuehlwein_style_log = false
-val nb_kuehlwein_style_params = false
+fun naive_bayes_query (kuehlwein_log, kuehlwein_params) _ (* num_facts *) num_visible_facts
+    max_suggs feats (tfreq, sfreq (*, dffreq*)) =
+  let
+    val tau = if kuehlwein_params then 0.05 else 0.02 (* FUDGE *)
+    val pos_weight = if kuehlwein_params then 10.0 else 2.0 (* FUDGE *)
+    val def_val = ~15.0 (* FUDGE *)
 
-val nb_tau = if nb_kuehlwein_style_params then 0.05 else 0.02 (* FUDGE *)
-val nb_pos_weight = if nb_kuehlwein_style_params then 10.0 else 2.0 (* FUDGE *)
-val nb_def_val = ~15.0 (* FUDGE *)
-
-fun naive_bayes_query _ (* num_facts *) num_visible_facts max_suggs feats (tfreq, sfreq (*, dffreq*)) =
-  let
 (*
     val afreq = Real.fromInt num_facts
     fun tfidf feat = Math.ln afreq - Math.ln (Real.fromInt (Vector.sub (dffreq, feat)))
@@ -520,14 +532,14 @@
         fun fold_feats (f, fw) (res, sfh) =
           (case Inttab.lookup sfh f of
             SOME sf =>
-            (res + tfidf f * fw * Math.ln (nb_pos_weight * Real.fromInt sf / tfreq),
+            (res + tfidf f * fw * Math.ln (pos_weight * Real.fromInt sf / tfreq),
              Inttab.delete f sfh)
-          | NONE => (res + fw * nb_def_val, sfh))
+          | NONE => (res + fw * def_val, sfh))
 
         val (res, sfh) = fold fold_feats feats (Math.ln tfreq, Vector.sub (sfreq, i))
 
         val fold_sfh =
-          if nb_kuehlwein_style_log then
+          if kuehlwein_log then
             (fn (f, sf) => fn sow => sow - tfidf f * Math.ln (Real.fromInt sf / tfreq))
           else
             (fn (f, sf) => fn sow =>
@@ -535,7 +547,7 @@
 
         val sum_of_weights = Inttab.fold fold_sfh sfh 0.0
       in
-        res + nb_tau * sum_of_weights
+        res + tau * sum_of_weights
       end
 
     val posterior = Array.tabulate (num_visible_facts, (fn j => (j, log_posterior j)))
@@ -547,9 +559,9 @@
     ret [] (Integer.max 0 (num_visible_facts - max_suggs))
   end
 
-fun naive_bayes num_facts num_visible_facts get_deps get_feats num_feats max_suggs feats =
+fun naive_bayes opts num_facts num_visible_facts get_deps get_feats num_feats max_suggs feats =
   naive_bayes_learn num_facts get_deps get_feats num_feats
-  |> naive_bayes_query num_facts num_visible_facts max_suggs feats
+  |> naive_bayes_query opts num_facts num_visible_facts max_suggs feats (* opts: Kuehlwein flags *)
 
 (* experimental *)
 fun naive_bayes_py ctxt overlord num_facts num_visible_facts get_deps get_feats num_feats max_suggs
@@ -633,7 +645,9 @@
          val get_unweighted_feats = curry Vector.sub unweighted_feats_ary
          val feats' = map (apfst (the_default ~1 o Symtab.lookup feat_tab)) feats
        in
-         (if engine = MaSh_SML_NB then naive_bayes else naive_bayes_py ctxt overlord)
+         (case engine of
+           MaSh_SML_NB opts => naive_bayes opts
+         | _ => naive_bayes_py ctxt overlord)
            num_facts num_visible_facts get_deps get_unweighted_feats num_feats max_suggs feats'
        end)
     |> map (curry Vector.sub fact_vec o fst)