src/Pure/ML-Systems/multithreading_polyml.ML
author wenzelm
Fri Aug 10 14:49:01 2007 +0200 (2007-08-10)
changeset 24214 0482ecc4ef11
parent 24208 f4cafbaa05e4
child 24291 fa72aab966dc
permissions -rw-r--r--
(un)interruptible: pass-through original thread attributes;
     1 (*  Title:      Pure/ML-Systems/multithreading_polyml.ML
     2     ID:         $Id$
     3     Author:     Makarius
     4 
     5 Multithreading in Poly/ML 5.1 or later (cf. polyml/basis/Thread.sml).
     6 *)
     7 
     8 open Thread;
     9 
    10 signature MULTITHREADING =
    11 sig
    12   include MULTITHREADING
    13   val uninterruptible: (Thread.threadAttribute list -> 'a -> 'b) -> 'a -> 'b
    14   val interruptible: (Thread.threadAttribute list -> 'a -> 'b) -> 'a -> 'b
    15 end;
    16 
    17 structure Multithreading: MULTITHREADING =
    18 struct
    19 
    20 (* options *)
    21 
    22 val trace = ref 0;
    23 fun tracing level msg =
    24   if level <= ! trace
    25   then (TextIO.output (TextIO.stdErr, (">>> " ^ msg () ^ "\n")); TextIO.flushOut TextIO.stdErr)
    26   else ();
    27 
    28 val available = true;
    29 val max_threads = ref 1;
    30 
    31 
    32 (* misc utils *)
    33 
    34 fun cons x xs = x :: xs;
    35 
    36 fun change r f = r := f (! r);
    37 
    38 fun inc i = (i := ! i + 1; ! i);
    39 fun dec i = (i := ! i - 1; ! i);
    40 
    41 fun show "" = "" | show name = " " ^ name;
    42 fun show' "" = "" | show' name = " [" ^ name ^ "]";
    43 
    44 
    45 (* thread attributes *)
    46 
    47 fun with_attributes new_atts f x =
    48   let
    49     val orig_atts = Thread.getAttributes ();
    50     fun restore () = Thread.setAttributes orig_atts;
    51   in
    52     Exn.release
    53     (*RACE for fully asynchronous interrupts!*)
    54     (let
    55         val _ = Thread.setAttributes new_atts;
    56         val result = Exn.capture (f orig_atts) x;
    57         val _ = restore ();
    58       in result end
    59       handle Interrupt => (restore (); Exn.Exn Interrupt))
    60   end;
    61 
    62 fun with_interrupt_state state = with_attributes [Thread.InterruptState state];
    63 
    64 fun uninterruptible f x = with_interrupt_state Thread.InterruptDefer f x;
    65 fun interruptible f x = with_interrupt_state Thread.InterruptAsynchOnce f x;
    66 
    67 
    68 (* critical section -- may be nested within the same thread *)
    69 
    70 local
    71 
    72 val critical_lock = Mutex.mutex ();
    73 val critical_thread = ref (NONE: Thread.thread option);
    74 val critical_name = ref "";
    75 
    76 in
    77 
    78 fun self_critical () =
    79   (case ! critical_thread of
    80     NONE => false
    81   | SOME id => Thread.equal (id, Thread.self ()));
    82 
    83 fun NAMED_CRITICAL name e =
    84   if self_critical () then e ()
    85   else
    86     uninterruptible (fn atts => fn () =>
    87       let
    88         val name' = ! critical_name;
    89         val _ =
    90           if Mutex.trylock critical_lock then ()
    91           else
    92             let
    93               val timer = Timer.startRealTimer ();
    94               val _ = tracing 4 (fn () => "CRITICAL" ^ show name ^ show' name' ^ ": waiting");
    95               val _ = Mutex.lock critical_lock;
    96               val time = Timer.checkRealTimer timer;
    97               val _ = tracing (if Time.> (time, Time.fromMilliseconds 10) then 3 else 4) (fn () =>
    98                 "CRITICAL" ^ show name ^ show' name' ^ ": passed after " ^ Time.toString time);
    99             in () end;
   100         val _ = critical_thread := SOME (Thread.self ());
   101         val _ = critical_name := name;
   102         val result = Exn.capture (with_attributes atts (fn _ => e)) ();
   103         val _ = critical_name := "";
   104         val _ = critical_thread := NONE;
   105         val _ = Mutex.unlock critical_lock;
   106       in Exn.release result end) ();
   107 
   108 fun CRITICAL e = NAMED_CRITICAL "" e;
   109 
   110 end;
   111 
   112 
   113 (* scheduling -- multiple threads working on a queue of tasks *)
   114 
   115 datatype 'a task =
   116   Task of {body: unit -> unit, cont: 'a -> 'a, fail: 'a -> 'a} | Wait | Terminate;
   117 
   118 local
   119 
   120 val protected_name = ref "";
   121 
   122 in
   123 
   124 fun schedule n next_task = uninterruptible (fn _ => fn tasks =>
   125   let
   126     (*protected execution*)
   127     val lock = Mutex.mutex ();
   128     fun PROTECTED name e =
   129       let
   130         val name' = ! protected_name;
   131         val _ =
   132           if Mutex.trylock lock then ()
   133           else
   134             let
   135               val _ = tracing 2 (fn () => "PROTECTED" ^ show name ^ show' name' ^ ": waiting");
   136               val _ = Mutex.lock lock;
   137               val _ = tracing 2 (fn () => "PROTECTED" ^ show name ^ show' name' ^ ": passed");
   138             in () end;
   139         val _ = protected_name := name;
   140         val res = Exn.capture e ();
   141         val _ = protected_name := "";
   142         val _ = Mutex.unlock lock;
   143       in Exn.release res end;
   144 
   145     (*wakeup condition*)
   146     val wakeup = ConditionVar.conditionVar ();
   147     fun wakeup_all () = ConditionVar.broadcast wakeup;
   148     fun wait () = ConditionVar.wait (wakeup, lock);
   149 
   150     (*queue of tasks*)
   151     val queue = ref tasks;
   152     val active = ref 0;
   153     fun trace_active () = tracing 1 (fn () => "SCHEDULE: " ^ Int.toString (! active) ^ " active");
   154     fun dequeue () =
   155       let
   156         val (next, tasks') = next_task (! queue);
   157         val _ = queue := tasks';
   158       in
   159         (case next of Wait =>
   160           (dec active; trace_active ();
   161             wait ();
   162             inc active; trace_active ();
   163             dequeue ())
   164         | _ => next)
   165       end;
   166 
   167     (*pool of running threads*)
   168     val status = ref ([]: exn list);
   169     val running = ref ([]: Thread.thread list);
   170     fun start f =
   171       (inc active;
   172        change running (cons (Thread.fork (f, [Thread.InterruptState Thread.InterruptDefer]))));
   173     fun stop () =
   174       (dec active;
   175        change running (List.filter (fn t => not (Thread.equal (t, Thread.self ())))));
   176 
   177    (*worker thread*)
   178     fun worker () =
   179       (case PROTECTED "dequeue" dequeue of
   180         Task {body, cont, fail} =>
   181           (case Exn.capture (interruptible (fn _ => body)) () of
   182             Exn.Result () =>
   183               (PROTECTED "cont" (fn () => (change queue cont; wakeup_all ())); worker ())
   184           | Exn.Exn exn =>
   185               PROTECTED "fail" (fn () =>
   186                 (change status (cons exn); change queue fail; stop (); wakeup_all ())))
   187       | Terminate => PROTECTED "terminate" (fn () => (stop (); wakeup_all ())));
   188 
   189     (*main control: fork and wait*)
   190     fun fork 0 = ()
   191       | fork k = (start worker; fork (k - 1));
   192     val _ = PROTECTED "main" (fn () =>
   193      (fork (Int.max (n, 1));
   194       while not (List.null (! running)) do
   195       (trace_active ();
   196        if not (List.null (! status)) then (List.app Thread.interrupt (! running)) else ();
   197        wait ())));
   198 
   199   in ! status end);
   200 
   201 end;
   202 
   203 end;
   204 
   205 val NAMED_CRITICAL = Multithreading.NAMED_CRITICAL;
   206 val CRITICAL = Multithreading.CRITICAL;
   207