moved task, thread_data, group, queue to task_queue.ML;
authorwenzelm
Mon, 08 Sep 2008 16:08:23 +0200
changeset 28166 43087721a66e
parent 28165 26bb048f463c
child 28167 27e2ca41b58c
moved task, thread_data, group, queue to task_queue.ML; tuned signature; SYNCHRONIZED notify_all! misc tuning;
src/Pure/Concurrent/future.ML
--- a/src/Pure/Concurrent/future.ML	Mon Sep 08 16:08:18 2008 +0200
+++ b/src/Pure/Concurrent/future.ML	Mon Sep 08 16:08:23 2008 +0200
@@ -7,13 +7,16 @@
 
 signature FUTURE =
 sig
+  type task = TaskQueue.task
+  type group = TaskQueue.group
   type 'a T
-  eqtype id
-  val id_of: 'a T -> id
-  val interrupt: id -> unit
-  val dependent_future: id list -> (unit -> 'a) -> 'a T
-  val future: (unit -> 'a) -> 'a T
-  val await: 'a T -> 'a
+  val task_of: 'a T -> task
+  val group_of: 'a T -> group option
+  val future: group option -> task list -> (unit -> 'a) -> 'a T
+  val fork: (unit -> 'a) -> 'a T
+  val join: 'a T -> 'a
+  val interrupt: task -> unit
+  val interrupt_group: group -> unit
 end;
 
 structure Future: FUTURE =
@@ -33,111 +36,71 @@
     val _ = Mutex.unlock lock;
   in Exn.release result end) ();
 
-fun wait () = ConditionVar.wait (cond, lock);
-fun wait_timeout timeout = ConditionVar.waitUntil (cond, lock, Time.+ (Time.now (), timeout));
+fun wait () = (*requires SYNCHRONIZED*)
+  ConditionVar.wait (cond, lock);
 
-fun notify_all () = ConditionVar.broadcast cond;
+fun wait_timeout timeout = (*requires SYNCHRONIZED*)
+  ConditionVar.waitUntil (cond, lock, Time.+ (Time.now (), timeout));
+
+fun notify_all () = (*requires SYNCHRONIZED*)
+  ConditionVar.broadcast cond;
 
 end;
 
 
-(* typed futures, unytped ids *)
-
-datatype 'a T = Future of serial * 'a Exn.result option ref;
-
-datatype id = Id of serial;
-fun id_of (Future (id, _)) = Id id;
+(* datatype future *)
 
-local val tag = Universal.tag () : serial Universal.tag in
-  fun get_id () = Thread.getLocal tag;
-  fun put_id id = Thread.setLocal (tag, id);
-end;
-
-
-(* ordered queue of tasks *)
-
-datatype task =
-  Task of (unit -> unit) |
-  Running of Thread.thread;
-
-datatype queue = Queue of task IntGraph.T * (serial * (unit -> unit)) Queue.T;
-
-val empty_queue = Queue (IntGraph.empty, Queue.empty);
+type task = TaskQueue.task;
+type group = TaskQueue.group;
 
-fun check_cache (queue as Queue (tasks, cache)) =
-  if not (Queue.is_empty cache) then queue
-  else
-    let fun ready (id, (Task task, ([], _))) = Queue.enqueue (id, task) | ready _ = I
-    in Queue (tasks, IntGraph.fold ready tasks Queue.empty) end;
-
-val next_task = check_cache #> (fn queue as Queue (tasks, cache) =>
-  if Queue.is_empty cache then (NONE, queue)
-  else
-    let val (task, cache') = Queue.dequeue cache
-    in (SOME task, Queue (tasks, cache')) end);
+datatype 'a T = Future of
+ {task: task,
+  group: group option,
+  result: 'a Exn.result option ref};
 
-fun get_task (Queue (tasks, _)) id = IntGraph.get_node tasks id;
-
-fun new_task deps id task (Queue (tasks, _)) =
-  let
-    fun add_dep (Id dep) G = IntGraph.add_edge_acyclic (dep, id) G
-      handle IntGraph.UNDEF _ => G;  (*dep already finished*)
-    val tasks' = tasks |> IntGraph.new_node (id, Task task) |> fold add_dep deps;
-  in Queue (tasks', Queue.empty) end;
-
-fun running_task id thread (Queue (tasks, cache)) =
-  Queue (IntGraph.map_node id (K (Running thread)) tasks, cache);
-
-fun finished_task id (Queue (tasks, _)) =
-  Queue (IntGraph.del_nodes [id] tasks, Queue.empty);
+fun task_of (Future {task, ...}) = task;
+fun group_of (Future {group, ...}) = group;
 
 
 (* global state *)
 
-val tasks = ref empty_queue;
+val queue = ref TaskQueue.empty;
 val scheduler = ref (NONE: Thread.thread option);
 val workers = ref ([]: Thread.thread list);
 
 
-fun interrupt (Id id) = SYNCHRONIZED (fn () =>
-  (case try (get_task (! tasks)) id of
-    SOME (Running thread) => Thread.interrupt thread
-  | _ => ()));
-
-
 (* worker thread *)
 
 local val active = ref 0 in
 
 fun change_active b = (*requires SYNCHRONIZED*)
-  let
-    val _ = change active (fn n => if b then n + 1 else n - 1);
-    val n = ! active;
-    val _ = Multithreading.tracing 1 (fn () => "SCHEDULE: " ^ string_of_int n ^ " active");
-  in () end;
+ (change active (fn n => if b then n + 1 else n - 1);
+  Multithreading.tracing 1 (fn () => "SCHEDULE: " ^ string_of_int (! active) ^ " active"));
 
 end;
 
 fun excessive_threads () = false;  (* FIXME *)
 
-fun worker_stop () = SYNCHRONIZED (fn () =>
-  (change_active false; change workers (filter (fn t => not (Thread.equal (t, Thread.self ()))))));
-
-fun worker_wait () = SYNCHRONIZED (fn () =>
-  (change_active false; wait (); change_active true));
+fun worker_next () = (*requires SYNCHRONIZED*)
+  if excessive_threads () then
+   (change_active false;
+    change workers (filter_out (fn thread => Thread.equal (thread, Thread.self ())));
+    NONE)
+  else
+    (case change_result queue (TaskQueue.dequeue (Thread.self ())) of
+      NONE => (change_active false; wait (); change_active true; worker_next ())
+    | some => some);
 
 fun worker_loop () =
-  if excessive_threads () then worker_stop ()
-  else
-    (case SYNCHRONIZED (fn () => change_result tasks next_task) of
-      NONE => (worker_wait (); worker_loop ())
-    | SOME (id, task) =>
-        let
-          val _ = SYNCHRONIZED (fn () => change tasks (running_task id (Thread.self ())));
-          val _ = task ();
-          val _ = SYNCHRONIZED (fn () => change tasks (finished_task id));
-          val _ = notify_all ();
-        in worker_loop () end);
+  (case SYNCHRONIZED worker_next of
+    NONE => ()
+  | SOME (task, run) =>
+      let
+        val _ = TaskQueue.set_thread_data (SOME task);
+        val _ = run ();
+        val _ = TaskQueue.set_thread_data NONE;
+        val _ = SYNCHRONIZED (fn () => (change queue (TaskQueue.finished task); notify_all ()));
+      in worker_loop () end);
 
 fun worker_start () = SYNCHRONIZED (fn () =>
  (change_active true;
@@ -151,7 +114,7 @@
     val m = Multithreading.max_threads_value ();
     val k = m - length (! workers);
     val _ = if k > 0 then funpow k worker_start () else ();
-  in wait_timeout (Time.fromSeconds 1); scheduler_loop () end;
+  in wait_timeout (Time.fromMilliseconds 300); scheduler_loop () end;
 
 fun check_scheduler () = SYNCHRONIZED (fn () =>
   let
@@ -168,30 +131,32 @@
 
 (* future values *)
 
-fun dependent_future deps (e: unit -> 'a) =
+fun future group deps (e: unit -> 'a) =
   let
     val _ = check_scheduler ();
 
-    val r = ref (NONE: 'a Exn.result option);
-    val task = Multithreading.with_attributes (Thread.getAttributes ())
-      (fn _ => fn () => r := SOME (Exn.capture e ()));
+    val result = ref (NONE: 'a Exn.result option);
+    val run = Multithreading.with_attributes (Thread.getAttributes ())
+      (fn _ => fn () => result := SOME (Exn.capture e ()));
+    val task = SYNCHRONIZED (fn () =>
+      change_result queue (TaskQueue.enqueue group deps run) before notify_all ());
+  in Future {task = task, group = group, result = result} end;
 
-    val id = serial ();
-    val _ = SYNCHRONIZED (fn () => change tasks (new_task deps id task));
-    val _ = notify_all ();
+fun fork e = future NONE [] e;
 
-  in Future (id, r) end;
-
-fun future e = dependent_future [] e;
-
-fun await (Future (_, r)) =
+fun join (Future {result, ...}) =
   let
     val _ = check_scheduler ();
+    fun loop () =
+      (case ! result of
+        NONE => (wait (); loop ())
+      | SOME res => res);
+  in Exn.release (SYNCHRONIZED loop) end;
 
-    fun loop () =
-      (case SYNCHRONIZED (fn () => ! r) of
-        NONE => (SYNCHRONIZED (fn () => wait ()); loop ())
-      | SOME res => Exn.release res);
-  in loop () end;
+
+(* interrupts *)
+
+fun interrupt task = SYNCHRONIZED (fn () => TaskQueue.interrupt (! queue) task);
+fun interrupt_group group = SYNCHRONIZED (fn () => TaskQueue.interrupt_group (! queue) group);
 
 end;