use async manager to manage MaSh learners to make sure they get killed cleanly
authorblanchet
Wed, 18 Jul 2012 08:44:04 +0200
changeset 48319 340187063d84
parent 48318 325c8fd0d762
child 48320 891a24a48155
use async manager to manage MaSh learners to make sure they get killed cleanly
src/HOL/Tools/Sledgehammer/async_manager.ML
src/HOL/Tools/Sledgehammer/sledgehammer_filter_mash.ML
src/HOL/Tools/Sledgehammer/sledgehammer_isar.ML
src/HOL/Tools/Sledgehammer/sledgehammer_provers.ML
src/HOL/Tools/Sledgehammer/sledgehammer_run.ML
--- a/src/HOL/Tools/Sledgehammer/async_manager.ML	Wed Jul 18 08:44:04 2012 +0200
+++ b/src/HOL/Tools/Sledgehammer/async_manager.ML	Wed Jul 18 08:44:04 2012 +0200
@@ -8,12 +8,12 @@
 
 signature ASYNC_MANAGER =
 sig
-  val implode_desc : string * string -> string
   val break_into_chunks : string -> string list
   val launch :
     string -> Time.time -> Time.time -> string * string
     -> (unit -> bool * string) -> unit
   val kill_threads : string -> string -> unit
+  val has_running_threads : string -> bool
   val running_threads : string -> string -> unit
   val thread_messages : string -> string -> int option -> unit
 end;
@@ -23,29 +23,27 @@
 
 (** preferences **)
 
-val message_store_limit = 20;
-val message_display_limit = 10;
+val message_store_limit = 20
+val message_display_limit = 10
 
 
 (** thread management **)
 
-val implode_desc = op ^ o apfst quote
-
 fun implode_message (workers, work) =
-  space_implode " " (Try.serial_commas "and" (map quote workers)) ^ work
+  space_implode " " (Try.serial_commas "and" workers) ^ work
 
 
 (* data structures over threads *)
 
 structure Thread_Heap = Heap
 (
-  type elem = Time.time * Thread.thread;
-  fun ord ((a, _), (b, _)) = Time.compare (a, b);
-);
+  type elem = Time.time * Thread.thread
+  fun ord ((a, _), (b, _)) = Time.compare (a, b)
+)
 
-fun lookup_thread xs = AList.lookup Thread.equal xs;
-fun delete_thread xs = AList.delete Thread.equal xs;
-fun update_thread xs = AList.update Thread.equal xs;
+fun lookup_thread xs = AList.lookup Thread.equal xs
+fun delete_thread xs = AList.delete Thread.equal xs
+fun update_thread xs = AList.update Thread.equal xs
 
 
 (* state of thread manager *)
@@ -65,7 +63,7 @@
    canceling = canceling, messages = messages, store = store}
 
 val global_state = Synchronized.var "async_manager"
-  (make_state NONE Thread_Heap.empty [] [] [] []);
+  (make_state NONE Thread_Heap.empty [] [] [] [])
 
 
 (* unregister thread *)
@@ -76,22 +74,23 @@
     (case lookup_thread active thread of
       SOME (tool, _, _, desc as (worker, its_desc)) =>
         let
-          val active' = delete_thread thread active;
+          val active' = delete_thread thread active
           val now = Time.now ()
           val canceling' = (thread, (tool, now, desc)) :: canceling
-          val message' = (worker, its_desc ^ "\n" ^ message)
+          val message' =
+            (worker, its_desc ^ (if message = "" then "" else "\n" ^ message))
           val messages' = (urgent, (tool, message')) :: messages
           val store' = (tool, message') ::
             (if length store <= message_store_limit then store
-             else #1 (chop message_store_limit store));
+             else #1 (chop message_store_limit store))
         in make_state manager timeout_heap active' canceling' messages' store' end
-    | NONE => state));
+    | NONE => state))
 
 
 (* main manager thread -- only one may exist *)
 
-val min_wait_time = seconds 0.3;
-val max_wait_time = seconds 10.0;
+val min_wait_time = seconds 0.3
+val max_wait_time = seconds 10.0
 
 fun replace_all bef aft =
   let
@@ -119,7 +118,8 @@
                                      postponed_messages store))
   |> map (fn (_, (tool, (worker, work))) => ((tool, work), worker))
   |> AList.group (op =)
-  |> List.app (fn ((tool, work), workers) =>
+  |> List.app (fn ((_, ""), _) => ()
+                | ((tool, work), workers) =>
                   tool ^ ": " ^
                   implode_message (workers |> sort_distinct string_ord, work)
                   |> break_into_chunks
@@ -133,12 +133,12 @@
         fun time_limit timeout_heap =
           (case try Thread_Heap.min timeout_heap of
             NONE => Time.+ (Time.now (), max_wait_time)
-          | SOME (time, _) => time);
+          | SOME (time, _) => time)
 
         (*action: find threads whose timeout is reached, and interrupt canceling threads*)
         fun action {manager, timeout_heap, active, canceling, messages, store} =
           let val (timeout_threads, timeout_heap') =
-            Thread_Heap.upto (Time.now (), Thread.self ()) timeout_heap;
+            Thread_Heap.upto (Time.now (), Thread.self ()) timeout_heap
           in
             if null timeout_threads andalso null canceling then
               NONE
@@ -146,9 +146,9 @@
               let
                 val _ = List.app (Simple_Thread.interrupt_unsynchronized o #1) canceling
                 val canceling' = filter (Thread.isActive o #1) canceling
-                val state' = make_state manager timeout_heap' active canceling' messages store;
+                val state' = make_state manager timeout_heap' active canceling' messages store
               in SOME (map #2 timeout_threads, state') end
-          end;
+          end
       in
         while Synchronized.change_result global_state
           (fn state as {timeout_heap, active, canceling, messages, store, ...} =>
@@ -156,12 +156,13 @@
             then (false, make_state NONE timeout_heap active canceling messages store)
             else (true, state))
         do
-          (Synchronized.timed_access global_state (SOME o time_limit o #timeout_heap) action
-            |> these
-            |> List.app (unregister (false, "Timed out."));
-            print_new_messages ();
-            (*give threads some time to respond to interrupt*)
-            OS.Process.sleep min_wait_time)
+          (Synchronized.timed_access global_state
+               (SOME o time_limit o #timeout_heap) action
+           |> these
+           |> List.app (unregister (false, "Timed out."));
+           print_new_messages ();
+           (* give threads some time to respond to interrupt *)
+           OS.Process.sleep min_wait_time)
       end))
     in make_state manager timeout_heap active canceling messages store end)
 
@@ -172,9 +173,9 @@
  (Synchronized.change global_state
     (fn {manager, timeout_heap, active, canceling, messages, store} =>
       let
-        val timeout_heap' = Thread_Heap.insert (death_time, thread) timeout_heap;
-        val active' = update_thread (thread, (tool, birth_time, death_time, desc)) active;
-        val state' = make_state manager timeout_heap' active' canceling messages store;
+        val timeout_heap' = Thread_Heap.insert (death_time, thread) timeout_heap
+        val active' = update_thread (thread, (tool, birth_time, death_time, desc)) active
+        val state' = make_state manager timeout_heap' active' canceling messages store
       in state' end);
   check_thread_manager ())
 
@@ -200,33 +201,36 @@
         map_filter (fn (th, (tool', _, _, desc)) =>
                        if tool' = tool then SOME (th, (tool', Time.now (), desc))
                        else NONE) active
-      val state' = make_state manager timeout_heap [] (killing @ canceling) messages store;
+      val state' = make_state manager timeout_heap [] (killing @ canceling) messages store
       val _ =
         if null killing then ()
         else Output.urgent_message ("Interrupted active " ^ das_wort_worker ^ "s.")
-    in state' end);
+    in state' end)
 
 
 (* running threads *)
 
 fun seconds time = string_of_int (Time.toSeconds time) ^ " s"
 
+fun has_running_threads tool =
+  exists (fn (_, (tool', _, _, _)) => tool' = tool)
+         (#active (Synchronized.value global_state))
+
 fun running_threads tool das_wort_worker =
   let
-    val {active, canceling, ...} = Synchronized.value global_state;
-
-    val now = Time.now ();
+    val {active, canceling, ...} = Synchronized.value global_state
+    val now = Time.now ()
     fun running_info (_, (tool', birth_time, death_time, desc)) =
       if tool' = tool then
         SOME ("Running: " ^ seconds (Time.- (now, birth_time)) ^ " -- " ^
               seconds (Time.- (death_time, now)) ^ " to live:\n" ^
-              implode_desc desc)
+              op ^ desc)
       else
         NONE
     fun canceling_info (_, (tool', death_time, desc)) =
       if tool' = tool then
         SOME ("Trying to interrupt " ^ das_wort_worker ^ " since " ^
-              seconds (Time.- (now, death_time)) ^ ":\n" ^ implode_desc desc)
+              seconds (Time.- (now, death_time)) ^ ":\n" ^ op ^ desc)
       else
         NONE
     val running =
@@ -241,14 +245,14 @@
 
 fun thread_messages tool das_wort_worker opt_limit =
   let
-    val limit = the_default message_display_limit opt_limit;
+    val limit = the_default message_display_limit opt_limit
     val tool_store = Synchronized.value global_state
                      |> #store |> filter (curry (op =) tool o fst)
     val header =
       "Recent " ^ das_wort_worker ^ " messages" ^
         (if length tool_store <= limit then ":"
-         else " (" ^ string_of_int limit ^ " displayed):");
-    val ss = tool_store |> chop limit |> #1 |> map (implode_desc o snd)
+         else " (" ^ string_of_int limit ^ " displayed):")
+    val ss = tool_store |> chop limit |> #1 |> map (op ^ o snd)
   in List.app Output.urgent_message (header :: maps break_into_chunks ss) end
 
 end;
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_filter_mash.ML	Wed Jul 18 08:44:04 2012 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_filter_mash.ML	Wed Jul 18 08:44:04 2012 +0200
@@ -15,6 +15,7 @@
   type prover_result = Sledgehammer_Provers.prover_result
 
   val trace : bool Config.T
+  val MaShN : string
   val meshN : string
   val iterN : string
   val mashN : string
@@ -51,12 +52,15 @@
   val mash_suggest_facts :
     Proof.context -> params -> string -> int -> term list -> term -> fact list
     -> fact list
-  val mash_learn_thy : Proof.context -> params -> theory -> Time.time -> unit
+  val mash_learn_thy :
+    Proof.context -> params -> theory -> Time.time -> fact list -> string
   val mash_learn_proof :
     Proof.context -> params -> term -> thm list -> fact list -> unit
   val relevant_facts :
     Proof.context -> params -> string -> int -> fact_override -> term list
     -> term -> fact list -> fact list
+  val kill_learners : unit -> unit
+  val running_learners : unit -> unit
 end;
 
 structure Sledgehammer_Filter_MaSh : SLEDGEHAMMER_FILTER_MASH =
@@ -74,6 +78,8 @@
   Attrib.setup_config_bool @{binding sledgehammer_filter_mash_trace} (K false)
 fun trace_msg ctxt msg = if Config.get ctxt trace then tracing (msg ()) else ()
 
+val MaShN = "MaSh"
+
 val meshN = "mesh"
 val iterN = "iter"
 val mashN = "mash"
@@ -481,8 +487,6 @@
   let
     val thy = Proof_Context.theory_of ctxt
     val fact_graph = #fact_graph (mash_get ())
-val _ = warning (PolyML.makestring (length (fact_graph |> Graph.keys), length (fact_graph |> Graph.maximals),
-length (fact_graph |> Graph.minimals))) (*###*)
     val parents = parents_wrt_facts facts fact_graph
     val feats = features_of ctxt prover thy General (concl_t :: hyp_ts)
     val suggs =
@@ -509,34 +513,23 @@
   in ((name, parents, feats, deps) :: upds, graph) end
 
 val pass1_learn_timeout_factor = 0.5
-val pass2_learn_timeout_factor = 10.0
 
 (* The timeout is understood in a very slack fashion. *)
-fun mash_learn_thy ctxt ({provers, verbose, overlord, ...} : params) thy
-                   timeout =
+fun mash_learn_thy ctxt ({provers, verbose, overlord, ...} : params) thy timeout
+                   facts =
   let
     val timer = Timer.startRealTimer ()
     val prover = hd provers
     fun timed_out frac =
       Time.> (Timer.checkRealTimer timer, time_mult frac timeout)
-    val css_table = clasimpset_rule_table_of ctxt
-    val facts = all_facts_of thy css_table
     val {fact_graph, ...} = mash_get ()
     fun is_old (_, th) = can (Graph.get_node fact_graph) (Thm.get_name_hint th)
     val new_facts = facts |> filter_out is_old |> sort (thm_ord o pairself snd)
   in
     if null new_facts then
-      ()
+      ""
     else
       let
-        val n = length new_facts
-        val _ =
-          if verbose then
-            "MaShing " ^ string_of_int n ^ " fact" ^ plural_s n ^
-            " (advisory timeout: " ^ string_from_time timeout ^ ")..."
-            |> Output.urgent_message
-          else
-            ()
         val ths = facts |> map snd
         val all_names =
           ths |> filter_out (is_likely_tautology ctxt prover orf is_too_meta)
@@ -566,22 +559,15 @@
               fact_graph = fact_graph})
           end
       in
-        TimeLimit.timeLimit (time_mult pass2_learn_timeout_factor timeout)
-                            mash_map trans
-        handle TimeLimit.TimeOut =>
-               (if verbose then
-                  "MaSh timed out trying to learn " ^ string_of_int n ^
-                  " fact" ^ plural_s n ^ " in " ^
-                  string_from_time (Timer.checkRealTimer timer) ^ "."
-                  |> Output.urgent_message
-                else
-                  ());
-        (if verbose then
-           "MaSh learned " ^ string_of_int n ^ " fact" ^ plural_s n ^ " in " ^
-           string_from_time (Timer.checkRealTimer timer) ^ "."
-           |> Output.urgent_message
-         else
-           ())
+        mash_map trans;
+        if verbose then
+          "Processed " ^ string_of_int n ^ " proof" ^ plural_s n ^
+          (if verbose then
+             " in " ^ string_from_time (Timer.checkRealTimer timer)
+           else
+             "") ^ "."
+        else
+          ""
       end
   end
 
@@ -623,19 +609,23 @@
     let
       val thy = Proof_Context.theory_of ctxt
       fun maybe_learn can_suggest =
-        if Time.toSeconds timeout >= min_secs_for_learning then
-          if Multithreading.enabled () then
-            let
-              val factor =
-                if can_suggest then short_learn_timeout_factor
-                else long_learn_timeout_factor
-            in
-              Future.fork (fn () => mash_learn_thy ctxt params thy
-                                        (time_mult factor timeout)); ()
-            end
-          else
-            mash_learn_thy ctxt params thy
-                           (time_mult short_learn_timeout_factor timeout)
+        if Async_Manager.has_running_threads MaShN orelse null facts then
+          ()
+        else if Time.toSeconds timeout >= min_secs_for_learning then
+          let
+            val factor =
+              if can_suggest then short_learn_timeout_factor
+              else long_learn_timeout_factor
+            val soft_timeout = time_mult factor timeout
+            val hard_timeout = time_mult 2.0 soft_timeout
+            val birth_time = Time.now ()
+            val death_time = Time.+ (birth_time, hard_timeout)
+            val desc = ("machine learner for Sledgehammer", "")
+          in
+            Async_Manager.launch MaShN birth_time death_time desc
+                (fn () =>
+                    (true, mash_learn_thy ctxt params thy soft_timeout facts))
+          end
         else
           ()
       val fact_filter =
@@ -667,4 +657,7 @@
       |> not (null add_ths) ? prepend_facts add_ths
     end
 
+fun kill_learners () = Async_Manager.kill_threads MaShN "learner"
+fun running_learners () = Async_Manager.running_threads MaShN "learner"
+
 end;
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_isar.ML	Wed Jul 18 08:44:04 2012 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_isar.ML	Wed Jul 18 08:44:04 2012 +0200
@@ -33,10 +33,12 @@
 val minN = "min"
 val messagesN = "messages"
 val supported_proversN = "supported_provers"
+val kill_proversN = "kill_provers"
 val running_proversN = "running_provers"
-val kill_proversN = "kill_provers"
+val kill_learnersN = "kill_learners"
+val running_learnersN = "running_learners"
+val unlearnN = "unlearn"
 val refresh_tptpN = "refresh_tptp"
-val reset_mashN = "reset_mash"
 
 val auto = Unsynchronized.ref false
 
@@ -374,14 +376,18 @@
       messages opt_i
     else if subcommand = supported_proversN then
       supported_provers ctxt
-    else if subcommand = running_proversN then
-      running_provers ()
     else if subcommand = kill_proversN then
       kill_provers ()
+    else if subcommand = running_proversN then
+      running_provers ()
+    else if subcommand = kill_learnersN then
+      kill_learners ()
+    else if subcommand = running_learnersN then
+      running_learners ()
+    else if subcommand = unlearnN then
+      mash_reset ctxt
     else if subcommand = refresh_tptpN then
       refresh_systems_on_tptp ()
-    else if subcommand = reset_mashN then
-      mash_reset ctxt
     else
       error ("Unknown subcommand: " ^ quote subcommand ^ ".")
   end
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_provers.ML	Wed Jul 18 08:44:04 2012 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_provers.ML	Wed Jul 18 08:44:04 2012 +0200
@@ -99,7 +99,7 @@
   val smt_slice_fact_frac : real Config.T
   val smt_slice_time_frac : real Config.T
   val smt_slice_min_secs : int Config.T
-  val das_tool : string
+  val SledgehammerN : string
   val plain_metis : reconstructor
   val select_smt_solver : string -> Proof.context -> Proof.context
   val extract_reconstructor :
@@ -153,7 +153,7 @@
 
 (* Identifier that distinguishes Sledgehammer from other tools that could use
    "Async_Manager". *)
-val das_tool = "Sledgehammer"
+val SledgehammerN = "Sledgehammer"
 
 val reconstructor_names = [metisN, smtN]
 val plain_metis = Metis (hd partial_type_encs, combsN)
@@ -298,9 +298,9 @@
                            commas (local_provers @ remote_provers) ^ ".")
   end
 
-fun kill_provers () = Async_Manager.kill_threads das_tool "prover"
-fun running_provers () = Async_Manager.running_threads das_tool "prover"
-val messages = Async_Manager.thread_messages das_tool "prover"
+fun kill_provers () = Async_Manager.kill_threads SledgehammerN "prover"
+fun running_provers () = Async_Manager.running_threads SledgehammerN "prover"
+val messages = Async_Manager.thread_messages SledgehammerN "prover"
 
 
 (** problems, results, ATPs, etc. **)
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_run.ML	Wed Jul 18 08:44:04 2012 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_run.ML	Wed Jul 18 08:44:04 2012 +0200
@@ -53,7 +53,7 @@
 
 fun prover_description ctxt ({verbose, blocking, ...} : params) name num_facts i
                        n goal =
-  (name,
+  (quote name,
    (if verbose then
       " with " ^ string_of_int num_facts ^ " fact" ^ plural_s num_facts
     else
@@ -141,7 +141,7 @@
         (outcome_code, state)
       end
     else
-      (Async_Manager.launch das_tool birth_time death_time (desc ())
+      (Async_Manager.launch SledgehammerN birth_time death_time (desc ())
                             ((fn (outcome_code, message) =>
                                  (verbose orelse outcome_code = someN,
                                   message ())) o go);