src/Pure/Concurrent/future.ML
changeset 28177 8c0335bc9336
parent 28170 a18cf8a0e656
child 28186 6a8417f36837
--- a/src/Pure/Concurrent/future.ML	Tue Sep 09 16:29:32 2008 +0200
+++ b/src/Pure/Concurrent/future.ML	Tue Sep 09 16:29:34 2008 +0200
@@ -11,13 +11,10 @@
   type group = TaskQueue.group
   type 'a T
   val task_of: 'a T -> task
-  val group_of: 'a T -> group option
-  val interrupt_task: task -> unit
-  val interrupt_group: group -> unit
-  val interrupt_task_group: task -> unit
-  val interrupt: 'a T -> unit
-  val shutdown: unit -> unit
-  val future: group option -> task list -> (unit -> 'a) -> 'a T
+  val group_of: 'a T -> group
+  val shutdown_request: unit -> unit
+  val cancel: 'a T -> unit
+  val future: bool -> task list -> (unit -> 'a) -> 'a T
   val fork: (unit -> 'a) -> 'a T
   val join: 'a T -> 'a
 end;
@@ -25,19 +22,16 @@
 structure Future: FUTURE =
 struct
 
+(** future values **)
+
 (* identifiers *)
 
 type task = TaskQueue.task;
 type group = TaskQueue.group;
 
-local val tag = Universal.tag () : task option Universal.tag in
-  fun get_task () = the_default NONE (Thread.getLocal tag);
-  fun set_task x = Thread.setLocal (tag, x);
-end;
-
-local val tag = Universal.tag () : group option Universal.tag in
-  fun get_group () = the_default NONE (Thread.getLocal tag);
-  fun set_group x = Thread.setLocal (tag, x);
+local val tag = Universal.tag () : (task * group) option Universal.tag in
+  fun thread_data () = the_default NONE (Thread.getLocal tag);
+  fun set_thread_data x = Thread.setLocal (tag, x);
 end;
 
 
@@ -45,14 +39,40 @@
 
 datatype 'a T = Future of
  {task: task,
-  group: group option,
+  group: group,
   result: 'a Exn.result option ref};
 
 fun task_of (Future {task, ...}) = task;
 fun group_of (Future {group, ...}) = group;
 
 
-(* synchronized execution *)
+
+(** scheduling **)
+
+(* global state *)
+
+val queue = ref TaskQueue.empty;
+val workers = ref ([]: Thread.thread list);
+val scheduler = ref (NONE: Thread.thread option);
+
+val excessive = ref 0;
+val active = ref 0;
+
+fun trace_active () =
+  Multithreading.tracing 1 (fn () => "SCHEDULE: " ^ string_of_int (! active) ^ " active");
+
+
+(* requests *)
+
+datatype request = Shutdown | Cancel of group;
+val requests = Mailbox.create () : request Mailbox.T;
+
+fun shutdown_request () = Mailbox.send requests Shutdown;
+fun cancel_request group = Mailbox.send requests (Cancel group);
+fun cancel x = cancel_request (group_of x);
+
+
+(* synchronization *)
 
 local
   val lock = Mutex.mutex ();
@@ -79,51 +99,31 @@
 end;
 
 
-(** scheduling **)
-
-datatype request = Shutdown | CancelGroup of group;
-val requests = Mailbox.create () : request Mailbox.T;
-
-val queue = ref TaskQueue.empty;
-val scheduler = ref (NONE: Thread.thread option);
-val workers = ref ([]: Thread.thread list);
-
+(* execute *)
 
-(* signals *)
-
-fun interrupt_task x = SYNCHRONIZED (fn () => TaskQueue.interrupt_task (! queue) x);
-fun interrupt_group x = SYNCHRONIZED (fn () => TaskQueue.interrupt_group (! queue) x);
-fun interrupt_task_group x = SYNCHRONIZED (fn () => TaskQueue.interrupt_task_group (! queue) x);
-
-fun interrupt (Future {task, ...}) = interrupt_task_group task;
-
-fun shutdown () = Mailbox.send requests Shutdown;
-
-
-(* execute *)
+fun cancel_group group = (*requires SYNCHRONIZED*)
+  (case change_result queue (TaskQueue.cancel group) of
+    [] => true
+  | running => (List.app (fn t => Thread.interrupt t handle Thread _ => ()) running; false));
 
 fun execute name (task, group, run) =
   let
-    val _ = set_task (SOME task);
-    val _ = set_group group;
+    val _ = set_thread_data (SOME (task, group));
     val _ = Multithreading.tracing 4 (fn () => name ^ ": running");
     val ok = run ();
     val _ = Multithreading.tracing 4 (fn () => name ^ ": finished");
-    val _ = set_task NONE;
-    val _ = set_group NONE;
-    val _ = SYNCHRONIZED (fn () => (change queue (TaskQueue.finished task); notify_all ()));
-    val _ = (case (ok, group) of (false, SOME g) => Mailbox.send requests (CancelGroup g) | _ => ());
+    val _ = set_thread_data NONE;
+    val _ = SYNCHRONIZED (fn () =>
+     (change queue (TaskQueue.finish task);
+      if ok then () else if cancel_group group then () else cancel_request group;
+      notify_all ()));
   in () end;
 
 
 (* worker threads *)
 
-val excessive = ref 0;
-val active = ref 0;
-
 fun change_active b = (*requires SYNCHRONIZED*)
- (change active (fn n => if b then n + 1 else n - 1);
-  Multithreading.tracing 1 (fn () => "SCHEDULE: " ^ string_of_int (! active) ^ " active"));
+ (change active (fn n => if b then n + 1 else n - 1); trace_active ());
 
 fun worker_next name = (*requires SYNCHRONIZED*)
   if ! excessive > 0 then
@@ -150,44 +150,59 @@
 
 fun scheduler_fork () = SYNCHRONIZED (fn () =>
   let
+    val _ = trace_active ();
     val m = Multithreading.max_threads_value ();
     val l = length (! workers);
     val _ = excessive := l - m;
   in List.app (fn i => worker_start ("worker " ^ string_of_int i)) (l upto m - 1) end);
 
-fun scheduler_loop () =
-  (scheduler_fork ();
-    (case Mailbox.receive_timeout (Time.fromMilliseconds 300) requests of
-      SOME Shutdown => ()   (* FIXME *)
-    | SOME (CancelGroup group) => (interrupt_group group; scheduler_loop ())  (* FIXME *)
-    | NONE => scheduler_loop ()));
+fun scheduler_loop canceled =
+  let
+    val canceled' = SYNCHRONIZED (fn () => filter_out cancel_group canceled);
+    val _ = scheduler_fork ();
+  in
+    (case Mailbox.receive_timeout (Time.fromSeconds 1) requests of
+      SOME Shutdown => ()   (* FIXME proper worker shutdown *)
+    | SOME (Cancel group) => scheduler_loop (group :: canceled')
+    | NONE => scheduler_loop canceled')
+  end;
 
 fun check_scheduler () = SYNCHRONIZED (fn () =>
   if (case ! scheduler of NONE => false | SOME thread => Thread.isActive thread) then ()
-  else scheduler := SOME (Thread.fork (scheduler_loop, Multithreading.no_interrupts)));
+  else scheduler := SOME (Thread.fork (fn () => scheduler_loop [], Multithreading.no_interrupts)));
 
 
 (* future values *)
 
-fun future group deps (e: unit -> 'a) =
+fun future new_group deps (e: unit -> 'a) =
   let
     val _ = check_scheduler ();
+
+    val group =
+      (case (new_group, thread_data ()) of
+        (false, SOME (_, group)) => group
+      | _ => TaskQueue.new_group ());
+
     val result = ref (NONE: 'a Exn.result option);
-    val run = Multithreading.with_attributes (Thread.getAttributes ()) (fn _ => fn () =>
-      let val res = Exn.capture e () in result := SOME res; is_some (Exn.get_result res) end);
+    val run = Multithreading.with_attributes (Thread.getAttributes ())
+      (fn _ => fn ok =>
+        let val res = if ok then Exn.capture e () else Exn.Exn Interrupt
+        in result := SOME res; is_some (Exn.get_result res) end);
+
     val task = SYNCHRONIZED (fn () =>
       change_result queue (TaskQueue.enqueue group deps run) before notify_all ());
   in Future {task = task, group = group, result = result} end;
 
-fun fork e = future (get_group ()) [] e;
+fun fork e = future false [] e;
 
 fun join (Future {result, ...}) =
   let
     val _ = check_scheduler ();
-    fun loop () =
+
+    fun passive_loop () =
       (case ! result of
-        NONE => (wait "join"; loop ())
+        NONE => (wait "join"; passive_loop ())
       | SOME res => res);
-  in Exn.release (SYNCHRONIZED loop) end;
+  in Exn.release (SYNCHRONIZED passive_loop) end;
 
 end;