store string-to-index tables in memory
authorblanchet
Thu, 26 Jun 2014 13:35:56 +0200
changeset 57371 0b2bce982afd
parent 57370 9d420da6c7e2
child 57372 24738b4f8c6b
store string-to-index tables in memory
src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Thu Jun 26 13:35:52 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Thu Jun 26 13:35:56 2014 +0200
@@ -121,8 +121,17 @@
 val relearn_isarN = "relearn_isar"
 val relearn_proverN = "relearn_prover"
 
+val hintsN = ".hints"
+
 fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i)))
 
+type xtab = int * int Symtab.table
+
+val empty_xtab = (0, Symtab.empty)
+
+fun add_to_xtab key (next, tab) = (next + 1, Symtab.update_new (key, next) tab)
+fun maybe_add_to_xtab key = perhaps (try (add_to_xtab key))
+
 fun mash_state_dir () = Path.explode "$ISABELLE_HOME_USER/mash" |> tap Isabelle_System.mkdir
 fun mash_state_file () = Path.append (mash_state_dir ()) (Path.explode "state")
 
@@ -384,10 +393,10 @@
 
 val number_of_nearest_neighbors = 10 (* FUDGE *)
 
-fun select_visible_facts recommends =
+fun select_visible_facts big_number recommends =
   List.app (fn at =>
     let val (j, ov) = Array.sub (recommends, at) in
-      Array.update (recommends, at, (j, 1000000000.0 + ov))
+      Array.update (recommends, at, (j, big_number + ov))
     end)
 
 exception EXIT of unit
@@ -461,7 +470,7 @@
   in
     while1 ();
     while2 ();
-    select_visible_facts recommends visible_facts;
+    select_visible_facts 1000000000.0 recommends visible_facts;
     heap (Real.compare o pairself snd) max_suggs num_facts recommends;
     ret [] (Integer.max 0 (num_facts - max_suggs))
   end
@@ -540,7 +549,7 @@
     fun ret at acc =
       if at = num_facts then acc else ret (at + 1) (Array.sub (posterior, at) :: acc)
   in
-    select_visible_facts posterior visible_facts;
+    select_visible_facts 100000.0 posterior visible_facts;
     heap (Real.compare o pairself snd) max_suggs num_facts posterior;
     ret (Integer.max 0 (num_facts - max_suggs)) []
   end
@@ -608,13 +617,20 @@
 val naive_bayes_cpp = c_plus_plus_tool "predict/nbayes"
 
 fun query ctxt overlord engine (num_facts, fact_tab) (num_feats, feat_tab) visible_facts max_suggs
-    learns conj_feats =
+    learns0 conj_feats =
   if engine = MaSh_SML_kNN_Cpp then
-    k_nearest_neighbors_cpp max_suggs learns conj_feats
+    k_nearest_neighbors_cpp max_suggs learns0 conj_feats
   else if engine = MaSh_SML_NB_Cpp then
-    naive_bayes_cpp max_suggs learns conj_feats
+    naive_bayes_cpp max_suggs learns0 conj_feats
   else
     let
+      val learn_ary = Array.array (num_facts, ("", [], []))
+      val _ =
+        List.app (fn entry as (fact, _, _) =>
+            Array.update (learn_ary, the (Symtab.lookup fact_tab fact), entry))
+          learns0
+      val learns = Array.foldr (op ::) [] learn_ary
+
       val facts = map #1 learns
       val featss = map (map_filter (Symtab.lookup feat_tab) o #2) learns
       val depss = map (map_filter (Symtab.lookup fact_tab) o #3) learns
@@ -675,10 +691,12 @@
   Graph.default_node (parent, (Isar_Proof, [], []))
   #> Graph.add_edge (parent, name)
 
-fun add_node kind name parents feats deps G =
-  (Graph.new_node (name, (kind, feats, deps)) G
-   handle Graph.DUP _ => Graph.map_node name (K (kind, feats, deps)) G)
-  |> fold (add_edge_to name) parents
+fun add_node kind name parents feats deps (access_G, fact_xtab, feat_xtab) =
+  ((Graph.new_node (name, (kind, feats, deps)) access_G
+    handle Graph.DUP _ => Graph.map_node name (K (kind, feats, deps)) access_G)
+   |> fold (add_edge_to name) parents,
+  maybe_add_to_xtab name fact_xtab,
+  fold maybe_add_to_xtab feats feat_xtab)
 
 fun try_graph ctxt when def f =
   f ()
@@ -703,10 +721,19 @@
 
 type mash_state =
   {access_G : (proof_kind * string list * string list) Graph.T,
-   num_known_facts : int,
+   fact_xtab : xtab,
+   feat_xtab : xtab,
+   num_known_facts : int, (* ### FIXME: kill *)
    dirty_facts : string list option}
 
-val empty_state = {access_G = Graph.empty, num_known_facts = 0, dirty_facts = SOME []} : mash_state
+val empty_state =
+  {access_G = Graph.empty,
+   fact_xtab = empty_xtab,
+   feat_xtab = empty_xtab,
+   num_known_facts = 0,
+   dirty_facts = SOME []} : mash_state
+
+val empty_graphxx = (Graph.empty, empty_xtab, empty_xtab)
 
 local
 
@@ -741,21 +768,22 @@
                  NONE => I (* should not happen *)
                | SOME (kind, name, parents, feats, deps) => add_node kind name parents feats deps)
 
-             val (access_G, num_known_facts) =
+             val ((access_G, fact_xtab, feat_xtab), num_known_facts) =
                (case string_ord (version', version) of
                  EQUAL =>
-                 (try_graph ctxt "loading state" Graph.empty (fn () =>
-                    fold extract_line_and_add_node node_lines Graph.empty),
+                 (try_graph ctxt "loading state" empty_graphxx (fn () =>
+                    fold extract_line_and_add_node node_lines empty_graphxx),
                   length node_lines)
                | LESS =>
                  (* cannot parse old file *)
                  (if the_mash_engine () = MaSh_Py then MaSh_Py.unlearn ctxt overlord
                   else wipe_out_mash_state_dir ();
-                  (Graph.empty, 0))
+                  (empty_graphxx, 0))
                | GREATER => raise FILE_VERSION_TOO_NEW ())
            in
              trace_msg ctxt (fn () => "Loaded fact graph (" ^ graph_info access_G ^ ")");
-             {access_G = access_G, num_known_facts = num_known_facts, dirty_facts = SOME []}
+             {access_G = access_G, fact_xtab = fact_xtab, feat_xtab = feat_xtab,
+              num_known_facts = num_known_facts, dirty_facts = SOME []}
            end
          | _ => empty_state)))
   end
@@ -765,7 +793,7 @@
   encode_strs feats ^ "; " ^ encode_strs deps ^ "\n"
 
 fun save_state _ (time_state as (_, {dirty_facts = SOME [], ...})) = time_state
-  | save_state ctxt (memory_time, {access_G, num_known_facts, dirty_facts}) =
+  | save_state ctxt (memory_time, {access_G, fact_xtab, feat_xtab, num_known_facts, dirty_facts}) =
     let
       fun append_entry (name, ((kind, feats, deps), (parents, _))) =
         cons (kind, name, Graph.Keys.dest parents, feats, deps)
@@ -786,7 +814,9 @@
         (case dirty_facts of
           SOME dirty_facts => "; " ^ string_of_int (length dirty_facts) ^ " dirty fact(s)"
         | _ => "") ^  ")");
-      (Time.now (), {access_G = access_G, num_known_facts = num_known_facts, dirty_facts = SOME []})
+      (Time.now (),
+       {access_G = access_G, fact_xtab = fact_xtab, feat_xtab = feat_xtab,
+        num_known_facts = num_known_facts, dirty_facts = SOME []})
     end
 
 val global_state = Synchronized.var "Sledgehammer_MaSh.global_state" (Time.zeroTime, empty_state)
@@ -1291,11 +1321,6 @@
 fun add_const_counts t =
   fold (fn s => Symtab.map_default (s, 0) (Integer.add 1)) (Term.add_const_names t [])
 
-val empty_xtab = (0, Symtab.empty)
-
-fun add_to_xtab key (next, tab) = (next + 1, Symtab.update_new (key, next) tab)
-fun maybe_add_to_xtab key = perhaps (try (add_to_xtab key))
-
 fun mash_suggested_facts ctxt ({debug, overlord, ...} : params) max_facts hyp_ts concl_t facts =
   let
     val thy = Proof_Context.theory_of ctxt
@@ -1344,19 +1369,18 @@
         (parents, hints, feats)
       end
 
-    val (access_G, py_suggs) =
-      peek_state ctxt overlord (fn {access_G, ...} =>
-        if Graph.is_empty access_G then
-          (trace_msg ctxt (K "Nothing has been learned yet"); (access_G, []))
-        else
-          (access_G,
-           if engine = MaSh_Py then
-             let val (parents, hints, feats) = query_args access_G in
-               MaSh_Py.query ctxt overlord max_facts ([], hints, parents, feats)
-               |> map fst
-             end
-           else
-             []))
+    val ((access_G, fact_xtab, feat_xtab), py_suggs) =
+      peek_state ctxt overlord (fn {access_G, fact_xtab, feat_xtab, ...} =>
+        ((access_G, fact_xtab, feat_xtab),
+         if Graph.is_empty access_G then
+           (trace_msg ctxt (K "Nothing has been learned yet"); [])
+         else if engine = MaSh_Py then
+           let val (parents, hints, feats) = query_args access_G in
+             MaSh_Py.query ctxt overlord max_facts ([], hints, parents, feats)
+             |> map fst
+           end
+         else
+           []))
 
     val sml_suggs =
       if engine = MaSh_Py then
@@ -1367,11 +1391,8 @@
           val feats = map fst feats0
           val visible_facts = Graph.all_preds access_G parents
           val learns =
-            Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G @
-            (if null hints then [] else [(".hints", feats, hints)])
-
-          val fact_xtab = fold (add_to_xtab o #1) learns empty_xtab
-          val feat_xtab = fold (fold maybe_add_to_xtab o #2) learns empty_xtab
+            (if null hints then [] else [(hintsN, feats, hints)]) @ (* ### FIXME *)
+            Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G
         in
           MaSh_SML.query ctxt overlord engine fact_xtab feat_xtab visible_facts max_facts learns
             feats
@@ -1383,27 +1404,33 @@
     |> pairself (map fact_of_raw_fact)
   end
 
-fun learn_wrt_access_graph ctxt (name, parents, feats, deps) (learns, G) =
+fun learn_wrt_access_graph ctxt (name, parents, feats, deps)
+    (learns, (access_G, fact_xtab, feat_xtab)) =
   let
-    fun maybe_learn_from from (accum as (parents, G)) =
+    fun maybe_learn_from from (accum as (parents, access_G)) =
       try_graph ctxt "updating graph" accum (fn () =>
-        (from :: parents, Graph.add_edge_acyclic (from, name) G))
-    val G = G |> Graph.default_node (name, (Isar_Proof, feats, deps))
-    val (parents, G) = ([], G) |> fold maybe_learn_from parents
-    val (deps, _) = ([], G) |> fold maybe_learn_from deps
+        (from :: parents, Graph.add_edge_acyclic (from, name) access_G))
+
+    val access_G = access_G |> Graph.default_node (name, (Isar_Proof, feats, deps))
+    val (parents, access_G) = ([], access_G) |> fold maybe_learn_from parents
+    val (deps, _) = ([], access_G) |> fold maybe_learn_from deps
+
+    val fact_xtab = maybe_add_to_xtab name fact_xtab
+    val feat_xtab = fold maybe_add_to_xtab feats feat_xtab
   in
-    ((name, parents, feats, deps) :: learns, G)
+    ((name, parents, feats, deps) :: learns, (access_G, fact_xtab, feat_xtab))
   end
 
-fun relearn_wrt_access_graph ctxt (name, deps) (relearns, G) =
+fun relearn_wrt_access_graph ctxt (name, deps) (relearns, access_G) =
   let
-    fun maybe_relearn_from from (accum as (parents, G)) =
+    fun maybe_relearn_from from (accum as (parents, access_G)) =
       try_graph ctxt "updating graph" accum (fn () =>
-        (from :: parents, Graph.add_edge_acyclic (from, name) G))
-    val G = G |> Graph.map_node name (fn (_, feats, _) => (Automatic_Proof, feats, deps))
-    val (deps, _) = ([], G) |> fold maybe_relearn_from deps
+        (from :: parents, Graph.add_edge_acyclic (from, name) access_G))
+    val access_G =
+      access_G |> Graph.map_node name (fn (_, feats, _) => (Automatic_Proof, feats, deps))
+    val (deps, _) = ([], access_G) |> fold maybe_relearn_from deps
   in
-    ((name, deps) :: relearns, G)
+    ((name, deps) :: relearns, access_G)
   end
 
 fun flop_wrt_access_graph name =
@@ -1431,24 +1458,28 @@
         val thy = Proof_Context.theory_of ctxt
         val feats = map fst (features_of ctxt thy 0 Symtab.empty (Local, General) [t])
       in
-        map_state ctxt overlord (fn state as {access_G, num_known_facts, dirty_facts} =>
-          let
-            val parents = maximal_wrt_access_graph access_G facts
-            val deps = used_ths
-              |> filter (is_fact_in_graph access_G)
-              |> map nickname_of_thm
-          in
-            if the_mash_engine () = MaSh_Py then
-              (MaSh_Py.learn ctxt overlord true [("", parents, feats, deps)]; state)
-            else
-              let
-                val name = learned_proof_name ()
-                val access_G = access_G |> add_node Automatic_Proof name parents feats deps
-              in
-                {access_G = access_G, num_known_facts = num_known_facts + 1,
-                 dirty_facts = Option.map (cons name) dirty_facts}
-              end
-          end);
+        map_state ctxt overlord
+          (fn state as {access_G, fact_xtab, feat_xtab, num_known_facts, dirty_facts} =>
+             let
+               val parents = maximal_wrt_access_graph access_G facts
+               val deps = used_ths
+                 |> filter (is_fact_in_graph access_G)
+                 |> map nickname_of_thm
+             in
+               if the_mash_engine () = MaSh_Py then
+                 (MaSh_Py.learn ctxt overlord true [("", parents, feats, deps)]; state)
+               else
+                 let
+                   val name = learned_proof_name ()
+                   val (access_G, fact_xtab, feat_xtab) =
+                     add_node Automatic_Proof name parents feats deps
+                       (access_G, fact_xtab, feat_xtab)
+                 in
+                   {access_G = access_G, fact_xtab = fact_xtab, feat_xtab = feat_xtab,
+                    num_known_facts = num_known_facts + 1,
+                    dirty_facts = Option.map (cons name) dirty_facts}
+                 end
+             end);
         (true, "")
       end)
   else
@@ -1466,7 +1497,7 @@
     fun next_commit_time () = Time.+ (Timer.checkRealTimer timer, commit_timeout)
 
     val engine = the_mash_engine ()
-    val {access_G, ...} = peek_state ctxt overlord I
+    val {access_G, fact_xtab, feat_xtab, ...} = peek_state ctxt overlord I
     val is_in_access_G = is_fact_in_graph access_G o snd
     val no_new_facts = forall is_in_access_G facts
   in
@@ -1493,12 +1524,15 @@
             isar_dependencies_of name_tabs th
 
         fun do_commit [] [] [] state = state
-          | do_commit learns relearns flops {access_G, num_known_facts, dirty_facts} =
+          | do_commit learns relearns flops
+              {access_G, fact_xtab, feat_xtab, num_known_facts, dirty_facts} =
             let
+              val (learns, (access_G, fact_xtab, feat_xtab)) =
+                fold (learn_wrt_access_graph ctxt) learns ([], (access_G, fact_xtab, feat_xtab))
+              val (relearns, access_G) =
+                fold (relearn_wrt_access_graph ctxt) relearns ([], access_G)
+
               val was_empty = Graph.is_empty access_G
-              val (learns, access_G) = ([], access_G) |> fold (learn_wrt_access_graph ctxt) learns
-              val (relearns, access_G) =
-                ([], access_G) |> fold (relearn_wrt_access_graph ctxt) relearns
               val access_G = access_G |> fold flop_wrt_access_graph flops
               val num_known_facts = num_known_facts + length learns
               val dirty_facts =
@@ -1511,7 +1545,8 @@
                  MaSh_Py.relearn ctxt overlord save relearns)
               else
                 ();
-              {access_G = access_G, num_known_facts = num_known_facts, dirty_facts = dirty_facts}
+              {access_G = access_G, fact_xtab = fact_xtab, feat_xtab = feat_xtab,
+               num_known_facts = num_known_facts, dirty_facts = dirty_facts}
             end
 
         fun commit last learns relearns flops =