--- a/src/Pure/Concurrent/future.ML Sun Jul 19 17:08:34 2009 +0200
+++ b/src/Pure/Concurrent/future.ML Sun Jul 19 18:02:40 2009 +0200
@@ -29,7 +29,7 @@
val enabled: unit -> bool
type task = Task_Queue.task
type group = Task_Queue.group
- val thread_data: unit -> (string * task) option
+ val is_worker: unit -> bool
type 'a future
val task_of: 'a future -> task
val group_of: 'a future -> group
@@ -40,6 +40,7 @@
val fork_group: group -> (unit -> 'a) -> 'a future
val fork_deps: 'b future list -> (unit -> 'a) -> 'a future
val fork_pri: int -> (unit -> 'a) -> 'a future
+ val fork_local: int -> (unit -> 'a) -> 'a future
val join_results: 'a future list -> 'a Exn.result list
val join_result: 'a future -> 'a Exn.result
val join: 'a future -> 'a
@@ -66,11 +67,16 @@
type task = Task_Queue.task;
type group = Task_Queue.group;
-local val tag = Universal.tag () : (string * task) option Universal.tag in
+local
+ val tag = Universal.tag () : (string * task * group) option Universal.tag;
+in
fun thread_data () = the_default NONE (Thread.getLocal tag);
- fun setmp_thread_data data f x = Library.setmp_thread_data tag (thread_data ()) (SOME data) f x;
+ fun setmp_thread_data data f x =
+ Library.setmp_thread_data tag (thread_data ()) (SOME data) f x;
end;
+val is_worker = is_some o thread_data;
+
(* datatype future *)
@@ -148,7 +154,7 @@
let
val _ = trace_active ();
val valid = Task_Queue.is_valid group;
- val ok = setmp_thread_data (name, task) (fn () =>
+ val ok = setmp_thread_data (name, task, group) (fn () =>
fold (fn job => fn ok => job valid andalso ok) jobs true) ();
val _ = SYNCHRONIZED "execute" (fn () =>
(change queue (Task_Queue.finish task);
@@ -277,6 +283,7 @@
fun fork_group group e = fork_future (SOME group) [] 0 e;
fun fork_deps deps e = fork_future NONE (map task_of deps) 0 e;
fun fork_pri pri e = fork_future NONE [] pri e;
+fun fork_local pri e = fork_future (Option.map #3 (thread_data ())) [] pri e;
(* join *)
@@ -300,13 +307,13 @@
val _ = Multithreading.self_critical () andalso
error "Cannot join future values within critical section";
- val is_worker = is_some (thread_data ());
+ val worker = is_worker ();
fun join_wait x =
if SYNCHRONIZED "join_wait" (fn () =>
- is_finished x orelse (if is_worker then worker_wait () else wait (); false))
+ is_finished x orelse (if worker then worker_wait () else wait (); false))
then () else join_wait x;
- val _ = if is_worker then join_deps (map task_of xs) else ();
+ val _ = if worker then join_deps (map task_of xs) else ();
val _ = List.app join_wait xs;
in map get_result xs end) ();
@@ -342,7 +349,7 @@
fun interruptible_task f x =
if Multithreading.available then
Multithreading.with_attributes
- (if is_some (thread_data ())
+ (if is_worker ()
then Multithreading.restricted_interrupts
else Multithreading.regular_interrupts)
(fn _ => f) x