speed up MaSh
authorblanchet
Sat, 03 Oct 2015 17:11:04 +0200
changeset 61318 6a5a188ab3e7
parent 61317 b089c00f4db0
child 61320 69022bbcd012
speed up MaSh
NEWS
src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
--- a/NEWS	Fri Oct 02 21:31:51 2015 +0200
+++ b/NEWS	Sat Oct 03 17:11:04 2015 +0200
@@ -243,6 +243,7 @@
 * Discontinued simp_legacy_precond. Potential INCOMPATIBILITY.
 
 * Sledgehammer:
+  - The MaSh relevance filter has been sped up.
   - Proof reconstruction has been improved, to minimize the incidence of
     cases where Sledgehammer gives a proof that does not work.
   - Auto Sledgehammer now minimizes and preplays the results.
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Fri Oct 02 21:31:51 2015 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Sat Oct 03 17:11:04 2015 +0200
@@ -94,6 +94,8 @@
 open Sledgehammer_Prover_Minimize
 open Sledgehammer_MePo
 
+val anonymous_proof_prefix = "."
+
 val trace = Attrib.setup_config_bool @{binding sledgehammer_mash_trace} (K false)
 val duplicates = Attrib.setup_config_bool @{binding sledgehammer_fact_duplicates} (K false)
 
@@ -1089,31 +1091,36 @@
       |> drop (length old_facts)
     end
 
-fun maximal_wrt_graph G keys =
-  let
-    val tab = Symtab.empty |> fold (fn name => Symtab.default (name, ())) keys
+fun maximal_wrt_graph _ [] = []
+  | maximal_wrt_graph G keys =
+    if can (Graph.get_node G o the_single) keys then
+      keys
+    else
+      let
+        val tab = Symtab.empty |> fold (fn name => Symtab.default (name, ())) keys
 
-    fun insert_new seen name = not (Symtab.defined seen name) ? insert (op =) name
+        fun insert_new seen name = not (Symtab.defined seen name) ? insert (op =) name
 
-    fun num_keys keys = Graph.Keys.fold (K (Integer.add 1)) keys 0
+        fun num_keys keys = Graph.Keys.fold (K (Integer.add 1)) keys 0
 
-    fun find_maxes _ (maxs, []) = map snd maxs
-      | find_maxes seen (maxs, new :: news) =
-        find_maxes (seen |> num_keys (Graph.imm_succs G new) > 1 ? Symtab.default (new, ()))
-          (if Symtab.defined tab new then
-             let
-               val newp = Graph.all_preds G [new]
-               fun is_ancestor x yp = member (op =) yp x
-               val maxs = maxs |> filter (fn (_, max) => not (is_ancestor max newp))
-             in
-               if exists (is_ancestor new o fst) maxs then (maxs, news)
-               else ((newp, new) :: filter_out (fn (_, max) => is_ancestor max newp) maxs, news)
-             end
-           else
-             (maxs, Graph.Keys.fold (insert_new seen) (Graph.imm_preds G new) news))
-  in
-    find_maxes Symtab.empty ([], Graph.maximals G)
-  end
+        fun find_maxes _ (maxs, []) = map snd maxs
+          | find_maxes seen (maxs, new :: news) =
+            find_maxes (seen |> num_keys (Graph.imm_succs G new) > 1 ? Symtab.default (new, ()))
+              (if Symtab.defined tab new then
+                 let
+                   val newp = Graph.all_preds G [new]
+                   fun is_ancestor x yp = member (op =) yp x
+                   val maxs = maxs |> filter (fn (_, max) => not (is_ancestor max newp))
+                 in
+                   if exists (is_ancestor new o fst) maxs then (maxs, news)
+                   else ((newp, new) :: filter_out (fn (_, max) => is_ancestor max newp) maxs, news)
+                 end
+               else
+                 (maxs, Graph.Keys.fold (insert_new seen) (Graph.imm_preds G new) news))
+      in
+        find_maxes Symtab.empty ([],
+          G |> Graph.restrict (not o String.isPrefix anonymous_proof_prefix) |> Graph.maximals)
+      end
 
 fun maximal_wrt_access_graph _ _ [] = []
   | maximal_wrt_access_graph ctxt access_G ((fact as (_, th)) :: facts) =
@@ -1259,8 +1266,9 @@
     Async_Manager_Legacy.thread MaShN birth_time death_time desc task
   end
 
-fun learned_proof_name () =
-  Date.fmt ".%Y%m%d.%H%M%S." (Date.fromTimeLocal (Time.now ())) ^ serial_string ()
+fun anonymous_proof_name () =
+  Date.fmt (anonymous_proof_prefix ^ "%Y%m%d.%H%M%S.") (Date.fromTimeLocal (Time.now ())) ^
+  serial_string ()
 
 fun mash_learn_proof ctxt ({timeout, ...} : params) t facts used_ths =
   if not (null used_ths) andalso is_mash_enabled () then
@@ -1278,7 +1286,7 @@
                  |> filter (is_fact_in_graph ctxt access_G)
                  |> map (nickname_of_thm ctxt)
 
-               val name = learned_proof_name ()
+               val name = anonymous_proof_name ()
                val (access_G', xtabs', rev_learns) =
                  add_node Automatic_Proof name parents feats deps (access_G, xtabs, [])