# HG changeset patch # User wenzelm # Date 1220882903 -7200 # Node ID 43087721a66e594d62008166077a67700c102aa9 # Parent 26bb048f463c66c43b229ca3cfea4776e3960b3c moved task, thread_data, group, queue to task_queue.ML; tuned signature; SYNCHRONIZED notify_all! misc tuning; diff -r 26bb048f463c -r 43087721a66e src/Pure/Concurrent/future.ML --- a/src/Pure/Concurrent/future.ML Mon Sep 08 16:08:18 2008 +0200 +++ b/src/Pure/Concurrent/future.ML Mon Sep 08 16:08:23 2008 +0200 @@ -7,13 +7,16 @@ signature FUTURE = sig + type task = TaskQueue.task + type group = TaskQueue.group type 'a T - eqtype id - val id_of: 'a T -> id - val interrupt: id -> unit - val dependent_future: id list -> (unit -> 'a) -> 'a T - val future: (unit -> 'a) -> 'a T - val await: 'a T -> 'a + val task_of: 'a T -> task + val group_of: 'a T -> group option + val future: group option -> task list -> (unit -> 'a) -> 'a T + val fork: (unit -> 'a) -> 'a T + val join: 'a T -> 'a + val interrupt: task -> unit + val interrupt_group: group -> unit end; structure Future: FUTURE = @@ -33,111 +36,71 @@ val _ = Mutex.unlock lock; in Exn.release result end) (); -fun wait () = ConditionVar.wait (cond, lock); -fun wait_timeout timeout = ConditionVar.waitUntil (cond, lock, Time.+ (Time.now (), timeout)); +fun wait () = (*requires SYNCHRONIZED*) + ConditionVar.wait (cond, lock); -fun notify_all () = ConditionVar.broadcast cond; +fun wait_timeout timeout = (*requires SYNCHRONIZED*) + ConditionVar.waitUntil (cond, lock, Time.+ (Time.now (), timeout)); + +fun notify_all () = (*requires SYNCHRONIZED*) + ConditionVar.broadcast cond; end; -(* typed futures, unytped ids *) - -datatype 'a T = Future of serial * 'a Exn.result option ref; - -datatype id = Id of serial; -fun id_of (Future (id, _)) = Id id; +(* datatype future *) -local val tag = Universal.tag () : serial Universal.tag in - fun get_id () = Thread.getLocal tag; - fun put_id id = Thread.setLocal (tag, id); -end; - - -(* ordered queue of tasks *) - -datatype task = - Task of (unit -> unit) | - Running of Thread.thread; - -datatype queue = Queue of task IntGraph.T * (serial * (unit -> unit)) Queue.T; - -val empty_queue = Queue (IntGraph.empty, Queue.empty); +type task = TaskQueue.task; +type group = TaskQueue.group; -fun check_cache (queue as Queue (tasks, cache)) = - if not (Queue.is_empty cache) then queue - else - let fun ready (id, (Task task, ([], _))) = Queue.enqueue (id, task) | ready _ = I - in Queue (tasks, IntGraph.fold ready tasks Queue.empty) end; - -val next_task = check_cache #> (fn queue as Queue (tasks, cache) => - if Queue.is_empty cache then (NONE, queue) - else - let val (task, cache') = Queue.dequeue cache - in (SOME task, Queue (tasks, cache')) end); +datatype 'a T = Future of + {task: task, + group: group option, + result: 'a Exn.result option ref}; -fun get_task (Queue (tasks, _)) id = IntGraph.get_node tasks id; - -fun new_task deps id task (Queue (tasks, _)) = - let - fun add_dep (Id dep) G = IntGraph.add_edge_acyclic (dep, id) G - handle IntGraph.UNDEF _ => G; (*dep already finished*) - val tasks' = tasks |> IntGraph.new_node (id, Task task) |> fold add_dep deps; - in Queue (tasks', Queue.empty) end; - -fun running_task id thread (Queue (tasks, cache)) = - Queue (IntGraph.map_node id (K (Running thread)) tasks, cache); - -fun finished_task id (Queue (tasks, _)) = - Queue (IntGraph.del_nodes [id] tasks, Queue.empty); +fun task_of (Future {task, ...}) = task; +fun group_of (Future {group, ...}) = group; (* global state *) -val tasks = ref empty_queue; +val queue = ref TaskQueue.empty; val scheduler = ref (NONE: Thread.thread option); val workers = ref ([]: Thread.thread list); -fun interrupt (Id id) = SYNCHRONIZED (fn () => - (case try (get_task (! tasks)) id of - SOME (Running thread) => Thread.interrupt thread - | _ => ())); - - (* worker thread *) local val active = ref 0 in fun change_active b = (*requires SYNCHRONIZED*) - let - val _ = change active (fn n => if b then n + 1 else n - 1); - val n = ! active; - val _ = Multithreading.tracing 1 (fn () => "SCHEDULE: " ^ string_of_int n ^ " active"); - in () end; + (change active (fn n => if b then n + 1 else n - 1); + Multithreading.tracing 1 (fn () => "SCHEDULE: " ^ string_of_int (! active) ^ " active")); end; fun excessive_threads () = false; (* FIXME *) -fun worker_stop () = SYNCHRONIZED (fn () => - (change_active false; change workers (filter (fn t => not (Thread.equal (t, Thread.self ())))))); - -fun worker_wait () = SYNCHRONIZED (fn () => - (change_active false; wait (); change_active true)); +fun worker_next () = (*requires SYNCHRONIZED*) + if excessive_threads () then + (change_active false; + change workers (filter_out (fn thread => Thread.equal (thread, Thread.self ()))); + NONE) + else + (case change_result queue (TaskQueue.dequeue (Thread.self ())) of + NONE => (change_active false; wait (); change_active true; worker_next ()) + | some => some); fun worker_loop () = - if excessive_threads () then worker_stop () - else - (case SYNCHRONIZED (fn () => change_result tasks next_task) of - NONE => (worker_wait (); worker_loop ()) - | SOME (id, task) => - let - val _ = SYNCHRONIZED (fn () => change tasks (running_task id (Thread.self ()))); - val _ = task (); - val _ = SYNCHRONIZED (fn () => change tasks (finished_task id)); - val _ = notify_all (); - in worker_loop () end); + (case SYNCHRONIZED worker_next of + NONE => () + | SOME (task, run) => + let + val _ = TaskQueue.set_thread_data (SOME task); + val _ = run (); + val _ = TaskQueue.set_thread_data NONE; + val _ = SYNCHRONIZED (fn () => (change queue (TaskQueue.finished task); notify_all ())); + in worker_loop () end); fun worker_start () = SYNCHRONIZED (fn () => (change_active true; @@ -151,7 +114,7 @@ val m = Multithreading.max_threads_value (); val k = m - length (! workers); val _ = if k > 0 then funpow k worker_start () else (); - in wait_timeout (Time.fromSeconds 1); scheduler_loop () end; + in wait_timeout (Time.fromMilliseconds 300); scheduler_loop () end; fun check_scheduler () = SYNCHRONIZED (fn () => let @@ -168,30 +131,32 @@ (* future values *) -fun dependent_future deps (e: unit -> 'a) = +fun future group deps (e: unit -> 'a) = let val _ = check_scheduler (); - val r = ref (NONE: 'a Exn.result option); - val task = Multithreading.with_attributes (Thread.getAttributes ()) - (fn _ => fn () => r := SOME (Exn.capture e ())); + val result = ref (NONE: 'a Exn.result option); + val run = Multithreading.with_attributes (Thread.getAttributes ()) + (fn _ => fn () => result := SOME (Exn.capture e ())); + val task = SYNCHRONIZED (fn () => + change_result queue (TaskQueue.enqueue group deps run) before notify_all ()); + in Future {task = task, group = group, result = result} end; - val id = serial (); - val _ = SYNCHRONIZED (fn () => change tasks (new_task deps id task)); - val _ = notify_all (); +fun fork e = future NONE [] e; - in Future (id, r) end; - -fun future e = dependent_future [] e; - -fun await (Future (_, r)) = +fun join (Future {result, ...}) = let val _ = check_scheduler (); + fun loop () = + (case ! result of + NONE => (wait (); loop ()) + | SOME res => res); + in Exn.release (SYNCHRONIZED loop) end; - fun loop () = - (case SYNCHRONIZED (fn () => ! r) of - NONE => (SYNCHRONIZED (fn () => wait ()); loop ()) - | SOME res => Exn.release res); - in loop () end; + +(* interrupts *) + +fun interrupt task = SYNCHRONIZED (fn () => TaskQueue.interrupt (! queue) task); +fun interrupt_group group = SYNCHRONIZED (fn () => TaskQueue.interrupt_group (! queue) group); end;