src/Pure/ML-Systems/multithreading_polyml.ML
author wenzelm
Mon Sep 20 23:36:26 2010 +0200 (2010-09-20)
changeset 39576 48baf61cb888
parent 39232 69c6d3e87660
child 39583 c1e9c6dfeff8
permissions -rw-r--r--
refined ML/Scala bash wrapper, based on more general lib/scripts/process;
     1 (*  Title:      Pure/ML-Systems/multithreading_polyml.ML
     2     Author:     Makarius
     3 
     4 Multithreading in Poly/ML 5.2.1 or later (cf. polyml/basis/Thread.sml).
     5 *)
     6 
     (*Poly/ML specific operations, on top of the generic multithreading interface*)
     signature MULTITHREADING_POLYML =
     sig
       (*evaluation with a time limit*)
       structure TimeLimit: TIME_LIMIT
       (*run a bash script; result: standard output and return code*)
       val bash_output: string -> string * int
       (*evaluate with broadcast interrupts enabled and asynchronous delivery*)
       val interruptible: ('a -> 'b) -> 'a -> 'b
       (*evaluate with interrupts deferred; the function argument receives a
         "restore" operation that runs a sub-computation under the caller's
         previous interrupt attributes*)
       val uninterruptible: ((('c -> 'd) -> 'c -> 'd) -> 'a -> 'b) -> 'a -> 'b
     end;
    14 
    (*BASIC_MULTITHREADING extended by the Poly/ML specific operations; this
      signature shadows an earlier signature of the same name (not in this file)*)
    15 signature BASIC_MULTITHREADING =
    16 sig
    17   include BASIC_MULTITHREADING
    18   include MULTITHREADING_POLYML
    19 end;
    20 
    (*MULTITHREADING extended by the Poly/ML specific operations; this
      signature shadows an earlier signature of the same name (not in this file)*)
    21 signature MULTITHREADING =
    22 sig
    23   include MULTITHREADING
    24   include MULTITHREADING_POLYML
    25 end;
    26 
    27 structure Multithreading: MULTITHREADING =
    28 struct
    29 
    (* options *)

    (*this implementation supports multithreading*)
    val available = true;

    (*requested number of threads; a value <= 0 selects the default*)
    val max_threads = Unsynchronized.ref 0;

    (*effective thread limit: ! max_threads if positive, otherwise the number
      of processors clamped to the range 1..4*)
    fun max_threads_value () =
      let
        val requested = ! max_threads;
        fun default_value () = Int.max (1, Int.min (Thread.numProcessors (), 4));
      in if requested <= 0 then default_value () else requested end;

    (*true if more than one thread may be used*)
    fun enabled () = 1 < max_threads_value ();
    43 
    44 
    (* misc utils *)

    (*" " ^ name for a non-empty name, "" for the empty name*)
    fun show name = if name = "" then "" else " " ^ name;

    (*" [" ^ name ^ "]" for a non-empty name, "" for the empty name*)
    fun show' name = if name = "" then "" else " [" ^ name ^ "]";
    49 
    (*entire content of a text file; the stream is closed before any failure
      of the read (as captured by Exn.capture) is re-raised*)
    fun read_file name =
      let
        val instream = TextIO.openIn name;
        val content = Exn.capture TextIO.inputAll instream;
        val () = TextIO.closeIn instream;
      in Exn.release content end;

    (*write txt to a file opened with TextIO.openOut; the stream is closed
      before any failure of the write (as captured by Exn.capture) is re-raised*)
    fun write_file name txt =
      let
        val outstream = TextIO.openOut name;
        val written = Exn.capture TextIO.output (outstream, txt);
        val () = TextIO.closeOut outstream;
      in Exn.release written end;
    57 
    58 
    (* thread attributes *)

    (*no broadcast interrupts, delivery deferred*)
    val no_interrupts =
      [Thread.EnableBroadcastInterrupt false, Thread.InterruptState Thread.InterruptDefer];

    (*broadcast interrupts enabled, asynchronous delivery (once)*)
    val public_interrupts =
      [Thread.EnableBroadcastInterrupt true, Thread.InterruptState Thread.InterruptAsynchOnce];

    (*no broadcast interrupts, asynchronous delivery (once)*)
    val private_interrupts =
      [Thread.EnableBroadcastInterrupt false, Thread.InterruptState Thread.InterruptAsynchOnce];

    (*keep a deferred interrupt state, turn any other one into synchronous delivery*)
    val sync_interrupts =
      let
        fun sync (Thread.InterruptState Thread.InterruptDefer) =
              Thread.InterruptState Thread.InterruptDefer
          | sync (Thread.InterruptState _) = Thread.InterruptState Thread.InterruptSynch
          | sync att = att;
      in map sync end;

    (*replace plain asynchronous delivery by its one-shot variant*)
    val safe_interrupts =
      let
        fun safe (Thread.InterruptState Thread.InterruptAsynch) =
              Thread.InterruptState Thread.InterruptAsynchOnce
          | safe att = att;
      in map safe end;

    (*evaluate e under the (safe variant of) new_atts, passing it the (safe
      variant of the) previous attributes; these are reinstated afterwards,
      also when e fails with an exception captured by Exn.capture*)
    fun with_attributes new_atts e =
      let
        val orig_atts = safe_interrupts (Thread.getAttributes ());
        fun body () = (Thread.setAttributes (safe_interrupts new_atts); e orig_atts);
        val result = Exn.capture body ();
      in (Thread.setAttributes orig_atts; Exn.release result) end;
    87 
    88 
    (* portable wrappers *)

    (*evaluate f x with public_interrupts: broadcast interrupts enabled,
      asynchronous delivery*)
    fun interruptible f x =
      let fun apply _ = f x
      in with_attributes public_interrupts apply end;

    (*evaluate f x with no_interrupts (delivery deferred); f receives a
      function restore such that restore g y evaluates g y under the
      attributes that were in effect before (as returned by with_attributes)*)
    fun uninterruptible f x =
      let
        fun run orig_atts =
          let fun restore g y = with_attributes orig_atts (fn _ => g y)
          in f restore x end;
      in with_attributes no_interrupts run end;
    96 
    97 
    (* synchronous wait *)

    (*wait on cond (with lock) under the given attributes -- or the current
      ones for NONE -- where interrupt delivery is made synchronous; with
      SOME t the result is that of ConditionVar.waitUntil, otherwise true.
      Any exception raised by the wait (e.g. an interrupt) is returned as
      Exn.Exn instead of being raised.*)
    fun sync_wait opt_atts time cond lock =
      let
        val base_atts =
          (case opt_atts of
            NONE => Thread.getAttributes ()
          | SOME atts => atts);
        fun wait () =
          (case time of
            NONE => (ConditionVar.wait (cond, lock); true)
          | SOME t => ConditionVar.waitUntil (cond, lock, t));
        fun body _ = Exn.Result (wait ()) handle exn => Exn.Exn exn;
      in with_attributes (sync_interrupts base_atts) body end;
   108 
   109 
   (* tracing *)

   (*verbosity threshold: messages of level <= ! trace are printed*)
   val trace = Unsynchronized.ref 0;

   (*print ">>> " ^ msg () on stderr if level does not exceed ! trace; this
     happens with interrupts deferred, and any exception raised while
     producing or printing the message is ignored*)
   fun tracing level msg =
     if level <= ! trace then
       uninterruptible (fn _ => fn () =>
         let
           val text = ">>> " ^ msg () ^ "\n";
         in TextIO.output (TextIO.stdErr, text); TextIO.flushOut TextIO.stdErr end
         handle _ (*sic*) => ()) ()
     else ();

   (*tracing at a level derived from an elapsed time: 5 if not detailed;
     otherwise 1 for >= 1000ms, 2 for >= 100ms, 3 for >= 10ms, 4 for >= 1ms,
     and 5 below that*)
   fun tracing_time detailed time =
     let
       fun reaches ms = Time.>= (time, Time.fromMilliseconds ms);
       fun level_of [] = 5
         | level_of ((ms, level) :: rest) = if reaches ms then level else level_of rest;
       val level = if detailed then level_of [(1000, 1), (100, 2), (10, 3), (1, 4)] else 5;
     in tracing level end;

   (*real time spent evaluating f x (whose result must be unit)*)
   fun real_time f x =
     let
       val start = Timer.startRealTimer ();
       val () = f x;
     in Timer.checkRealTimer start end;
   134 
   135 
   (* execution with time limit *)

   structure TimeLimit =
   struct

   exception TimeOut;

   (*Evaluate f x under the caller's interrupt attributes; raise TimeOut if
     the given time elapses and f x is interrupted because of it.

     A watchdog thread (forked with private_interrupts, so that it can be
     stopped by Thread.interrupt) sleeps for time, then sets the timeout flag
     and interrupts the worker.  After f x has finished, the watchdog is
     interrupted and awaited: afterwards ! timeout is final and no further
     interrupt can be sent.  An interrupt that reached the worker while its
     interrupts were deferred (e.g. sent just after f x had finished) is then
     consumed here via Thread.testInterrupt -- it counts as a timeout if the
     flag is set and is re-raised otherwise -- instead of surfacing later in
     unrelated code.*)
   fun timeLimit time f x = uninterruptible (fn restore_attributes => fn () =>
     let
       val worker = Thread.self ();
       val timeout = Unsynchronized.ref false;
       val watchdog = Thread.fork (fn () =>
         (OS.Process.sleep time; timeout := true; Thread.interrupt worker), private_interrupts);

       val result = Exn.capture (restore_attributes f) x;

       (*stop the watchdog and wait for its termination*)
       val _ = Thread.interrupt watchdog handle Thread _ => ();
       val _ = while Thread.isActive watchdog do OS.Process.sleep (Time.fromMilliseconds 1);

       (*consume an interrupt that is pending due to deferred delivery*)
       val pending =
         Exn.capture (fn () =>
           with_attributes
             [Thread.EnableBroadcastInterrupt false, Thread.InterruptState Thread.InterruptSynch]
             (fn _ => Thread.testInterrupt ())) ();
     in
       if ! timeout andalso (Exn.is_interrupt_exn result orelse Exn.is_interrupt_exn pending)
       then raise TimeOut
       else if Exn.is_interrupt_exn pending then Exn.interrupt ()
       else Exn.release result
     end) ();

   end;
   157 
   158 
   (* GNU bash processes, with propagation of interrupts *)

   (*Run a bash script; result: its standard output and return code.

     The script is written to a temporary file and executed via
     "$ISABELLE_HOME/lib/scripts/process" (with argument "group" and a pid
     file), with stdout redirected to another temporary file.  The calling
     thread runs with interrupts deferred, except while waiting for the
     result under its original attributes made synchronous (sync_wait); an
     interrupt received there is forwarded as SIGINT to the process group
     whose id is read from the pid file.  If the script terminates by SIGINT
     (or exits with status 0wx82 = 130), an interrupt is raised instead of
     returning a code.  The temporary files are removed afterwards.*)
   fun bash_output script = with_attributes no_interrupts (fn orig_atts =>
     let
       val script_name = OS.FileSys.tmpName ();
       val _ = write_file script_name script;

       val pid_name = OS.FileSys.tmpName ();
       val output_name = OS.FileSys.tmpName ();

       (*result state, updated under lock*)
       datatype result = Wait | Signal | Result of int;
       val result = Unsynchronized.ref Wait;
       val lock = Mutex.mutex ();
       val cond = ConditionVar.conditionVar ();
       fun set_result res =
         (Mutex.lock lock; result := res; ConditionVar.signal cond; Mutex.unlock lock);

       (*held by this thread, except while it waits on cond*)
       val _ = Mutex.lock lock;

       (*system thread: run the script and classify its termination; other
         signals are encoded as 256 + signal, stops as 512 + signal, and any
         exception yields return code 2*)
       val system_thread = Thread.fork (fn () =>
         let
           val status =
             OS.Process.system ("exec \"$ISABELLE_HOME/lib/scripts/process\" group " ^ pid_name ^
               " \"exec bash " ^ script_name ^ " > " ^ output_name ^ "\"");
           val res =
             (case Posix.Process.fromStatus status of
               Posix.Process.W_EXITED => Result 0
             | Posix.Process.W_EXITSTATUS 0wx82 => Signal
             | Posix.Process.W_EXITSTATUS w => Result (Word8.toInt w)
             | Posix.Process.W_SIGNALED s =>
                 if s = Posix.Signal.int then Signal
                 else Result (256 + LargeWord.toInt (Posix.Signal.toWord s))
             | Posix.Process.W_STOPPED s => Result (512 + LargeWord.toInt (Posix.Signal.toWord s)));
         in set_result res end handle _ (*sic*) => set_result (Result 2), []);

       (*main thread -- proxy for interrupts: send SIGINT to the process group
         recorded in pid_name.  While the pid file is missing OR does not yet
         contain a number (the process script may not have written it yet),
         retry up to n more times, 100ms apart, instead of dropping the
         interrupt.  A failing kill (OS.SysErr) is ignored.*)
       fun kill n =
         let
           fun retry () =
             (OS.Process.sleep (Time.fromMilliseconds 100); if n > 0 then kill (n - 1) else ());
         in
           (case Int.fromString (read_file pid_name) of
             SOME pid =>
               Posix.Process.kill
                 (Posix.Process.K_GROUP (Posix.Process.wordToPid (LargeWord.fromInt pid)),
                   Posix.Signal.int)
           | NONE => retry ())
           handle OS.SysErr _ => () | IO.Io _ => retry ()
         end;

       (*wait in slices of 100ms until the system thread has set a result*)
       val _ =
         while ! result = Wait do
           let val res =
             sync_wait (SOME orig_atts)
               (SOME (Time.+ (Time.now (), Time.fromMilliseconds 100))) cond lock
           in if Exn.is_interrupt_exn res then kill 10 else () end;

       (*cleanup*)
       val output = read_file output_name handle IO.Io _ => "";
       val _ = OS.FileSys.remove script_name handle OS.SysErr _ => ();
       val _ = OS.FileSys.remove pid_name handle OS.SysErr _ => ();
       val _ = OS.FileSys.remove output_name handle OS.SysErr _ => ();
       val _ = Thread.interrupt system_thread handle Thread _ => ();
       val rc = (case ! result of Signal => Exn.interrupt () | Result rc => rc);
     in (output, rc) end);
   222 
   223 
   (* critical section -- may be nested within the same thread *)

   local

   val critical_lock = Mutex.mutex ();
   val critical_thread = Unsynchronized.ref (NONE: Thread.thread option);
   val critical_name = Unsynchronized.ref "";

   (*acquire critical_lock; when it is contended, trace the waiting and the
     time it took, mentioning the requested name and the holder's name*)
   fun acquire_critical name holder =
     if Mutex.trylock critical_lock then ()
     else
       let
         fun header () = "CRITICAL" ^ show name ^ show' holder;
         val () = tracing 5 (fn () => header () ^ ": waiting");
         val waited = real_time Mutex.lock critical_lock;
       in tracing_time true waited (fn () => header () ^ ": passed after " ^ Time.toString waited) end;

   in

   (*is the current thread inside the critical section?*)
   fun self_critical () =
     (case ! critical_thread of
       SOME t => Thread.equal (t, Thread.self ())
     | NONE => false);

   (*evaluate e () inside the global critical section, recording name for
     tracing; a thread already inside simply evaluates e ().  Entering and
     leaving happen with interrupts deferred, e runs under the caller's
     previous attributes, and an exception from e is re-raised only after
     the lock has been released.*)
   fun NAMED_CRITICAL name e =
     if self_critical () then e ()
     else
       let
         fun enter restore_attributes () =
           let
             val holder = ! critical_name;
             val () = acquire_critical name holder;
             val () = critical_thread := SOME (Thread.self ());
             val () = critical_name := name;
             val result = Exn.capture (restore_attributes e) ();
             val () = critical_name := "";
             val () = critical_thread := NONE;
             val () = Mutex.unlock critical_lock;
           in result end;
       in Exn.release (uninterruptible enter ()) end;

   (*anonymous critical section*)
   fun CRITICAL e = NAMED_CRITICAL "" e;

   end;
   265 
   266 
   (* serial numbers *)

   local

   val serial_lock = Mutex.mutex ();
   val serial_count = Unsynchronized.ref 0;

   in

   (*next value of a global counter (the first call yields 1); the update is
     protected by serial_lock and performed with interrupts deferred*)
   val serial = uninterruptible (fn _ => fn () =>
     let
       val () = Mutex.lock serial_lock;
       val next = ! serial_count + 1;
       val () = serial_count := next;
       val () = Mutex.unlock serial_lock;
     in next end);

   end;
   285 
   286 end;
   287 
   (*Multithreading restricted to the BASIC_MULTITHREADING signature (transparent
     ascription), opened at the toplevel*)
   288 structure Basic_Multithreading: BASIC_MULTITHREADING = Multithreading;
   289 open Basic_Multithreading;