schedule: more precise task model;
authorwenzelm
Thu, 09 Aug 2007 23:53:50 +0200
changeset 24208 f4cafbaa05e4
parent 24207 402d629925ed
child 24209 8a2c8d623e43
schedule: more precise task model; improved error handling: first failure causes interrupt of all threads; misc cleanup;
src/Pure/ML-Systems/multithreading_polyml.ML
--- a/src/Pure/ML-Systems/multithreading_polyml.ML	Thu Aug 09 23:53:49 2007 +0200
+++ b/src/Pure/ML-Systems/multithreading_polyml.ML	Thu Aug 09 23:53:50 2007 +0200
@@ -7,6 +7,13 @@
 
 open Thread;
 
+signature MULTITHREADING =
+sig
+  include MULTITHREADING
+  val uninterruptible: ('a -> 'b) -> 'a -> 'b
+  val interruptible: ('a -> 'b) -> 'a -> 'b
+end;
+
 structure Multithreading: MULTITHREADING =
 struct
 
@@ -24,15 +31,44 @@
 
 (* misc utils *)
 
-fun show "" = ""
-  | show name = " " ^ name;
+fun cons x xs = x :: xs;
 
-fun show' "" = ""
-  | show' name = " [" ^ name ^ "]";
+fun change r f = r := f (! r);
 
 fun inc i = (i := ! i + 1; ! i);
 fun dec i = (i := ! i - 1; ! i);
 
+fun show "" = "" | show name = " " ^ name;
+fun show' "" = "" | show' name = " [" ^ name ^ "]";
+
+
+(* thread attributes *)
+
+local
+
+fun with_attributes new_atts f x =
+  let
+    val orig_atts = Thread.getAttributes ();
+    fun restore () = Thread.setAttributes orig_atts;
+  in
+    Exn.release
+     (let
+        val _ = Thread.setAttributes new_atts;
+        val result = Exn.capture f x;
+        val _ = restore ();
+      in result end
+      handle Interrupt => (restore (); Exn.Exn Interrupt))
+  end;
+
+fun with_interrupt_state state = with_attributes [Thread.InterruptState state];
+
+in
+
+fun uninterruptible f x = with_interrupt_state Thread.InterruptDefer f x;
+fun interruptible f x = with_interrupt_state Thread.InterruptAsynchOnce f x;
+
+end;
+
 
 (* critical section -- may be nested within the same thread *)
 
@@ -52,33 +88,37 @@
 fun NAMED_CRITICAL name e =
   if self_critical () then e ()
   else
-    let
-      val name' = ! critical_name;
-      val _ =
-        if Mutex.trylock critical_lock then ()
-        else
-          let
-            val timer = Timer.startRealTimer ();
-            val _ = tracing 4 (fn () => "CRITICAL" ^ show name ^ show' name' ^ ": waiting");
-            val _ = Mutex.lock critical_lock;
-            val time = Timer.checkRealTimer timer;
-            val _ = tracing (if Time.> (time, Time.fromMilliseconds 10) then 3 else 4) (fn () =>
-              "CRITICAL" ^ show name ^ show' name' ^ ": passed after " ^ Time.toString time);
-          in () end;
-      val _ = critical_thread := SOME (Thread.self ());
-      val _ = critical_name := name;
-      val result = Exn.capture e ();
-      val _ = critical_name := "";
-      val _ = critical_thread := NONE;
-      val _ = Mutex.unlock critical_lock;
-    in Exn.release result end;
+    uninterruptible (fn () =>
+      let
+        val name' = ! critical_name;
+        val _ =
+          if Mutex.trylock critical_lock then ()
+          else
+            let
+              val timer = Timer.startRealTimer ();
+              val _ = tracing 4 (fn () => "CRITICAL" ^ show name ^ show' name' ^ ": waiting");
+              val _ = Mutex.lock critical_lock;
+              val time = Timer.checkRealTimer timer;
+              val _ = tracing (if Time.> (time, Time.fromMilliseconds 10) then 3 else 4) (fn () =>
+                "CRITICAL" ^ show name ^ show' name' ^ ": passed after " ^ Time.toString time);
+            in () end;
+        val _ = critical_thread := SOME (Thread.self ());
+        val _ = critical_name := name;
+        val result = Exn.capture e ();
+        val _ = critical_name := "";
+        val _ = critical_thread := NONE;
+        val _ = Mutex.unlock critical_lock;
+      in Exn.release result end) ();
 
 fun CRITICAL e = NAMED_CRITICAL "" e;
 
 end;
 
 
-(* scheduling -- non-interruptible threads working on a queue of tasks *)
+(* scheduling -- multiple threads working on a queue of tasks *)
+
+datatype 'a task =
+  Task of {body: unit -> unit, cont: 'a -> 'a, fail: 'a -> 'a} | Wait | Terminate;
 
 local
 
@@ -86,7 +126,7 @@
 
 in
 
-fun schedule n next_task tasks =
+fun schedule n next_task = uninterruptible (fn tasks =>
   let
     (*protected execution*)
     val lock = Mutex.mutex ();
@@ -121,40 +161,47 @@
         val (next, tasks') = next_task (! queue);
         val _ = queue := tasks';
       in
-        if Task.is_running (#1 next) then
-         (dec active; trace_active ();
-          wait ();
-          inc active; trace_active ();
-          dequeue ())
-        else next
+        (case next of Wait =>
+          (dec active; trace_active ();
+            wait ();
+            inc active; trace_active ();
+            dequeue ())
+        | _ => next)
       end;
 
-    (*worker threads*)
-    val running = ref 0;
+    (*pool of running threads*)
     val status = ref ([]: exn list);
-    fun work () =
+    val running = ref ([]: Thread.thread list);
+    fun start f =
+      (inc active;
+       change running (cons (Thread.fork (f, [Thread.InterruptState Thread.InterruptDefer]))));
+    fun stop () =
+      (dec active;
+       change running (List.filter (fn t => not (Thread.equal (t, Thread.self ())))));
+
+   (*worker thread*)
+    fun worker () =
       (case PROTECTED "dequeue" dequeue of
-        (Task.Task f, cont) =>
-          (case Exn.capture f () of
-            Exn.Result () => continue cont
+        Task {body, cont, fail} =>
+          (case Exn.capture (interruptible body) () of
+            Exn.Result () =>
+              (PROTECTED "cont" (fn () => (change queue cont; wakeup_all ())); worker ())
           | Exn.Exn exn =>
-              (PROTECTED "status" (fn () => status := exn :: ! status); continue cont))
-      | (Task.Finished, _) =>
-         (PROTECTED "running" (fn () => (dec active; dec running; wakeup_all ()))))
-    and continue cont =
-      (PROTECTED "cont" (fn () => (queue := cont (! queue); wakeup_all ())); work ());
+              PROTECTED "fail" (fn () =>
+                (change status (cons exn); change queue fail; stop (); wakeup_all ())))
+      | Terminate => PROTECTED "terminate" (fn () => (stop (); wakeup_all ())));
 
     (*main control: fork and wait*)
     fun fork 0 = ()
-      | fork k =
-         (inc running; inc active;
-          Thread.fork (work, [Thread.InterruptState Thread.InterruptDefer]);
-          fork (k - 1));
+      | fork k = (start worker; fork (k - 1));
     val _ = PROTECTED "main" (fn () =>
      (fork (Int.max (n, 1));
-      while ! running <> 0 do (trace_active (); wait ())));
+      while not (List.null (! running)) do
+      (trace_active ();
+       if not (List.null (! status)) then (List.app Thread.interrupt (! running)) else ();
+       wait ())));
 
-  in ! status end;
+  in ! status end);
 
 end;
 
@@ -162,3 +209,4 @@
 
 val NAMED_CRITICAL = Multithreading.NAMED_CRITICAL;
 val CRITICAL = Multithreading.CRITICAL;
+