src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
changeset 57355 a9e0f9d35125
parent 57354 ded92100ffd7
child 57356 9816f692b0ca
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Thu Jun 26 13:33:21 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Thu Jun 26 13:33:27 2014 +0200
@@ -352,7 +352,7 @@
 
 exception BOTTOM of int
 
-fun heap cmp bnd a =
+fun heap cmp bnd al a =
   let
     fun maxson l i =
       let val i31 = i + i + i + 1 in
@@ -394,12 +394,10 @@
           Array.update (a, i, e)
       end
 
-    val l = Array.length a
-
-    fun for i = if i < 0 then () else (trickle l i (Array.sub (a, i)); for (i - 1))
+    fun for i = if i < 0 then () else (trickle al i (Array.sub (a, i)); for (i - 1))
 
     fun for2 i =
-      if i < Integer.max 2 (l - bnd) then
+      if i < Integer.max 2 (al - bnd) then
         ()
       else
         let val e = Array.sub (a, i) in
@@ -408,9 +406,9 @@
           for2 (i - 1)
         end
   in
-    for (((l + 1) div 3) - 1);
-    for2 (l - 1);
-    if l > 1 then
+    for (((al + 1) div 3) - 1);
+    for2 (al - 1);
+    if al > 1 then
       let val e = Array.sub (a, 1) in
         Array.update (a, 1, Array.sub (a, 0));
         Array.update (a, 0, e)
@@ -457,7 +455,7 @@
       end
 
     val _ = List.app do_feat feats
-    val _ = heap (Real.compare o pairself snd) num_facts overlaps_sqr
+    val _ = heap (Real.compare o pairself snd) num_facts num_facts overlaps_sqr
     val no_recommends = Unsynchronized.ref 0
     val recommends = Array.tabulate (num_visible_facts, rpair 0.0)
     val age = Unsynchronized.ref 1000000000.0
@@ -498,39 +496,34 @@
       if at = Array.length recommends then acc else ret (Array.sub (recommends, at) :: acc) (at + 1)
   in
     while1 (); while2 ();
-    heap (Real.compare o pairself snd) max_suggs recommends;
+    heap (Real.compare o pairself snd) max_suggs num_visible_facts recommends;
     ret [] (Integer.max 0 (num_visible_facts - max_suggs))
   end
 
 val nb_def_prior_weight = 21 (* FUDGE *)
 
-fun naive_bayes_learn_fact tfreq sfreq dffreq th feats deps =
+fun learn_facts tfreq sfreq dffreq num_facts get_deps get_feats num_feats =
   let
-    fun add_th weight t =
+    fun learn_fact th feats deps =
       let
-        val im = Array.sub (sfreq, t)
-        fun fold_fn s sf = Inttab.map_default (s, 0) (Integer.add weight) sf
+        fun add_th weight t =
+          let
+            val im = Array.sub (sfreq, t)
+            fun fold_fn s sf = Inttab.map_default (s, 0) (Integer.add weight) sf
+          in
+            Array.update (tfreq, t, weight + Array.sub (tfreq, t));
+            Array.update (sfreq, t, fold fold_fn feats im)
+          end
+
+        fun add_sym s = Array.update (dffreq, s, 1 + Array.sub (dffreq, s))
       in
-        Array.update (tfreq, t, weight + Array.sub (tfreq, t));
-        Array.update (sfreq, t, fold fold_fn feats im)
+        add_th nb_def_prior_weight th;
+        List.app (add_th 1) deps;
+        List.app add_sym feats
       end
 
-    fun add_sym s = Array.update (dffreq, s, 1 + Array.sub (dffreq, s))
-  in
-    add_th nb_def_prior_weight th;
-    List.app (add_th 1) deps;
-    List.app add_sym feats
-  end
-
-fun naive_bayes_learn num_facts get_deps get_feats num_feats =
-  let
-    val tfreq = Array.array (num_facts, 0)
-    val sfreq = Array.array (num_facts, Inttab.empty)
-    val dffreq = Array.array (num_feats, 0)
-
     fun for i =
-      if i = num_facts then ()
-      else (naive_bayes_learn_fact tfreq sfreq dffreq i (get_feats i) (get_deps i); for (i + 1))
+      if i = num_facts then () else (learn_fact i (get_feats i) (get_deps i); for (i + 1))
 
     val ln_afreq = Math.ln (Real.fromInt num_facts)
   in
@@ -539,6 +532,15 @@
      Vector.map (fn i => ln_afreq - Math.ln (Real.fromInt i)) (Array.vector dffreq))
   end
 
+fun learn num_facts get_deps get_feats num_feats =
+  let
+    val tfreq = Array.array (num_facts, 0)
+    val sfreq = Array.array (num_facts, Inttab.empty)
+    val dffreq = Array.array (num_feats, 0)
+  in
+    learn_facts tfreq sfreq dffreq num_facts get_deps get_feats num_feats
+  end
+
 fun naive_bayes_query (kuehlwein_log, kuehlwein_params) num_facts num_visible_facts max_suggs feats
     (tfreq, sfreq, idf) =
   let
@@ -579,12 +581,12 @@
     fun ret acc at =
       if at = num_visible_facts then acc else ret (Array.sub (posterior, at) :: acc) (at + 1)
   in
-    heap (Real.compare o pairself snd) max_suggs posterior;
+    heap (Real.compare o pairself snd) max_suggs num_visible_facts posterior;
     ret [] (Integer.max 0 (num_visible_facts - max_suggs))
   end
 
 fun naive_bayes opts num_facts num_visible_facts get_deps get_feats num_feats max_suggs feats =
-  naive_bayes_learn num_facts get_deps get_feats num_feats
+  learn num_facts get_deps get_feats num_feats
   |> naive_bayes_query opts num_facts num_visible_facts max_suggs feats
 
 (* experimental *)