add_prover: plain prover function, without thread;
authorwenzelm
Tue, 14 Oct 2008 20:10:44 +0200
changeset 28595 67e3945b53f1
parent 28594 ed3351ff3f1b
child 28596 fcd463a6b6de
add_prover: plain prover function, without thread; removed obsolete atp_thread interface; moved kill_excessive into main thread manager loop -- avoids race condition wrt. register/unregister; start_prover: register/unregister self -- avoids race condition;
src/HOL/Tools/atp_manager.ML
--- a/src/HOL/Tools/atp_manager.ML	Tue Oct 14 20:10:43 2008 +0200
+++ b/src/HOL/Tools/atp_manager.ML	Tue Oct 14 20:10:44 2008 +0200
@@ -19,8 +19,8 @@
   val set_timeout: int -> unit
   val kill: unit -> unit
   val info: unit -> unit
-  val atp_thread: (unit -> 'a option) -> ('a -> string) -> Thread.thread
-  val add_prover: string -> (int -> Proof.state -> Thread.thread) -> theory -> theory
+  type prover = int -> Proof.state -> bool * string
+  val add_prover: string -> prover -> theory -> theory
   val print_provers: theory -> unit
   val sledgehammer: string list -> Proof.state -> unit
 end;
@@ -72,7 +72,7 @@
 (
   type elem = Time.time * Thread.thread;
   fun ord ((a, _), (b, _)) = Time.compare (a, b);
-)
+);
 
 val lookup_thread = AList.lookup Thread.equal;
 val delete_thread = AList.delete Thread.equal;
@@ -102,7 +102,7 @@
 
 (* unregister thread from thread manager -- move to cancelling *)
 
-fun unregister success message thread = Synchronized.change_result state
+fun unregister (success, message) thread = Synchronized.change_result state
   (fn State {timeout_heap, oldest_heap, active, cancelling} =>
     let
       val info = lookup_thread active thread
@@ -130,6 +130,35 @@
     in (message', make_state timeout_heap oldest_heap active'' cancelling'') end);
 
 
+(* kill excessive atp threads *)
+
+fun excessive_atps active =
+  let val max = get_max_atps ()
+  in length active > max andalso max > ~1 end;
+
+local
+
+fun kill_oldest () =
+  let exception Unchanged in
+    Synchronized.change_result state (fn State {timeout_heap, oldest_heap, active, cancelling} =>
+        if ThreadHeap.is_empty oldest_heap orelse not (excessive_atps active)
+        then raise Unchanged
+        else
+          let val ((_, oldest_thread), oldest_heap') = ThreadHeap.min_elem oldest_heap
+          in (oldest_thread, make_state timeout_heap oldest_heap' active cancelling) end)
+      |> (priority o unregister (false, "Interrupted (maximum number of ATPs exceeded)"))
+    handle Unchanged => ()
+  end;
+
+in
+
+fun kill_excessive () =
+  let val State {active, ...} = Synchronized.value state
+  in if excessive_atps active then (kill_oldest (); kill_excessive ()) else () end;
+
+end;
+
+
 (* start a watching thread which runs forever -- only one may exist *)
 
 fun check_thread_manager () = CRITICAL (fn () =>
@@ -150,7 +179,8 @@
         let val (timeout_threads, timeout_heap') =
           ThreadHeap.upto (Time.now (), Thread.self ()) timeout_heap
         in
-          if null timeout_threads andalso null cancelling then NONE
+          if null timeout_threads andalso null cancelling andalso not (excessive_atps active)
+          then NONE
           else
             let
               val _ = List.app (SimpleThread.interrupt o #1) cancelling
@@ -160,11 +190,11 @@
         end
     in
       while true do
-       ((* cancel threads found by action *)
-        Synchronized.timed_access state time_limit action
+       (Synchronized.timed_access state time_limit action
         |> these
-        |> List.app (priority o unregister false "Interrupted (reached timeout)");
-        (* give threads time to respond to interrupt *)
+        |> List.app (priority o unregister (false, "Interrupted (reached timeout)"));
+        kill_excessive ();
+        (*give threads time to respond to interrupt*)
         OS.Process.sleep min_wait_time)
     end)));
 
@@ -220,11 +250,13 @@
 
 (* named provers *)
 
+type prover = int -> Proof.state -> bool * string;
+
 fun err_dup_prover name = error ("Duplicate prover: " ^ quote name);
 
 structure Provers = TheoryDataFun
 (
-  type T = ((int -> Proof.state -> Thread.thread) * stamp) Symtab.table
+  type T = (prover * stamp) Symtab.table
   val empty = Symtab.empty
   val copy = I
   val extend = I
@@ -232,79 +264,44 @@
     handle Symtab.DUP dup => err_dup_prover dup
 );
 
-fun add_prover name prover_fn thy =
-  Provers.map (Symtab.update_new (name, (prover_fn, stamp ()))) thy
+fun add_prover name prover thy =
+  Provers.map (Symtab.update_new (name, (prover, stamp ()))) thy
     handle Symtab.DUP dup => err_dup_prover dup;
 
 fun print_provers thy = Pretty.writeln
   (Pretty.strs ("external provers:" :: sort_strings (Symtab.keys (Provers.get thy))));
 
-fun prover_desc state subgoal name =
-  let val (ctxt, (_, goal)) = Proof.get_goal state in
-    "external prover " ^ quote name ^ " for subgoal " ^ string_of_int subgoal ^ ":\n" ^
-      Syntax.string_of_term ctxt (Thm.term_of (Thm.cprem_of goal subgoal))
-  end;
 
-
-(* thread wrapping an atp-call *)
-
-fun atp_thread call_prover produce_answer =
-  SimpleThread.fork true (fn () =>
-    let
-      val result = call_prover ()
-      val message = case result of NONE => "Failed."
-          | SOME result => "Try this command: " ^ produce_answer result
-    in priority (unregister (is_some result) message (Thread.self ()))
-    end handle Interrupt => ());
-
-fun run_prover state subgoal name =
-  (case Symtab.lookup (Provers.get (Proof.theory_of state)) name of
-    NONE => (warning ("Unknown external prover: " ^ quote name); NONE)
-  | SOME (prover_fn, _) => SOME (prover_fn subgoal state, prover_desc state subgoal name));
-
-
-(* kill excessive atp threads *)
+(* start prover thread *)
 
-local
-
-fun excessive_atps active =
-  let val max = get_max_atps ()
-  in length active > max andalso max > ~1 end;
-
-fun kill_oldest () =
-  let exception Unchanged in
-    Synchronized.change_result state (fn State {timeout_heap, oldest_heap, active, cancelling} =>
-        if ThreadHeap.is_empty oldest_heap orelse not (excessive_atps active)
-        then raise Unchanged
-        else
-          let val ((_, oldest_thread), oldest_heap') = ThreadHeap.min_elem oldest_heap
-          in (oldest_thread, make_state timeout_heap oldest_heap' active cancelling) end)
-      |> (priority o unregister false "Interrupted (maximum number of ATPs exceeded)")
-    handle Unchanged => ()
-  end;
-
-in
-
-fun kill_excessive () =
-  let val State {active, ...} = Synchronized.value state
-  in if excessive_atps active then (kill_oldest (); kill_excessive ()) else () end;
-
-end;
+fun start_prover name birthtime deadtime i proof_state =
+  (case Symtab.lookup (Provers.get (Proof.theory_of proof_state)) name of
+    NONE => warning ("Unknown external prover: " ^ quote name)
+  | SOME (prover, _) =>
+      let
+        val (ctxt, (_, goal)) = Proof.get_goal proof_state
+        val desc =
+          "external prover " ^ quote name ^ " for subgoal " ^ string_of_int i ^ ":\n" ^
+            Syntax.string_of_term ctxt (Thm.term_of (Thm.cprem_of goal i))
+        val _ = SimpleThread.fork true (fn () =>
+          let
+            val _ = register birthtime deadtime (Thread.self (), desc)
+            val result = prover i proof_state
+            val _ = priority (unregister result (Thread.self ()))
+          in () end handle Interrupt => ())
+      in () end);
 
 
 (* sledghammer for first subgoal *)
 
 fun sledgehammer names proof_state =
   let
-    val proverids =
+    val provers =
       if null names then String.tokens (Symbol.is_ascii_blank o String.str) (get_atps ())
       else names
-    val threads_names = map_filter (run_prover proof_state 1) proverids
     val birthtime = Time.now ()
-    val deadtime = Time.+ (Time.now (), Time.fromSeconds (get_timeout ()))
-    val _ = List.app (register birthtime deadtime) threads_names
-    val _ = kill_excessive ()
-  in () end;
+    val deadtime = Time.+ (birthtime, Time.fromSeconds (get_timeout ()))
+  in List.app (fn name => start_prover name birthtime deadtime 1 proof_state) provers end;