src/Pure/ML-Systems/multithreading_polyml.ML
author wenzelm
Tue Dec 18 22:21:42 2007 +0100 (2007-12-18)
changeset 25704 df9c8074ff09
parent 24688 a5754ca5c510
child 25735 4d147263f71f
permissions -rw-r--r--
signature BASIC_MULTITHREADING;
added a specific serial number generator, which avoids the global CRITICAL lock;
     1 (*  Title:      Pure/ML-Systems/multithreading_polyml.ML
     2     ID:         $Id$
     3     Author:     Makarius
     4 
     5 Multithreading in Poly/ML 5.1 (cf. polyml/basis/Thread.sml).
     6 *)
     7 
     8 open Thread;
     9 
    10 signature MULTITHREADING_POLYML =
    11 sig
    12   val ignore_interrupt: ('a -> 'b) -> 'a -> 'b
    13   val raise_interrupt: ('a -> 'b) -> 'a -> 'b
    14   structure TimeLimit: TIME_LIMIT
    15 end;
    16 
    17 signature BASIC_MULTITHREADING =
    18 sig
    19   include BASIC_MULTITHREADING
    20   include MULTITHREADING_POLYML
    21 end;
    22 
    23 signature MULTITHREADING =
    24 sig
    25   include MULTITHREADING
    26   include MULTITHREADING_POLYML
    27 end;
    28 
    29 structure Multithreading: MULTITHREADING =
    30 struct
    31 
    32 (* options *)
    33 
    34 val trace = ref 0;
    35 fun tracing level msg =
    36   if level <= ! trace
    37   then (TextIO.output (TextIO.stdErr, (">>> " ^ msg () ^ "\n")); TextIO.flushOut TextIO.stdErr)
    38   else ();
    39 
    40 val available = true;
    41 val max_threads = ref 1;
    42 
    43 
    44 (* misc utils *)
    45 
    46 fun cons x xs = x :: xs;
    47 
    48 fun change r f = r := f (! r);
    49 
    50 fun inc i = (i := ! i + 1; ! i);
    51 fun dec i = (i := ! i - 1; ! i);
    52 
    53 fun show "" = "" | show name = " " ^ name;
    54 fun show' "" = "" | show' name = " [" ^ name ^ "]";
    55 
    56 
    57 (* thread attributes *)
    58 
    59 fun with_attributes new_atts f x =
    60   let
    61     val orig_atts = Thread.getAttributes ();
    62     fun restore () = Thread.setAttributes orig_atts;
    63   in
    64     Exn.release
    65     (*RACE for fully asynchronous interrupts!*)
    66     (let
    67         val _ = Thread.setAttributes new_atts;
    68         val result = Exn.capture (f orig_atts) x;
    69         val _ = restore ();
    70       in result end
    71       handle Interrupt => (restore (); Exn.Exn Interrupt))
    72   end;
    73 
    74 
    75 (* interrupt handling *)
    76 
    77 fun uninterruptible f x = with_attributes
    78   [Thread.EnableBroadcastInterrupt false, Thread.InterruptState Thread.InterruptDefer] f x;
    79 
    80 fun interruptible f x = with_attributes
    81   [Thread.EnableBroadcastInterrupt true, Thread.InterruptState Thread.InterruptAsynchOnce] f x;
    82 
    83 fun ignore_interrupt f = uninterruptible (fn _ => f);
    84 fun raise_interrupt f = interruptible (fn _ => f);
    85 
    86 
    87 (* execution with time limit *)
    88 
    89 structure TimeLimit =
    90 struct
    91 
    92 exception TimeOut;
    93 
    94 fun timeLimit time f x =
    95   uninterruptible (fn atts => fn () =>
    96     let
    97       val worker = Thread.self ();
    98       val timeout = ref false;
    99       val watchdog = Thread.fork (interruptible (fn _ => fn () =>
   100         (OS.Process.sleep time; timeout := true; Thread.interrupt worker)), []);
   101 
   102       (*RACE! timeout signal vs. external Interrupt*)
   103       val result = Exn.capture (with_attributes atts (fn _ => f)) x;
   104       val was_timeout = (case result of Exn.Exn Interrupt => ! timeout | _ => false);
   105 
   106       val _ = Thread.interrupt watchdog handle Thread _ => ();
   107     in if was_timeout then raise TimeOut else Exn.release result end) ();
   108 
   109 end;
   110 
   111 
   112 (* critical section -- may be nested within the same thread *)
   113 
   114 local
   115 
   116 val critical_lock = Mutex.mutex ();
   117 val critical_thread = ref (NONE: Thread.thread option);
   118 val critical_name = ref "";
   119 
   120 in
   121 
   122 fun self_critical () =
   123   (case ! critical_thread of
   124     NONE => false
   125   | SOME id => Thread.equal (id, Thread.self ()));
   126 
   127 fun NAMED_CRITICAL name e =
   128   if self_critical () then e ()
   129   else
   130     uninterruptible (fn atts => fn () =>
   131       let
   132         val name' = ! critical_name;
   133         val _ =
   134           if Mutex.trylock critical_lock then ()
   135           else
   136             let
   137               val timer = Timer.startRealTimer ();
   138               val _ = tracing 4 (fn () => "CRITICAL" ^ show name ^ show' name' ^ ": waiting");
   139               val _ = Mutex.lock critical_lock;
   140               val time = Timer.checkRealTimer timer;
   141               val _ = tracing (if Time.> (time, Time.fromMilliseconds 10) then 3 else 4) (fn () =>
   142                 "CRITICAL" ^ show name ^ show' name' ^ ": passed after " ^ Time.toString time);
   143             in () end;
   144         val _ = critical_thread := SOME (Thread.self ());
   145         val _ = critical_name := name;
   146         val result = Exn.capture (with_attributes atts (fn _ => e)) ();
   147         val _ = critical_name := "";
   148         val _ = critical_thread := NONE;
   149         val _ = Mutex.unlock critical_lock;
   150       in Exn.release result end) ();
   151 
   152 fun CRITICAL e = NAMED_CRITICAL "" e;
   153 
   154 end;
   155 
   156 
   157 (* scheduling -- multiple threads working on a queue of tasks *)
   158 
   159 datatype 'a task =
   160   Task of {body: unit -> unit, cont: 'a -> 'a, fail: 'a -> 'a} | Wait | Terminate;
   161 
   162 fun schedule n next_task = uninterruptible (fn _ => fn tasks =>
   163   let
   164     (*protected execution*)
   165     val lock = Mutex.mutex ();
   166     val protected_name = ref "";
   167     fun PROTECTED name e =
   168       let
   169         val name' = ! protected_name;
   170         val _ =
   171           if Mutex.trylock lock then ()
   172           else
   173             let
   174               val _ = tracing 2 (fn () => "PROTECTED" ^ show name ^ show' name' ^ ": waiting");
   175               val _ = Mutex.lock lock;
   176               val _ = tracing 2 (fn () => "PROTECTED" ^ show name ^ show' name' ^ ": passed");
   177             in () end;
   178         val _ = protected_name := name;
   179         val res = Exn.capture e ();
   180         val _ = protected_name := "";
   181         val _ = Mutex.unlock lock;
   182       in Exn.release res end;
   183 
   184     (*wakeup condition*)
   185     val wakeup = ConditionVar.conditionVar ();
   186     fun wakeup_all () = ConditionVar.broadcast wakeup;
   187     fun wait () = ConditionVar.wait (wakeup, lock);
   188     fun wait_timeout () = ConditionVar.waitUntil (wakeup, lock, Time.now () + Time.fromSeconds 1);
   189 
   190     (*queue of tasks*)
   191     val queue = ref tasks;
   192     val active = ref 0;
   193     fun trace_active () = tracing 1 (fn () => "SCHEDULE: " ^ Int.toString (! active) ^ " active");
   194     fun dequeue () =
   195       let
   196         val (next, tasks') = next_task (! queue);
   197         val _ = queue := tasks';
   198       in
   199         (case next of Wait =>
   200           (dec active; trace_active ();
   201             wait ();
   202             inc active; trace_active ();
   203             dequeue ())
   204         | _ => next)
   205       end;
   206 
   207     (*pool of running threads*)
   208     val status = ref ([]: exn list);
   209     val running = ref ([]: Thread.thread list);
   210     fun start f =
   211       (inc active;
   212        change running (cons (Thread.fork (f, [Thread.InterruptState Thread.InterruptDefer]))));
   213     fun stop () =
   214       (dec active;
   215        change running (List.filter (fn t => not (Thread.equal (t, Thread.self ())))));
   216 
   217    (*worker thread*)
   218     fun worker () =
   219       (case PROTECTED "dequeue" dequeue of
   220         Task {body, cont, fail} =>
   221           (case Exn.capture (interruptible (fn _ => body)) () of
   222             Exn.Result () =>
   223               (PROTECTED "cont" (fn () => (change queue cont; wakeup_all ())); worker ())
   224           | Exn.Exn exn =>
   225               PROTECTED "fail" (fn () =>
   226                 (change status (cons exn); change queue fail; stop (); wakeup_all ())))
   227       | Terminate => PROTECTED "terminate" (fn () => (stop (); wakeup_all ())));
   228 
   229     (*main control: fork and wait*)
   230     fun fork 0 = ()
   231       | fork k = (start worker; fork (k - 1));
   232     val _ = PROTECTED "main" (fn () =>
   233      (fork (Int.max (n, 1));
   234       while not (List.null (! running)) do
   235       (trace_active ();
   236        if not (List.null (! status)) then (List.app Thread.interrupt (! running)) else ();
   237        wait_timeout ())));
   238 
   239   in ! status end);
   240 
   241 
   242 (* serial numbers *)
   243 
   244 local
   245 
   246 val serial_lock = Mutex.mutex ();
   247 val serial_count = ref 0;
   248 
   249 in
   250 
   251 val serial = uninterruptible (fn _ => fn () =>
   252   let
   253     val _ = Mutex.lock serial_lock;
   254     val res = inc serial_count;
   255     val _ = Mutex.unlock serial_lock;
   256   in res end);
   257 
   258 end;
   259 
   260 end;
   261 
   262 structure BasicMultithreading: BASIC_MULTITHREADING = Multithreading;
   263 open BasicMultithreading;