tuned code: factored out parent computation
authorblanchet
Mon, 18 Feb 2013 11:33:43 +0100
changeset 51177 e8c9755fd14e
parent 51176 407b0258464b
child 51178 06689dbfe072
tuned code: factored out parent computation
src/HOL/TPTP/mash_export.ML
src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
--- a/src/HOL/TPTP/mash_export.ML	Mon Feb 18 10:43:36 2013 +0100
+++ b/src/HOL/TPTP/mash_export.ML	Mon Feb 18 11:33:43 2013 +0100
@@ -104,7 +104,7 @@
             | NONE => isar_dependencies_of name_tabs th
         in (if null deps then unprovable_marker else isar_marker, deps) end
   in
-    case trim_dependencies th deps of
+    case trim_dependencies deps of
       SOME deps => (marker, deps)
     | NONE => (omitted_marker, [])
   end
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Mon Feb 18 10:43:36 2013 +0100
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Mon Feb 18 11:33:43 2013 +0100
@@ -62,7 +62,7 @@
   val features_of :
     Proof.context -> string -> theory -> stature -> term list
     -> (string * real) list
-  val trim_dependencies : thm -> string list -> string list option
+  val trim_dependencies : string list -> string list option
   val isar_dependencies_of :
     string Symtab.table * string Symtab.table -> thm -> string list
   val prover_dependencies_of :
@@ -238,7 +238,7 @@
    sense. *)
 fun extract_suggestion sugg =
   case space_explode "=" sugg of
-    [name, weight] =>
+    [name, _ (* weight *)] =>
     SOME (unencode_str name (* , Real.fromString weight |> the_default 1.0 *))
   | [name] => SOME (unencode_str name (* , 1.0 *))
   | _ => NONE
@@ -679,7 +679,7 @@
      | NONE => false)
   | is_size_def _ _ = false
 
-fun trim_dependencies th deps =
+fun trim_dependencies deps =
   if length deps > max_dependencies then NONE else SOME deps
 
 fun isar_dependencies_of name_tabs th =
@@ -746,13 +746,12 @@
 
 (*** High-level communication with MaSh ***)
 
-fun num_keys keys = Graph.Keys.fold (K (Integer.add 1)) keys 0
-
 fun maximal_wrt_graph G keys =
   let
     val tab = Symtab.empty |> fold (fn name => Symtab.default (name, ())) keys
     fun insert_new seen name =
       not (Symtab.defined seen name) ? insert (op =) name
+    fun num_keys keys = Graph.Keys.fold (K (Integer.add 1)) keys 0
     fun find_maxes _ (maxs, []) = map snd maxs
       | find_maxes seen (maxs, new :: news) =
         find_maxes
@@ -901,6 +900,13 @@
         (true, "")
       end)
 
+fun attach_parents_to_facts facts =
+  let
+    fun do_facts _ [] = []
+      | do_facts parents ((fact as (_, th)) :: facts) =
+        (parents, fact) :: do_facts [nickname_of_thm th] facts
+  in do_facts [] facts end
+
 fun sendback sub = Active.sendback_markup (sledgehammerN ^ " " ^ sub)
 
 val commit_timeout = seconds 30.0
@@ -913,11 +919,11 @@
     fun next_commit_time () =
       Time.+ (Timer.checkRealTimer timer, commit_timeout)
     val {access_G, ...} = peek_state ctxt I
+    val is_in_access_G = is_fact_in_graph access_G snd
     val facts = facts |> sort (crude_thm_ord o pairself snd)
-    val (old_facts, new_facts) =
-      facts |> List.partition (is_fact_in_graph access_G snd)
+    val no_new_facts = forall is_in_access_G facts
   in
-    if null new_facts andalso (not run_prover orelse null old_facts) then
+    if no_new_facts andalso not run_prover then
       if auto_level < 2 then
         "No new " ^ (if run_prover then "automatic" else "Isar") ^
         " proofs to learn." ^
@@ -938,10 +944,10 @@
             prover_dependencies_of ctxt params prover auto_level facts name_tabs
                                    th
             |> (fn (false, _) => NONE
-                 | (true, deps) => trim_dependencies th deps)
+                 | (true, deps) => trim_dependencies deps)
           else
             isar_dependencies_of name_tabs th
-            |> trim_dependencies th
+            |> trim_dependencies
         fun do_commit [] [] [] state = state
           | do_commit learns relearns flops {access_G, dirty} =
             let
@@ -976,9 +982,9 @@
              end
            else
              ())
-        fun learn_new_fact _ (accum as (_, (_, _, _, true))) = accum
-          | learn_new_fact ((_, stature as (_, status)), th)
-                           (learns, (parents, n, next_commit, _)) =
+        fun learn_new_fact _ (accum as (_, (_, _, true))) = accum
+          | learn_new_fact (parents, ((_, stature as (_, status)), th))
+                           (learns, (n, next_commit, _)) =
             let
               val name = nickname_of_thm th
               val feats =
@@ -995,21 +1001,18 @@
                 case learn_timeout of
                   SOME timeout => Time.> (Timer.checkRealTimer timer, timeout)
                 | NONE => false
-            in (learns, ([name], n, next_commit, timed_out)) end
+            in (learns, (n, next_commit, timed_out)) end
         val n =
-          if null new_facts then
+          if no_new_facts then
             0
           else
             let
-              val last_th = new_facts |> List.last |> snd
-              (* crude approximation *)
-              val ancestors =
-                old_facts
-                |> filter (fn (_, th) => crude_thm_ord (th, last_th) <> GREATER)
-              val parents = maximal_wrt_access_graph access_G ancestors
-              val (learns, (_, n, _, _)) =
-                ([], (parents, 0, next_commit_time (), false))
-                |> fold learn_new_fact new_facts
+              val facts =
+                facts |> attach_parents_to_facts
+                      |> filter_out (is_in_access_G o snd)
+              val (learns, (n, _, _)) =
+                ([], (0, next_commit_time (), false))
+                |> fold learn_new_fact facts
             in commit true learns [] []; n end
         fun relearn_old_fact _ (accum as (_, (_, _, true))) = accum
           | relearn_old_fact ((_, (_, status)), th)
@@ -1032,7 +1035,7 @@
                 | NONE => false
             in ((relearns, flops), (n, next_commit, timed_out)) end
         val n =
-          if not run_prover orelse null old_facts then
+          if not run_prover then
             n
           else
             let
@@ -1048,9 +1051,10 @@
                    | Isar_Proof_wegen_Prover_Flop => max_isar)
                 - 500 * length (isar_dependencies_of name_tabs th)
               val old_facts =
-                old_facts |> map (`priority_of)
-                          |> sort (int_ord o pairself fst)
-                          |> map snd
+                facts |> filter is_in_access_G
+                      |> map (`priority_of)
+                      |> sort (int_ord o pairself fst)
+                      |> map snd
               val ((relearns, flops), (n, _, _)) =
                 (([], []), (n, next_commit_time (), false))
                 |> fold relearn_old_fact old_facts