src/Pure/Concurrent/future.ML
changeset 54649 99b9249b3e05
parent 54637 db3d3d99c69d
child 54671 d64a4ef26edb
--- a/src/Pure/Concurrent/future.ML	Mon Nov 25 21:36:10 2013 +0100
+++ b/src/Pure/Concurrent/future.ML	Thu Nov 28 12:54:39 2013 +0100
@@ -48,7 +48,6 @@
   val worker_group: unit -> group option
   val the_worker_group: unit -> group
   val worker_subgroup: unit -> group
-  val worker_context: string -> group -> ('a -> 'b) -> 'a -> 'b
   type 'a future
   val task_of: 'a future -> task
   val peek: 'a future -> 'a Exn.result option
@@ -68,6 +67,7 @@
   val joins: 'a future list -> 'a list
   val join: 'a future -> 'a
   val join_tasks: task list -> unit
+  val task_context: string -> group -> ('a -> 'b) -> 'a -> 'b
   val value_result: 'a Exn.result -> 'a future
   val value: 'a -> 'a future
   val cond_forks: params -> (unit -> 'a) list -> 'a future list
@@ -109,9 +109,6 @@
 
 fun worker_subgroup () = new_group (worker_group ());
 
-fun worker_context name group f x =
-  setmp_worker_task (Task_Queue.new_task group name NONE) f x;
-
 fun worker_joining e =
   (case worker_task () of
     NONE => e ()
@@ -471,7 +468,7 @@
 
 (* future jobs *)
 
-fun future_job group interrupts (e: unit -> 'a) =
+fun future_job group atts (e: unit -> 'a) =
   let
     val result = Single_Assignment.var "future" : 'a result;
     val pos = Position.thread_data ();
@@ -480,10 +477,7 @@
         val res =
           if ok then
             Exn.capture (fn () =>
-              Multithreading.with_attributes
-                (if interrupts
-                 then Multithreading.private_interrupts else Multithreading.no_interrupts)
-                (fn _ => Position.setmp_thread_data pos e ())) ()
+              Multithreading.with_attributes atts (fn _ => Position.setmp_thread_data pos e ())) ()
           else Exn.interrupt_exn;
       in assign_result group result (identify_result pos res) end;
   in (result, job) end;
@@ -504,7 +498,11 @@
         | SOME grp => grp);
       fun enqueue e queue =
         let
-          val (result, job) = future_job grp interrupts e;
+          val atts =
+            if interrupts
+            then Multithreading.private_interrupts
+            else Multithreading.no_interrupts;
+          val (result, job) = future_job grp atts e;
           val (task, queue') = Task_Queue.enqueue name grp deps pri job queue;
           val future = Future {promised = false, task = task, result = result};
         in (future, queue') end;
@@ -580,6 +578,23 @@
     |> join;
 
 
+(* task context for running thread *)
+
+fun task_context name group f x =
+  Multithreading.with_attributes Multithreading.no_interrupts (fn orig_atts =>
+    let
+      val (result, job) = future_job group orig_atts (fn () => f x);
+      val task =
+        SYNCHRONIZED "enroll" (fn () =>
+          Unsynchronized.change_result queue (Task_Queue.enroll (Thread.self ()) name group));
+      val _ = worker_exec (task, [job]);
+    in
+      (case Single_Assignment.peek result of
+        NONE => raise Fail "Missing task context result"
+      | SOME res => Exn.release res)
+    end);
+
+
 (* fast-path operations -- bypass task queue if possible *)
 
 fun value_result (res: 'a Exn.result) =
@@ -602,7 +617,8 @@
     let
       val task = task_of x;
       val group = Task_Queue.group_of_task task;
-      val (result, job) = future_job group true (fn () => f (join x));
+      val (result, job) =
+        future_job group Multithreading.private_interrupts (fn () => f (join x));
 
       val extended = SYNCHRONIZED "extend" (fn () =>
         (case Task_Queue.extend task job (! queue) of