src/Pure/Concurrent/future.ML
changeset 28166 43087721a66e
parent 28163 8bf8c21296ca
child 28167 27e2ca41b58c
equal deleted inserted replaced
28165:26bb048f463c 28166:43087721a66e
     5 Functional threads as future values.
     5 Functional threads as future values.
     6 *)
     6 *)
     7 
     7 
     8 signature FUTURE =
     8 signature FUTURE =
     9 sig
     9 sig
       
    10   type task = TaskQueue.task
       
    11   type group = TaskQueue.group
    10   type 'a T
    12   type 'a T
    11   eqtype id
    13   val task_of: 'a T -> task
    12   val id_of: 'a T -> id
    14   val group_of: 'a T -> group option
    13   val interrupt: id -> unit
    15   val future: group option -> task list -> (unit -> 'a) -> 'a T
    14   val dependent_future: id list -> (unit -> 'a) -> 'a T
    16   val fork: (unit -> 'a) -> 'a T
    15   val future: (unit -> 'a) -> 'a T
    17   val join: 'a T -> 'a
    16   val await: 'a T -> 'a
    18   val interrupt: task -> unit
       
    19   val interrupt_group: group -> unit
    17 end;
    20 end;
    18 
    21 
    19 structure Future: FUTURE =
    22 structure Future: FUTURE =
    20 struct
    23 struct
    21 
    24 
    31     val _ = Mutex.lock lock;
    34     val _ = Mutex.lock lock;
    32     val result = Exn.capture (restore_attributes e) ();
    35     val result = Exn.capture (restore_attributes e) ();
    33     val _ = Mutex.unlock lock;
    36     val _ = Mutex.unlock lock;
    34   in Exn.release result end) ();
    37   in Exn.release result end) ();
    35 
    38 
    36 fun wait () = ConditionVar.wait (cond, lock);
    39 fun wait () = (*requires SYNCHRONIZED*)
    37 fun wait_timeout timeout = ConditionVar.waitUntil (cond, lock, Time.+ (Time.now (), timeout));
    40   ConditionVar.wait (cond, lock);
    38 
    41 
    39 fun notify_all () = ConditionVar.broadcast cond;
    42 fun wait_timeout timeout = (*requires SYNCHRONIZED*)
       
    43   ConditionVar.waitUntil (cond, lock, Time.+ (Time.now (), timeout));
       
    44 
       
    45 fun notify_all () = (*requires SYNCHRONIZED*)
       
    46   ConditionVar.broadcast cond;
    40 
    47 
    41 end;
    48 end;
    42 
    49 
    43 
    50 
    44 (* typed futures, unytped ids *)
    51 (* datatype future *)
    45 
    52 
    46 datatype 'a T = Future of serial * 'a Exn.result option ref;
    53 type task = TaskQueue.task;
       
    54 type group = TaskQueue.group;
    47 
    55 
    48 datatype id = Id of serial;
    56 datatype 'a T = Future of
    49 fun id_of (Future (id, _)) = Id id;
    57  {task: task,
       
    58   group: group option,
       
    59   result: 'a Exn.result option ref};
    50 
    60 
    51 local val tag = Universal.tag () : serial Universal.tag in
    61 fun task_of (Future {task, ...}) = task;
    52   fun get_id () = Thread.getLocal tag;
    62 fun group_of (Future {group, ...}) = group;
    53   fun put_id id = Thread.setLocal (tag, id);
       
    54 end;
       
    55 
       
    56 
       
    57 (* ordered queue of tasks *)
       
    58 
       
    59 datatype task =
       
    60   Task of (unit -> unit) |
       
    61   Running of Thread.thread;
       
    62 
       
    63 datatype queue = Queue of task IntGraph.T * (serial * (unit -> unit)) Queue.T;
       
    64 
       
    65 val empty_queue = Queue (IntGraph.empty, Queue.empty);
       
    66 
       
    67 fun check_cache (queue as Queue (tasks, cache)) =
       
    68   if not (Queue.is_empty cache) then queue
       
    69   else
       
    70     let fun ready (id, (Task task, ([], _))) = Queue.enqueue (id, task) | ready _ = I
       
    71     in Queue (tasks, IntGraph.fold ready tasks Queue.empty) end;
       
    72 
       
    73 val next_task = check_cache #> (fn queue as Queue (tasks, cache) =>
       
    74   if Queue.is_empty cache then (NONE, queue)
       
    75   else
       
    76     let val (task, cache') = Queue.dequeue cache
       
    77     in (SOME task, Queue (tasks, cache')) end);
       
    78 
       
    79 fun get_task (Queue (tasks, _)) id = IntGraph.get_node tasks id;
       
    80 
       
    81 fun new_task deps id task (Queue (tasks, _)) =
       
    82   let
       
    83     fun add_dep (Id dep) G = IntGraph.add_edge_acyclic (dep, id) G
       
    84       handle IntGraph.UNDEF _ => G;  (*dep already finished*)
       
    85     val tasks' = tasks |> IntGraph.new_node (id, Task task) |> fold add_dep deps;
       
    86   in Queue (tasks', Queue.empty) end;
       
    87 
       
    88 fun running_task id thread (Queue (tasks, cache)) =
       
    89   Queue (IntGraph.map_node id (K (Running thread)) tasks, cache);
       
    90 
       
    91 fun finished_task id (Queue (tasks, _)) =
       
    92   Queue (IntGraph.del_nodes [id] tasks, Queue.empty);
       
    93 
    63 
    94 
    64 
    95 (* global state *)
    65 (* global state *)
    96 
    66 
    97 val tasks = ref empty_queue;
    67 val queue = ref TaskQueue.empty;
    98 val scheduler = ref (NONE: Thread.thread option);
    68 val scheduler = ref (NONE: Thread.thread option);
    99 val workers = ref ([]: Thread.thread list);
    69 val workers = ref ([]: Thread.thread list);
   100 
       
   101 
       
   102 fun interrupt (Id id) = SYNCHRONIZED (fn () =>
       
   103   (case try (get_task (! tasks)) id of
       
   104     SOME (Running thread) => Thread.interrupt thread
       
   105   | _ => ()));
       
   106 
    70 
   107 
    71 
   108 (* worker thread *)
    72 (* worker thread *)
   109 
    73 
   110 local val active = ref 0 in
    74 local val active = ref 0 in
   111 
    75 
   112 fun change_active b = (*requires SYNCHRONIZED*)
    76 fun change_active b = (*requires SYNCHRONIZED*)
   113   let
    77  (change active (fn n => if b then n + 1 else n - 1);
   114     val _ = change active (fn n => if b then n + 1 else n - 1);
    78   Multithreading.tracing 1 (fn () => "SCHEDULE: " ^ string_of_int (! active) ^ " active"));
   115     val n = ! active;
       
   116     val _ = Multithreading.tracing 1 (fn () => "SCHEDULE: " ^ string_of_int n ^ " active");
       
   117   in () end;
       
   118 
    79 
   119 end;
    80 end;
   120 
    81 
   121 fun excessive_threads () = false;  (* FIXME *)
    82 fun excessive_threads () = false;  (* FIXME *)
   122 
    83 
   123 fun worker_stop () = SYNCHRONIZED (fn () =>
    84 fun worker_next () = (*requires SYNCHRONIZED*)
   124   (change_active false; change workers (filter (fn t => not (Thread.equal (t, Thread.self ()))))));
    85   if excessive_threads () then
   125 
    86    (change_active false;
   126 fun worker_wait () = SYNCHRONIZED (fn () =>
    87     change workers (filter_out (fn thread => Thread.equal (thread, Thread.self ())));
   127   (change_active false; wait (); change_active true));
    88     NONE)
       
    89   else
       
    90     (case change_result queue (TaskQueue.dequeue (Thread.self ())) of
       
    91       NONE => (change_active false; wait (); change_active true; worker_next ())
       
    92     | some => some);
   128 
    93 
   129 fun worker_loop () =
    94 fun worker_loop () =
   130   if excessive_threads () then worker_stop ()
    95   (case SYNCHRONIZED worker_next of
   131   else
    96     NONE => ()
   132     (case SYNCHRONIZED (fn () => change_result tasks next_task) of
    97   | SOME (task, run) =>
   133       NONE => (worker_wait (); worker_loop ())
    98       let
   134     | SOME (id, task) =>
    99         val _ = TaskQueue.set_thread_data (SOME task);
   135         let
   100         val _ = run ();
   136           val _ = SYNCHRONIZED (fn () => change tasks (running_task id (Thread.self ())));
   101         val _ = TaskQueue.set_thread_data NONE;
   137           val _ = task ();
   102         val _ = SYNCHRONIZED (fn () => (change queue (TaskQueue.finished task); notify_all ()));
   138           val _ = SYNCHRONIZED (fn () => change tasks (finished_task id));
   103       in worker_loop () end);
   139           val _ = notify_all ();
       
   140         in worker_loop () end);
       
   141 
   104 
   142 fun worker_start () = SYNCHRONIZED (fn () =>
   105 fun worker_start () = SYNCHRONIZED (fn () =>
   143  (change_active true;
   106  (change_active true;
   144   change workers (cons (Thread.fork (worker_loop, Multithreading.no_interrupts)))));
   107   change workers (cons (Thread.fork (worker_loop, Multithreading.no_interrupts)))));
   145 
   108 
   149 fun scheduler_loop () = (*requires SYNCHRONIZED*)
   112 fun scheduler_loop () = (*requires SYNCHRONIZED*)
   150   let
   113   let
   151     val m = Multithreading.max_threads_value ();
   114     val m = Multithreading.max_threads_value ();
   152     val k = m - length (! workers);
   115     val k = m - length (! workers);
   153     val _ = if k > 0 then funpow k worker_start () else ();
   116     val _ = if k > 0 then funpow k worker_start () else ();
   154   in wait_timeout (Time.fromSeconds 1); scheduler_loop () end;
   117   in wait_timeout (Time.fromMilliseconds 300); scheduler_loop () end;
   155 
   118 
   156 fun check_scheduler () = SYNCHRONIZED (fn () =>
   119 fun check_scheduler () = SYNCHRONIZED (fn () =>
   157   let
   120   let
   158     val scheduler_active =
   121     val scheduler_active =
   159       (case ! scheduler of
   122       (case ! scheduler of
   166   end);
   129   end);
   167 
   130 
   168 
   131 
   169 (* future values *)
   132 (* future values *)
   170 
   133 
   171 fun dependent_future deps (e: unit -> 'a) =
   134 fun future group deps (e: unit -> 'a) =
   172   let
   135   let
   173     val _ = check_scheduler ();
   136     val _ = check_scheduler ();
   174 
   137 
   175     val r = ref (NONE: 'a Exn.result option);
   138     val result = ref (NONE: 'a Exn.result option);
   176     val task = Multithreading.with_attributes (Thread.getAttributes ())
   139     val run = Multithreading.with_attributes (Thread.getAttributes ())
   177       (fn _ => fn () => r := SOME (Exn.capture e ()));
   140       (fn _ => fn () => result := SOME (Exn.capture e ()));
       
   141     val task = SYNCHRONIZED (fn () =>
       
   142       change_result queue (TaskQueue.enqueue group deps run) before notify_all ());
       
   143   in Future {task = task, group = group, result = result} end;
   178 
   144 
   179     val id = serial ();
   145 fun fork e = future NONE [] e;
   180     val _ = SYNCHRONIZED (fn () => change tasks (new_task deps id task));
       
   181     val _ = notify_all ();
       
   182 
   146 
   183   in Future (id, r) end;
   147 fun join (Future {result, ...}) =
   184 
       
   185 fun future e = dependent_future [] e;
       
   186 
       
   187 fun await (Future (_, r)) =
       
   188   let
   148   let
   189     val _ = check_scheduler ();
   149     val _ = check_scheduler ();
       
   150     fun loop () =
       
   151       (case ! result of
       
   152         NONE => (wait (); loop ())
       
   153       | SOME res => res);
       
   154   in Exn.release (SYNCHRONIZED loop) end;
   190 
   155 
   191     fun loop () =
   156 
   192       (case SYNCHRONIZED (fn () => ! r) of
   157 (* interrupts *)
   193         NONE => (SYNCHRONIZED (fn () => wait ()); loop ())
   158 
   194       | SOME res => Exn.release res);
   159 fun interrupt task = SYNCHRONIZED (fn () => TaskQueue.interrupt (! queue) task);
   195   in loop () end;
   160 fun interrupt_group group = SYNCHRONIZED (fn () => TaskQueue.interrupt_group (! queue) group);
   196 
   161 
   197 end;
   162 end;