src/Pure/ML-Systems/multithreading_polyml.ML
author wenzelm
Fri Feb 15 23:22:02 2008 +0100 (2008-02-15)
changeset 26074 44c5419cd9f1
parent 25775 90525e67ede7
child 26083 abb3f8dd66dc
permissions -rw-r--r--
support for managed external processes;
     1 (*  Title:      Pure/ML-Systems/multithreading_polyml.ML
     2     ID:         $Id$
     3     Author:     Makarius
     4 
     5 Multithreading in Poly/ML 5.1 (cf. polyml/basis/Thread.sml).
     6 *)
     7 
     8 open Thread;
     9 
    10 signature MULTITHREADING_POLYML =
    11 sig
    12   val ignore_interrupt: ('a -> 'b) -> 'a -> 'b
    13   val raise_interrupt: ('a -> 'b) -> 'a -> 'b
    14   structure TimeLimit: TIME_LIMIT
    15 end;
    16 
    17 signature BASIC_MULTITHREADING =
    18 sig
    19   include BASIC_MULTITHREADING
    20   include MULTITHREADING_POLYML
    21 end;
    22 
    23 signature MULTITHREADING =
    24 sig
    25   include MULTITHREADING
    26   include MULTITHREADING_POLYML
    27 end;
    28 
    29 structure Multithreading: MULTITHREADING =
    30 struct
    31 
    32 (* options *)
    33 
    34 val trace = ref 0;
    35 fun tracing level msg =
    36   if level <= ! trace
    37   then (TextIO.output (TextIO.stdErr, (">>> " ^ msg () ^ "\n")); TextIO.flushOut TextIO.stdErr)
    38   else ();
    39 
    40 val available = true;
    41 
    42 val max_threads = ref 1;
    43 
    44 fun max_threads_value () =
    45   let val m = ! max_threads
    46   in if m <= 0 then Thread.numProcessors () else m end;
    47 
    48 
    49 (* misc utils *)
    50 
    51 fun cons x xs = x :: xs;
    52 
    53 fun change r f = r := f (! r);
    54 
    55 fun inc i = (i := ! i + 1; ! i);
    56 fun dec i = (i := ! i - 1; ! i);
    57 
    58 fun show "" = "" | show name = " " ^ name;
    59 fun show' "" = "" | show' name = " [" ^ name ^ "]";
    60 
    61 
    62 (* thread attributes *)
    63 
    64 fun with_attributes new_atts f x =
    65   let
    66     val orig_atts = Thread.getAttributes ();
    67     fun restore () = Thread.setAttributes orig_atts;
    68   in
    69     Exn.release
    70     (*RACE for fully asynchronous interrupts!*)
    71     (let
    72         val _ = Thread.setAttributes new_atts;
    73         val result = Exn.capture (f orig_atts) x;
    74         val _ = restore ();
    75       in result end
    76       handle Interrupt => (restore (); Exn.Exn Interrupt))
    77   end;
    78 
    79 
    80 (* interrupt handling *)
    81 
    82 fun uninterruptible f x = with_attributes
    83   [Thread.EnableBroadcastInterrupt false, Thread.InterruptState Thread.InterruptDefer] f x;
    84 
    85 fun interruptible f x = with_attributes
    86   [Thread.EnableBroadcastInterrupt true, Thread.InterruptState Thread.InterruptAsynchOnce] f x;
    87 
    88 fun ignore_interrupt f = uninterruptible (fn _ => f);
    89 fun raise_interrupt f = interruptible (fn _ => f);
    90 
    91 
    92 (* execution with time limit *)
    93 
    94 structure TimeLimit =
    95 struct
    96 
    97 exception TimeOut;
    98 
    99 fun timeLimit time f x =
   100   uninterruptible (fn atts => fn () =>
   101     let
   102       val worker = Thread.self ();
   103       val timeout = ref false;
   104       val watchdog = Thread.fork (interruptible (fn _ => fn () =>
   105         (OS.Process.sleep time; timeout := true; Thread.interrupt worker)), []);
   106 
   107       (*RACE! timeout signal vs. external Interrupt*)
   108       val result = Exn.capture (with_attributes atts (fn _ => f)) x;
   109       val was_timeout = (case result of Exn.Exn Interrupt => ! timeout | _ => false);
   110 
   111       val _ = Thread.interrupt watchdog handle Thread _ => ();
   112     in if was_timeout then raise TimeOut else Exn.release result end) ();
   113 
   114 end;
   115 
   116 
   117 (* managed external processes -- with propagation of interrupts *)
   118 
   119 fun managed_process cmdline = uninterruptible (fn atts => fn () =>
   120   let
   121     val proc = Unix.execute (cmdline, []);
   122     val (proc_stdout, proc_stdin) = Unix.streamsOf proc;
   123     val _ = TextIO.closeOut proc_stdin;
   124 
   125     (*finished state*)
   126     val finished = ref false;
   127     val finished_mutex = Mutex.mutex ();
   128     val finished_cond = ConditionVar.conditionVar ();
   129     fun signal_finished () =
   130       (Mutex.lock finished_mutex; finished := true; Mutex.unlock finished_mutex;
   131         ConditionVar.signal finished_cond);
   132 
   133     val _ = Mutex.lock finished_mutex;
   134 
   135     (*reader thread*)
   136     val buffer = ref [];
   137     fun reader () =
   138       (case Exn.capture TextIO.input proc_stdout of
   139         Exn.Exn Interrupt => ()
   140       | Exn.Exn _ => signal_finished ()
   141       | Exn.Result "" => signal_finished ()
   142       | Exn.Result txt => (change buffer (cons txt); reader ()));
   143     val reader_thread = Thread.fork (reader, []);
   144 
   145     (*main thread*)
   146     val () =
   147       while not (! finished) do with_attributes atts (fn _ => fn () =>
   148         ((ConditionVar.waitUntil (finished_cond, finished_mutex, Time.now () + Time.fromSeconds 1); ())
   149           handle Interrupt => Unix.kill (proc, Posix.Signal.int))) ();  (* FIXME lock!?! *)
   150     val _ = Thread.interrupt reader_thread handle Thread _ => ();
   151 
   152     val status = OS.Process.isSuccess (Unix.reap proc);
   153     val output = implode (rev (! buffer));
   154   in (output, status) end) ();
   155 
   156 
   157 (* critical section -- may be nested within the same thread *)
   158 
   159 local
   160 
   161 val critical_lock = Mutex.mutex ();
   162 val critical_thread = ref (NONE: Thread.thread option);
   163 val critical_name = ref "";
   164 
   165 in
   166 
   167 fun self_critical () =
   168   (case ! critical_thread of
   169     NONE => false
   170   | SOME id => Thread.equal (id, Thread.self ()));
   171 
   172 fun NAMED_CRITICAL name e =
   173   if self_critical () then e ()
   174   else
   175     uninterruptible (fn atts => fn () =>
   176       let
   177         val name' = ! critical_name;
   178         val _ =
   179           if Mutex.trylock critical_lock then ()
   180           else
   181             let
   182               val timer = Timer.startRealTimer ();
   183               val _ = tracing 4 (fn () => "CRITICAL" ^ show name ^ show' name' ^ ": waiting");
   184               val _ = Mutex.lock critical_lock;
   185               val time = Timer.checkRealTimer timer;
   186               val _ = tracing (if Time.> (time, Time.fromMilliseconds 10) then 3 else 4) (fn () =>
   187                 "CRITICAL" ^ show name ^ show' name' ^ ": passed after " ^ Time.toString time);
   188             in () end;
   189         val _ = critical_thread := SOME (Thread.self ());
   190         val _ = critical_name := name;
   191         val result = Exn.capture (with_attributes atts (fn _ => e)) ();
   192         val _ = critical_name := "";
   193         val _ = critical_thread := NONE;
   194         val _ = Mutex.unlock critical_lock;
   195       in Exn.release result end) ();
   196 
   197 fun CRITICAL e = NAMED_CRITICAL "" e;
   198 
   199 end;
   200 
   201 
   202 (* scheduling -- multiple threads working on a queue of tasks *)
   203 
   204 datatype 'a task =
   205   Task of {body: unit -> unit, cont: 'a -> 'a, fail: 'a -> 'a} | Wait | Terminate;
   206 
   207 fun schedule n next_task = uninterruptible (fn _ => fn tasks =>
   208   let
   209     (*protected execution*)
   210     val lock = Mutex.mutex ();
   211     val protected_name = ref "";
   212     fun PROTECTED name e =
   213       let
   214         val name' = ! protected_name;
   215         val _ =
   216           if Mutex.trylock lock then ()
   217           else
   218             let
   219               val _ = tracing 2 (fn () => "PROTECTED" ^ show name ^ show' name' ^ ": waiting");
   220               val _ = Mutex.lock lock;
   221               val _ = tracing 2 (fn () => "PROTECTED" ^ show name ^ show' name' ^ ": passed");
   222             in () end;
   223         val _ = protected_name := name;
   224         val res = Exn.capture e ();
   225         val _ = protected_name := "";
   226         val _ = Mutex.unlock lock;
   227       in Exn.release res end;
   228 
   229     (*wakeup condition*)
   230     val wakeup = ConditionVar.conditionVar ();
   231     fun wakeup_all () = ConditionVar.broadcast wakeup;
   232     fun wait () = ConditionVar.wait (wakeup, lock);
   233     fun wait_timeout () = ConditionVar.waitUntil (wakeup, lock, Time.now () + Time.fromSeconds 1);
   234 
   235     (*queue of tasks*)
   236     val queue = ref tasks;
   237     val active = ref 0;
   238     fun trace_active () = tracing 1 (fn () => "SCHEDULE: " ^ Int.toString (! active) ^ " active");
   239     fun dequeue () =
   240       let
   241         val (next, tasks') = next_task (! queue);
   242         val _ = queue := tasks';
   243       in
   244         (case next of Wait =>
   245           (dec active; trace_active ();
   246             wait ();
   247             inc active; trace_active ();
   248             dequeue ())
   249         | _ => next)
   250       end;
   251 
   252     (*pool of running threads*)
   253     val status = ref ([]: exn list);
   254     val running = ref ([]: Thread.thread list);
   255     fun start f =
   256       (inc active;
   257        change running (cons (Thread.fork (f, [Thread.InterruptState Thread.InterruptDefer]))));
   258     fun stop () =
   259       (dec active;
   260        change running (List.filter (fn t => not (Thread.equal (t, Thread.self ())))));
   261 
   262    (*worker thread*)
   263     fun worker () =
   264       (case PROTECTED "dequeue" dequeue of
   265         Task {body, cont, fail} =>
   266           (case Exn.capture (interruptible (fn _ => body)) () of
   267             Exn.Result () =>
   268               (PROTECTED "cont" (fn () => (change queue cont; wakeup_all ())); worker ())
   269           | Exn.Exn exn =>
   270               PROTECTED "fail" (fn () =>
   271                 (change status (cons exn); change queue fail; stop (); wakeup_all ())))
   272       | Terminate => PROTECTED "terminate" (fn () => (stop (); wakeup_all ())));
   273 
   274     (*main control: fork and wait*)
   275     fun fork 0 = ()
   276       | fork k = (start worker; fork (k - 1));
   277     val _ = PROTECTED "main" (fn () =>
   278      (fork (Int.max (n, 1));
   279       while not (List.null (! running)) do
   280       (trace_active ();
   281        if not (List.null (! status)) then (List.app Thread.interrupt (! running)) else ();
   282        wait_timeout ())));
   283 
   284   in ! status end);
   285 
   286 
   287 (* serial numbers *)
   288 
   289 local
   290 
   291 val serial_lock = Mutex.mutex ();
   292 val serial_count = ref 0;
   293 
   294 in
   295 
   296 val serial = uninterruptible (fn _ => fn () =>
   297   let
   298     val _ = Mutex.lock serial_lock;
   299     val res = inc serial_count;
   300     val _ = Mutex.unlock serial_lock;
   301   in res end);
   302 
   303 end;
   304 
   305 
   306 (* thread data *)
   307 
   308 val get_data = Thread.getLocal;
   309 val put_data = Thread.setLocal;
   310 
   311 end;
   312 
   313 structure BasicMultithreading: BASIC_MULTITHREADING = Multithreading;
   314 open BasicMultithreading;