recompute learning data at learning time, not query time
authorblanchet
Thu, 26 Jun 2014 16:41:30 +0200
changeset 57378 fe96689f393b
parent 57377 73e9b858ec8d
child 57379 dcaf04545de2
recompute learning data at learning time, not query time
src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Thu Jun 26 16:41:30 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Thu Jun 26 16:41:30 2014 +0200
@@ -616,10 +616,10 @@
      MaSh_SML_kNN_Ext => k_nearest_neighbors_ext max_suggs learns goal_feats
    | MaSh_SML_NB_Ext => naive_bayes_ext max_suggs learns goal_feats))
 
-fun query_internal ctxt engine num_facts num_feats (facts, featss, depss) (freqs as (_, _, dffreq))
-    visible_facts max_suggs goal_feats int_goal_feats =
+fun query_internal ctxt engine num_facts num_feats (fact_names, featss, depss)
+    (freqs as (_, _, dffreq)) visible_facts max_suggs goal_feats int_goal_feats =
   (trace_msg ctxt (fn () => "MaSh_SML query internal " ^ encode_strs goal_feats ^ " from {" ^
-     elide_string 1000 (space_implode " " (Vector.foldr (op ::) [] facts)) ^ "}");
+     elide_string 1000 (space_implode " " (Vector.foldr (op ::) [] fact_names)) ^ "}");
    (case engine of
      MaSh_SML_kNN =>
      let
@@ -632,7 +632,7 @@
        k_nearest_neighbors dffreq num_facts depss feat_facts max_suggs visible_facts int_goal_feats
      end
    | MaSh_SML_NB => naive_bayes freqs num_facts max_suggs visible_facts int_goal_feats)
-   |> map (curry Vector.sub facts o fst))
+   |> map (curry Vector.sub fact_names o fst))
 
 end;
 
@@ -684,14 +684,47 @@
 type mash_state =
   {access_G : (proof_kind * string list * string list) Graph.T,
    xtabs : xtab * xtab,
+   ffds : string vector * int list vector * int list vector,
+   freqs : int vector * int Inttab.table vector * int vector,
    dirty_facts : string list option}
 
+val empty_xtabs = (empty_xtab, empty_xtab)
+val empty_ffds = (Vector.fromList [], Vector.fromList [], Vector.fromList [])
+val empty_freqs = (Vector.fromList [], Vector.fromList [], Vector.fromList [])
+val empty_graphxx = (Graph.empty, empty_xtabs)
+
 val empty_state =
   {access_G = Graph.empty,
-   xtabs = (empty_xtab, empty_xtab),
+   xtabs = empty_xtabs,
+   ffds = empty_ffds,
+   freqs = empty_freqs,
    dirty_facts = SOME []} : mash_state
 
-val empty_graphxx = (Graph.empty, (empty_xtab, empty_xtab))
+fun reorder_learns (num_facts, fact_tab) learns =
+  let val ary = Array.array (num_facts, ("", [], [])) in
+    List.app (fn learn as (fact, _, _) =>
+        Array.update (ary, the (Symtab.lookup fact_tab fact), learn))
+      learns;
+    Array.foldr (op ::) [] ary
+  end
+
+fun recompute_ffd_freqs access_G (fact_xtab as (num_facts, fact_tab), (num_feats, feat_tab)) =
+  let
+    val learns =
+      Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G
+      |> reorder_learns fact_xtab
+
+    val fact_names = Vector.fromList (map #1 learns)
+    val featss = Vector.fromList (map (map_filter (Symtab.lookup feat_tab) o #2) learns)
+    val depss = Vector.fromList (map (map_filter (Symtab.lookup fact_tab) o #3) learns)
+
+    val tfreq = Vector.tabulate (num_facts, K 0)
+    val sfreq = Vector.tabulate (num_facts, K Inttab.empty)
+    val dffreq = Vector.tabulate (num_feats, K 0)
+  in
+    ((fact_names, featss, depss),
+     MaSh_SML.learn_facts (tfreq, sfreq, dffreq) 0 num_facts num_feats depss featss)
+  end
 
 local
 
@@ -737,9 +770,11 @@
                   else wipe_out_mash_state_dir ();
                   empty_graphxx)
                | GREATER => raise FILE_VERSION_TOO_NEW ())
+
+             val (ffds, freqs) = recompute_ffd_freqs access_G xtabs
            in
              trace_msg ctxt (fn () => "Loaded fact graph (" ^ graph_info access_G ^ ")");
-             {access_G = access_G, xtabs = xtabs, dirty_facts = SOME []}
+             {access_G = access_G, xtabs = xtabs, ffds = ffds, freqs = freqs, dirty_facts = SOME []}
            end
          | _ => empty_state)))
   end
@@ -749,7 +784,7 @@
   encode_strs feats ^ "; " ^ encode_strs deps ^ "\n"
 
 fun save_state _ (time_state as (_, {dirty_facts = SOME [], ...})) = time_state
-  | save_state ctxt (memory_time, {access_G, xtabs, dirty_facts}) =
+  | save_state ctxt (memory_time, {access_G, xtabs, ffds, freqs, dirty_facts}) =
     let
       fun append_entry (name, ((kind, feats, deps), (parents, _))) =
         cons (kind, name, Graph.Keys.dest parents, feats, deps)
@@ -770,7 +805,8 @@
         (case dirty_facts of
           SOME dirty_facts => "; " ^ string_of_int (length dirty_facts) ^ " dirty fact(s)"
         | _ => "") ^  ")");
-      (Time.now (), {access_G = access_G, xtabs = xtabs, dirty_facts = SOME []})
+      (Time.now (),
+       {access_G = access_G, xtabs = xtabs, ffds = ffds, freqs = freqs, dirty_facts = SOME []})
     end
 
 val global_state = Synchronized.var "Sledgehammer_MaSh.global_state" (Time.zeroTime, empty_state)
@@ -1275,16 +1311,6 @@
 fun add_const_counts t =
   fold (fn s => Symtab.map_default (s, 0) (Integer.add 1)) (Term.add_const_names t [])
 
-fun reorder_learns (num_facts, fact_tab) learns0 =
-  let
-    val learns = Array.array (num_facts, ("", [], []))
-  in
-    List.app (fn learn as (fact, _, _) =>
-        Array.update (learns, the (Symtab.lookup fact_tab fact), learn))
-      learns0;
-    Array.foldr (op ::) [] learns
-  end
-
 fun mash_suggested_facts ctxt ({debug, overlord, ...} : params) max_suggs hyp_ts concl_t facts =
   let
     val thy = Proof_Context.theory_of ctxt
@@ -1333,9 +1359,9 @@
         (parents, hints, feats)
       end
 
-    val ((access_G, (fact_xtab as (num_facts, fact_tab), (num_feats, feat_tab))), py_suggs) =
-      peek_state ctxt overlord (fn {access_G, xtabs, ...} =>
-        ((access_G, xtabs),
+    val ((access_G, ((num_facts, fact_tab), (num_feats, feat_tab)), ffds, freqs), py_suggs) =
+      peek_state ctxt overlord (fn {access_G, xtabs, ffds, freqs, ...} =>
+        ((access_G, xtabs, ffds, freqs),
          if Graph.is_empty access_G then
            (trace_msg ctxt (K "Nothing has been learned yet"); [])
          else if engine = MaSh_Py then
@@ -1364,25 +1390,10 @@
             end
           else
             let
-              val learns0 =
-                Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G
-              val learns = reorder_learns fact_xtab learns0
-
-              val facts = Vector.fromList (map #1 learns)
-              val featss = Vector.fromList (map (map_filter (Symtab.lookup feat_tab) o #2) learns)
-              val depss = Vector.fromList (map (map_filter (Symtab.lookup fact_tab) o #3) learns)
-
-              val tfreq = Vector.tabulate (num_facts, K 0)
-              val sfreq = Vector.tabulate (num_facts, K Inttab.empty)
-              val dffreq = Vector.tabulate (num_feats, K 0)
-
-              val freqs' =
-                MaSh_SML.learn_facts (tfreq, sfreq, dffreq) 0 num_facts num_feats depss featss
-
               val int_goal_feats = map_filter (Symtab.lookup feat_tab) goal_feats
             in
-              MaSh_SML.query_internal ctxt engine num_facts num_feats (facts, featss, depss) freqs'
-                visible_facts max_suggs goal_feats int_goal_feats
+              MaSh_SML.query_internal ctxt engine num_facts num_feats ffds freqs visible_facts
+                max_suggs goal_feats int_goal_feats
             end
         end
 
@@ -1447,7 +1458,7 @@
         val feats = map fst (features_of ctxt thy 0 Symtab.empty (Local, General) [t])
       in
         map_state ctxt overlord
-          (fn state as {access_G, xtabs, dirty_facts} =>
+          (fn state as {access_G, xtabs, ffds, freqs, dirty_facts} =>
              let
                val parents = maximal_wrt_access_graph access_G facts
                val deps = used_ths
@@ -1459,10 +1470,12 @@
                else
                  let
                    val name = learned_proof_name ()
-                   val (access_G, xtabs) =
+                   val (access_G', xtabs') =
                      add_node Automatic_Proof name parents feats deps (access_G, xtabs)
+
+                   val (ffds', freqs') = recompute_ffd_freqs access_G' xtabs'
                  in
-                   {access_G = access_G, xtabs = xtabs,
+                   {access_G = access_G', xtabs = xtabs', ffds = ffds', freqs = freqs',
                     dirty_facts = Option.map (cons name) dirty_facts}
                  end
              end);
@@ -1510,26 +1523,31 @@
             isar_dependencies_of name_tabs th
 
         fun do_commit [] [] [] state = state
-          | do_commit learns relearns flops {access_G, xtabs, dirty_facts} =
+          | do_commit learns relearns flops {access_G, xtabs, ffds, freqs, dirty_facts} =
             let
+              val was_empty = Graph.is_empty access_G
+
+              (* TODO: use "fold_map" *)
               val (learns, (access_G, xtabs)) =
                 fold (learn_wrt_access_graph ctxt) learns ([], (access_G, xtabs))
               val (relearns, access_G) =
                 fold (relearn_wrt_access_graph ctxt) relearns ([], access_G)
 
-              val was_empty = Graph.is_empty access_G
               val access_G = access_G |> fold flop_wrt_access_graph flops
               val dirty_facts =
                 (case (was_empty, dirty_facts) of
                   (false, SOME names) => SOME (map #1 learns @ map #1 relearns @ names)
                 | _ => NONE)
+
+              val (ffds', freqs') = recompute_ffd_freqs access_G xtabs
             in
               if engine = MaSh_Py then
                 (MaSh_Py.learn ctxt overlord (save andalso null relearns) (rev learns);
                  MaSh_Py.relearn ctxt overlord save relearns)
               else
                 ();
-              {access_G = access_G, xtabs = xtabs, dirty_facts = dirty_facts}
+              {access_G = access_G, xtabs = xtabs, ffds = ffds', freqs = freqs',
+               dirty_facts = dirty_facts}
             end
 
         fun commit last learns relearns flops =