src/Pure/ML-Systems/multithreading_polyml.ML
changeset 32295 400cc493d466
parent 32286 1fb5db48002d
child 32738 15bb09ca0378
--- a/src/Pure/ML-Systems/multithreading_polyml.ML	Thu Jul 30 23:50:11 2009 +0200
+++ b/src/Pure/ML-Systems/multithreading_polyml.ML	Sat Aug 01 00:09:45 2009 +0200
@@ -27,31 +27,6 @@
 structure Multithreading: MULTITHREADING =
 struct
 
-(* tracing *)
-
-val trace = ref 0;
-
-fun tracing level msg =
-  if level > ! trace then ()
-  else (TextIO.output (TextIO.stdErr, (">>> " ^ msg () ^ "\n")); TextIO.flushOut TextIO.stdErr)
-    handle _ (*sic*) => ();
-
-fun tracing_time detailed time =
-  tracing
-   (if not detailed then 5
-    else if Time.>= (time, Time.fromMilliseconds 1000) then 1
-    else if Time.>= (time, Time.fromMilliseconds 100) then 2
-    else if Time.>= (time, Time.fromMilliseconds 10) then 3
-    else if Time.>= (time, Time.fromMilliseconds 1) then 4 else 5);
-
-fun real_time f x =
-  let
-    val timer = Timer.startRealTimer ();
-    val () = f x;
-    val time = Timer.checkRealTimer timer;
-  in time end;
-
-
 (* options *)
 
 val available = true;
@@ -91,57 +66,76 @@
 val no_interrupts =
   [Thread.EnableBroadcastInterrupt false, Thread.InterruptState Thread.InterruptDefer];
 
-val regular_interrupts =
+val public_interrupts =
   [Thread.EnableBroadcastInterrupt true, Thread.InterruptState Thread.InterruptAsynchOnce];
 
-val restricted_interrupts =
+val private_interrupts =
   [Thread.EnableBroadcastInterrupt false, Thread.InterruptState Thread.InterruptAsynchOnce];
 
+val sync_interrupts = map
+  (fn x as Thread.InterruptState Thread.InterruptDefer => x
+    | Thread.InterruptState _ => Thread.InterruptState Thread.InterruptSynch
+    | x => x);
+
 val safe_interrupts = map
   (fn Thread.InterruptState Thread.InterruptAsynch =>
       Thread.InterruptState Thread.InterruptAsynchOnce
     | x => x);
 
-fun with_attributes new_atts f x =
+fun with_attributes new_atts e =
   let
     val orig_atts = safe_interrupts (Thread.getAttributes ());
     val result = Exn.capture (fn () =>
-      (Thread.setAttributes (safe_interrupts new_atts); f orig_atts x)) ();
+      (Thread.setAttributes (safe_interrupts new_atts); e orig_atts)) ();
     val _ = Thread.setAttributes orig_atts;
   in Exn.release result end;
 
 
-(* regular interruptibility *)
+(* portable wrappers *)
+
+fun interruptible f x = with_attributes public_interrupts (fn _ => f x);
 
-fun interruptible f x =
-  (Thread.testInterrupt (); with_attributes regular_interrupts (fn _ => fn x => f x) x);
-
-fun uninterruptible f =
-  with_attributes no_interrupts (fn atts => fn x =>
-    f (fn g => with_attributes atts (fn _ => fn y => g y)) x);
+fun uninterruptible f x =
+  with_attributes no_interrupts (fn atts =>
+    f (fn g => fn y => with_attributes atts (fn _ => g y)) x);
 
 
 (* synchronous wait *)
 
-fun sync_attributes e =
+fun sync_wait opt_atts time cond lock =
+  with_attributes
+    (sync_interrupts (case opt_atts of SOME atts => atts | NONE => Thread.getAttributes ()))
+    (fn _ =>
+      (case time of
+        SOME t => Exn.Result (ConditionVar.waitUntil (cond, lock, t))
+      | NONE => (ConditionVar.wait (cond, lock); Exn.Result true))
+      handle exn => Exn.Exn exn);
+
+
+(* tracing *)
+
+val trace = ref 0;
+
+fun tracing level msg =
+  if level > ! trace then ()
+  else uninterruptible (fn _ => fn () =>
+    (TextIO.output (TextIO.stdErr, (">>> " ^ msg () ^ "\n")); TextIO.flushOut TextIO.stdErr)
+      handle _ (*sic*) => ()) ();
+
+fun tracing_time detailed time =
+  tracing
+   (if not detailed then 5
+    else if Time.>= (time, Time.fromMilliseconds 1000) then 1
+    else if Time.>= (time, Time.fromMilliseconds 100) then 2
+    else if Time.>= (time, Time.fromMilliseconds 10) then 3
+    else if Time.>= (time, Time.fromMilliseconds 1) then 4 else 5);
+
+fun real_time f x =
   let
-    val orig_atts = Thread.getAttributes ();
-    val broadcast =
-      (case List.find (fn Thread.EnableBroadcastInterrupt _ => true | _ => false) orig_atts of
-        NONE => Thread.EnableBroadcastInterrupt false
-      | SOME att => att);
-    val interrupt_state =
-      (case List.find (fn Thread.InterruptState _ => true | _ => false) orig_atts of
-        NONE => Thread.InterruptState Thread.InterruptDefer
-      | SOME (state as Thread.InterruptState Thread.InterruptDefer) => state
-      | _ => Thread.InterruptState Thread.InterruptSynch);
-  in with_attributes [broadcast, interrupt_state] (fn _ => fn () => e ()) () end;
-
-fun sync_wait time cond lock =
-  sync_attributes (fn () =>
-    (case time of
-      SOME t => ConditionVar.waitUntil (cond, lock, t)
-    | NONE => (ConditionVar.wait (cond, lock); true)));
+    val timer = Timer.startRealTimer ();
+    val () = f x;
+    val time = Timer.checkRealTimer timer;
+  in time end;
 
 
 (* execution with time limit *)
@@ -169,7 +163,7 @@
 
 (* system shell processes, with propagation of interrupts *)
 
-fun system_out script = uninterruptible (fn restore_attributes => fn () =>
+fun system_out script = with_attributes no_interrupts (fn orig_atts =>
   let
     val script_name = OS.FileSys.tmpName ();
     val _ = write_file script_name script;
@@ -180,13 +174,12 @@
     (*result state*)
     datatype result = Wait | Signal | Result of int;
     val result = ref Wait;
-    val result_mutex = Mutex.mutex ();
-    val result_cond = ConditionVar.conditionVar ();
+    val lock = Mutex.mutex ();
+    val cond = ConditionVar.conditionVar ();
     fun set_result res =
-      (Mutex.lock result_mutex; result := res; Mutex.unlock result_mutex;
-        ConditionVar.signal result_cond);
+      (Mutex.lock lock; result := res; ConditionVar.signal cond; Mutex.unlock lock);
 
-    val _ = Mutex.lock result_mutex;
+    val _ = Mutex.lock lock;
 
     (*system thread*)
     val system_thread = Thread.fork (fn () =>
@@ -216,11 +209,12 @@
       handle OS.SysErr _ => () | IO.Io _ =>
         (OS.Process.sleep (Time.fromMilliseconds 100); if n > 0 then kill (n - 1) else ());
 
-    val _ = while ! result = Wait do
-      restore_attributes (fn () =>
-        (ignore (sync_wait (SOME (Time.+ (Time.now (), Time.fromMilliseconds 100)))
-            result_cond result_mutex)
-          handle Exn.Interrupt => kill 10)) ();
+    val _ =
+      while ! result = Wait do
+        let val res =
+          sync_wait (SOME orig_atts)
+            (SOME (Time.+ (Time.now (), Time.fromMilliseconds 100))) cond lock
+        in case res of Exn.Exn Exn.Interrupt => kill 10 | _ => () end;
 
     (*cleanup*)
     val output = read_file output_name handle IO.Io _ => "";
@@ -229,7 +223,7 @@
     val _ = OS.FileSys.remove output_name handle OS.SysErr _ => ();
     val _ = Thread.interrupt system_thread handle Thread _ => ();
     val rc = (case ! result of Signal => raise Exn.Interrupt | Result rc => rc);
-  in (output, rc) end) ();
+  in (output, rc) end);
 
 
 (* critical section -- may be nested within the same thread *)