src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
changeset 57102 3e6af473d666
parent 57101 c881a983a19f
child 57103 c9e400a05c9e
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Wed May 28 12:34:26 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Wed May 28 13:02:47 2014 +0200
@@ -362,19 +362,20 @@
 val number_of_nearest_neighbors = 40 (* FUDGE *)
 
 (*
-  avail_num = maximum number of theorems to check dependencies and symbols
-  adv_max = do not return theorems over or equal to this number. Must satisfy: adv_max <= avail_num
+  num_facts = only theorems with an index below this number are checked for dependencies and features
+  num_visible_facts = only theorems with an index below this number are returned.
+    Must satisfy: num_visible_facts <= num_facts.
   get_deps = returns dependencies of a theorem
   get_sym_ths = get theorems that have this feature
-  advno = number of predictions to return
-  syms = symbols of the conjecture
+  max_suggs = number of suggestions to return
+  feats = features of the goal
 *)
-fun k_nearest_neighbors avail_num adv_max get_deps get_sym_ths advno syms =
+fun k_nearest_neighbors num_facts num_visible_facts get_deps get_sym_ths max_suggs feats =
   let
     (* Can be later used for TFIDF *)
     fun sym_wght _ = 1.0
 
-    val overlaps_sqr = Array.tabulate (avail_num, (fn i => (i, 0.0)))
+    val overlaps_sqr = Array.tabulate (num_facts, rpair 0.0)
 
     fun inc_overlap j v =
       let
@@ -383,30 +384,30 @@
         Array.update (overlaps_sqr, j, (j, v + ov))
       end
 
-    fun do_sym (s, con_wght) =
+    fun do_feat (s, con_wght) =
       let
         val sw = sym_wght s
         val w2 = sw * sw * con_wght
 
-        fun do_th (j, prem_wght) = if j < avail_num then inc_overlap j (w2 * prem_wght) else ()
+        fun do_th (j, prem_wght) = if j < num_facts then inc_overlap j (w2 * prem_wght) else ()
       in
         List.app do_th (get_sym_ths s)
       end
 
-    val _ = List.app do_sym syms
+    val _ = List.app do_feat feats
     val _ = heap (Real.compare o pairself snd) number_of_nearest_neighbors overlaps_sqr
-    val recommends = Array.tabulate (adv_max, rpair 0.0)
+    val recommends = Array.tabulate (num_visible_facts, rpair 0.0)
 
     fun inc_recommend j v =
-      if j >= adv_max then ()
+      if j >= num_visible_facts then ()
       else Array.update (recommends, j, (j, v + snd (Array.sub (recommends, j))))
 
     fun for k =
-      if k = number_of_nearest_neighbors orelse k >= adv_max then
+      if k = number_of_nearest_neighbors orelse k >= num_visible_facts then
         ()
       else
         let
-          val (j, o2) = Array.sub (overlaps_sqr, avail_num - k - 1)
+          val (j, o2) = Array.sub (overlaps_sqr, num_facts - k - 1)
           val o1 = Math.sqrt o2
           val _ = inc_recommend j o1
           val ds = get_deps j
@@ -419,8 +420,8 @@
       if at = Array.length recommends then acc else ret (Array.sub (recommends, at) :: acc) (at + 1)
   in
     for 0;
-    heap (Real.compare o pairself snd) advno recommends;
-    ret [] (Integer.max 0 (adv_max - advno))
+    heap (Real.compare o pairself snd) max_suggs recommends;
+    ret [] (Integer.max 0 (num_visible_facts - max_suggs))
   end
 
 val nb_tau = 0.02 (* FUDGE *)
@@ -428,14 +429,14 @@
 val nb_def_val = ~15.0 (* FUDGE *)
 val nb_def_prior_weight = 20 (* FUDGE *)
 
-fun naive_bayes avail_num adv_max get_deps get_th_syms sym_num advno syms =
+fun naive_bayes_learn num_facts get_deps get_th_feats num_feats =
   let
     val afreq = Unsynchronized.ref 0
-    val tfreq = Array.array (avail_num, 0)
-    val sfreq = Array.array (avail_num, Inttab.empty)
-    val dffreq = Array.array (sym_num, 0)
+    val tfreq = Array.array (num_facts, 0)
+    val sfreq = Array.array (num_facts, Inttab.empty)
+    val dffreq = Array.array (num_feats, 0)
 
-    fun learn th syms deps =
+    fun learn th feats deps =
       let
         fun add_th weight t =
           let
@@ -443,57 +444,61 @@
             fun fold_fn s sf = Inttab.update (s, weight + the_default 0 (Inttab.lookup im s)) sf
           in
             Array.update (tfreq, t, weight + Array.sub (tfreq, t));
-            Array.update (sfreq, t, fold fold_fn syms im)
+            Array.update (sfreq, t, fold fold_fn feats im)
           end
 
         fun add_sym s = Array.update (dffreq, s, 1 + Array.sub (dffreq, s))
       in
         add_th nb_def_prior_weight th;
         List.app (add_th 1) deps;
-        List.app add_sym syms;
+        List.app add_sym feats;
         afreq := !afreq + 1
       end
 
-    fun tfidf sym = Math.ln (Real.fromInt (!afreq)) - Math.ln (Real.fromInt (Array.sub (dffreq, sym)))
+    fun for i =
+      if i = num_facts then () else (learn i (get_th_feats i) (get_deps i); for (i + 1))
+  in
+    for 0; (Real.fromInt (!afreq), Array.vector tfreq, Array.vector sfreq, Array.vector dffreq)
+  end
 
-    fun eval syms =
+fun naive_bayes_query num_visible_facts max_suggs feats (afreq, tfreq, sfreq, dffreq) =
+  let (* scores facts 0 .. num_visible_facts - 1 against feats *)
+    fun tfidf feat = Math.ln afreq - Math.ln (Real.fromInt (Vector.sub (dffreq, feat)))
+
+    fun log_posterior i = (* score of fact i w.r.t. feats *)
       let
-        fun log_posterior i =
-          let
-            val tfreq = Real.fromInt (Array.sub (tfreq, i))
-
-            fun fold_syms (f, fw) (res, sfh) =
-              (case Inttab.lookup sfh f of
-                SOME sf =>
-                (res + tfidf f * fw * Math.ln (nb_pos_weight * Real.fromInt sf / tfreq),
-                 Inttab.delete f sfh)
-              | NONE => (res + fw * nb_def_val, sfh))
+        val tfreq = Real.fromInt (Vector.sub (tfreq, i)) (* shadows the tfreq vector *)
 
-            val (res, sfh) = fold fold_syms syms (Math.ln tfreq, Array.sub (sfreq,i))
-
-            fun fold_sfh (f, sf) sow =
-              sow + tfidf f * (Math.ln (1.0 + (1.0 - Real.fromInt sf) / tfreq))
+        fun fold_feats (f, fw) (res, sfh) = (* f = feature, fw = its weight *)
+          (case Inttab.lookup sfh f of
+            SOME sf => (* f is recorded for fact i: add its term, drop it from sfh *)
+            (res + tfidf f * fw * Math.ln (nb_pos_weight * Real.fromInt sf / tfreq),
+             Inttab.delete f sfh)
+          | NONE => (res + fw * nb_def_val, sfh)) (* f is not recorded for fact i *)
 
-            val sum_of_weights = Inttab.fold fold_sfh sfh 0.0
-          in
-            res + nb_tau * sum_of_weights
-          end
+        val (res, sfh) = fold fold_feats feats (Math.ln tfreq, Vector.sub (sfreq, i))
 
-        val posterior = Array.tabulate (adv_max, (fn j => (j, log_posterior j)))
+        fun fold_sfh (f, sf) sow = (* f is in fact i's table but was not matched by feats *)
+          sow + tfidf f * (Math.ln (1.0 + (1.0 - Real.fromInt sf) / tfreq))
 
-        fun ret acc at =
-          if at = adv_max then acc else ret (Array.sub (posterior, at) :: acc) (at + 1)
+        val sum_of_weights = Inttab.fold fold_sfh sfh 0.0
       in
-        heap (Real.compare o pairself snd) advno posterior;
-        ret [] (Integer.max 0 (adv_max - advno))
+        res + nb_tau * sum_of_weights
       end
 
-    fun for i =
-      if i = avail_num then () else (learn i (get_th_syms i) (get_deps i); for (i + 1))
+    val posterior = Array.tabulate (num_visible_facts, (fn j => (j, log_posterior j)))
+
+    fun ret acc at = (* conses entries at .. end onto acc, so the last entry comes first *)
+      if at = num_visible_facts then acc else ret (Array.sub (posterior, at) :: acc) (at + 1)
   in
-    for 0; eval syms
+    heap (Real.compare o pairself snd) max_suggs posterior; (* presumably a partial sort; confirm *)
+    ret [] (Integer.max 0 (num_visible_facts - max_suggs)) (* the last max_suggs entries *)
   end
 
+fun naive_bayes num_facts num_visible_facts get_deps get_th_feats num_feats max_suggs feats =
+  naive_bayes_query num_visible_facts max_suggs feats (* learn from all facts, then query *)
+    (naive_bayes_learn num_facts get_deps get_th_feats num_feats)
+
 fun add_to_xtab key (next, tab, keys) = (next + 1, Symtab.update_new (key, next) tab, key :: keys)
 
 fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i)))
@@ -547,19 +552,19 @@
                end)
              rev_featss num_facts
          val get_facts = curry Array.sub facts_ary
-         val syms = map_filter (fn (feat, weight) =>
+         val feats' = map_filter (fn (feat, weight) =>
            Option.map (rpair weight) (Symtab.lookup feat_tab feat)) feats
        in
-         k_nearest_neighbors num_facts num_visible_facts get_deps get_facts max_suggs syms
+         k_nearest_neighbors num_facts num_visible_facts get_deps get_facts max_suggs feats'
        end
      else
        let
          val unweighted_feats_ary = Vector.fromList (map (map fst) (rev rev_featss))
          val get_unweighted_feats = curry Vector.sub unweighted_feats_ary
-         val syms = map (apfst (the_default ~1 o Symtab.lookup feat_tab)) feats
+         val feats' = map (apfst (the_default ~1 o Symtab.lookup feat_tab)) feats
        in
          naive_bayes num_facts num_visible_facts get_deps get_unweighted_feats num_feats max_suggs
-           syms
+           feats'
        end)
     |> map (curry Vector.sub fact_vec o fst)
   end