schedule: more precise task model;
improved error handling: first failure causes interrupt of all threads;
misc cleanup;
--- a/src/Pure/ML-Systems/multithreading_polyml.ML Thu Aug 09 23:53:49 2007 +0200
+++ b/src/Pure/ML-Systems/multithreading_polyml.ML Thu Aug 09 23:53:50 2007 +0200
@@ -7,6 +7,13 @@
open Thread;
+signature MULTITHREADING =
+sig
+ include MULTITHREADING
+ val uninterruptible: ('a -> 'b) -> 'a -> 'b
+ val interruptible: ('a -> 'b) -> 'a -> 'b
+end;
+
structure Multithreading: MULTITHREADING =
struct
@@ -24,15 +31,44 @@
(* misc utils *)
-fun show "" = ""
- | show name = " " ^ name;
+fun cons x xs = x :: xs;
-fun show' "" = ""
- | show' name = " [" ^ name ^ "]";
+fun change r f = r := f (! r);
fun inc i = (i := ! i + 1; ! i);
fun dec i = (i := ! i - 1; ! i);
+fun show "" = "" | show name = " " ^ name;
+fun show' "" = "" | show' name = " [" ^ name ^ "]";
+
+
+(* thread attributes *)
+
+local
+
+fun with_attributes new_atts f x =
+ let
+ val orig_atts = Thread.getAttributes ();
+ fun restore () = Thread.setAttributes orig_atts;
+ in
+ Exn.release
+ (let
+ val _ = Thread.setAttributes new_atts;
+ val result = Exn.capture f x;
+ val _ = restore ();
+ in result end
+ handle Interrupt => (restore (); Exn.Exn Interrupt))
+ end;
+
+fun with_interrupt_state state = with_attributes [Thread.InterruptState state];
+
+in
+
+fun uninterruptible f x = with_interrupt_state Thread.InterruptDefer f x;
+fun interruptible f x = with_interrupt_state Thread.InterruptAsynchOnce f x;
+
+end;
+
(* critical section -- may be nested within the same thread *)
@@ -52,33 +88,37 @@
fun NAMED_CRITICAL name e =
if self_critical () then e ()
else
- let
- val name' = ! critical_name;
- val _ =
- if Mutex.trylock critical_lock then ()
- else
- let
- val timer = Timer.startRealTimer ();
- val _ = tracing 4 (fn () => "CRITICAL" ^ show name ^ show' name' ^ ": waiting");
- val _ = Mutex.lock critical_lock;
- val time = Timer.checkRealTimer timer;
- val _ = tracing (if Time.> (time, Time.fromMilliseconds 10) then 3 else 4) (fn () =>
- "CRITICAL" ^ show name ^ show' name' ^ ": passed after " ^ Time.toString time);
- in () end;
- val _ = critical_thread := SOME (Thread.self ());
- val _ = critical_name := name;
- val result = Exn.capture e ();
- val _ = critical_name := "";
- val _ = critical_thread := NONE;
- val _ = Mutex.unlock critical_lock;
- in Exn.release result end;
+ uninterruptible (fn () =>
+ let
+ val name' = ! critical_name;
+ val _ =
+ if Mutex.trylock critical_lock then ()
+ else
+ let
+ val timer = Timer.startRealTimer ();
+ val _ = tracing 4 (fn () => "CRITICAL" ^ show name ^ show' name' ^ ": waiting");
+ val _ = Mutex.lock critical_lock;
+ val time = Timer.checkRealTimer timer;
+ val _ = tracing (if Time.> (time, Time.fromMilliseconds 10) then 3 else 4) (fn () =>
+ "CRITICAL" ^ show name ^ show' name' ^ ": passed after " ^ Time.toString time);
+ in () end;
+ val _ = critical_thread := SOME (Thread.self ());
+ val _ = critical_name := name;
+ val result = Exn.capture e ();
+ val _ = critical_name := "";
+ val _ = critical_thread := NONE;
+ val _ = Mutex.unlock critical_lock;
+ in Exn.release result end) ();
fun CRITICAL e = NAMED_CRITICAL "" e;
end;
-(* scheduling -- non-interruptible threads working on a queue of tasks *)
+(* scheduling -- multiple threads working on a queue of tasks *)
+
+datatype 'a task =
+ Task of {body: unit -> unit, cont: 'a -> 'a, fail: 'a -> 'a} | Wait | Terminate;
local
@@ -86,7 +126,7 @@
in
-fun schedule n next_task tasks =
+fun schedule n next_task = uninterruptible (fn tasks =>
let
(*protected execution*)
val lock = Mutex.mutex ();
@@ -121,40 +161,47 @@
val (next, tasks') = next_task (! queue);
val _ = queue := tasks';
in
- if Task.is_running (#1 next) then
- (dec active; trace_active ();
- wait ();
- inc active; trace_active ();
- dequeue ())
- else next
+ (case next of Wait =>
+ (dec active; trace_active ();
+ wait ();
+ inc active; trace_active ();
+ dequeue ())
+ | _ => next)
end;
- (*worker threads*)
- val running = ref 0;
+ (*pool of running threads*)
val status = ref ([]: exn list);
- fun work () =
+ val running = ref ([]: Thread.thread list);
+ fun start f =
+ (inc active;
+ change running (cons (Thread.fork (f, [Thread.InterruptState Thread.InterruptDefer]))));
+ fun stop () =
+ (dec active;
+ change running (List.filter (fn t => not (Thread.equal (t, Thread.self ())))));
+
+ (*worker thread*)
+ fun worker () =
(case PROTECTED "dequeue" dequeue of
- (Task.Task f, cont) =>
- (case Exn.capture f () of
- Exn.Result () => continue cont
+ Task {body, cont, fail} =>
+ (case Exn.capture (interruptible body) () of
+ Exn.Result () =>
+ (PROTECTED "cont" (fn () => (change queue cont; wakeup_all ())); worker ())
| Exn.Exn exn =>
- (PROTECTED "status" (fn () => status := exn :: ! status); continue cont))
- | (Task.Finished, _) =>
- (PROTECTED "running" (fn () => (dec active; dec running; wakeup_all ()))))
- and continue cont =
- (PROTECTED "cont" (fn () => (queue := cont (! queue); wakeup_all ())); work ());
+ PROTECTED "fail" (fn () =>
+ (change status (cons exn); change queue fail; stop (); wakeup_all ())))
+ | Terminate => PROTECTED "terminate" (fn () => (stop (); wakeup_all ())));
(*main control: fork and wait*)
fun fork 0 = ()
- | fork k =
- (inc running; inc active;
- Thread.fork (work, [Thread.InterruptState Thread.InterruptDefer]);
- fork (k - 1));
+ | fork k = (start worker; fork (k - 1));
val _ = PROTECTED "main" (fn () =>
(fork (Int.max (n, 1));
- while ! running <> 0 do (trace_active (); wait ())));
+ while not (List.null (! running)) do
+ (trace_active ();
+ if not (List.null (! status)) then (List.app Thread.interrupt (! running)) else ();
+ wait ())));
- in ! status end;
+ in ! status end);
end;
@@ -162,3 +209,4 @@
val NAMED_CRITICAL = Multithreading.NAMED_CRITICAL;
val CRITICAL = Multithreading.CRITICAL;
+