# HG changeset patch # User wenzelm # Date 1186696430 -7200 # Node ID f4cafbaa05e49152b83581c3352fcb165e82072d # Parent 402d629925ed9cfcdf2edd9686b22cb52003dee7 schedule: more precise task model; improved error handling: first failure causes interrupt of all threads; misc cleanup; diff -r 402d629925ed -r f4cafbaa05e4 src/Pure/ML-Systems/multithreading_polyml.ML --- a/src/Pure/ML-Systems/multithreading_polyml.ML Thu Aug 09 23:53:49 2007 +0200 +++ b/src/Pure/ML-Systems/multithreading_polyml.ML Thu Aug 09 23:53:50 2007 +0200 @@ -7,6 +7,13 @@ open Thread; +signature MULTITHREADING = +sig + include MULTITHREADING + val uninterruptible: ('a -> 'b) -> 'a -> 'b + val interruptible: ('a -> 'b) -> 'a -> 'b +end; + structure Multithreading: MULTITHREADING = struct @@ -24,15 +31,44 @@ (* misc utils *) -fun show "" = "" - | show name = " " ^ name; +fun cons x xs = x :: xs; -fun show' "" = "" - | show' name = " [" ^ name ^ "]"; +fun change r f = r := f (! r); fun inc i = (i := ! i + 1; ! i); fun dec i = (i := ! i - 1; ! i); +fun show "" = "" | show name = " " ^ name; +fun show' "" = "" | show' name = " [" ^ name ^ "]"; + + +(* thread attributes *) + +local + +fun with_attributes new_atts f x = + let + val orig_atts = Thread.getAttributes (); + fun restore () = Thread.setAttributes orig_atts; + in + Exn.release + (let + val _ = Thread.setAttributes new_atts; + val result = Exn.capture f x; + val _ = restore (); + in result end + handle Interrupt => (restore (); Exn.Exn Interrupt)) + end; + +fun with_interrupt_state state = with_attributes [Thread.InterruptState state]; + +in + +fun uninterruptible f x = with_interrupt_state Thread.InterruptDefer f x; +fun interruptible f x = with_interrupt_state Thread.InterruptAsynchOnce f x; + +end; + (* critical section -- may be nested within the same thread *) @@ -52,33 +88,37 @@ fun NAMED_CRITICAL name e = if self_critical () then e () else - let - val name' = ! critical_name; - val _ = - if Mutex.trylock critical_lock then () - else - let - val timer = Timer.startRealTimer (); - val _ = tracing 4 (fn () => "CRITICAL" ^ show name ^ show' name' ^ ": waiting"); - val _ = Mutex.lock critical_lock; - val time = Timer.checkRealTimer timer; - val _ = tracing (if Time.> (time, Time.fromMilliseconds 10) then 3 else 4) (fn () => - "CRITICAL" ^ show name ^ show' name' ^ ": passed after " ^ Time.toString time); - in () end; - val _ = critical_thread := SOME (Thread.self ()); - val _ = critical_name := name; - val result = Exn.capture e (); - val _ = critical_name := ""; - val _ = critical_thread := NONE; - val _ = Mutex.unlock critical_lock; - in Exn.release result end; + uninterruptible (fn () => + let + val name' = ! critical_name; + val _ = + if Mutex.trylock critical_lock then () + else + let + val timer = Timer.startRealTimer (); + val _ = tracing 4 (fn () => "CRITICAL" ^ show name ^ show' name' ^ ": waiting"); + val _ = Mutex.lock critical_lock; + val time = Timer.checkRealTimer timer; + val _ = tracing (if Time.> (time, Time.fromMilliseconds 10) then 3 else 4) (fn () => + "CRITICAL" ^ show name ^ show' name' ^ ": passed after " ^ Time.toString time); + in () end; + val _ = critical_thread := SOME (Thread.self ()); + val _ = critical_name := name; + val result = Exn.capture e (); + val _ = critical_name := ""; + val _ = critical_thread := NONE; + val _ = Mutex.unlock critical_lock; + in Exn.release result end) (); fun CRITICAL e = NAMED_CRITICAL "" e; end; -(* scheduling -- non-interruptible threads working on a queue of tasks *) +(* scheduling -- multiple threads working on a queue of tasks *) + +datatype 'a task = + Task of {body: unit -> unit, cont: 'a -> 'a, fail: 'a -> 'a} | Wait | Terminate; local @@ -86,7 +126,7 @@ in -fun schedule n next_task tasks = +fun schedule n next_task = uninterruptible (fn tasks => let (*protected execution*) val lock = Mutex.mutex (); @@ -121,40 +161,47 @@ val (next, tasks') = next_task (! queue); val _ = queue := tasks'; in - if Task.is_running (#1 next) then - (dec active; trace_active (); - wait (); - inc active; trace_active (); - dequeue ()) - else next + (case next of Wait => + (dec active; trace_active (); + wait (); + inc active; trace_active (); + dequeue ()) + | _ => next) end; - (*worker threads*) - val running = ref 0; + (*pool of running threads*) val status = ref ([]: exn list); - fun work () = + val running = ref ([]: Thread.thread list); + fun start f = + (inc active; + change running (cons (Thread.fork (f, [Thread.InterruptState Thread.InterruptDefer])))); + fun stop () = + (dec active; + change running (List.filter (fn t => not (Thread.equal (t, Thread.self ()))))); + + (*worker thread*) + fun worker () = (case PROTECTED "dequeue" dequeue of - (Task.Task f, cont) => - (case Exn.capture f () of - Exn.Result () => continue cont + Task {body, cont, fail} => + (case Exn.capture (interruptible body) () of + Exn.Result () => + (PROTECTED "cont" (fn () => (change queue cont; wakeup_all ())); worker ()) | Exn.Exn exn => - (PROTECTED "status" (fn () => status := exn :: ! status); continue cont)) - | (Task.Finished, _) => - (PROTECTED "running" (fn () => (dec active; dec running; wakeup_all ())))) - and continue cont = - (PROTECTED "cont" (fn () => (queue := cont (! queue); wakeup_all ())); work ()); + PROTECTED "fail" (fn () => + (change status (cons exn); change queue fail; stop (); wakeup_all ()))) + | Terminate => PROTECTED "terminate" (fn () => (stop (); wakeup_all ()))); (*main control: fork and wait*) fun fork 0 = () - | fork k = - (inc running; inc active; - Thread.fork (work, [Thread.InterruptState Thread.InterruptDefer]); - fork (k - 1)); + | fork k = (start worker; fork (k - 1)); val _ = PROTECTED "main" (fn () => (fork (Int.max (n, 1)); - while ! running <> 0 do (trace_active (); wait ()))); + while not (List.null (! running)) do + (trace_active (); + if not (List.null (! status)) then (List.app Thread.interrupt (! running)) else (); + wait ()))); - in ! status end; + in ! status end); end; @@ -162,3 +209,4 @@ val NAMED_CRITICAL = Multithreading.NAMED_CRITICAL; val CRITICAL = Multithreading.CRITICAL; +