src/Pure/ML-Systems/multithreading_polyml.ML
author wenzelm
Wed Jan 02 16:32:53 2008 +0100 (2008-01-02)
changeset 25775 90525e67ede7
parent 25735 4d147263f71f
child 26074 44c5419cd9f1
permissions -rw-r--r--
added Multithreading.max_threads_value, which maps a value of 0 (or any other non-positive value) to the number of CPUs;
     1 (*  Title:      Pure/ML-Systems/multithreading_polyml.ML
     2     ID:         $Id$
     3     Author:     Makarius
     4 
     5 Multithreading in Poly/ML 5.1 (cf. polyml/basis/Thread.sml).
     6 *)
     7 
     8 open Thread;
     9 
    10 signature MULTITHREADING_POLYML =
    11 sig
    12   val ignore_interrupt: ('a -> 'b) -> 'a -> 'b
    13   val raise_interrupt: ('a -> 'b) -> 'a -> 'b
    14   structure TimeLimit: TIME_LIMIT
    15 end;
    16 
    17 signature BASIC_MULTITHREADING =
    18 sig
    19   include BASIC_MULTITHREADING
    20   include MULTITHREADING_POLYML
    21 end;
    22 
    23 signature MULTITHREADING =
    24 sig
    25   include MULTITHREADING
    26   include MULTITHREADING_POLYML
    27 end;
    28 
    29 structure Multithreading: MULTITHREADING =
    30 struct
    31 
    32 (* options *)
    33 
    34 val trace = ref 0;
    35 fun tracing level msg =
    36   if level <= ! trace
    37   then (TextIO.output (TextIO.stdErr, (">>> " ^ msg () ^ "\n")); TextIO.flushOut TextIO.stdErr)
    38   else ();
    39 
    40 val available = true;
    41 
    42 val max_threads = ref 1;
    43 
    44 fun max_threads_value () =
    45   let val m = ! max_threads
    46   in if m <= 0 then Thread.numProcessors () else m end;
    47 
    48 
    49 (* misc utils *)
    50 
    51 fun cons x xs = x :: xs;
    52 
    53 fun change r f = r := f (! r);
    54 
    55 fun inc i = (i := ! i + 1; ! i);
    56 fun dec i = (i := ! i - 1; ! i);
    57 
    58 fun show "" = "" | show name = " " ^ name;
    59 fun show' "" = "" | show' name = " [" ^ name ^ "]";
    60 
    61 
    62 (* thread attributes *)
    63 
    64 fun with_attributes new_atts f x =
    65   let
    66     val orig_atts = Thread.getAttributes ();
    67     fun restore () = Thread.setAttributes orig_atts;
    68   in
    69     Exn.release
    70     (*RACE for fully asynchronous interrupts!*)
    71     (let
    72         val _ = Thread.setAttributes new_atts;
    73         val result = Exn.capture (f orig_atts) x;
    74         val _ = restore ();
    75       in result end
    76       handle Interrupt => (restore (); Exn.Exn Interrupt))
    77   end;
    78 
    79 
    80 (* interrupt handling *)
    81 
    82 fun uninterruptible f x = with_attributes
    83   [Thread.EnableBroadcastInterrupt false, Thread.InterruptState Thread.InterruptDefer] f x;
    84 
    85 fun interruptible f x = with_attributes
    86   [Thread.EnableBroadcastInterrupt true, Thread.InterruptState Thread.InterruptAsynchOnce] f x;
    87 
    88 fun ignore_interrupt f = uninterruptible (fn _ => f);
    89 fun raise_interrupt f = interruptible (fn _ => f);
    90 
    91 
    92 (* execution with time limit *)
    93 
    94 structure TimeLimit =
    95 struct
    96 
    97 exception TimeOut;
    98 
    99 fun timeLimit time f x =
   100   uninterruptible (fn atts => fn () =>
   101     let
   102       val worker = Thread.self ();
   103       val timeout = ref false;
   104       val watchdog = Thread.fork (interruptible (fn _ => fn () =>
   105         (OS.Process.sleep time; timeout := true; Thread.interrupt worker)), []);
   106 
   107       (*RACE! timeout signal vs. external Interrupt*)
   108       val result = Exn.capture (with_attributes atts (fn _ => f)) x;
   109       val was_timeout = (case result of Exn.Exn Interrupt => ! timeout | _ => false);
   110 
   111       val _ = Thread.interrupt watchdog handle Thread _ => ();
   112     in if was_timeout then raise TimeOut else Exn.release result end) ();
   113 
   114 end;
   115 
   116 
   117 (* critical section -- may be nested within the same thread *)
   118 
   119 local
   120 
   121 val critical_lock = Mutex.mutex ();
   122 val critical_thread = ref (NONE: Thread.thread option);
   123 val critical_name = ref "";
   124 
   125 in
   126 
   127 fun self_critical () =
   128   (case ! critical_thread of
   129     NONE => false
   130   | SOME id => Thread.equal (id, Thread.self ()));
   131 
   132 fun NAMED_CRITICAL name e =
   133   if self_critical () then e ()
   134   else
   135     uninterruptible (fn atts => fn () =>
   136       let
   137         val name' = ! critical_name;
   138         val _ =
   139           if Mutex.trylock critical_lock then ()
   140           else
   141             let
   142               val timer = Timer.startRealTimer ();
   143               val _ = tracing 4 (fn () => "CRITICAL" ^ show name ^ show' name' ^ ": waiting");
   144               val _ = Mutex.lock critical_lock;
   145               val time = Timer.checkRealTimer timer;
   146               val _ = tracing (if Time.> (time, Time.fromMilliseconds 10) then 3 else 4) (fn () =>
   147                 "CRITICAL" ^ show name ^ show' name' ^ ": passed after " ^ Time.toString time);
   148             in () end;
   149         val _ = critical_thread := SOME (Thread.self ());
   150         val _ = critical_name := name;
   151         val result = Exn.capture (with_attributes atts (fn _ => e)) ();
   152         val _ = critical_name := "";
   153         val _ = critical_thread := NONE;
   154         val _ = Mutex.unlock critical_lock;
   155       in Exn.release result end) ();
   156 
   157 fun CRITICAL e = NAMED_CRITICAL "" e;
   158 
   159 end;
   160 
   161 
   162 (* scheduling -- multiple threads working on a queue of tasks *)
   163 
   164 datatype 'a task =
   165   Task of {body: unit -> unit, cont: 'a -> 'a, fail: 'a -> 'a} | Wait | Terminate;
   166 
   167 fun schedule n next_task = uninterruptible (fn _ => fn tasks =>
   168   let
   169     (*protected execution*)
   170     val lock = Mutex.mutex ();
   171     val protected_name = ref "";
   172     fun PROTECTED name e =
   173       let
   174         val name' = ! protected_name;
   175         val _ =
   176           if Mutex.trylock lock then ()
   177           else
   178             let
   179               val _ = tracing 2 (fn () => "PROTECTED" ^ show name ^ show' name' ^ ": waiting");
   180               val _ = Mutex.lock lock;
   181               val _ = tracing 2 (fn () => "PROTECTED" ^ show name ^ show' name' ^ ": passed");
   182             in () end;
   183         val _ = protected_name := name;
   184         val res = Exn.capture e ();
   185         val _ = protected_name := "";
   186         val _ = Mutex.unlock lock;
   187       in Exn.release res end;
   188 
   189     (*wakeup condition*)
   190     val wakeup = ConditionVar.conditionVar ();
   191     fun wakeup_all () = ConditionVar.broadcast wakeup;
   192     fun wait () = ConditionVar.wait (wakeup, lock);
   193     fun wait_timeout () = ConditionVar.waitUntil (wakeup, lock, Time.now () + Time.fromSeconds 1);
   194 
   195     (*queue of tasks*)
   196     val queue = ref tasks;
   197     val active = ref 0;
   198     fun trace_active () = tracing 1 (fn () => "SCHEDULE: " ^ Int.toString (! active) ^ " active");
   199     fun dequeue () =
   200       let
   201         val (next, tasks') = next_task (! queue);
   202         val _ = queue := tasks';
   203       in
   204         (case next of Wait =>
   205           (dec active; trace_active ();
   206             wait ();
   207             inc active; trace_active ();
   208             dequeue ())
   209         | _ => next)
   210       end;
   211 
   212     (*pool of running threads*)
   213     val status = ref ([]: exn list);
   214     val running = ref ([]: Thread.thread list);
   215     fun start f =
   216       (inc active;
   217        change running (cons (Thread.fork (f, [Thread.InterruptState Thread.InterruptDefer]))));
   218     fun stop () =
   219       (dec active;
   220        change running (List.filter (fn t => not (Thread.equal (t, Thread.self ())))));
   221 
   222    (*worker thread*)
   223     fun worker () =
   224       (case PROTECTED "dequeue" dequeue of
   225         Task {body, cont, fail} =>
   226           (case Exn.capture (interruptible (fn _ => body)) () of
   227             Exn.Result () =>
   228               (PROTECTED "cont" (fn () => (change queue cont; wakeup_all ())); worker ())
   229           | Exn.Exn exn =>
   230               PROTECTED "fail" (fn () =>
   231                 (change status (cons exn); change queue fail; stop (); wakeup_all ())))
   232       | Terminate => PROTECTED "terminate" (fn () => (stop (); wakeup_all ())));
   233 
   234     (*main control: fork and wait*)
   235     fun fork 0 = ()
   236       | fork k = (start worker; fork (k - 1));
   237     val _ = PROTECTED "main" (fn () =>
   238      (fork (Int.max (n, 1));
   239       while not (List.null (! running)) do
   240       (trace_active ();
   241        if not (List.null (! status)) then (List.app Thread.interrupt (! running)) else ();
   242        wait_timeout ())));
   243 
   244   in ! status end);
   245 
   246 
   247 (* serial numbers *)
   248 
   249 local
   250 
   251 val serial_lock = Mutex.mutex ();
   252 val serial_count = ref 0;
   253 
   254 in
   255 
   256 val serial = uninterruptible (fn _ => fn () =>
   257   let
   258     val _ = Mutex.lock serial_lock;
   259     val res = inc serial_count;
   260     val _ = Mutex.unlock serial_lock;
   261   in res end);
   262 
   263 end;
   264 
   265 
   266 (* thread data *)
   267 
   268 val get_data = Thread.getLocal;
   269 val put_data = Thread.setLocal;
   270 
   271 end;
   272 
   273 structure BasicMultithreading: BASIC_MULTITHREADING = Multithreading;
   274 open BasicMultithreading;