propagate exceptions within future groups;
authorwenzelm
Tue, 21 Jul 2009 15:25:22 +0200
changeset 32099 5382c93108db
parent 32098 c1e280ab4746
child 32100 8ac6b1102f16
propagate exceptions within future groups; Future.map: inherit group;
src/Pure/Concurrent/future.ML
src/Pure/Concurrent/task_queue.ML
--- a/src/Pure/Concurrent/future.ML	Tue Jul 21 13:46:18 2009 +0200
+++ b/src/Pure/Concurrent/future.ML	Tue Jul 21 15:25:22 2009 +0200
@@ -151,7 +151,20 @@
   count_active (! workers) > Multithreading.max_threads_value ();
 
 
-(* execute jobs *)
+(* execute future jobs *)
+
+fun future_job group (e: unit -> 'a) =
+  let
+    val result = ref (NONE: 'a Exn.result option);
+    val job = Multithreading.with_attributes Multithreading.restricted_interrupts
+      (fn _ => fn ok =>
+        let
+          val res = if ok then Exn.capture e () else Exn.Exn Exn.Interrupt;
+          val _ = result := SOME res;
+          val _ = (case res of Exn.Exn exn => Task_Queue.cancel_group group exn | _ => ());
+          val res_ok = is_some (Exn.get_result res);
+        in res_ok end);
+  in (result, job) end;
 
 fun do_cancel group = (*requires SYNCHRONIZED*)
   change canceled (insert Task_Queue.eq_group group);
@@ -159,7 +172,7 @@
 fun execute name (task, group, jobs) =
   let
     val _ = trace_active ();
-    val valid = Task_Queue.is_valid group;
+    val valid = null (Task_Queue.group_exns group);
     val ok = setmp_thread_data (name, task, group) (fn () =>
       fold (fn job => fn ok => job valid andalso ok) jobs true) ();
     val _ = SYNCHRONIZED "execute" (fn () =>
@@ -260,25 +273,6 @@
 
 (** futures **)
 
-(* future job: fill result *)
-
-fun future_job group (e: unit -> 'a) =
-  let
-    val result = ref (NONE: 'a Exn.result option);
-    val job = Multithreading.with_attributes Multithreading.restricted_interrupts
-      (fn _ => fn ok =>
-        let
-          val res = if ok then Exn.capture e () else Exn.Exn Exn.Interrupt;
-          val _ = result := SOME res;
-          val res_ok =
-            (case res of
-              Exn.Result _ => true
-            | Exn.Exn Exn.Interrupt => (Task_Queue.invalidate_group group; true)
-            | _ => false);
-        in res_ok end);
-  in (result, job) end;
-
-
 (* fork *)
 
 fun fork_future opt_group deps pri e =
@@ -302,7 +296,11 @@
 
 local
 
-fun get_result x = the_default (Exn.Exn (SYS_ERROR "unfinished future")) (peek x);
+fun get_result x =
+  (case peek x of
+    SOME (Exn.Exn Exn.Interrupt) => Exn.Exn (Exn.EXCEPTIONS (Task_Queue.group_exns (group_of x)))
+  | SOME res => res
+  | NONE => Exn.Exn (SYS_ERROR "unfinished future"));
 
 fun join_next deps = (*requires SYNCHRONIZED*)
   if overloaded () then (worker_wait (); join_next deps)
@@ -356,7 +354,7 @@
       | NONE => false));
   in
     if extended then Future {task = task, group = group, result = result}
-    else fork_future NONE [task] (Task_Queue.pri_of_task task) (fn () => f (join x))
+    else fork_future (SOME group) [task] (Task_Queue.pri_of_task task) (fn () => f (join x))
   end;
 
 
--- a/src/Pure/Concurrent/task_queue.ML	Tue Jul 21 13:46:18 2009 +0200
+++ b/src/Pure/Concurrent/task_queue.ML	Tue Jul 21 15:25:22 2009 +0200
@@ -14,8 +14,7 @@
   val group_id: group -> int
   val eq_group: group * group -> bool
   val new_group: unit -> group
-  val is_valid: group -> bool
-  val invalidate_group: group -> unit
+  val group_exns: group -> exn list
   val str_of_group: group -> string
   type queue
   val empty: queue
@@ -28,6 +27,7 @@
     (((task * group * (bool -> bool) list) * task list) option * queue)
   val interrupt: queue -> task -> unit
   val interrupt_external: queue -> string -> unit
+  val cancel_group: group -> exn -> unit
   val cancel: queue -> group -> bool
   val cancel_all: queue -> group list
   val finish: task -> queue -> queue
@@ -50,18 +50,17 @@
 
 (* groups *)
 
-datatype group = Group of serial * bool ref;
+datatype group = Group of serial * exn list ref;
 
 fun group_id (Group (gid, _)) = gid;
 fun eq_group (Group (gid1, _), Group (gid2, _)) = gid1 = gid2;
 
-fun new_group () = Group (serial (), ref true);
+fun new_group () = Group (serial (), ref []);
 
-fun is_valid (Group (_, ref ok)) = ok;
-fun invalidate_group (Group (_, ok)) = ok := false;
+fun group_exns (Group (_, ref exns)) = exns;
 
-fun str_of_group (Group (i, ref ok)) =
-  if ok then string_of_int i else enclose "(" ")" (string_of_int i);
+fun str_of_group (Group (i, ref exns)) =
+  if null exns then string_of_int i else enclose "(" ")" (string_of_int i);
 
 
 (* jobs *)
@@ -195,9 +194,14 @@
 
 (* termination *)
 
+fun cancel_group (Group (_, r)) exn = CRITICAL (fn () =>
+  (case exn of
+    Exn.Interrupt => if null (! r) then r := [exn] else ()
+  | _ => change r (cons exn)));
+
 fun cancel (Queue {groups, jobs, ...}) (group as Group (gid, _)) =
   let
-    val _ = invalidate_group group;
+    val _ = cancel_group group Exn.Interrupt;
     val tasks = Inttab.lookup_list groups gid;
     val running = fold (get_job jobs #> (fn Running t => insert Thread.equal t | _ => I)) tasks [];
     val _ = List.app SimpleThread.interrupt running;
@@ -206,7 +210,7 @@
 fun cancel_all (Queue {jobs, ...}) =
   let
     fun cancel_job (group, job) (groups, running) =
-      (invalidate_group group;
+      (cancel_group group Exn.Interrupt;
         (case job of Running t => (insert eq_group group groups, insert Thread.equal t running)
         | _ => (groups, running)));
     val (groups, running) = Task_Graph.fold (cancel_job o #1 o #2) jobs ([], []);