src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
changeset 57052 ea5912e3b008
parent 57039 1ddd1f75fb40
child 57055 df3a26987a8d
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Wed May 21 22:06:10 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Thu May 22 03:29:35 2014 +0200
@@ -375,7 +375,7 @@
   advno = number of predictions to return
   syms = symbols of the conjecture
 *)
-fun knn avail_num adv_max get_deps get_sym_ths knns advno syms =
+fun k_nearest_neighbors avail_num adv_max get_deps get_sym_ths knns advno syms =
   let
     (* Can be later used for TFIDF *)
     fun sym_wght _ = 1.0
@@ -429,10 +429,10 @@
     ret [] (Integer.max 0 (adv_max - advno))
   end
 
-(* Two arguments control the behaviour of nbayes: prior and ess. Prior expresses our belief in
+(* Two arguments control the behaviour of naive Bayes: prior and ess. Prior expresses our belief in
    usefulness of unknown features, and ess (equivalent sample size) expresses our confidence in the
    prior. *)
-fun nbayes avail_num adv_max get_deps get_th_syms sym_num ess prior advno syms =
+fun naive_bayes avail_num adv_max get_deps get_th_syms sym_num ess prior advno syms =
   let
     val afreq = Unsynchronized.ref 0
     val tfreq = Array.array (avail_num, 0)
@@ -540,13 +540,11 @@
 
     val deps_vec = Vector.fromList (rev rev_depss)
 
-    val avail_num = Vector.length deps_vec
-    val adv_max = length visible_facts
+    val num_visible_facts = length visible_facts
     val get_deps = curry Vector.sub deps_vec
-    val advno = max_suggs
   in
     trace_msg ctxt (fn () => "MaSh_SML query " ^ encode_features feats ^ " from {" ^
-      elide_string 1000 (space_implode " " facts) ^ "}");
+      elide_string 1000 (space_implode " " (take num_visible_facts facts)) ^ "}");
     (if engine = MaSh_SML_kNN then
        let
         val facts_ary = Array.array (num_feats, [])
@@ -558,20 +556,20 @@
                 fact'
               end)
             rev_featss num_facts
-         val get_sym_ths = curry Array.sub facts_ary
+         val get_facts = curry Array.sub facts_ary
          val syms = map_filter (fn (feat, weight) =>
            Option.map (rpair weight) (Symtab.lookup feat_tab (str_of_feat feat))) feats
        in
-         knn avail_num adv_max get_deps get_sym_ths knns advno syms
+         k_nearest_neighbors num_facts num_visible_facts get_deps get_facts knns max_suggs syms
        end
      else
        let
          val unweighted_feats_ary = Vector.fromList (map (map fst) (rev rev_featss))
-         val get_th_syms = curry Vector.sub unweighted_feats_ary
-         val sym_num = num_feats
+         val get_unweighted_feats = curry Vector.sub unweighted_feats_ary
          val unweighted_syms = map_filter (Symtab.lookup feat_tab o str_of_feat o fst) feats
        in
-         nbayes avail_num adv_max get_deps get_th_syms sym_num ess prior advno unweighted_syms
+         naive_bayes num_facts num_visible_facts get_deps get_unweighted_feats num_feats ess prior
+           max_suggs unweighted_syms
        end)
     |> map (curry Vector.sub fact_vec o fst)
   end
@@ -1258,6 +1256,8 @@
   let
     val thy = Proof_Context.theory_of ctxt
     val thy_name = Context.theory_name thy
+    val engine = the_mash_engine ()
+
     val facts = facts |> sort (crude_thm_ord o pairself snd o swap)
     val chained = facts |> filter (fn ((_, (scope, _)), _) => scope = Chained)
     val num_facts = length facts
@@ -1284,20 +1284,24 @@
           |> map (rpair 1.0)
           |> map (chained_or_extra_features_of chained_feature_factor)
           |> rpair [] |-> fold (union (eq_fst (op =)))
-        val extra_feats = facts
-          |> take (Int.max (0, num_extra_feature_facts - length chained))
-          |> filter fact_has_right_theory
-          |> weight_facts_steeply
-          |> map (chained_or_extra_features_of extra_feature_factor)
-          |> rpair [] |-> fold (union (eq_fst (op =)))
+        val extra_feats =
+          (* As long as SML NB does not support weights, it makes little sense to include these
+             extra features. *)
+          if engine = MaSh_SML_NB then
+            []
+          else
+            facts
+            |> take (Int.max (0, num_extra_feature_facts - length chained))
+            |> filter fact_has_right_theory
+            |> weight_facts_steeply
+            |> map (chained_or_extra_features_of extra_feature_factor)
+            |> rpair [] |-> fold (union (eq_fst (op =)))
         val feats = fold (union (eq_fst (op =))) [chained_feats, extra_feats] goal_feats
           |> debug ? sort (Real.compare o swap o pairself snd)
       in
         (parents, hints, feats)
       end
 
-    val engine = the_mash_engine ()
-
     val (access_G, py_suggs) =
       peek_state ctxt overlord (fn {access_G, ...} =>
         if Graph.is_empty access_G then