src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
changeset 48407 47fe0ca12fc2
parent 48406 b002cc16aa99
child 48408 5493e67982ee
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Fri Jul 20 22:19:46 2012 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Fri Jul 20 22:19:46 2012 +0200
@@ -184,7 +184,8 @@
   in map_filter find_sugg suggs end
 
 fun sum_avg [] = 0
-  | sum_avg xs = Real.ceil (100000.0 * fold (curry (op +)) xs 0.0) div length xs
+  | sum_avg xs =
+    Real.ceil (100000000.0 * fold (curry (op +)) xs 0.0) div length xs
 
 fun normalize_scores [] = []
   | normalize_scores ((fact, score) :: tail) =
@@ -562,32 +563,37 @@
 fun mash_could_suggest_facts () = mash_home () <> ""
 fun mash_can_suggest_facts ctxt = not (Graph.is_empty (#fact_G (mash_get ctxt)))
 
-fun queue_of xs = Queue.empty |> fold Queue.enqueue xs
+fun num_keys keys = Graph.Keys.fold (K (Integer.add 1)) keys 0
 
-fun max_facts_in_graph fact_G facts =
+fun maximal_in_graph fact_G facts =
   let
     val facts = [] |> fold (cons o nickname_of o snd) facts
-    val tab = Symtab.empty |> fold (fn name => Symtab.update (name, ())) facts
-    fun enqueue_new seen name =
-      not (member (op =) seen name) ? Queue.enqueue name
-    fun find_maxes seen maxs names =
-      case try Queue.dequeue names of
-        NONE => map snd maxs
-      | SOME (name, names) =>
-        if Symtab.defined tab name then
-          let
-            val new = (Graph.all_preds fact_G [name], name)
-            fun is_ancestor (_, x) (yp, _) = member (op =) yp x
-            val maxs = maxs |> filter (fn max => not (is_ancestor max new))
-            val maxs =
-              if exists (is_ancestor new) maxs then maxs
-              else new :: filter_out (fn max => is_ancestor max new) maxs
-          in find_maxes (name :: seen) maxs names end
-        else
-          find_maxes (name :: seen) maxs
-                     (Graph.Keys.fold (enqueue_new seen)
-                                      (Graph.imm_preds fact_G name) names)
-  in find_maxes [] [] (queue_of (Graph.maximals fact_G)) end
+    val tab = Symtab.empty |> fold (fn name => Symtab.default (name, ())) facts
+    fun insert_new seen name =
+      not (Symtab.defined seen name) ? insert (op =) name
+    fun find_maxes _ (maxs, []) = map snd maxs
+      | find_maxes seen (maxs, new :: news) =
+        find_maxes
+            (seen |> num_keys (Graph.imm_succs fact_G new) > 1
+                     ? Symtab.default (new, ()))
+            (if Symtab.defined tab new then
+               let
+                 val newp = Graph.all_preds fact_G [new]
+                 fun is_ancestor x yp = member (op =) yp x
+                 val maxs =
+                   maxs |> filter (fn (_, max) => not (is_ancestor max newp))
+               in
+                 if exists (is_ancestor new o fst) maxs then
+                   (maxs, news)
+                 else
+                   ((newp, new)
+                    :: filter_out (fn (_, max) => is_ancestor max newp) maxs,
+                    news)
+               end
+             else
+               (maxs, Graph.Keys.fold (insert_new seen)
+                                      (Graph.imm_preds fact_G new) news))
+  in find_maxes Symtab.empty ([], Graph.maximals fact_G) end
 
 (* Generate more suggestions than requested, because some might be thrown out
    later for various reasons and "meshing" gives better results with some
@@ -602,7 +608,7 @@
   let
     val thy = Proof_Context.theory_of ctxt
     val fact_G = #fact_G (mash_get ctxt)
-    val parents = max_facts_in_graph fact_G facts
+    val parents = maximal_in_graph fact_G facts
     val feats = features_of ctxt prover thy (Local, General) (concl_t :: hyp_ts)
     val suggs =
       if Graph.is_empty fact_G then []
@@ -618,7 +624,7 @@
           (from :: parents, Graph.add_edge_acyclic (from, name) graph))
     val graph = graph |> Graph.default_node (name, ())
     val (parents, graph) = ([], graph) |> fold maybe_add_from parents
-    val (deps, graph) = ([], graph) |> fold maybe_add_from deps
+    val (deps, _) = ([], graph) |> fold maybe_add_from deps
   in ((name, parents, feats, deps) :: adds, graph) end
 
 val learn_timeout_slack = 2.0
@@ -647,7 +653,7 @@
           val feats = features_of ctxt prover thy (Local, General) [t]
           val deps = used_ths |> map nickname_of
           val {fact_G} = mash_get ctxt
-          val parents = max_facts_in_graph fact_G facts
+          val parents = timeit (fn () => maximal_in_graph fact_G facts)
         in
           mash_ADD ctxt overlord [(name, parents, feats, deps)]; (true, "")
         end)
@@ -743,7 +749,7 @@
               val ancestors =
                 old_facts
                 |> filter (fn (_, th) => thm_ord (th, last_th) <> GREATER)
-              val parents = max_facts_in_graph fact_G ancestors
+              val parents = maximal_in_graph fact_G ancestors
               val (adds, (_, n, _, _)) =
                 ([], (parents, 0, next_commit_time (), false))
                 |> fold learn_new_fact new_facts
@@ -853,10 +859,13 @@
         case fact_filter of
           SOME ff => (() |> ff <> mepoN ? maybe_learn; ff)
         | NONE =>
-          if is_smt_prover ctxt prover then mepoN
-          else if mash_can_suggest_facts ctxt then (maybe_learn (); meshN)
-          else if mash_could_suggest_facts () then (maybe_learn (); mepoN)
-          else mepoN
+          if is_smt_prover ctxt prover then
+            mepoN
+          else if mash_could_suggest_facts () then
+            (maybe_learn ();
+             if mash_can_suggest_facts ctxt then meshN else mepoN)
+          else
+            mepoN
       val add_ths = Attrib.eval_thms ctxt add
       fun prepend_facts ths accepts =
         ((facts |> filter (member Thm.eq_thm_prop ths o snd)) @