src/Pure/ML-Systems/multithreading_polyml.ML
author wenzelm
Fri Sep 21 22:51:13 2007 +0200 (2007-09-21 ago)
changeset 24672 f311717d1f03
parent 24668 4058b7b0925c
child 24688 a5754ca5c510
permissions -rw-r--r--
tuned;
     1 (*  Title:      Pure/ML-Systems/multithreading_polyml.ML
     2     ID:         $Id$
     3     Author:     Makarius
     4 
     5 Multithreading in Poly/ML 5.1 or later (cf. polyml/basis/Thread.sml).
     6 *)
     7 
     8 open Thread;
     9 
    10 signature MULTITHREADING =
    11 sig
    12   include MULTITHREADING
    13   val ignore_interrupt: ('a -> 'b) -> 'a -> 'b
    14   val raise_interrupt: ('a -> 'b) -> 'a -> 'b
    15   val interrupt_timeout: Time.time -> ('a -> 'b) -> 'a -> 'b
    16 end;
    17 
    18 structure Multithreading: MULTITHREADING =
    19 struct
    20 
    21 (* options *)
    22 
    23 val trace = ref 0;
    24 fun tracing level msg =
    25   if level <= ! trace
    26   then (TextIO.output (TextIO.stdErr, (">>> " ^ msg () ^ "\n")); TextIO.flushOut TextIO.stdErr)
    27   else ();
    28 
    29 val available = true;
    30 val max_threads = ref 1;
    31 
    32 
    33 (* misc utils *)
    34 
    35 fun cons x xs = x :: xs;
    36 
    37 fun change r f = r := f (! r);
    38 
    39 fun inc i = (i := ! i + 1; ! i);
    40 fun dec i = (i := ! i - 1; ! i);
    41 
    42 fun show "" = "" | show name = " " ^ name;
    43 fun show' "" = "" | show' name = " [" ^ name ^ "]";
    44 
    45 
    46 (* thread attributes *)
    47 
    48 fun with_attributes new_atts f x =
    49   let
    50     val orig_atts = Thread.getAttributes ();
    51     fun restore () = Thread.setAttributes orig_atts;
    52   in
    53     Exn.release
    54     (*RACE for fully asynchronous interrupts!*)
    55     (let
    56         val _ = Thread.setAttributes new_atts;
    57         val result = Exn.capture (f orig_atts) x;
    58         val _ = restore ();
    59       in result end
    60       handle Interrupt => (restore (); Exn.Exn Interrupt))
    61   end;
    62 
    63 
    64 (* interrupt handling *)
    65 
    66 fun uninterruptible f x = with_attributes
    67   [Thread.EnableBroadcastInterrupt false, Thread.InterruptState Thread.InterruptDefer] f x;
    68 
    69 fun interruptible f x = with_attributes
    70   [Thread.EnableBroadcastInterrupt true, Thread.InterruptState Thread.InterruptAsynchOnce] f x;
    71 
    72 fun ignore_interrupt f = uninterruptible (fn _ => f);
    73 fun raise_interrupt f = interruptible (fn _ => f);
    74 
    75 fun interrupt_timeout time f x =
    76   uninterruptible (fn atts => fn () =>
    77     let
    78       val worker = Thread.self ();
    79       val watchdog = Thread.fork (interruptible (fn _ => fn () =>
    80         (OS.Process.sleep time; Thread.interrupt worker)), []);
    81       val result = Exn.capture (with_attributes atts (fn _ => f)) x;
    82       val _ = Thread.interrupt watchdog handle Thread _ => ();
    83     in Exn.release result end) ();
    84 
    85 
    86 (* critical section -- may be nested within the same thread *)
    87 
    88 local
    89 
    90 val critical_lock = Mutex.mutex ();
    91 val critical_thread = ref (NONE: Thread.thread option);
    92 val critical_name = ref "";
    93 
    94 in
    95 
    96 fun self_critical () =
    97   (case ! critical_thread of
    98     NONE => false
    99   | SOME id => Thread.equal (id, Thread.self ()));
   100 
   101 fun NAMED_CRITICAL name e =
   102   if self_critical () then e ()
   103   else
   104     uninterruptible (fn atts => fn () =>
   105       let
   106         val name' = ! critical_name;
   107         val _ =
   108           if Mutex.trylock critical_lock then ()
   109           else
   110             let
   111               val timer = Timer.startRealTimer ();
   112               val _ = tracing 4 (fn () => "CRITICAL" ^ show name ^ show' name' ^ ": waiting");
   113               val _ = Mutex.lock critical_lock;
   114               val time = Timer.checkRealTimer timer;
   115               val _ = tracing (if Time.> (time, Time.fromMilliseconds 10) then 3 else 4) (fn () =>
   116                 "CRITICAL" ^ show name ^ show' name' ^ ": passed after " ^ Time.toString time);
   117             in () end;
   118         val _ = critical_thread := SOME (Thread.self ());
   119         val _ = critical_name := name;
   120         val result = Exn.capture (with_attributes atts (fn _ => e)) ();
   121         val _ = critical_name := "";
   122         val _ = critical_thread := NONE;
   123         val _ = Mutex.unlock critical_lock;
   124       in Exn.release result end) ();
   125 
   126 fun CRITICAL e = NAMED_CRITICAL "" e;
   127 
   128 end;
   129 
   130 
   131 (* scheduling -- multiple threads working on a queue of tasks *)
   132 
   133 datatype 'a task =
   134   Task of {body: unit -> unit, cont: 'a -> 'a, fail: 'a -> 'a} | Wait | Terminate;
   135 
   136 fun schedule n next_task = uninterruptible (fn _ => fn tasks =>
   137   let
   138     (*protected execution*)
   139     val lock = Mutex.mutex ();
   140     val protected_name = ref "";
   141     fun PROTECTED name e =
   142       let
   143         val name' = ! protected_name;
   144         val _ =
   145           if Mutex.trylock lock then ()
   146           else
   147             let
   148               val _ = tracing 2 (fn () => "PROTECTED" ^ show name ^ show' name' ^ ": waiting");
   149               val _ = Mutex.lock lock;
   150               val _ = tracing 2 (fn () => "PROTECTED" ^ show name ^ show' name' ^ ": passed");
   151             in () end;
   152         val _ = protected_name := name;
   153         val res = Exn.capture e ();
   154         val _ = protected_name := "";
   155         val _ = Mutex.unlock lock;
   156       in Exn.release res end;
   157 
   158     (*wakeup condition*)
   159     val wakeup = ConditionVar.conditionVar ();
   160     fun wakeup_all () = ConditionVar.broadcast wakeup;
   161     fun wait () = ConditionVar.wait (wakeup, lock);
   162     fun wait_timeout () = ConditionVar.waitUntil (wakeup, lock, Time.now () + Time.fromSeconds 1);
   163 
   164     (*queue of tasks*)
   165     val queue = ref tasks;
   166     val active = ref 0;
   167     fun trace_active () = tracing 1 (fn () => "SCHEDULE: " ^ Int.toString (! active) ^ " active");
   168     fun dequeue () =
   169       let
   170         val (next, tasks') = next_task (! queue);
   171         val _ = queue := tasks';
   172       in
   173         (case next of Wait =>
   174           (dec active; trace_active ();
   175             wait ();
   176             inc active; trace_active ();
   177             dequeue ())
   178         | _ => next)
   179       end;
   180 
   181     (*pool of running threads*)
   182     val status = ref ([]: exn list);
   183     val running = ref ([]: Thread.thread list);
   184     fun start f =
   185       (inc active;
   186        change running (cons (Thread.fork (f, [Thread.InterruptState Thread.InterruptDefer]))));
   187     fun stop () =
   188       (dec active;
   189        change running (List.filter (fn t => not (Thread.equal (t, Thread.self ())))));
   190 
   191    (*worker thread*)
   192     fun worker () =
   193       (case PROTECTED "dequeue" dequeue of
   194         Task {body, cont, fail} =>
   195           (case Exn.capture (interruptible (fn _ => body)) () of
   196             Exn.Result () =>
   197               (PROTECTED "cont" (fn () => (change queue cont; wakeup_all ())); worker ())
   198           | Exn.Exn exn =>
   199               PROTECTED "fail" (fn () =>
   200                 (change status (cons exn); change queue fail; stop (); wakeup_all ())))
   201       | Terminate => PROTECTED "terminate" (fn () => (stop (); wakeup_all ())));
   202 
   203     (*main control: fork and wait*)
   204     fun fork 0 = ()
   205       | fork k = (start worker; fork (k - 1));
   206     val _ = PROTECTED "main" (fn () =>
   207      (fork (Int.max (n, 1));
   208       while not (List.null (! running)) do
   209       (trace_active ();
   210        if not (List.null (! status)) then (List.app Thread.interrupt (! running)) else ();
   211        wait_timeout ())));
   212 
   213   in ! status end);
   214 
   215 end;
   216 
   217 val NAMED_CRITICAL = Multithreading.NAMED_CRITICAL;
   218 val CRITICAL = Multithreading.CRITICAL;
   219 val ignore_interrupt = Multithreading.ignore_interrupt;
   220 val raise_interrupt = Multithreading.raise_interrupt;
   221 val interrupt_timeout = Multithreading.interrupt_timeout;
   222