tuning
authorblanchet
Wed, 28 May 2014 09:38:39 +0200
changeset 57097 80b7c07e7a73
parent 57096 e4074b91b2a6
child 57098 c0a25c7c4b8e
tuning
src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Wed May 28 03:10:30 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Wed May 28 09:38:39 2014 +0200
@@ -359,16 +359,17 @@
       ()
   end
 
+val number_of_nearest_neighbors = 40 (* FUDGE *)
+
 (*
   avail_num = maximum number of theorems to check dependencies and symbols
   adv_max = do not return theorems over or equal to this number. Must satisfy: adv_max <= avail_num
   get_deps = returns dependencies of a theorem
   get_sym_ths = get theorems that have this feature
-  knns = number of nearest neighbours
   advno = number of predictions to return
   syms = symbols of the conjecture
 *)
-fun k_nearest_neighbors avail_num adv_max get_deps get_sym_ths knns advno syms =
+fun k_nearest_neighbors avail_num adv_max get_deps get_sym_ths advno syms =
   let
     (* Can be later used for TFIDF *)
     fun sym_wght _ = 1.0
@@ -393,7 +394,7 @@
       end
 
     val _ = List.app do_sym syms
-    val _ = heap (Real.compare o pairself snd) knns overlaps_sqr
+    val _ = heap (Real.compare o pairself snd) number_of_nearest_neighbors overlaps_sqr
     val recommends = Array.tabulate (adv_max, rpair 0.0)
 
     fun inc_recommend j v =
@@ -401,7 +402,7 @@
       else Array.update (recommends, j, (j, v + snd (Array.sub (recommends, j))))
 
     fun for k =
-      if k = knns orelse k >= adv_max then
+      if k = number_of_nearest_neighbors orelse k >= adv_max then
         ()
       else
         let
@@ -422,64 +423,71 @@
     ret [] (Integer.max 0 (adv_max - advno))
   end
 
-val tau = 0.02
-val posWeight = 2.0
-val defVal = ~15.0
-val defPriWei = 20
+val nb_tau = 0.02 (* FUDGE *)
+val nb_pos_weight = 2.0 (* FUDGE *)
+val nb_def_val = ~15.0 (* FUDGE *)
+val nb_def_prior_weight = 20 (* FUDGE *)
 
 fun naive_bayes avail_num adv_max get_deps get_th_syms sym_num advno syms =
   let
-    val afreq = Unsynchronized.ref 0;
-    val tfreq = Array.array (avail_num, 0);
-    val sfreq = Array.array (avail_num, Inttab.empty);
-    val dffreq = Array.array (sym_num, 0);
+    val afreq = Unsynchronized.ref 0
+    val tfreq = Array.array (avail_num, 0)
+    val sfreq = Array.array (avail_num, Inttab.empty)
+    val dffreq = Array.array (sym_num, 0)
 
     fun learn th syms deps =
       let
         fun add_th t =
           let
-            val im = Array.sub (sfreq, t);
-            fun fold_fn s sf = Inttab.update (s, 1 + the_default 0 (Inttab.lookup im s)) sf;
+            val im = Array.sub (sfreq, t)
+            fun fold_fn s sf = Inttab.update (s, 1 + the_default 0 (Inttab.lookup im s)) sf
           in
             Array.update (tfreq, t, 1 + Array.sub (tfreq, t));
             Array.update (sfreq, t, fold fold_fn syms im)
-          end;
-        fun add_sym s = Array.update (dffreq, s, 1 + Array.sub (dffreq, s));
+          end
+
+        fun add_sym s = Array.update (dffreq, s, 1 + Array.sub (dffreq, s))
       in
-        List.app add_th (replicate defPriWei th);
+        List.app add_th (replicate nb_def_prior_weight th);
         List.app add_th deps;
         List.app add_sym syms;
         afreq := !afreq + 1
-      end;
+      end
 
-    fun tfidf _ = 1.0;
-    (*fun tfidf sym = Math.ln (Real.fromInt (!afreq)) - Math.ln (Real.fromInt (Array.sub (dffreq, sym)));*)
+    fun tfidf _ = 1.0
+    (*fun tfidf sym = Math.ln (Real.fromInt (!afreq)) - Math.ln (Real.fromInt (Array.sub (dffreq, sym)))*)
 
     fun eval syms =
       let
         fun log_posterior i =
           let
-            val tfreq = Real.fromInt (Array.sub (tfreq, i));
+            val tfreq = Real.fromInt (Array.sub (tfreq, i))
+
             fun fold_syms (f, fw) (res, sfh) =
               (case Inttab.lookup sfh f of
                 SOME sf =>
-                (res + tfidf f * fw * Math.ln (posWeight * Real.fromInt sf / tfreq),
+                (res + tfidf f * fw * Math.ln (nb_pos_weight * Real.fromInt sf / tfreq),
                  Inttab.delete f sfh)
-              | NONE => (res + fw * defVal, sfh));
-            val (res, sfh) = fold fold_syms syms (Math.ln tfreq, Array.sub (sfreq,i));
+              | NONE => (res + fw * nb_def_val, sfh))
+
+            val (res, sfh) = fold fold_syms syms (Math.ln tfreq, Array.sub (sfreq,i))
+
             fun fold_sfh (f, sf) sow =
-              sow + tfidf f * (Math.ln (1.0 + (1.0 - Real.fromInt sf) / tfreq));
-            val sumOfWei = Inttab.fold fold_sfh sfh 0.0;
+              sow + tfidf f * (Math.ln (1.0 + (1.0 - Real.fromInt sf) / tfreq))
+
+            val sum_of_weights = Inttab.fold fold_sfh sfh 0.0
           in
-            res + tau * sumOfWei
+            res + nb_tau * sum_of_weights
           end
-        val posterior = Array.tabulate (adv_max, (fn j => (j, log_posterior j)));
+
+        val posterior = Array.tabulate (adv_max, (fn j => (j, log_posterior j)))
+
         fun ret acc at =
-          if at = adv_max then acc else ret (Array.sub (posterior,at) :: acc) (at + 1)
+          if at = adv_max then acc else ret (Array.sub (posterior, at) :: acc) (at + 1)
       in
         heap (Real.compare o pairself snd) advno posterior;
         ret [] (Integer.max 0 (adv_max - advno))
-      end;
+      end
 
     fun for i =
       if i = avail_num then () else (learn i (get_th_syms i) (get_deps i); for (i + 1))
@@ -487,8 +495,6 @@
     for 0; eval syms
   end
 
-val knns = 40 (* FUDGE *)
-
 fun add_to_xtab key (next, tab, keys) = (next + 1, Symtab.update_new (key, next) tab, key :: keys)
 
 fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i)))
@@ -544,7 +550,7 @@
              rev_featss num_facts
          val get_facts = curry Array.sub facts_ary
        in
-         k_nearest_neighbors num_facts num_visible_facts get_deps get_facts knns max_suggs syms
+         k_nearest_neighbors num_facts num_visible_facts get_deps get_facts max_suggs syms
        end
      else
        let
@@ -1317,7 +1323,6 @@
       in
         map_state ctxt overlord (fn state as {access_G, num_known_facts, dirty} =>
           let
-            val name = learned_proof_name ()
             val parents = maximal_wrt_access_graph access_G facts
             val deps = used_ths
               |> filter (is_fact_in_graph access_G)
@@ -1326,7 +1331,10 @@
             if the_mash_engine () = MaSh_Py then
               (MaSh_Py.learn ctxt overlord true [("", parents, map fst feats, deps)]; state)
             else
-              let val access_G = access_G |> add_node Automatic_Proof name parents feats deps in
+              let
+                val name = learned_proof_name ()
+                val access_G = access_G |> add_node Automatic_Proof name parents feats deps
+              in
                 {access_G = access_G, num_known_facts = num_known_facts + 1,
                  dirty = Option.map (cons name) dirty}
               end