src/Pure/ML-Systems/multithreading_polyml.ML
author wenzelm
Mon Mar 31 23:08:55 2008 +0200 (2008-03-31)
changeset 26504 6e87c0a60104
parent 26493 de4764e95166
child 28124 10a1f1f4c6ae
permissions -rw-r--r--
before close: Exn.capture/release;
     1 (*  Title:      Pure/ML-Systems/multithreading_polyml.ML
     2     ID:         $Id$
     3     Author:     Makarius
     4 
     5 Multithreading in Poly/ML 5.1 (cf. polyml/basis/Thread.sml).
     6 *)
     7 
     8 open Thread;
     9 
    10 signature MULTITHREADING_POLYML =
    11 sig
    12   val interruptible: ('a -> 'b) -> 'a -> 'b
    13   val uninterruptible: ((('c -> 'd) -> 'c -> 'd) -> 'a -> 'b) -> 'a -> 'b
    14   val system_out: string -> string * int
    15   structure TimeLimit: TIME_LIMIT
    16   val profile: int -> ('a -> 'b) -> 'a -> 'b
    17 end;
    18 
    19 signature BASIC_MULTITHREADING =
    20 sig
    21   include BASIC_MULTITHREADING
    22   include MULTITHREADING_POLYML
    23 end;
    24 
    25 signature MULTITHREADING =
    26 sig
    27   include MULTITHREADING
    28   include MULTITHREADING_POLYML
    29 end;
    30 
    31 structure Multithreading: MULTITHREADING =
    32 struct
    33 
    34 (* options *)
    35 
    36 val trace = ref 0;
    37 fun tracing level msg =
    38   if level <= ! trace
    39   then (TextIO.output (TextIO.stdErr, (">>> " ^ msg () ^ "\n")); TextIO.flushOut TextIO.stdErr)
    40   else ();
    41 
    42 val available = true;
    43 
    44 val max_threads = ref 1;
    45 
    46 fun max_threads_value () =
    47   let val m = ! max_threads
    48   in if m <= 0 then Thread.numProcessors () else m end;
    49 
    50 
    51 (* misc utils *)
    52 
    53 fun cons x xs = x :: xs;
    54 
    55 fun change r f = r := f (! r);
    56 
    57 fun inc i = (i := ! i + 1; ! i);
    58 fun dec i = (i := ! i - 1; ! i);
    59 
    60 fun show "" = "" | show name = " " ^ name;
    61 fun show' "" = "" | show' name = " [" ^ name ^ "]";
    62 
    63 fun read_file name =
    64   let val is = TextIO.openIn name
    65   in Exn.release (Exn.capture TextIO.inputAll is before TextIO.closeIn is) end;
    66 
    67 fun write_file name txt =
    68   let val os = TextIO.openOut name
    69   in Exn.release (Exn.capture TextIO.output (os, txt) before TextIO.closeOut os) end;
    70 
    71 
    72 (* thread attributes *)
    73 
    74 fun with_attributes new_atts f x =
    75   let
    76     val orig_atts = Thread.getAttributes ();
    77     fun restore () = Thread.setAttributes orig_atts;
    78   in
    79     Exn.release
    80     (*RACE for fully asynchronous interrupts!*)
    81     (let
    82         val _ = Thread.setAttributes new_atts;
    83         val result = Exn.capture (f orig_atts) x;
    84         val _ = restore ();
    85       in result end
    86       handle Interrupt => (restore (); Exn.Exn Interrupt))
    87   end;
    88 
    89 fun interruptible f =
    90   with_attributes
    91     [Thread.EnableBroadcastInterrupt true, Thread.InterruptState Thread.InterruptAsynchOnce]
    92     (fn _ => f);
    93 
    94 fun uninterruptible f =
    95   with_attributes
    96     [Thread.EnableBroadcastInterrupt false, Thread.InterruptState Thread.InterruptDefer]
    97     (fn atts => f (fn g => with_attributes atts (fn _ => g)));
    98 
    99 
   100 (* execution with time limit *)
   101 
   102 structure TimeLimit =
   103 struct
   104 
   105 exception TimeOut;
   106 
   107 fun timeLimit time f x = uninterruptible (fn restore_attributes => fn () =>
   108   let
   109     val worker = Thread.self ();
   110     val timeout = ref false;
   111     val watchdog = Thread.fork (fn () =>
   112       (OS.Process.sleep time; timeout := true; Thread.interrupt worker), []);
   113 
   114     (*RACE! timeout signal vs. external Interrupt*)
   115     val result = Exn.capture (restore_attributes f) x;
   116     val was_timeout = (case result of Exn.Exn Interrupt => ! timeout | _ => false);
   117 
   118     val _ = Thread.interrupt watchdog handle Thread _ => ();
   119   in if was_timeout then raise TimeOut else Exn.release result end) ();
   120 
   121 end;
   122 
   123 
   124 (* system shell processes, with propagation of interrupts *)
   125 
   126 fun system_out_threaded script = uninterruptible (fn restore_attributes => fn () =>
   127   let
   128     val script_name = OS.FileSys.tmpName ();
   129     val _ = write_file script_name script;
   130 
   131     val pid_name = OS.FileSys.tmpName ();
   132     val output_name = OS.FileSys.tmpName ();
   133 
   134     (*result state*)
   135     datatype result = Wait | Signal | Result of int;
   136     val result = ref Wait;
   137     val result_mutex = Mutex.mutex ();
   138     val result_cond = ConditionVar.conditionVar ();
   139     fun set_result res =
   140       (Mutex.lock result_mutex; result := res; Mutex.unlock result_mutex;
   141         ConditionVar.signal result_cond);
   142 
   143     val _ = Mutex.lock result_mutex;
   144 
   145     (*system thread*)
   146     val system_thread = Thread.fork (fn () =>
   147       let
   148         val status =
   149           OS.Process.system ("perl -w \"$ISABELLE_HOME/lib/scripts/system.pl\" group " ^
   150             script_name ^ " " ^ pid_name ^ " " ^ output_name);
   151         val res =
   152           (case Posix.Process.fromStatus status of
   153             Posix.Process.W_EXITED => Result 0
   154           | Posix.Process.W_EXITSTATUS 0wx82 => Signal
   155           | Posix.Process.W_EXITSTATUS w => Result (Word8.toInt w)
   156           | Posix.Process.W_SIGNALED s =>
   157               if s = Posix.Signal.int then Signal
   158               else Result (256 + LargeWord.toInt (Posix.Signal.toWord s))
   159           | Posix.Process.W_STOPPED s => Result (512 + LargeWord.toInt (Posix.Signal.toWord s)));
   160       in set_result res end handle _ => set_result (Result 2), []);
   161 
   162     (*main thread -- proxy for interrupts*)
   163     fun kill n =
   164       (case Int.fromString (read_file pid_name) of
   165         SOME pid =>
   166           Posix.Process.kill
   167             (Posix.Process.K_GROUP (Posix.Process.wordToPid (LargeWord.fromInt pid)),
   168               Posix.Signal.int)
   169       | NONE => ())
   170       handle OS.SysErr _ => () | IO.Io _ =>
   171         (OS.Process.sleep (Time.fromMilliseconds 100); if n > 0 then kill (n - 1) else ());
   172 
   173     val _ = while ! result = Wait do
   174       restore_attributes (fn () =>
   175         (ConditionVar.waitUntil (result_cond, result_mutex, Time.now () + Time.fromMilliseconds 100); ())
   176           handle Interrupt => kill 10) ();
   177 
   178     (*cleanup*)
   179     val output = read_file output_name handle IO.Io _ => "";
   180     val _ = OS.FileSys.remove script_name handle OS.SysErr _ => ();
   181     val _ = OS.FileSys.remove pid_name handle OS.SysErr _ => ();
   182     val _ = OS.FileSys.remove output_name handle OS.SysErr _ => ();
   183     val _ = Thread.interrupt system_thread handle Thread _ => ();
   184     val rc = (case ! result of Signal => raise Interrupt | Result rc => rc);
   185   in (output, rc) end) ();
   186 
   187 val system_out =
   188   if ml_system = "polyml-5.1" then system_out  (*signals not propagated from root thread!*)
   189   else system_out_threaded;
   190 
   191 
   192 (* critical section -- may be nested within the same thread *)
   193 
   194 local
   195 
   196 val critical_lock = Mutex.mutex ();
   197 val critical_thread = ref (NONE: Thread.thread option);
   198 val critical_name = ref "";
   199 
   200 in
   201 
   202 fun self_critical () =
   203   (case ! critical_thread of
   204     NONE => false
   205   | SOME id => Thread.equal (id, Thread.self ()));
   206 
   207 fun NAMED_CRITICAL name e =
   208   if self_critical () then e ()
   209   else
   210     uninterruptible (fn restore_attributes => fn () =>
   211       let
   212         val name' = ! critical_name;
   213         val _ =
   214           if Mutex.trylock critical_lock then ()
   215           else
   216             let
   217               val timer = Timer.startRealTimer ();
   218               val _ = tracing 4 (fn () => "CRITICAL" ^ show name ^ show' name' ^ ": waiting");
   219               val _ = Mutex.lock critical_lock;
   220               val time = Timer.checkRealTimer timer;
   221               val trace_time =
   222                 if Time.>= (time, Time.fromMilliseconds 1000) then 1
   223                 else if Time.>= (time, Time.fromMilliseconds 100) then 2
   224                 else if Time.>= (time, Time.fromMilliseconds 10) then 3 else 4;
   225               val _ = tracing trace_time (fn () =>
   226                 "CRITICAL" ^ show name ^ show' name' ^ ": passed after " ^ Time.toString time);
   227             in () end;
   228         val _ = critical_thread := SOME (Thread.self ());
   229         val _ = critical_name := name;
   230         val result = Exn.capture (restore_attributes e) ();
   231         val _ = critical_name := "";
   232         val _ = critical_thread := NONE;
   233         val _ = Mutex.unlock critical_lock;
   234       in Exn.release result end) ();
   235 
   236 fun CRITICAL e = NAMED_CRITICAL "" e;
   237 
   238 end;
   239 
   240 
   241 (* scheduling -- multiple threads working on a queue of tasks *)
   242 
   243 datatype 'a task =
   244   Task of {body: unit -> unit, cont: 'a -> 'a, fail: 'a -> 'a} | Wait | Terminate;
   245 
   246 fun schedule n next_task = uninterruptible (fn restore_attributes => fn tasks =>
   247   let
   248     (*protected execution*)
   249     val lock = Mutex.mutex ();
   250     val protected_name = ref "";
   251     fun PROTECTED name e =
   252       let
   253         val name' = ! protected_name;
   254         val _ =
   255           if Mutex.trylock lock then ()
   256           else
   257             let
   258               val _ = tracing 2 (fn () => "PROTECTED" ^ show name ^ show' name' ^ ": waiting");
   259               val _ = Mutex.lock lock;
   260               val _ = tracing 2 (fn () => "PROTECTED" ^ show name ^ show' name' ^ ": passed");
   261             in () end;
   262         val _ = protected_name := name;
   263         val res = Exn.capture e ();
   264         val _ = protected_name := "";
   265         val _ = Mutex.unlock lock;
   266       in Exn.release res end;
   267 
   268     (*wakeup condition*)
   269     val wakeup = ConditionVar.conditionVar ();
   270     fun wakeup_all () = ConditionVar.broadcast wakeup;
   271     fun wait () = ConditionVar.wait (wakeup, lock);
   272     fun wait_timeout () = ConditionVar.waitUntil (wakeup, lock, Time.now () + Time.fromSeconds 1);
   273 
   274     (*queue of tasks*)
   275     val queue = ref tasks;
   276     val active = ref 0;
   277     fun trace_active () = tracing 1 (fn () => "SCHEDULE: " ^ Int.toString (! active) ^ " active");
   278     fun dequeue () =
   279       let
   280         val (next, tasks') = next_task (! queue);
   281         val _ = queue := tasks';
   282       in
   283         (case next of Wait =>
   284           (dec active; trace_active ();
   285             wait ();
   286             inc active; trace_active ();
   287             dequeue ())
   288         | _ => next)
   289       end;
   290 
   291     (*pool of running threads*)
   292     val status = ref ([]: exn list);
   293     val running = ref ([]: Thread.thread list);
   294     fun start f =
   295       (inc active;
   296        change running (cons (Thread.fork (f, [Thread.InterruptState Thread.InterruptDefer]))));
   297     fun stop () =
   298       (dec active;
   299        change running (List.filter (fn t => not (Thread.equal (t, Thread.self ())))));
   300 
   301    (*worker thread*)
   302     fun worker () =
   303       (case PROTECTED "dequeue" dequeue of
   304         Task {body, cont, fail} =>
   305           (case Exn.capture (restore_attributes body) () of
   306             Exn.Result () =>
   307               (PROTECTED "cont" (fn () => (change queue cont; wakeup_all ())); worker ())
   308           | Exn.Exn exn =>
   309               PROTECTED "fail" (fn () =>
   310                 (change status (cons exn); change queue fail; stop (); wakeup_all ())))
   311       | Terminate => PROTECTED "terminate" (fn () => (stop (); wakeup_all ())));
   312 
   313     (*main control: fork and wait*)
   314     fun fork 0 = ()
   315       | fork k = (start worker; fork (k - 1));
   316     val _ = PROTECTED "main" (fn () =>
   317      (fork (Int.max (n, 1));
   318       while not (List.null (! running)) do
   319       (trace_active ();
   320        if not (List.null (! status))
   321        then (List.app (fn t => Thread.interrupt t handle Thread _ => ()) (! running))
   322        else ();
   323        wait_timeout ())));
   324 
   325   in ! status end);
   326 
   327 
   328 (* profiling *)
   329 
   330 local val profile_orig = profile in
   331 
   332 fun profile 0 f x = f x
   333   | profile n f x = NAMED_CRITICAL "profile" (fn () => profile_orig n f x);
   334 
   335 end;
   336 
   337 
   338 (* serial numbers *)
   339 
   340 local
   341 
   342 val serial_lock = Mutex.mutex ();
   343 val serial_count = ref 0;
   344 
   345 in
   346 
   347 val serial = uninterruptible (fn _ => fn () =>
   348   let
   349     val _ = Mutex.lock serial_lock;
   350     val res = inc serial_count;
   351     val _ = Mutex.unlock serial_lock;
   352   in res end);
   353 
   354 end;
   355 
   356 
   357 (* thread data *)
   358 
   359 val get_data = Thread.getLocal;
   360 val put_data = Thread.setLocal;
   361 
   362 end;
   363 
   364 structure BasicMultithreading: BASIC_MULTITHREADING = Multithreading;
   365 open BasicMultithreading;