src/Pure/ML-Systems/multithreading_polyml.ML
author wenzelm
Mon Sep 24 13:52:50 2007 +0200 (2007-09-24)
changeset 24688 a5754ca5c510
parent 24672 f311717d1f03
child 25704 df9c8074ff09
permissions -rw-r--r--
replaced interrupt_timeout by TimeLimit.timeLimit (available on SML/NJ and Poly/ML 5.1);
     1 (*  Title:      Pure/ML-Systems/multithreading_polyml.ML
     2     ID:         $Id$
     3     Author:     Makarius
     4 
     5 Multithreading in Poly/ML 5.1 or later (cf. polyml/basis/Thread.sml).
     6 *)
     7 
     8 open Thread;
     9 
    10 signature MULTITHREADING =
    11 sig
    12   include MULTITHREADING
    13   val ignore_interrupt: ('a -> 'b) -> 'a -> 'b
    14   val raise_interrupt: ('a -> 'b) -> 'a -> 'b
    15   structure TimeLimit: TIME_LIMIT
    16 end;
    17 
    18 structure Multithreading: MULTITHREADING =
    19 struct
    20 
    21 (* options *)
    22 
    23 val trace = ref 0;
    24 fun tracing level msg =
    25   if level <= ! trace
    26   then (TextIO.output (TextIO.stdErr, (">>> " ^ msg () ^ "\n")); TextIO.flushOut TextIO.stdErr)
    27   else ();
    28 
    29 val available = true;
    30 val max_threads = ref 1;
    31 
    32 
    33 (* misc utils *)
    34 
    35 fun cons x xs = x :: xs;
    36 
    37 fun change r f = r := f (! r);
    38 
    39 fun inc i = (i := ! i + 1; ! i);
    40 fun dec i = (i := ! i - 1; ! i);
    41 
    42 fun show "" = "" | show name = " " ^ name;
    43 fun show' "" = "" | show' name = " [" ^ name ^ "]";
    44 
    45 
    46 (* thread attributes *)
    47 
    48 fun with_attributes new_atts f x =
    49   let
    50     val orig_atts = Thread.getAttributes ();
    51     fun restore () = Thread.setAttributes orig_atts;
    52   in
    53     Exn.release
    54     (*RACE for fully asynchronous interrupts!*)
    55     (let
    56         val _ = Thread.setAttributes new_atts;
    57         val result = Exn.capture (f orig_atts) x;
    58         val _ = restore ();
    59       in result end
    60       handle Interrupt => (restore (); Exn.Exn Interrupt))
    61   end;
    62 
    63 
    64 (* interrupt handling *)
    65 
    66 fun uninterruptible f x = with_attributes
    67   [Thread.EnableBroadcastInterrupt false, Thread.InterruptState Thread.InterruptDefer] f x;
    68 
    69 fun interruptible f x = with_attributes
    70   [Thread.EnableBroadcastInterrupt true, Thread.InterruptState Thread.InterruptAsynchOnce] f x;
    71 
    72 fun ignore_interrupt f = uninterruptible (fn _ => f);
    73 fun raise_interrupt f = interruptible (fn _ => f);
    74 
    75 
    76 (* execution with time limit *)
    77 
    78 structure TimeLimit =
    79 struct
    80 
    81 exception TimeOut;
    82 
    83 fun timeLimit time f x =
    84   uninterruptible (fn atts => fn () =>
    85     let
    86       val worker = Thread.self ();
    87       val timeout = ref false;
    88       val watchdog = Thread.fork (interruptible (fn _ => fn () =>
    89         (OS.Process.sleep time; timeout := true; Thread.interrupt worker)), []);
    90 
    91       (*RACE! timeout signal vs. external Interrupt*)
    92       val result = Exn.capture (with_attributes atts (fn _ => f)) x;
    93       val was_timeout = (case result of Exn.Exn Interrupt => ! timeout | _ => false);
    94 
    95       val _ = Thread.interrupt watchdog handle Thread _ => ();
    96     in if was_timeout then raise TimeOut else Exn.release result end) ();
    97 
    98 end;
    99 
   100 
   101 (* critical section -- may be nested within the same thread *)
   102 
   103 local
   104 
   105 val critical_lock = Mutex.mutex ();
   106 val critical_thread = ref (NONE: Thread.thread option);
   107 val critical_name = ref "";
   108 
   109 in
   110 
   111 fun self_critical () =
   112   (case ! critical_thread of
   113     NONE => false
   114   | SOME id => Thread.equal (id, Thread.self ()));
   115 
   116 fun NAMED_CRITICAL name e =
   117   if self_critical () then e ()
   118   else
   119     uninterruptible (fn atts => fn () =>
   120       let
   121         val name' = ! critical_name;
   122         val _ =
   123           if Mutex.trylock critical_lock then ()
   124           else
   125             let
   126               val timer = Timer.startRealTimer ();
   127               val _ = tracing 4 (fn () => "CRITICAL" ^ show name ^ show' name' ^ ": waiting");
   128               val _ = Mutex.lock critical_lock;
   129               val time = Timer.checkRealTimer timer;
   130               val _ = tracing (if Time.> (time, Time.fromMilliseconds 10) then 3 else 4) (fn () =>
   131                 "CRITICAL" ^ show name ^ show' name' ^ ": passed after " ^ Time.toString time);
   132             in () end;
   133         val _ = critical_thread := SOME (Thread.self ());
   134         val _ = critical_name := name;
   135         val result = Exn.capture (with_attributes atts (fn _ => e)) ();
   136         val _ = critical_name := "";
   137         val _ = critical_thread := NONE;
   138         val _ = Mutex.unlock critical_lock;
   139       in Exn.release result end) ();
   140 
   141 fun CRITICAL e = NAMED_CRITICAL "" e;
   142 
   143 end;
   144 
   145 
   146 (* scheduling -- multiple threads working on a queue of tasks *)
   147 
   148 datatype 'a task =
   149   Task of {body: unit -> unit, cont: 'a -> 'a, fail: 'a -> 'a} | Wait | Terminate;
   150 
   151 fun schedule n next_task = uninterruptible (fn _ => fn tasks =>
   152   let
   153     (*protected execution*)
   154     val lock = Mutex.mutex ();
   155     val protected_name = ref "";
   156     fun PROTECTED name e =
   157       let
   158         val name' = ! protected_name;
   159         val _ =
   160           if Mutex.trylock lock then ()
   161           else
   162             let
   163               val _ = tracing 2 (fn () => "PROTECTED" ^ show name ^ show' name' ^ ": waiting");
   164               val _ = Mutex.lock lock;
   165               val _ = tracing 2 (fn () => "PROTECTED" ^ show name ^ show' name' ^ ": passed");
   166             in () end;
   167         val _ = protected_name := name;
   168         val res = Exn.capture e ();
   169         val _ = protected_name := "";
   170         val _ = Mutex.unlock lock;
   171       in Exn.release res end;
   172 
   173     (*wakeup condition*)
   174     val wakeup = ConditionVar.conditionVar ();
   175     fun wakeup_all () = ConditionVar.broadcast wakeup;
   176     fun wait () = ConditionVar.wait (wakeup, lock);
   177     fun wait_timeout () = ConditionVar.waitUntil (wakeup, lock, Time.now () + Time.fromSeconds 1);
   178 
   179     (*queue of tasks*)
   180     val queue = ref tasks;
   181     val active = ref 0;
   182     fun trace_active () = tracing 1 (fn () => "SCHEDULE: " ^ Int.toString (! active) ^ " active");
   183     fun dequeue () =
   184       let
   185         val (next, tasks') = next_task (! queue);
   186         val _ = queue := tasks';
   187       in
   188         (case next of Wait =>
   189           (dec active; trace_active ();
   190             wait ();
   191             inc active; trace_active ();
   192             dequeue ())
   193         | _ => next)
   194       end;
   195 
   196     (*pool of running threads*)
   197     val status = ref ([]: exn list);
   198     val running = ref ([]: Thread.thread list);
   199     fun start f =
   200       (inc active;
   201        change running (cons (Thread.fork (f, [Thread.InterruptState Thread.InterruptDefer]))));
   202     fun stop () =
   203       (dec active;
   204        change running (List.filter (fn t => not (Thread.equal (t, Thread.self ())))));
   205 
   206    (*worker thread*)
   207     fun worker () =
   208       (case PROTECTED "dequeue" dequeue of
   209         Task {body, cont, fail} =>
   210           (case Exn.capture (interruptible (fn _ => body)) () of
   211             Exn.Result () =>
   212               (PROTECTED "cont" (fn () => (change queue cont; wakeup_all ())); worker ())
   213           | Exn.Exn exn =>
   214               PROTECTED "fail" (fn () =>
   215                 (change status (cons exn); change queue fail; stop (); wakeup_all ())))
   216       | Terminate => PROTECTED "terminate" (fn () => (stop (); wakeup_all ())));
   217 
   218     (*main control: fork and wait*)
   219     fun fork 0 = ()
   220       | fork k = (start worker; fork (k - 1));
   221     val _ = PROTECTED "main" (fn () =>
   222      (fork (Int.max (n, 1));
   223       while not (List.null (! running)) do
   224       (trace_active ();
   225        if not (List.null (! status)) then (List.app Thread.interrupt (! running)) else ();
   226        wait_timeout ())));
   227 
   228   in ! status end);
   229 
   230 end;
   231 
   232 val NAMED_CRITICAL = Multithreading.NAMED_CRITICAL;
   233 val CRITICAL = Multithreading.CRITICAL;
   234 
   235 val ignore_interrupt = Multithreading.ignore_interrupt;
   236 val raise_interrupt = Multithreading.raise_interrupt;
   237 
   238 structure TimeLimit = Multithreading.TimeLimit;
   239