src/Pure/ML-Systems/multithreading_polyml.ML
author wenzelm
Sat Feb 16 16:44:02 2008 +0100 (2008-02-16)
changeset 26083 abb3f8dd66dc
parent 26074 44c5419cd9f1
child 26098 b59d33f73aed
permissions -rw-r--r--
removed managed_process (cf. General/shell_process.ML);
replaced ignore/raise_interrupt by more flexible (un)interruptible combinators;
tuned timeLimit: sleep already interruptible by default;
schedule: restore attributes of body, instead of forcing interruptible execution;
     1 (*  Title:      Pure/ML-Systems/multithreading_polyml.ML
     2     ID:         $Id$
     3     Author:     Makarius
     4 
     5 Multithreading in Poly/ML 5.1 (cf. polyml/basis/Thread.sml).
     6 *)
     7 
     8 open Thread;
     9 
    (* Operations added by this Poly/ML implementation on top of the generic
       multithreading interface:
         interruptible f x: apply f to x with interrupts enabled;
         uninterruptible f x: apply f to x with interrupts deferred, where f first
           receives a combinator that runs a function under the thread attributes
           that were in force before the call;
         TimeLimit: execution with time limit. *)
    10 signature MULTITHREADING_POLYML =
    11 sig
    12   val interruptible: ('a -> 'b) -> 'a -> 'b
    13   val uninterruptible: ((('c -> 'd) -> 'c -> 'd) -> 'a -> 'b) -> 'a -> 'b
    14   structure TimeLimit: TIME_LIMIT
    15 end;
    16 
    (* Rebinds BASIC_MULTITHREADING: the signature of that name already in scope
       (defined outside this file), extended by MULTITHREADING_POLYML. *)
    17 signature BASIC_MULTITHREADING =
    18 sig
    19   include BASIC_MULTITHREADING
    20   include MULTITHREADING_POLYML
    21 end;
    22 
    (* Rebinds MULTITHREADING: the signature of that name already in scope
       (defined outside this file), extended by MULTITHREADING_POLYML. *)
    23 signature MULTITHREADING =
    24 sig
    25   include MULTITHREADING
    26   include MULTITHREADING_POLYML
    27 end;
    28 
    29 structure Multithreading: MULTITHREADING =
    30 struct
    31 
    32 (* options *)
    33 
    (*current trace level; messages of a higher level are suppressed*)
    val trace = ref 0;

    (* tracing level msg: if level does not exceed the current trace level, force
       the message thunk, print it to stderr with prefix ">>> " and flush;
       otherwise the thunk is never evaluated. *)
    fun tracing level msg =
      if ! trace < level then ()
      else
        let val line = ">>> " ^ msg () ^ "\n"
        in TextIO.output (TextIO.stdErr, line); TextIO.flushOut TextIO.stdErr end;
    39 
    (*this implementation provides multithreading*)
    val available = true;

    (*configured number of threads; a value <= 0 means "number of processors"*)
    val max_threads = ref 1;

    (* Effective thread count: the configured max_threads if positive, otherwise
       the processor count reported by Thread.numProcessors. *)
    fun max_threads_value () =
      let val requested = ! max_threads
      in if requested > 0 then requested else Thread.numProcessors () end;
    47 
    48 
    49 (* misc utils *)
    50 
    (*curried list constructor*)
    val cons = fn x => fn xs => x :: xs;

    (*change r f: replace the content of reference r by f applied to it*)
    fun change r f =
      let val old = ! r
      in r := f old end;
    54 
    (* inc / dec: add or subtract 1 in place and return the content of the
       reference as read back after the update. *)
    local
      fun bump delta i = (i := ! i + delta; ! i);
    in
      fun inc i = bump 1 i;
      fun dec i = bump ~1 i;
    end;
    57 
    (* Formatting of optional names in trace messages: the empty name yields the
       empty string; otherwise show prefixes a blank, and show' additionally
       wraps the name in square brackets. *)
    fun show name = if name = "" then "" else " " ^ name;
    fun show' name = if name = "" then "" else " [" ^ name ^ "]";
    60 
    61 
    62 (* thread attributes *)
    63 
    (* with_attributes new_atts f x: install new_atts as attributes of the
       current thread, evaluate f orig_atts x (where orig_atts are the attributes
       found on entry), and reinstate orig_atts afterwards.  The outcome of f --
       result or exception -- is passed on to the caller.  An Interrupt that hits
       the surrounding bookkeeping itself is answered by restoring the attributes
       and re-raising Interrupt. *)
    fun with_attributes new_atts f x =
      let
        val orig_atts = Thread.getAttributes ();
        fun restore () = Thread.setAttributes orig_atts;

        (*switch attributes, run f with its outcome captured, switch back*)
        fun attempt () =
          let
            val _ = Thread.setAttributes new_atts;
            val result = Exn.capture (f orig_atts) x;
            val _ = restore ();
          in result end;
      in
        (*RACE for fully asynchronous interrupts!*)
        Exn.release (attempt () handle Interrupt => (restore (); Exn.Exn Interrupt))
      end;
    78 
    (* interruptible f x: evaluate f x with broadcast interrupts enabled and the
       interrupt state set to InterruptAsynchOnce; with_attributes reinstates the
       previous thread attributes afterwards. *)
    fun interruptible f =
      let
        val atts =
          [Thread.EnableBroadcastInterrupt true, Thread.InterruptState Thread.InterruptAsynchOnce];
      in with_attributes atts (fn _ => f) end;
    83 
    (* uninterruptible f x: evaluate f restore_attributes x with broadcast
       interrupts disabled and the interrupt state set to InterruptDefer.  The
       combinator restore_attributes g y runs g y under the thread attributes that
       were in force when uninterruptible was entered, so selected parts of the
       body may be made as interruptible as the caller was. *)
    fun uninterruptible f =
      let
        val deferred =
          [Thread.EnableBroadcastInterrupt false, Thread.InterruptState Thread.InterruptDefer];
        fun body orig_atts =
          let fun restore_attributes g = with_attributes orig_atts (fn _ => g)
          in f restore_attributes end;
      in with_attributes deferred body end;
    88 
    89 
    90 (* execution with time limit *)
    91 
    structure TimeLimit =
    struct

    (*raised by timeLimit when the watchdog has fired*)
    exception TimeOut;

    (* timeLimit time f x: evaluate f x under the caller's thread attributes
       while a forked watchdog thread sleeps for the given time and then
       interrupts the calling thread.  If the evaluation ends with Interrupt and
       the watchdog has already set its flag, TimeOut is raised; any other
       outcome of f x (result or exception) is passed on unchanged.  The
       watchdog is interrupted once the evaluation is over. *)
    fun timeLimit time f x =
      let
        fun limited restore_attributes () =
          let
            val self = Thread.self ();
            val expired = ref false;
            fun alarm () = (OS.Process.sleep time; expired := true; Thread.interrupt self);
            val watchdog = Thread.fork (alarm, []);

            (*RACE! timeout signal vs. external Interrupt*)
            val result = Exn.capture (restore_attributes f) x;
            val was_timeout =
              (case result of
                Exn.Exn Interrupt => ! expired
              | _ => false);

            (*the watchdog may be gone already*)
            val _ = Thread.interrupt watchdog handle Thread _ => ();
          in if was_timeout then raise TimeOut else Exn.release result end;
      in uninterruptible limited () end;

    end;
   112 
   113 
   114 (* critical section -- may be nested within the same thread *)
   115 
   local

   val critical_lock = Mutex.mutex ();
   val critical_thread = ref (NONE: Thread.thread option);
   val critical_name = ref "";

   in

   (*does the current thread own the critical section?*)
   fun self_critical () =
     (case ! critical_thread of
       SOME owner => Thread.equal (owner, Thread.self ())
     | NONE => false);

   (* NAMED_CRITICAL name e: evaluate e () inside the global critical section.
      If the current thread owns the section already, e is run directly (nested
      use).  Otherwise the lock is acquired with interrupts deferred, e is run
      under the caller's thread attributes, and the section is left again
      whatever the outcome of e, which is then passed on to the caller.
      Contention is traced at level 4, or level 3 if waiting took more than
      10ms; the name in brackets is the section name observed before waiting. *)
   fun NAMED_CRITICAL name e =
     if self_critical () then e ()
     else
       let
         fun acquire holder =
           if Mutex.trylock critical_lock then ()
           else
             let
               fun report suffix = "CRITICAL" ^ show name ^ show' holder ^ suffix;
               val timer = Timer.startRealTimer ();
               val _ = tracing 4 (fn () => report ": waiting");
               val _ = Mutex.lock critical_lock;
               val time = Timer.checkRealTimer timer;
               val level = if Time.> (time, Time.fromMilliseconds 10) then 3 else 4;
             in tracing level (fn () => report (": passed after " ^ Time.toString time)) end;

         fun enter () = (critical_thread := SOME (Thread.self ()); critical_name := name);
         fun leave () = (critical_name := ""; critical_thread := NONE; Mutex.unlock critical_lock);

         fun section restore_attributes () =
           let
             val _ = acquire (! critical_name);
             val _ = enter ();
             val result = Exn.capture (restore_attributes e) ();
             val _ = leave ();
           in Exn.release result end;
       in uninterruptible section () end;

   (*anonymous critical section*)
   fun CRITICAL e = NAMED_CRITICAL "" e;

   end;
   157 
   158 
   159 (* scheduling -- multiple threads working on a queue of tasks *)
   160 
   (* Requests handed to a worker of schedule, where 'a is the type of the task
      queue:
        Task: run body; then update the queue by cont if body returned normally,
          or by fail if it raised an exception;
        Wait: nothing to do right now -- the worker sleeps until woken up;
        Terminate: the worker retires. *)
   161 datatype 'a task =
   162   Task of {body: unit -> unit, cont: 'a -> 'a, fail: 'a -> 'a} | Wait | Terminate;
   163 
   (* schedule n next_task tasks: process a queue of tasks by a pool of worker
      threads and return the list of exceptions that occurred, most recent first
      ([] means that nothing failed).

      Int.max (n, 1) workers are forked.  Each worker repeatedly applies
      next_task to the current queue (under the internal lock), which yields the
      next request together with the updated queue:
        Task {body, cont, fail}: run body with the caller's thread attributes
          restored; on normal return update the queue by cont and carry on; on
          an exception record it, update the queue by fail and retire;
        Wait: sleep until another worker signals a change, then ask again;
        Terminate: retire.
      The calling thread remains uninterruptible.  It re-checks the pool whenever
      it is woken up (at the latest after one second) and interrupts all
      remaining workers once some failure has been recorded.

      An exception escaping from next_task, cont or fail is recorded like a task
      failure and retires the worker in the regular way.  Without that the dead
      thread would stay in the running list forever and the main loop could
      never terminate. *)
   fun schedule n next_task = uninterruptible (fn restore_attributes => fn tasks =>
     let
       (*protected execution: mutual exclusion for all shared state below*)
       val lock = Mutex.mutex ();
       val protected_name = ref "";
       fun PROTECTED name e =
         let
           val name' = ! protected_name;
           val _ =
             if Mutex.trylock lock then ()
             else
               let
                 val _ = tracing 2 (fn () => "PROTECTED" ^ show name ^ show' name' ^ ": waiting");
                 val _ = Mutex.lock lock;
                 val _ = tracing 2 (fn () => "PROTECTED" ^ show name ^ show' name' ^ ": passed");
               in () end;
           val _ = protected_name := name;
           (*capture the outcome so that the lock is released in any case*)
           val res = Exn.capture e ();
           val _ = protected_name := "";
           val _ = Mutex.unlock lock;
         in Exn.release res end;

       (*wakeup condition -- to be used while holding the lock*)
       val wakeup = ConditionVar.conditionVar ();
       fun wakeup_all () = ConditionVar.broadcast wakeup;
       fun wait () = ConditionVar.wait (wakeup, lock);
       fun wait_timeout () =
         ConditionVar.waitUntil (wakeup, lock, Time.+ (Time.now (), Time.fromSeconds 1));

       (*queue of tasks*)
       val queue = ref tasks;
       val active = ref 0;
       fun trace_active () = tracing 1 (fn () => "SCHEDULE: " ^ Int.toString (! active) ^ " active");
       fun dequeue () =
         let
           val (next, tasks') = next_task (! queue);
           val _ = queue := tasks';
         in
           (case next of Wait =>
             (dec active; trace_active ();
               wait ();
               inc active; trace_active ();
               dequeue ())
           | _ => next)
         end;

       (*pool of running threads*)
       val status = ref ([]: exn list);
       val running = ref ([]: Thread.thread list);
       fun start f =
         (inc active;
          change running (cons (Thread.fork (f, [Thread.InterruptState Thread.InterruptDefer]))));
       fun stop () =
         (dec active;
          change running (List.filter (fn t => not (Thread.equal (t, Thread.self ())))));

       (*worker loop: returns after this worker has called stop, unless an
         exception escapes from next_task, cont or fail*)
       fun work () =
         (case PROTECTED "dequeue" dequeue of
           Task {body, cont, fail} =>
             (case Exn.capture (restore_attributes body) () of
               Exn.Result () =>
                 (PROTECTED "cont" (fn () => (change queue cont; wakeup_all ())); work ())
             | Exn.Exn exn =>
                 PROTECTED "fail" (fn () =>
                   (change status (cons exn); change queue fail; stop (); wakeup_all ())))
         | Terminate => PROTECTED "terminate" (fn () => (stop (); wakeup_all ())));

       (*worker thread: never dies without deregistering itself -- an escaped
         exception is recorded and the worker retires properly*)
       fun worker () =
         (case Exn.capture work () of
           Exn.Result () => ()
         | Exn.Exn exn =>
             PROTECTED "abort" (fn () => (change status (cons exn); stop (); wakeup_all ())));

       (*main control: fork and wait*)
       fun fork 0 = ()
         | fork k = (start worker; fork (k - 1));
       val _ = PROTECTED "main" (fn () =>
        (fork (Int.max (n, 1));
         while not (List.null (! running)) do
         (trace_active ();
          if not (List.null (! status)) then (List.app Thread.interrupt (! running)) else ();
          wait_timeout ())));

     in ! status end);
   242 
   243 
   244 (* serial numbers *)
   245 
   local

   val serial_lock = Mutex.mutex ();
   val serial_count = ref 0;

   (*increment the counter while holding serial_lock*)
   fun next_serial () =
     let
       val _ = Mutex.lock serial_lock;
       val n = inc serial_count;
     in Mutex.unlock serial_lock; n end;

   in

   (* serial (): next number of a global counter that starts at 1; the update is
      guarded by a mutex and performed with interrupts deferred. *)
   val serial = uninterruptible (fn _ => next_serial);

   end;
   261 
   262 
   263 (* thread data *)
   264 
   (* Per-thread data, delegated to Thread.getLocal / Thread.setLocal:
      get_data takes a tag, put_data takes a pair of tag and value. *)
   fun get_data tag = Thread.getLocal tag;
   fun put_data binding = Thread.setLocal binding;
   267 
   268 end;
   269 
   (* View of Multithreading restricted to signature BASIC_MULTITHREADING; it is
      opened so that its members can be used unqualified in subsequent code. *)
   270 structure BasicMultithreading: BASIC_MULTITHREADING = Multithreading;
   271 open BasicMultithreading;