better way to take invisible facts into account than 'island' business
authorblanchet
Tue, 20 May 2014 09:38:39 +0200
changeset 57013 ed95456499e6
parent 57012 43fd82a537a3
child 57014 b7999893ffcc
better way to take invisible facts into account than 'island' business
src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Tue May 20 02:47:23 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Tue May 20 09:38:39 2014 +0200
@@ -108,12 +108,6 @@
 val relearn_isarN = "relearn_isar"
 val relearn_proverN = "relearn_prover"
 
-val learned_proof_prefix = ".."
-
-fun learned_proof_name () =
-  learned_proof_prefix ^ Date.fmt "%Y%m%d_%H%M%S__" (Date.fromTimeLocal (Time.now ())) ^
-  serial_string ()
-
 fun mash_state_dir () = Path.explode "$ISABELLE_HOME_USER/mash" |> tap Isabelle_System.mkdir
 fun mash_state_file () = Path.append (mash_state_dir ()) (Path.explode "state")
 
@@ -382,68 +376,62 @@
   end
 
 (*
-  avail_no = maximum number of theorems to check dependencies and symbols
+  avail_num = maximum number of theorems to check dependencies and symbols
+  adv_max = do not return theorems over or equal to this number. Must satisfy: adv_max <= avail_num
   get_deps = returns dependencies of a theorem
   get_sym_ths = get theorems that have this feature
-  knns    = number of nearest neighbours
-  advno   = number of predictions to return
-  syms    = symbols of the conjecture
+  knns = number of nearest neighbours
+  advno = number of predictions to return
+  syms = symbols of the conjecture
 *)
-fun knn avail_no get_deps get_sym_ths knns advno syms =
+fun knn avail_num adv_max get_deps get_sym_ths knns advno syms =
   let
     (* Can be later used for TFIDF *)
-    fun sym_wght _ = 1.0
-
-    val overlaps_sqr = Array.tabulate (avail_no, (fn i => (i, 0.0)))
-
+    fun sym_wght _ = 1.0;
+    val overlaps_sqr = Array.tabulate (avail_num, (fn i => (i, 0.0)));
     fun inc_overlap j v =
       let
         val ov = snd (Array.sub (overlaps_sqr,j))
       in
         Array.update (overlaps_sqr, j, (j, v + ov))
-      end
-
+      end;
     fun do_sym (s, con_wght) =
       let
-        val sw = sym_wght s
-        val w2 = sw * sw * con_wght
-        fun do_th (j, prem_wght) = if j < avail_no then inc_overlap j (w2 * prem_wght) else ()
+        val sw = sym_wght s;
+        val w2 = sw * sw * con_wght;
+        fun do_th (j, prem_wght) = if j < avail_num then inc_overlap j (w2 * prem_wght) else ()
       in
         ignore (map do_th (get_sym_ths s))
-      end
-
-    val _ = ignore (map do_sym syms)
-    val _ = heap (fn (a, b) => Real.compare (snd a, snd b)) knns overlaps_sqr
-    val recommends = Array.tabulate (avail_no, (fn j => (j, 0.0)))
-
+      end;
+    val () = ignore (map do_sym syms);
+    val () = heap (fn (a, b) => Real.compare (snd a, snd b)) knns overlaps_sqr;
+    val recommends = Array.tabulate (adv_max, (fn j => (j, 0.0)));
     fun inc_recommend j v =
+      if j >= adv_max then () else
       let
         val ov = snd (Array.sub (recommends,j))
       in
         Array.update (recommends, j, (j, v + ov))
-      end
-
+      end;
     fun for k =
       if k = knns then () else
-      if k >= avail_no then () else
+      if k >= adv_max then () else
       let
-        val (j, o2) = Array.sub (overlaps_sqr, avail_no - k - 1)
-        val o1 = Math.sqrt o2
-        val _ = inc_recommend j o1
-        val ds = get_deps j
-        val l = Real.fromInt (length ds)
+        val (j, o2) = Array.sub (overlaps_sqr, avail_num - k - 1);
+        val o1 = Math.sqrt o2;
+        val () = inc_recommend j o1;
+        val ds = get_deps j;
+        val l = Real.fromInt (length ds);
         val _ = map (fn d => inc_recommend d (o1 / l)) ds
       in
         for (k + 1)
-      end
-
-    val _ = for 0
-    val _ = heap (fn (a, b) => Real.compare (snd a, snd b)) advno recommends
-
+      end;
+    val () = for 0;
+    val () = heap (fn (a, b) => Real.compare (snd a, snd b)) advno recommends;
     fun ret acc at =
       if at = Array.length recommends then acc else ret (Array.sub (recommends,at) :: acc) (at + 1)
   in
-    ret [] (max 0 (avail_no - advno))
+    ret [] (max 0 (adv_max - advno))
   end
 
 val knns = 40 (* FUDGE *)
@@ -456,11 +444,18 @@
   let
     val str_of_feat = space_implode "|"
 
-    val (depss0, featss, (_, _, facts0), (num_feats, feat_tab, _)) =
-      fold_rev (fn fact => fn (depss, featss, fact_xtab as (_, fact_tab, _), feat_xtab) =>
+    val visible_facts = Graph.all_preds access_G parents
+    val visible_fact_set = Symtab.make_set visible_facts
+
+    val all_nodes =
+      Graph.schedule (K I) access_G
+      |> List.partition (Symtab.defined visible_fact_set o fst)
+      |> op @
+
+    val (rev_depss, featss, (_, _, rev_facts), (num_feats, feat_tab, _)) =
+      fold (fn (fact, (_, feats, deps)) =>
+            fn (depss, featss, fact_xtab as (_, fact_tab, _), feat_xtab) =>
           let
-            val (_, feats, deps) = Graph.get_node access_G fact
-
             fun add_feat (feat, weight) (xtab as (n, tab, _)) =
               (case Symtab.lookup tab feat of
                 SOME i => ((i, weight), xtab)
@@ -471,12 +466,12 @@
             (map_filter (Symtab.lookup fact_tab) deps :: depss, feats' :: featss,
              add_to_xtab fact fact_xtab, feat_xtab')
           end)
-        (Graph.all_preds access_G parents) ([], [], (0, Symtab.empty, []), (0, Symtab.empty, []))
+        all_nodes ([], [], (0, Symtab.empty, []), (0, Symtab.empty, []))
 
-    val facts = rev facts0
+    val facts = rev rev_facts
     val fact_ary = Array.fromList facts
 
-    val deps_ary = Array.fromList (rev depss0)
+    val deps_ary = Array.fromList (rev rev_depss)
     val facts_ary = Array.array (num_feats, [])
     val _ =
       fold (fn feats => fn fact =>
@@ -487,15 +482,13 @@
           end)
         featss (length featss)
   in
-    trace_msg ctxt (fn () =>
-      "MaSh_SML query " ^ encode_features feats ^ " from {" ^
-       elide_string 1000 (space_implode " " facts) ^ "}");
-    knn (Array.length deps_ary) (curry Array.sub deps_ary) (curry Array.sub facts_ary) knns
-      max_suggs
+    trace_msg ctxt (fn () => "MaSh_SML query " ^ encode_features feats ^ " from {" ^
+      elide_string 1000 (space_implode " " facts) ^ "}");
+    knn (Array.length deps_ary) (length visible_facts) (curry Array.sub deps_ary)
+      (curry Array.sub facts_ary) knns max_suggs
       (map_filter (fn (feat, weight) =>
          Option.map (rpair weight) (Symtab.lookup feat_tab (str_of_feat feat))) feats)
     |> map ((fn i => Array.sub (fact_ary, i)) o fst)
-    |> filter_out (String.isPrefix learned_proof_prefix)
   end
 
 end;
@@ -517,9 +510,10 @@
   Graph.default_node (parent, (Isar_Proof, [], []))
   #> Graph.add_edge (parent, name)
 
-fun add_node kind name feats deps =
+fun add_node kind name parents feats deps =
   Graph.default_node (name, (kind, feats, deps))
   #> Graph.map_node name (K (kind, feats, deps))
+  #> fold (add_edge_to name) parents
 
 fun try_graph ctxt when def f =
   f ()
@@ -575,9 +569,7 @@
            fun extract_line_and_add_node line =
              (case extract_node line of
                NONE => I (* should not happen *)
-             | SOME (kind, name, parents, feats, deps) =>
-               add_node kind name feats deps
-               #> fold (add_edge_to name) parents)
+             | SOME (kind, name, parents, feats, deps) => add_node kind name parents feats deps)
 
            val (access_G, num_known_facts) =
              (case string_ord (version', version) of
@@ -1132,13 +1124,8 @@
     find_maxes Symtab.empty ([], Graph.maximals G)
   end
 
-fun graph_islands G =
-  Graph.fold (fn (m, (_, (preds, succs))) =>
-    (Graph.Keys.is_empty preds andalso Graph.Keys.is_empty succs) ? cons m) G [];
-
-(* islands represent learned proofs associated with no facts *)
 fun maximal_wrt_access_graph access_G facts =
-  map (nickname_of_thm o snd) facts @ graph_islands access_G
+  map (nickname_of_thm o snd) facts
   |> maximal_wrt_graph access_G
 
 fun is_fact_in_graph access_G = can (Graph.get_node access_G) o nickname_of_thm
@@ -1240,7 +1227,7 @@
 fun learn_wrt_access_graph ctxt (name, parents, feats, deps) (learns, G) =
   let
     fun maybe_learn_from from (accum as (parents, G)) =
-      try_graph ctxt "updating G" accum (fn () =>
+      try_graph ctxt "updating graph" accum (fn () =>
         (from :: parents, Graph.add_edge_acyclic (from, name) G))
     val G = G |> Graph.default_node (name, (Isar_Proof, feats, deps))
     val (parents, G) = ([], G) |> fold maybe_learn_from parents
@@ -1275,6 +1262,9 @@
     Async_Manager.thread MaShN birth_time death_time desc task
   end
 
+fun learned_proof_name () =
+  Date.fmt ".%Y%m%d.%H%M%S." (Date.fromTimeLocal (Time.now ())) ^ serial_string ()
+
 fun mash_learn_proof ctxt ({overlord, timeout, ...} : params) t facts used_ths =
   if is_mash_enabled () then
     launch_thread timeout (fn () =>
@@ -1285,19 +1275,18 @@
         map_state ctxt overlord (fn state as {access_G, num_known_facts, dirty} =>
           let
             val name = learned_proof_name ()
+            val parents = maximal_wrt_access_graph access_G facts
             val deps = used_ths
               |> filter (is_fact_in_graph access_G)
               |> map nickname_of_thm
           in
             if Config.get ctxt sml then
-              let val access_G = access_G |> add_node Automatic_Proof name feats deps in
+              let val access_G = access_G |> add_node Automatic_Proof name parents feats deps in
                 {access_G = access_G, num_known_facts = num_known_facts + 1,
                  dirty = Option.map (cons name) dirty}
               end
             else
-              let val parents = maximal_wrt_access_graph access_G facts in
-                (MaSh_Py.learn ctxt overlord true [("", parents, map fst feats, deps)]; state)
-              end
+              (MaSh_Py.learn ctxt overlord true [("", parents, map fst feats, deps)]; state)
           end);
         (true, "")
       end)