Functional threads as future values.
--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/src/Pure/Concurrent/future.ML Sun Sep 07 22:19:42 2008 +0200
@@ -0,0 +1,207 @@
+(* Title: Pure/Concurrent/future.ML
+ ID: $Id$
+ Author: Makarius
+
+Functional threads as future values.
+*)
+
+signature FUTURE =
+sig
+ type 'a T
+ eqtype id
+ val id_of: 'a T -> id
+ val interrupt: id -> unit
+ val dependent_future: id list -> (unit -> 'a) -> 'a T
+ val future: (unit -> 'a) -> 'a T
+ val await: 'a T -> 'a
+end;
+
+structure Future: FUTURE =
+struct
+
+(* synchronized execution *)
+
+local
+ val thread = ref (NONE: Thread.thread option);
+ val lock = Mutex.mutex ();
+ val cond = ConditionVar.conditionVar ();
+in
+
+fun self_synchronized () =
+ (case ! thread of
+ NONE => false
+ | SOME t => Thread.equal (t, Thread.self ()));
+
+fun SYNCHRONIZED e =
+ if self_synchronized () then e ()
+ else
+ uninterruptible (fn restore_attributes => fn () =>
+ let
+ val _ = Mutex.lock lock;
+ val _ = thread := SOME (Thread.self ());
+ val result = Exn.capture (restore_attributes e) ();
+ val _ = thread := NONE;
+ val _ = Mutex.unlock lock;
+ in Exn.release result end) ();
+
+fun wait () = ConditionVar.wait (cond, lock);
+fun wait_timeout timeout = ConditionVar.waitUntil (cond, lock, Time.+ (Time.now (), timeout));
+
+fun notify_all () = ConditionVar.broadcast cond;
+
+end;
+
+
+(* typed futures, unytped ids *)
+
+datatype 'a T = Future of serial * 'a Exn.result option ref;
+
+datatype id = Id of serial;
+fun id_of (Future (id, _)) = Id id;
+
+local val tag = Universal.tag () : serial Universal.tag in
+ fun get_id () = Thread.getLocal tag;
+ fun put_id id = Thread.setLocal (tag, id);
+end;
+
+
+(* ordered queue of tasks *)
+
+datatype task =
+ Task of (unit -> unit) |
+ Running of Thread.thread;
+
+datatype queue = Queue of task IntGraph.T * (serial * (unit -> unit)) Queue.T;
+
+val empty_queue = Queue (IntGraph.empty, Queue.empty);
+
+fun check_cache (queue as Queue (tasks, cache)) =
+ if not (Queue.is_empty cache) then queue
+ else
+ let
+ val cache' = fold (fn id =>
+ (case IntGraph.get_node tasks id of
+ Task task => Queue.enqueue (id, task)
+ | Running _ => I)) (IntGraph.minimals tasks) Queue.empty;
+ in Queue (tasks, cache') end;
+
+val next_task = check_cache #> (fn queue as Queue (tasks, cache) =>
+ if Queue.is_empty cache then (NONE, queue)
+ else
+ let val (task, cache') = Queue.dequeue cache
+ in (SOME task, Queue (tasks, cache')) end);
+
+fun get_task (Queue (tasks, _)) id = IntGraph.get_node tasks id;
+
+fun new_task deps id task (Queue (tasks, _)) =
+ let
+ fun add_dep (Id dep) G = IntGraph.add_edge_acyclic (dep, id) G
+ handle IntGraph.UNDEF _ => G; (*dep already finished*)
+ val tasks' = tasks |> IntGraph.new_node (id, Task task) |> fold add_dep deps;
+ in Queue (tasks', Queue.empty) end;
+
+fun running_task id thread (Queue (tasks, cache)) =
+ Queue (IntGraph.map_node id (K (Running thread)) tasks, cache);
+
+fun finished_task id (Queue (tasks, _)) =
+ Queue (IntGraph.del_nodes [id] tasks, Queue.empty);
+
+
+(* global state *)
+
+local val active = ref 0 in
+
+fun change_active b = SYNCHRONIZED (fn () =>
+ let
+ val _ = change active (fn n => if b then n + 1 else n - 1);
+ val n = ! active;
+ val _ = Multithreading.tracing 1 (fn () => "SCHEDULE: " ^ string_of_int n ^ " active tasks");
+ in () end);
+
+end;
+
+val tasks = ref empty_queue;
+val scheduler = ref (NONE: Thread.thread option);
+val workers = ref ([]: Thread.thread list);
+
+
+fun interrupt (Id id) = SYNCHRONIZED (fn () =>
+ (case try (get_task (! tasks)) id of
+ SOME (Running thread) => Thread.interrupt thread
+ | _ => ()));
+
+
+(* worker thread *)
+
+fun excessive_threads () = false; (* FIXME *)
+
+fun worker_stop () =
+ (change_active false; change workers (filter (fn t => not (Thread.equal (t, Thread.self ())))));
+
+fun worker_wait () =
+ (change_active false; wait (); change_active true);
+
+fun worker_loop () =
+ (case SYNCHRONIZED (fn () => change_result tasks next_task) of
+ SOME (id, task) =>
+ let
+ val _ = SYNCHRONIZED (fn () => change tasks (running_task id (Thread.self ())));
+ val _ = task ();
+ val _ = SYNCHRONIZED (fn () => change tasks (finished_task id));
+ val _ = notify_all ();
+ in if excessive_threads () then worker_stop () else worker_loop () end
+ | NONE => (worker_wait (); worker_loop ()));
+
+fun worker_start () =
+ (change_active true;
+ change workers (cons (Thread.fork (worker_loop, Multithreading.no_interrupts))));
+
+
+(* scheduler *)
+
+fun scheduler_loop () =
+ let
+ val m = Multithreading.max_threads_value ();
+ val k = m - length (! workers);
+ val _ = if k > 0 then funpow k worker_start () else ();
+ in wait_timeout (Time.fromSeconds 1); scheduler_loop () end;
+
+fun check_scheduler () = SYNCHRONIZED (fn () =>
+ let
+ val scheduler_active =
+ (case ! scheduler of
+ NONE => false
+ | SOME t => Thread.isActive t);
+ in
+ if scheduler_active then ()
+ else scheduler := SOME (Thread.fork (SYNCHRONIZED scheduler_loop, Multithreading.no_interrupts))
+ end);
+
+
+(* future values *)
+
+fun dependent_future deps (e: unit -> 'a) =
+ let
+ val _ = check_scheduler ();
+
+ val r = ref (NONE: 'a Exn.result option);
+ val task = Multithreading.with_attributes (Thread.getAttributes ())
+ (fn _ => fn () => r := SOME (Exn.capture e ()));
+ val id = serial ();
+ val _ = SYNCHRONIZED (fn () => change tasks (new_task deps id task));
+ val _ = notify_all ();
+ in Future (id, r) end;
+
+fun future e = dependent_future [] e;
+
+fun await (Future (_, r)) =
+ let
+ val _ = check_scheduler ();
+
+ fun loop () =
+ (case SYNCHRONIZED (fn () => ! r) of
+ NONE => (wait (); loop ())
+ | SOME res => Exn.release res);
+ in loop () end;
+
+end;