tuned (reordered) code
authorblanchet
Tue, 01 Jul 2014 16:47:10 +0200
changeset 57458 419180c354c0
parent 57457 b2bafc09b7e7
child 57459 22023ab4df3c
tuned (reordered) code
src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Tue Jul 01 16:47:10 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Tue Jul 01 16:47:10 2014 +0200
@@ -34,10 +34,10 @@
   val decode_strs : string -> string list
 
   datatype mash_engine =
-    MaSh_kNN
+    MaSh_NB
+  | MaSh_NB_Ext
+  | MaSh_kNN
   | MaSh_kNN_Ext
-  | MaSh_NB
-  | MaSh_NB_Ext
 
   val is_mash_enabled : unit -> bool
   val the_mash_engine : unit -> mash_engine
@@ -138,20 +138,20 @@
   end
 
 datatype mash_engine =
-  MaSh_kNN
+  MaSh_NB
+| MaSh_NB_Ext
+| MaSh_kNN
 | MaSh_kNN_Ext
-| MaSh_NB
-| MaSh_NB_Ext
 
 fun mash_engine () =
   let val flag1 = Options.default_string @{system_option MaSh} in
     (case if flag1 <> "none" (* default *) then flag1 else getenv "MASH" of
       "yes" => SOME MaSh_NB
     | "sml" => SOME MaSh_NB
+    | "nb" => SOME MaSh_NB
+    | "nb_ext" => SOME MaSh_NB_Ext
     | "knn" => SOME MaSh_kNN
     | "knn_ext" => SOME MaSh_kNN_Ext
-    | "nb" => SOME MaSh_NB
-    | "nb_ext" => SOME MaSh_NB_Ext
     | _ => NONE)
   end
 
@@ -231,86 +231,12 @@
       ()
   end
 
-val number_of_nearest_neighbors = 10 (* FUDGE *)
-
 fun select_visible_facts big_number recommends =
   List.app (fn at =>
     let val (j, ov) = Array.sub (recommends, at) in
       Array.update (recommends, at, (j, big_number + ov))
     end)
 
-fun k_nearest_neighbors dffreq num_facts depss feat_facts max_suggs visible_facts goal_feats =
-  let
-    exception EXIT of unit
-
-    val ln_afreq = Math.ln (Real.fromInt num_facts)
-    fun tfidf feat = ln_afreq - Math.ln (Real.fromInt (Vector.sub (dffreq, feat)))
-
-    val overlaps_sqr = Array.tabulate (num_facts, rpair 0.0)
-
-    fun do_feat (s, sw0) =
-      let
-        val sw = sw0 * tfidf s
-        val w2 = sw * sw
-
-        fun inc_overlap j =
-          let val (_, ov) = Array.sub (overlaps_sqr, j) in
-            Array.update (overlaps_sqr, j, (j, w2 + ov))
-          end
-      in
-        List.app inc_overlap (Array.sub (feat_facts, s))
-      end
-
-    val _ = List.app do_feat goal_feats
-    val _ = heap (Real.compare o pairself snd) num_facts num_facts overlaps_sqr
-    val no_recommends = Unsynchronized.ref 0
-    val recommends = Array.tabulate (num_facts, rpair 0.0)
-    val age = Unsynchronized.ref 500000000.0
-
-    fun inc_recommend j v =
-      let val (_, ov) = Array.sub (recommends, j) in
-        if ov <= 0.0 then
-          (no_recommends := !no_recommends + 1; Array.update (recommends, j, (j, !age + ov)))
-        else if ov < !age + 1000.0 then
-          Array.update (recommends, j, (j, v + ov))
-        else
-          ()
-      end
-
-    val k = Unsynchronized.ref 0
-    fun do_k k =
-      if k >= num_facts then
-        raise EXIT ()
-      else
-        let
-          val (j, o2) = Array.sub (overlaps_sqr, num_facts - k - 1)
-          val o1 = Math.sqrt o2
-          val _ = inc_recommend j o1
-          val ds = Vector.sub (depss, j)
-          val l = Real.fromInt (length ds)
-        in
-          List.app (fn d => inc_recommend d (o1 / l)) ds
-        end
-
-    fun while1 () =
-      if !k = number_of_nearest_neighbors then () else (do_k (!k); k := !k + 1; while1 ())
-      handle EXIT () => ()
-
-    fun while2 () =
-      if !no_recommends >= max_suggs then ()
-      else (do_k (!k); k := !k + 1; age := !age - 10000.0; while2 ())
-      handle EXIT () => ()
-
-    fun ret acc at =
-      if at = num_facts then acc else ret (Array.sub (recommends, at) :: acc) (at + 1)
-  in
-    while1 ();
-    while2 ();
-    select_visible_facts 1000000000.0 recommends visible_facts;
-    heap (Real.compare o pairself snd) max_suggs num_facts recommends;
-    ret [] (Integer.max 0 (num_facts - max_suggs))
-  end
-
 fun wider_array_of_vector init vec =
   let val ary = Array.array init in
     Array.copyVec {src = vec, dst = ary, di = 0};
@@ -393,6 +319,80 @@
     ret (Integer.max 0 (num_facts - max_suggs)) []
   end
 
+val number_of_nearest_neighbors = 10 (* FUDGE *)
+
+fun k_nearest_neighbors dffreq num_facts depss feat_facts max_suggs visible_facts goal_feats =
+  let
+    exception EXIT of unit
+
+    val ln_afreq = Math.ln (Real.fromInt num_facts)
+    fun tfidf feat = ln_afreq - Math.ln (Real.fromInt (Vector.sub (dffreq, feat)))
+
+    val overlaps_sqr = Array.tabulate (num_facts, rpair 0.0)
+
+    fun do_feat (s, sw0) =
+      let
+        val sw = sw0 * tfidf s
+        val w2 = sw * sw
+
+        fun inc_overlap j =
+          let val (_, ov) = Array.sub (overlaps_sqr, j) in
+            Array.update (overlaps_sqr, j, (j, w2 + ov))
+          end
+      in
+        List.app inc_overlap (Array.sub (feat_facts, s))
+      end
+
+    val _ = List.app do_feat goal_feats
+    val _ = heap (Real.compare o pairself snd) num_facts num_facts overlaps_sqr
+    val no_recommends = Unsynchronized.ref 0
+    val recommends = Array.tabulate (num_facts, rpair 0.0)
+    val age = Unsynchronized.ref 500000000.0
+
+    fun inc_recommend v j =
+      let val (_, ov) = Array.sub (recommends, j) in
+        if ov <= 0.0 then
+          (no_recommends := !no_recommends + 1; Array.update (recommends, j, (j, !age + ov)))
+        else if ov < !age + 1000.0 then
+          Array.update (recommends, j, (j, v + ov))
+        else
+          ()
+      end
+
+    val k = Unsynchronized.ref 0
+    fun do_k k =
+      if k >= num_facts then
+        raise EXIT ()
+      else
+        let
+          val (j, o2) = Array.sub (overlaps_sqr, num_facts - k - 1)
+          val o1 = Math.sqrt o2
+          val _ = inc_recommend o1 j
+          val ds = Vector.sub (depss, j)
+          val l = Real.fromInt (length ds)
+        in
+          List.app (inc_recommend (o1 / l)) ds
+        end
+
+    fun while1 () =
+      if !k = number_of_nearest_neighbors then () else (do_k (!k); k := !k + 1; while1 ())
+      handle EXIT () => ()
+
+    fun while2 () =
+      if !no_recommends >= max_suggs then ()
+      else (do_k (!k); k := !k + 1; age := !age - 10000.0; while2 ())
+      handle EXIT () => ()
+
+    fun ret acc at =
+      if at = num_facts then acc else ret (Array.sub (recommends, at) :: acc) (at + 1)
+  in
+    while1 ();
+    while2 ();
+    select_visible_facts 1000000000.0 recommends visible_facts;
+    heap (Real.compare o pairself snd) max_suggs num_facts recommends;
+    ret [] (Integer.max 0 (num_facts - max_suggs))
+  end
+
 (* experimental *)
 fun external_tool tool max_suggs learns goal_feats =
   let
@@ -435,15 +435,16 @@
 fun query_external ctxt engine max_suggs learns goal_feats =
   (trace_msg ctxt (fn () => "MaSh query external " ^ commas (map fst goal_feats));
    (case engine of
-     MaSh_kNN_Ext => k_nearest_neighbors_ext max_suggs learns goal_feats
-   | MaSh_NB_Ext => naive_bayes_ext max_suggs learns goal_feats))
+     MaSh_NB_Ext => naive_bayes_ext max_suggs learns goal_feats
+   | MaSh_kNN_Ext => k_nearest_neighbors_ext max_suggs learns goal_feats))
 
 fun query_internal ctxt engine num_facts num_feats (fact_names, featss, depss)
     (freqs as (_, _, dffreq)) visible_facts max_suggs goal_feats int_goal_feats =
   (trace_msg ctxt (fn () => "MaSh query internal " ^ commas (map fst goal_feats) ^ " from {" ^
      elide_string 1000 (space_implode " " (Vector.foldr (op ::) [] fact_names)) ^ "}");
    (case engine of
-     MaSh_kNN =>
+     MaSh_NB => naive_bayes freqs num_facts max_suggs visible_facts int_goal_feats
+   | MaSh_kNN =>
      let
        val feat_facts = Array.array (num_feats, [])
        val _ =
@@ -452,8 +453,7 @@
            0 featss
      in
        k_nearest_neighbors dffreq num_facts depss feat_facts max_suggs visible_facts int_goal_feats
-     end
-   | MaSh_NB => naive_bayes freqs num_facts max_suggs visible_facts int_goal_feats)
+     end)
    |> map (curry Vector.sub fact_names o fst))
 
 end;
@@ -1178,7 +1178,7 @@
         val (parents, goal_feats) = query_args access_G
         val visible_facts = map_filter (Symtab.lookup fact_tab) (Graph.all_preds access_G parents)
       in
-        if engine = MaSh_kNN_Ext orelse engine = MaSh_NB_Ext then
+        if engine = MaSh_NB_Ext orelse engine = MaSh_kNN_Ext then
           let
             val learns =
               Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G