tuning
authorblanchet
Thu, 26 Jun 2014 13:36:13 +0200
changeset 57374 cb6667e7cbc1
parent 57373 e9d47cd3239b
child 57375 b75438e23925
tuning
src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Thu Jun 26 13:36:06 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Thu Jun 26 13:36:13 2014 +0200
@@ -398,10 +398,10 @@
 
 exception EXIT of unit
 
-fun k_nearest_neighbors dffreq num_facts deps_vec get_sym_ths max_suggs visible_facts conj_feats =
+fun k_nearest_neighbors dffreq num_facts depss feat_facts max_suggs visible_facts conj_feats =
   let
     val ln_afreq = Math.ln (Real.fromInt num_facts)
-    fun tfidf feat = ln_afreq - Math.ln (Real.fromInt (Array.sub (dffreq, feat)))
+    fun tfidf feat = ln_afreq - Math.ln (Real.fromInt (Vector.sub (dffreq, feat)))
 
     val overlaps_sqr = Array.tabulate (num_facts, rpair 0.0)
 
@@ -416,7 +416,7 @@
         val w2 = sw * sw
         fun do_th j = if j < num_facts then inc_overlap j w2 else ()
       in
-        List.app do_th (get_sym_ths s)
+        List.app do_th (Array.sub (feat_facts, s))
       end
 
     val _ = List.app do_feat conj_feats
@@ -427,8 +427,13 @@
 
     fun inc_recommend j v =
       let val ov = snd (Array.sub (recommends, j)) in
-      if ov <= 0.0 then (no_recommends := !no_recommends + 1; Array.update (recommends, j, (j, !age + ov)))
-      else (if ov < !age + 1000.0 then Array.update (recommends, j, (j, v + ov)) else ()) end
+        if ov <= 0.0 then
+          (no_recommends := !no_recommends + 1; Array.update (recommends, j, (j, !age + ov)))
+        else if ov < !age + 1000.0 then
+          Array.update (recommends, j, (j, v + ov))
+        else
+          ()
+      end
 
     val k = Unsynchronized.ref 0
     fun do_k k =
@@ -439,7 +444,7 @@
           val (j, o2) = Array.sub (overlaps_sqr, num_facts - k - 1)
           val o1 = Math.sqrt o2
           val _ = inc_recommend j o1
-          val ds = Vector.sub (deps_vec, j)
+          val ds = Vector.sub (depss, j)
           val l = Real.fromInt (length ds)
         in
           List.app (fn d => inc_recommend d (o1 / l)) ds
@@ -464,16 +469,26 @@
     ret [] (Integer.max 0 (num_facts - max_suggs))
   end
 
+fun wider_array_of_vector init vec =
+  let val ary = Array.array init in
+    Array.copyVec {src = vec, dst = ary, di = 0};
+    ary
+  end
+
 val nb_def_prior_weight = 21 (* FUDGE *)
 
-fun learn_facts tfreq sfreq dffreq num_facts depss featss =
+fun learn_facts (tfreq0, sfreq0, dffreq0) num_facts0 num_facts num_feats depss featss =
   let
-    fun learn_fact th feats deps =
+    val tfreq = wider_array_of_vector (num_facts, 0) tfreq0
+    val sfreq = wider_array_of_vector (num_facts, Inttab.empty) sfreq0
+    val dffreq = wider_array_of_vector (num_feats, 0) dffreq0
+
+    fun learn_one th feats deps =
       let
         fun add_th weight t =
           let
             val im = Array.sub (sfreq, t)
-            fun fold_fn s sf = Inttab.map_default (s, 0) (Integer.add weight) sf
+            fun fold_fn s = Inttab.map_default (s, 0) (Integer.add weight)
           in
             map_array_at tfreq (Integer.add weight) t;
             Array.update (sfreq, t, fold fold_fn feats im)
@@ -487,26 +502,30 @@
       end
 
     fun for i =
-      if i = num_facts then ()
-      else (learn_fact i (Vector.sub (featss, i)) (Vector.sub (depss, i)); for (i + 1))
+      if i = num_facts then
+        ()
+      else
+        (learn_one (num_facts0 + i) (Vector.sub (featss, i)) (Vector.sub (depss, i));
+         for (i + 1))
   in
-    for 0
+    for 0;
+    (Array.vector tfreq, Array.vector sfreq, Array.vector dffreq)
   end
 
-fun naive_bayes_query tfreq sfreq dffreq num_facts max_suggs visible_facts conj_feats =
+fun naive_bayes tfreq sfreq dffreq num_facts max_suggs visible_facts conj_feats =
   let
     val tau = 0.05 (* FUDGE *)
     val pos_weight = 10.0 (* FUDGE *)
     val def_val = ~15.0 (* FUDGE *)
 
     val ln_afreq = Math.ln (Real.fromInt num_facts)
-    val idf = Vector.map (fn i => ln_afreq - Math.ln (Real.fromInt i)) (Array.vector dffreq)
+    val idf = Vector.map (fn i => ln_afreq - Math.ln (Real.fromInt i)) dffreq
 
     fun tfidf feat = Vector.sub (idf, feat)
 
     fun log_posterior i =
       let
-        val tfreq = Real.fromInt (Array.sub (tfreq, i))
+        val tfreq = Real.fromInt (Vector.sub (tfreq, i))
 
         fun fold_feats f (res, sfh) =
           (case Inttab.lookup sfh f of
@@ -515,7 +534,7 @@
              Inttab.delete f sfh)
           | NONE => (res + tfidf f * def_val, sfh))
 
-        val (res, sfh) = fold fold_feats conj_feats (Math.ln tfreq, Array.sub (sfreq, i))
+        val (res, sfh) = fold fold_feats conj_feats (Math.ln tfreq, Vector.sub (sfreq, i))
 
         fun fold_sfh (f, sf) sow = sow + tfidf f * Math.ln (1.0 + (1.0 - Real.fromInt sf) / tfreq)
 
@@ -593,56 +612,57 @@
   c_plus_plus_tool ("newknn/knn" ^ " " ^ string_of_int number_of_nearest_neighbors)
 val naive_bayes_cpp = c_plus_plus_tool "predict/nbayes"
 
-fun query ctxt engine (num_facts, fact_tab) (num_feats, feat_tab) visible_facts max_suggs learns0
-    conj_feats =
+fun reorder_learns (num_facts, fact_tab) learns0 =
+  let
+    val learns = Array.array (num_facts, ("", [], []))
+  in
+    List.app (fn learn as (fact, _, _) =>
+        Array.update (learns, the (Symtab.lookup fact_tab fact), learn))
+      learns0;
+    Array.foldr (op ::) [] learns
+  end
+
+fun query ctxt engine (fact_xtab as (num_facts, fact_tab)) (num_feats, feat_tab) visible_facts
+    max_suggs learns0 conj_feats =
   if engine = MaSh_SML_kNN_Cpp then
     k_nearest_neighbors_cpp max_suggs learns0 conj_feats
   else if engine = MaSh_SML_NB_Cpp then
     naive_bayes_cpp max_suggs learns0 conj_feats
   else
     let
-      val learn_ary = Array.array (num_facts, ("", [], []))
-      val _ =
-        List.app (fn entry as (fact, _, _) =>
-            Array.update (learn_ary, the (Symtab.lookup fact_tab fact), entry))
-          learns0
-      val learns = Array.foldr (op ::) [] learn_ary
+      val learns = reorder_learns fact_xtab learns0
+
+      val facts = Vector.fromList (map #1 learns)
+      val featss = Vector.fromList (map (map_filter (Symtab.lookup feat_tab) o #2) learns)
+      val depss = Vector.fromList (map (map_filter (Symtab.lookup fact_tab) o #3) learns)
 
-      val facts = map #1 learns
-      val featss = map (map_filter (Symtab.lookup feat_tab) o #2) learns
-      val depss = map (map_filter (Symtab.lookup fact_tab) o #3) learns
+      val tfreq = Vector.tabulate (num_facts, K 0)
+      val sfreq = Vector.tabulate (num_facts, K Inttab.empty)
+      val dffreq = Vector.tabulate (num_feats, K 0)
 
-      val fact_vec = Vector.fromList facts
-      val feats_vec = Vector.fromList featss
-      val deps_vec = Vector.fromList depss
-
-      val tfreq = Array.array (num_facts, 0)
-      val sfreq = Array.array (num_facts, Inttab.empty)
-      val dffreq = Array.array (num_feats, 0)
-
-      val _ = learn_facts tfreq sfreq dffreq num_facts deps_vec feats_vec
+      val (tfreq, sfreq, dffreq) =
+        learn_facts (tfreq, sfreq, dffreq) 0 num_facts num_feats depss featss
 
       val int_visible_facts = map_filter (Symtab.lookup fact_tab) visible_facts
       val int_conj_feats = map_filter (Symtab.lookup feat_tab) conj_feats
     in
       trace_msg ctxt (fn () => "MaSh_SML query " ^ encode_strs conj_feats ^ " from {" ^
-        elide_string 1000 (space_implode " " (take num_facts facts)) ^ "}");
+        elide_string 1000 (space_implode " " (Vector.foldr (op ::) [] facts)) ^ "}");
       (case engine of
         MaSh_SML_kNN =>
         let
-          val facts_ary = Array.array (num_feats, [])
+          val feat_facts = Array.array (num_feats, [])
           val _ =
-            fold (fn feats => fn fact =>
-                (List.app (map_array_at facts_ary (cons fact)) feats; fact + 1))
-              featss 0
-          val get_facts = curry Array.sub facts_ary
+            Vector.foldl (fn (feats, fact) =>
+                (List.app (map_array_at feat_facts (cons fact)) feats; fact + 1))
+              0 featss
         in
-          k_nearest_neighbors dffreq num_facts deps_vec get_facts max_suggs int_visible_facts
+          k_nearest_neighbors dffreq num_facts depss feat_facts max_suggs int_visible_facts
             int_conj_feats
         end
       | MaSh_SML_NB =>
-        naive_bayes_query tfreq sfreq dffreq num_facts max_suggs int_visible_facts int_conj_feats)
-      |> map (curry Vector.sub fact_vec o fst)
+        naive_bayes tfreq sfreq dffreq num_facts max_suggs int_visible_facts int_conj_feats)
+      |> map (curry Vector.sub facts o fst)
     end
 
 end;
@@ -664,12 +684,12 @@
   Graph.default_node (parent, (Isar_Proof, [], []))
   #> Graph.add_edge (parent, name)
 
-fun add_node kind name parents feats deps (access_G, fact_xtab, feat_xtab) =
+fun add_node kind name parents feats deps (access_G, (fact_xtab, feat_xtab)) =
   ((Graph.new_node (name, (kind, feats, deps)) access_G
     handle Graph.DUP _ => Graph.map_node name (K (kind, feats, deps)) access_G)
    |> fold (add_edge_to name) parents,
-  maybe_add_to_xtab name fact_xtab,
-  fold maybe_add_to_xtab feats feat_xtab)
+  (maybe_add_to_xtab name fact_xtab,
+   fold maybe_add_to_xtab feats feat_xtab))
 
 fun try_graph ctxt when def f =
   f ()
@@ -694,19 +714,17 @@
 
 type mash_state =
   {access_G : (proof_kind * string list * string list) Graph.T,
-   fact_xtab : xtab,
-   feat_xtab : xtab,
+   xtabs : xtab * xtab,
    num_known_facts : int, (* ### FIXME: kill *)
    dirty_facts : string list option}
 
 val empty_state =
   {access_G = Graph.empty,
-   fact_xtab = empty_xtab,
-   feat_xtab = empty_xtab,
+   xtabs = (empty_xtab, empty_xtab),
    num_known_facts = 0,
    dirty_facts = SOME []} : mash_state
 
-val empty_graphxx = (Graph.empty, empty_xtab, empty_xtab)
+val empty_graphxx = (Graph.empty, (empty_xtab, empty_xtab))
 
 local
 
@@ -741,7 +759,7 @@
                  NONE => I (* should not happen *)
                | SOME (kind, name, parents, feats, deps) => add_node kind name parents feats deps)
 
-             val ((access_G, fact_xtab, feat_xtab), num_known_facts) =
+             val ((access_G, xtabs), num_known_facts) =
                (case string_ord (version', version) of
                  EQUAL =>
                  (try_graph ctxt "loading state" empty_graphxx (fn () =>
@@ -755,8 +773,8 @@
                | GREATER => raise FILE_VERSION_TOO_NEW ())
            in
              trace_msg ctxt (fn () => "Loaded fact graph (" ^ graph_info access_G ^ ")");
-             {access_G = access_G, fact_xtab = fact_xtab, feat_xtab = feat_xtab,
-              num_known_facts = num_known_facts, dirty_facts = SOME []}
+             {access_G = access_G, xtabs = xtabs, num_known_facts = num_known_facts,
+              dirty_facts = SOME []}
            end
          | _ => empty_state)))
   end
@@ -766,7 +784,7 @@
   encode_strs feats ^ "; " ^ encode_strs deps ^ "\n"
 
 fun save_state _ (time_state as (_, {dirty_facts = SOME [], ...})) = time_state
-  | save_state ctxt (memory_time, {access_G, fact_xtab, feat_xtab, num_known_facts, dirty_facts}) =
+  | save_state ctxt (memory_time, {access_G, xtabs, num_known_facts, dirty_facts}) =
     let
       fun append_entry (name, ((kind, feats, deps), (parents, _))) =
         cons (kind, name, Graph.Keys.dest parents, feats, deps)
@@ -788,8 +806,8 @@
           SOME dirty_facts => "; " ^ string_of_int (length dirty_facts) ^ " dirty fact(s)"
         | _ => "") ^  ")");
       (Time.now (),
-       {access_G = access_G, fact_xtab = fact_xtab, feat_xtab = feat_xtab,
-        num_known_facts = num_known_facts, dirty_facts = SOME []})
+       {access_G = access_G, xtabs = xtabs, num_known_facts = num_known_facts,
+        dirty_facts = SOME []})
     end
 
 val global_state = Synchronized.var "Sledgehammer_MaSh.global_state" (Time.zeroTime, empty_state)
@@ -1342,9 +1360,9 @@
         (parents, hints, feats)
       end
 
-    val ((access_G, fact_xtab, feat_xtab), py_suggs) =
-      peek_state ctxt overlord (fn {access_G, fact_xtab, feat_xtab, ...} =>
-        ((access_G, fact_xtab, feat_xtab),
+    val ((access_G, (fact_xtab, feat_xtab)), py_suggs) =
+      peek_state ctxt overlord (fn {access_G, xtabs, ...} =>
+        ((access_G, xtabs),
          if Graph.is_empty access_G then
            (trace_msg ctxt (K "Nothing has been learned yet"); [])
          else if engine = MaSh_Py then
@@ -1377,7 +1395,7 @@
   end
 
 fun learn_wrt_access_graph ctxt (name, parents, feats, deps)
-    (learns, (access_G, fact_xtab, feat_xtab)) =
+    (learns, (access_G, (fact_xtab, feat_xtab))) =
   let
     fun maybe_learn_from from (accum as (parents, access_G)) =
       try_graph ctxt "updating graph" accum (fn () =>
@@ -1390,7 +1408,7 @@
     val fact_xtab = maybe_add_to_xtab name fact_xtab
     val feat_xtab = fold maybe_add_to_xtab feats feat_xtab
   in
-    ((name, parents, feats, deps) :: learns, (access_G, fact_xtab, feat_xtab))
+    ((name, parents, feats, deps) :: learns, (access_G, (fact_xtab, feat_xtab)))
   end
 
 fun relearn_wrt_access_graph ctxt (name, deps) (relearns, access_G) =
@@ -1431,7 +1449,7 @@
         val feats = map fst (features_of ctxt thy 0 Symtab.empty (Local, General) [t])
       in
         map_state ctxt overlord
-          (fn state as {access_G, fact_xtab, feat_xtab, num_known_facts, dirty_facts} =>
+          (fn state as {access_G, xtabs, num_known_facts, dirty_facts} =>
              let
                val parents = maximal_wrt_access_graph access_G facts
                val deps = used_ths
@@ -1443,12 +1461,10 @@
                else
                  let
                    val name = learned_proof_name ()
-                   val (access_G, fact_xtab, feat_xtab) =
-                     add_node Automatic_Proof name parents feats deps
-                       (access_G, fact_xtab, feat_xtab)
+                   val (access_G, xtabs) =
+                     add_node Automatic_Proof name parents feats deps (access_G, xtabs)
                  in
-                   {access_G = access_G, fact_xtab = fact_xtab, feat_xtab = feat_xtab,
-                    num_known_facts = num_known_facts + 1,
+                   {access_G = access_G, xtabs = xtabs, num_known_facts = num_known_facts + 1,
                     dirty_facts = Option.map (cons name) dirty_facts}
                  end
              end);
@@ -1496,11 +1512,10 @@
             isar_dependencies_of name_tabs th
 
         fun do_commit [] [] [] state = state
-          | do_commit learns relearns flops
-              {access_G, fact_xtab, feat_xtab, num_known_facts, dirty_facts} =
+          | do_commit learns relearns flops {access_G, xtabs, num_known_facts, dirty_facts} =
             let
-              val (learns, (access_G, fact_xtab, feat_xtab)) =
-                fold (learn_wrt_access_graph ctxt) learns ([], (access_G, fact_xtab, feat_xtab))
+              val (learns, (access_G, xtabs)) =
+                fold (learn_wrt_access_graph ctxt) learns ([], (access_G, xtabs))
               val (relearns, access_G) =
                 fold (relearn_wrt_access_graph ctxt) relearns ([], access_G)
 
@@ -1517,8 +1532,8 @@
                  MaSh_Py.relearn ctxt overlord save relearns)
               else
                 ();
-              {access_G = access_G, fact_xtab = fact_xtab, feat_xtab = feat_xtab,
-               num_known_facts = num_known_facts, dirty_facts = dirty_facts}
+              {access_G = access_G, xtabs = xtabs, num_known_facts = num_known_facts,
+               dirty_facts = dirty_facts}
             end
 
         fun commit last learns relearns flops =