src/Pure/Concurrent/future.ML
changeset 28167 27e2ca41b58c
parent 28166 43087721a66e
child 28170 a18cf8a0e656
--- a/src/Pure/Concurrent/future.ML	Mon Sep 08 16:08:23 2008 +0200
+++ b/src/Pure/Concurrent/future.ML	Mon Sep 08 20:33:24 2008 +0200
@@ -12,16 +12,46 @@
   type 'a T
   val task_of: 'a T -> task
   val group_of: 'a T -> group option
+  val interrupt_task: task -> unit
+  val interrupt_group: group -> unit
+  val interrupt_task_group: task -> unit
+  val interrupt: 'a T -> unit
+  val shutdown: unit -> unit
   val future: group option -> task list -> (unit -> 'a) -> 'a T
   val fork: (unit -> 'a) -> 'a T
   val join: 'a T -> 'a
-  val interrupt: task -> unit
-  val interrupt_group: group -> unit
 end;
 
 structure Future: FUTURE =
 struct
 
+(* identifiers *)
+
+type task = TaskQueue.task;
+type group = TaskQueue.group;
+
+local val tag = Universal.tag () : task option Universal.tag in
+  fun get_task () = the_default NONE (Thread.getLocal tag);
+  fun set_task x = Thread.setLocal (tag, x);
+end;
+
+local val tag = Universal.tag () : group option Universal.tag in
+  fun get_group () = the_default NONE (Thread.getLocal tag);
+  fun set_group x = Thread.setLocal (tag, x);
+end;
+
+
+(* datatype future *)
+
+datatype 'a T = Future of
+ {task: task,
+  group: group option,
+  result: 'a Exn.result option ref};
+
+fun task_of (Future {task, ...}) = task;
+fun group_of (Future {group, ...}) = group;
+
+
 (* synchronized execution *)
 
 local
@@ -36,11 +66,12 @@
     val _ = Mutex.unlock lock;
   in Exn.release result end) ();
 
-fun wait () = (*requires SYNCHRONIZED*)
-  ConditionVar.wait (cond, lock);
-
-fun wait_timeout timeout = (*requires SYNCHRONIZED*)
-  ConditionVar.waitUntil (cond, lock, Time.+ (Time.now (), timeout));
+fun wait name = (*requires SYNCHRONIZED*)
+  let
+    val _ = Multithreading.tracing 4 (fn () => name ^ " : waiting");
+    val _ = ConditionVar.wait (cond, lock);
+    val _ = Multithreading.tracing 4 (fn () => name ^ " : notified");
+  in () end;
 
 fun notify_all () = (*requires SYNCHRONIZED*)
   ConditionVar.broadcast cond;
@@ -48,85 +79,92 @@
 end;
 
 
-(* datatype future *)
-
-type task = TaskQueue.task;
-type group = TaskQueue.group;
+(** scheduling **)
 
-datatype 'a T = Future of
- {task: task,
-  group: group option,
-  result: 'a Exn.result option ref};
-
-fun task_of (Future {task, ...}) = task;
-fun group_of (Future {group, ...}) = group;
-
-
-(* global state *)
+datatype request = Shutdown | CancelGroup of group;
+val requests = Mailbox.create () : request Mailbox.T;
 
 val queue = ref TaskQueue.empty;
 val scheduler = ref (NONE: Thread.thread option);
 val workers = ref ([]: Thread.thread list);
 
 
-(* worker thread *)
+(* signals *)
+
+fun interrupt_task x = SYNCHRONIZED (fn () => TaskQueue.interrupt_task (! queue) x);
+fun interrupt_group x = SYNCHRONIZED (fn () => TaskQueue.interrupt_group (! queue) x);
+fun interrupt_task_group x = SYNCHRONIZED (fn () => TaskQueue.interrupt_task_group (! queue) x);
+
+fun interrupt (Future {task, ...}) = interrupt_task_group task;
+
+fun shutdown () = Mailbox.send Shutdown requests;
+
+
+(* execute *)
 
-local val active = ref 0 in
+fun execute name (task, group, run) =
+  let
+    val _ = set_task (SOME task);
+    val _ = set_group group;
+    val _ = Multithreading.tracing 4 (fn () => name ^ ": running");
+    val ok = run ();
+    val _ = Multithreading.tracing 4 (fn () => name ^ ": finished");
+    val _ = set_task NONE;
+    val _ = set_group NONE;
+    val _ = SYNCHRONIZED (fn () => (change queue (TaskQueue.finished task); notify_all ()));
+    val _ = (case (ok, group) of (false, SOME g) => Mailbox.send (CancelGroup g) requests | _ => ());
+  in () end;
+
+
+(* worker threads *)
+
+val excessive = ref 0;
+val active = ref 0;
 
 fun change_active b = (*requires SYNCHRONIZED*)
  (change active (fn n => if b then n + 1 else n - 1);
   Multithreading.tracing 1 (fn () => "SCHEDULE: " ^ string_of_int (! active) ^ " active"));
 
-end;
-
-fun excessive_threads () = false;  (* FIXME *)
-
-fun worker_next () = (*requires SYNCHRONIZED*)
-  if excessive_threads () then
-   (change_active false;
-    change workers (filter_out (fn thread => Thread.equal (thread, Thread.self ())));
-    NONE)
+fun worker_next name = (*requires SYNCHRONIZED*)
+  if ! excessive > 0 then
+    (dec excessive;
+     change_active false;
+     change workers (remove Thread.equal (Thread.self ()));
+     NONE)
   else
     (case change_result queue (TaskQueue.dequeue (Thread.self ())) of
-      NONE => (change_active false; wait (); change_active true; worker_next ())
+      NONE => (change_active false; wait name; change_active true; worker_next name)
     | some => some);
 
-fun worker_loop () =
-  (case SYNCHRONIZED worker_next of
+fun worker_loop name =
+  (case SYNCHRONIZED (fn () => worker_next name) of
     NONE => ()
-  | SOME (task, run) =>
-      let
-        val _ = TaskQueue.set_thread_data (SOME task);
-        val _ = run ();
-        val _ = TaskQueue.set_thread_data NONE;
-        val _ = SYNCHRONIZED (fn () => (change queue (TaskQueue.finished task); notify_all ()));
-      in worker_loop () end);
+  | SOME work => (execute name work; worker_loop name));
 
-fun worker_start () = SYNCHRONIZED (fn () =>
+fun worker_start name = (*requires SYNCHRONIZED*)
  (change_active true;
-  change workers (cons (Thread.fork (worker_loop, Multithreading.no_interrupts)))));
+  change workers (cons (Thread.fork (fn () => worker_loop name, Multithreading.no_interrupts))));
 
 
 (* scheduler *)
 
-fun scheduler_loop () = (*requires SYNCHRONIZED*)
+fun scheduler_fork () = SYNCHRONIZED (fn () =>
   let
     val m = Multithreading.max_threads_value ();
-    val k = m - length (! workers);
-    val _ = if k > 0 then funpow k worker_start () else ();
-  in wait_timeout (Time.fromMilliseconds 300); scheduler_loop () end;
+    val l = length (! workers);
+    val _ = excessive := l - m;
+  in List.app (fn i => worker_start ("worker " ^ string_of_int i)) (l upto m - 1) end);
+
+fun scheduler_loop () =
+  (scheduler_fork ();
+    (case Mailbox.receive_timeout (Time.fromMilliseconds 300) requests of
+      SOME Shutdown => ()   (* FIXME *)
+    | SOME (CancelGroup group) => (interrupt_group group; scheduler_loop ())  (* FIXME *)
+    | NONE => scheduler_loop ()));
 
 fun check_scheduler () = SYNCHRONIZED (fn () =>
-  let
-    val scheduler_active =
-      (case ! scheduler of
-        NONE => false
-      | SOME t => Thread.isActive t);
-  in
-    if scheduler_active then ()
-    else scheduler :=
-      SOME (Thread.fork (SYNCHRONIZED o scheduler_loop, Multithreading.no_interrupts))
-  end);
+  if (case ! scheduler of NONE => false | SOME thread => Thread.isActive thread) then ()
+  else scheduler := SOME (Thread.fork (scheduler_loop, Multithreading.no_interrupts)));
 
 
 (* future values *)
@@ -134,29 +172,22 @@
 fun future group deps (e: unit -> 'a) =
   let
     val _ = check_scheduler ();
-
     val result = ref (NONE: 'a Exn.result option);
-    val run = Multithreading.with_attributes (Thread.getAttributes ())
-      (fn _ => fn () => result := SOME (Exn.capture e ()));
+    val run = Multithreading.with_attributes (Thread.getAttributes ()) (fn _ => fn () =>
+      let val res = Exn.capture e () in result := SOME res; is_some (Exn.get_result res) end);
     val task = SYNCHRONIZED (fn () =>
       change_result queue (TaskQueue.enqueue group deps run) before notify_all ());
   in Future {task = task, group = group, result = result} end;
 
-fun fork e = future NONE [] e;
+fun fork e = future (get_group ()) [] e;
 
 fun join (Future {result, ...}) =
   let
     val _ = check_scheduler ();
     fun loop () =
       (case ! result of
-        NONE => (wait (); loop ())
+        NONE => (wait "join"; loop ())
       | SOME res => res);
   in Exn.release (SYNCHRONIZED loop) end;
 
-
-(* interrupts *)
-
-fun interrupt task = SYNCHRONIZED (fn () => TaskQueue.interrupt (! queue) task);
-fun interrupt_group group = SYNCHRONIZED (fn () => TaskQueue.interrupt_group (! queue) group);
-
 end;