src/Pure/Concurrent/future.ML
author wenzelm
Thu Oct 02 23:52:12 2008 +0200 (2008-10-02)
changeset 28470 409fedeece30
parent 28468 7c80ab57f56d
child 28471 00046e3b46b5
permissions -rw-r--r--
added simple heartbeat thread;
     1 (*  Title:      Pure/Concurrent/future.ML
     2     ID:         $Id$
     3     Author:     Makarius
     4 
     5 Future values.
     6 
     7 Notes:
     8 
     9   * Futures are similar to delayed evaluation, i.e. delay/force is
    10     generalized to fork/join (and variants).  The idea is to model
    11     parallel value-oriented computations, but *not* communicating
    12     processes.
    13 
    14   * Futures are grouped; failure of one group member causes the whole
    15     group to be interrupted eventually.
    16 
    17   * Forked futures are evaluated spontaneously by a farm of worker
    18     threads in the background; join resynchronizes the computation and
    19     delivers results (values or exceptions).
    20 
    21   * The pool of worker threads is limited, usually in correlation with
    22     the number of physical cores on the machine.  Note that allocation
    23     of runtime resources is distorted either if workers yield CPU time
    24     (e.g. via system sleep or wait operations), or if non-worker
    25     threads contend for significant runtime resources independently.
    26 *)
    27 
    28 signature FUTURE =
    29 sig
    30   type task = TaskQueue.task
    31   type group = TaskQueue.group
         (*(name, task, group) of the future running in the current thread, if any*)
    32   val thread_data: unit -> (string * task * group) option
         (*future value with result type 'a*)
    33   type 'a T
    34   val task_of: 'a T -> task
    35   val group_of: 'a T -> group
         (*status string only; the result itself is never printed*)
    36   val str_of: 'a T -> string
         (*true once a result (value or exception) has been stored*)
    37   val is_finished: 'a T -> bool
         (*general fork: optional group (NONE: new group), dependencies, pri flag, body*)
    38   val future: group option -> task list -> bool -> (unit -> 'a) -> 'a T
         (*fork with pri = true, within the group of the running future (if any)*)
    39   val fork: (unit -> 'a) -> 'a T
         (*like fork, but with pri = false*)
    40   val fork_background: (unit -> 'a) -> 'a T
         (*wait until all given futures are finished; results in the order of the arguments*)
    41   val join_results: 'a T list -> 'a Exn.result list
         (*join a single future and release its result via Exn.release*)
    42   val join: 'a T -> 'a
         (*hands the given tasks to TaskQueue.focus*)
    43   val focus: task list -> unit
         (*permissive interrupt signal, addressed by string id; may get ignored*)
    44   val interrupt_task: string -> unit
         (*registers the group of the given future for cancellation*)
    45   val cancel: 'a T -> unit
         (*global join and shutdown of workers and scheduler*)
    46   val shutdown: unit -> unit
    47 end;
    48 
    49 structure Future: FUTURE =
    50 struct
    51 
    52 (** future values **)
    53 
    54 (* identifiers *)
    55 
    56 type task = TaskQueue.task;  (*abbreviation, exported via signature FUTURE*)
    57 type group = TaskQueue.group;  (*abbreviation, exported via signature FUTURE*)
    58 
    59 local val tag: (string * task * group) option Universal.tag = Universal.tag () in
         (*thread-local (name, task, group) as installed by setmp_thread_data; NONE if nothing is stored*)
    60   fun thread_data () = (case Thread.getLocal tag of SOME data => data | NONE => NONE);
         (*evaluate f x via Library.setmp_thread_data, passing the current thread data and SOME data*)
    61   fun setmp_thread_data data f x =
           Library.setmp_thread_data tag (thread_data ()) (SOME data) f x;
    62 end;
    63 
    64 
    65 (* datatype future *)
    66 
    67 datatype 'a T = Future of
    68  {task: task,  (*task as returned by TaskQueue.enqueue for this future*)
    69   group: group,  (*group this future was forked in*)
    70   result: 'a Exn.result option ref};  (*NONE while unfinished, SOME outcome afterwards*)
    71 
    72 fun task_of (Future args) = #task args;  (*selector: task of a future*)
    73 fun group_of (Future args) = #group args;  (*selector: group of a future*)
    74 
    75 fun str_of (Future {result = ref NONE, ...}) = "<future>"  (*no result stored yet*)
    76   | str_of (Future {result = ref (SOME outcome), ...}) =
           (*result present: only report whether it is a value or an exception*)
    77       (case outcome of
    78         Exn.Exn _ => "<failed future>"
    79       | Exn.Result _ => "<finished future>");
    80 
    81 fun is_finished (Future {result, ...}) = (case ! result of NONE => false | SOME _ => true);  (*some result (value or exception) has been stored*)
    82 
    83 
    84 
    85 (** scheduling **)
    86 
    87 (* global state *)
    88 
    89 val queue = ref TaskQueue.empty;  (*pending and running tasks; updated within SYNCHRONIZED sections*)
    90 val next = ref 0;  (*counter used to number worker and heartbeat thread names*)
    91 val workers = ref ([]: (Thread.thread * bool) list);  (*worker threads, each with its "active" flag (see change_active)*)
    92 val scheduler = ref (NONE: Thread.thread option);  (*scheduler thread, if one has been forked*)
    93 val excessive = ref 0;  (*workers beyond the current limit, as computed by scheduler_next; worker_next exits while > 0*)
    94 val canceled = ref ([]: TaskQueue.group list);  (*groups still awaiting successful TaskQueue.cancel*)
    95 val do_shutdown = ref false;  (*set by shutdown, reset when a new scheduler is forked*)
    96 
    97 
    98 (* synchronization *)
    99 
   100 local
         (*single lock and condition variable shared by all operations of this block*)
   101   val lock = Mutex.mutex ();
   102   val cond = ConditionVar.conditionVar ();
   103 in
   104 
       (*SYNCHRONIZED name e: evaluate e () while holding the lock.  The outcome is
         captured as Exn.result, so the lock is released even if e raises; the final
         Exn.release hands the outcome back to the caller only after unlocking.
         name is used in trace messages only.
         NOTE(review): uninterruptible/restore_attributes presumably shield the lock
         handling from interrupts while e runs with the caller's interrupt settings --
         confirm against their definitions.*)
   105 fun SYNCHRONIZED name e = Exn.release (uninterruptible (fn restore_attributes => fn () =>
   106   let
   107     val _ =
             (*immediate success is traced at level 3, having to block at level 2*)
   108       if Mutex.trylock lock then Multithreading.tracing 3 (fn () => name ^ ": locked")
   109       else
   110        (Multithreading.tracing 2 (fn () => name ^ ": locking ...");
   111         Mutex.lock lock;
   112         Multithreading.tracing 2 (fn () => name ^ ": ... locked"));
   113     val result = Exn.capture (restore_attributes e) ();
   114     val _ = Mutex.unlock lock;
   115     val _ = Multithreading.tracing 3 (fn () => name ^ ": unlocked");
   116   in result end) ());
   117 
       (*wait on the shared condition variable, passing the shared lock; traced at level 3*)
   118 fun wait name = (*requires SYNCHRONIZED*)
   119  (Multithreading.tracing 3 (fn () => name ^ ": wait ...");
   120   ConditionVar.wait (cond, lock);
   121   Multithreading.tracing 3 (fn () => name ^ ": ... continue"));
   122 
       (*like wait, but with the deadline Time.now () + timeout given to ConditionVar.waitUntil*)
   123 fun wait_timeout name timeout = (*requires SYNCHRONIZED*)
   124  (Multithreading.tracing 3 (fn () => name ^ ": wait ...");
   125   ConditionVar.waitUntil (cond, lock, Time.+ (Time.now (), timeout));
   126   Multithreading.tracing 3 (fn () => name ^ ": ... continue"));
   127 
       (*broadcast on the shared condition variable*)
   128 fun notify_all () = (*requires SYNCHRONIZED*)
   129   ConditionVar.broadcast cond;
   130 
   131 end;
   132 
   133 
   134 (* worker activity *)
   135 
   136 fun trace_active () =
         (*level 1 trace of the worker pool: total size and number of entries flagged active*)
   137   let
   138     val pool = ! workers;
   139     val total = string_of_int (length pool);
   140     val busy = string_of_int (length (filter (fn (_, active) => active) pool));
   141   in Multithreading.tracing 1 (fn () => "SCHEDULE: " ^ total ^ " workers, " ^ busy ^ " active") end;
   142 
   143 fun change_active active = (*requires SYNCHRONIZED*)
         (*record the given active flag for the calling thread in the workers list*)
   144   let val self = Thread.self () in change workers (AList.update Thread.equal (self, active)) end;
   145 
   146 
   147 (* execute *)
   148 
   149 fun execute name (task, group, run) =
         (*Run one work item (task, group, run) and record its completion in the queue.
           name labels the trace messages and becomes part of the thread data.*)
   150   let
   151     val _ = trace_active ();
   152     val _ = Multithreading.tracing 3 (fn () => name ^ ": running");
           (*run () is evaluated with thread data (name, task, group); ok = false triggers group cancellation below*)
   153     val ok = setmp_thread_data (name, task, group) run ();
   154     val _ = Multithreading.tracing 3 (fn () => name ^ ": finished");
   155     val _ = SYNCHRONIZED "execute" (fn () =>
   156      (change queue (TaskQueue.finish task);
             (*not ok: try to cancel the group right away; if TaskQueue.cancel yields false,
               remember the group in canceled, which scheduler_next processes again*)
   157       if ok then ()
   158       else if TaskQueue.cancel (! queue) group then ()
   159       else change canceled (cons group);
             (*always wake up all threads waiting on the condition variable*)
   160       notify_all ()));
   161   in () end;
   162 
   163 
   164 (* worker threads *)
   165 
   166 fun worker_wait name = (*requires SYNCHRONIZED*)
         (*flag the calling thread as inactive for the duration of the wait*)
   167   let val _ = change_active false; val _ = wait name in change_active true end;
   168 
   169 fun worker_next name = (*requires SYNCHRONIZED*)
         (*Next work item for a worker thread, waiting while the queue offers none.
           NONE tells the worker to exit: this happens while excessive > 0, after the
           calling thread has removed itself from the workers list.*)
   170   if ! excessive <= 0 then
   171     (case change_result queue TaskQueue.dequeue of
   172       SOME work => SOME work
   173     | NONE => (worker_wait name; worker_next name))
   174   else
   175     (dec excessive;
   176      change workers (filter_out (fn (thread, _) => Thread.equal (thread, Thread.self ())));
   177      notify_all ();
   178      NONE);
   179 
   180 fun worker_loop name =
         (*main loop of a worker thread: fetch work under the lock, execute it outside, repeat until NONE*)
   181   let val next_work = SYNCHRONIZED name (fn () => worker_next name) in
   182     (case next_work of SOME work => (execute name work; worker_loop name)
   183     | NONE => Multithreading.tracing 3 (fn () => name ^ ": exit")) end;
   184 
   185 fun worker_start name = (*requires SYNCHRONIZED*)
         (*fork a thread running worker_loop and register it in workers, flagged active*)
   186   let val thread = SimpleThread.fork false (fn () => worker_loop name) in change workers (cons (thread, true)) end;
   187 
   188 
   189 (* scheduler *)
   190 
   191 fun heartbeat name =
         (*emit name as level 1 trace message every 100 ms, until do_shutdown is observed after a sleep*)
   192   let val _ = Multithreading.tracing 1 (fn () => name);
   193       val _ = OS.Process.sleep (Time.fromMilliseconds 100);
   194   in if not (! do_shutdown) then heartbeat name else () end;
   195 
   196 fun scheduler_next () = (*requires SYNCHRONIZED*)
         (*One round of the scheduler; the result tells scheduler_loop whether to continue.*)
   197   let
   198     (*worker threads*)
   199     val _ =
             (*drop entries whose thread is no longer active according to Thread.isActive*)
   200       (case List.partition (Thread.isActive o #1) (! workers) of
   201         (_, []) => ()
   202       | (active, inactive) =>
   203           (workers := active; Multithreading.tracing 0 (fn () =>
   204             "SCHEDULE: disposed " ^ string_of_int (length inactive) ^ " dead worker threads")));
   205     val _ = trace_active ();
   206 
           (*m: intended number of workers (0 during shutdown); l: current number*)
   207     val m = if ! do_shutdown then 0 else Multithreading.max_threads_value ();
   208     val l = length (! workers);
           (*surplus of workers; may be negative, worker_next only reacts to values > 0*)
   209     val _ = excessive := l - m;
   210     val _ =
             (*start the m - l missing workers, named "worker <n>" via the counter next*)
   211       if m > l then funpow (m - l) (fn () => worker_start ("worker " ^ string_of_int (inc next))) ()
   212       else ();
   213 
   214     (*canceled groups*)
           (*retry TaskQueue.cancel; groups for which it yields true are removed from the list*)
   215     val _ =  change canceled (filter_out (TaskQueue.cancel (! queue)));
   216 
   217     (*shutdown*)
           (*stop once shutdown is requested and no workers are left; scheduler is then reset to NONE*)
   218     val continue = not (! do_shutdown andalso null (! workers));
   219     val _ = if continue then () else scheduler := NONE;
   220 
           (*wake up waiting threads, then wait via wait_timeout with a 3 second timeout*)
   221     val _ = notify_all ();
   222     val _ = wait_timeout "scheduler" (Time.fromSeconds 3);
   223   in continue end;
   224 
   225 fun scheduler_loop () =
         (*body of the scheduler thread: synchronized rounds of scheduler_next until it yields false*)
   226   if SYNCHRONIZED "scheduler" scheduler_next then scheduler_loop ()
   227   else Multithreading.tracing 2 (fn () => "scheduler: exit");
   228 
   229 fun scheduler_active () = (*requires SYNCHRONIZED*)
         (*true iff a scheduler thread is registered and Thread.isActive holds for it*)
   230   the_default false (Option.map Thread.isActive (! scheduler));
   231 
   232 fun scheduler_check name = SYNCHRONIZED name (fn () =>
         (*Ensure that a scheduler thread is running; name labels the synchronized section.
           Without an active scheduler: reset do_shutdown, fork the scheduler thread and an
           additional heartbeat thread named "heartbeat <n>".
           With an active scheduler that is shutting down: raise an error.*)
   233   if not (scheduler_active ()) then
   234     (Multithreading.tracing 2 (fn () => "scheduler: fork");
   235      do_shutdown := false; scheduler := SOME (SimpleThread.fork false scheduler_loop);
            (*the heartbeat thread is not registered anywhere; its loop ends only via do_shutdown*)
   236      SimpleThread.fork false (fn () => heartbeat ("heartbeat " ^ string_of_int (inc next))); ())
   237   else if ! do_shutdown then error "Scheduler shutdown in progress"
   238   else ());
   239 
   240 
   241 (* future values: fork independent computation *)
   242 
   243 fun future opt_group deps pri (e: unit -> 'a) =
         (*Create a future for e and enqueue it.
           opt_group: group to use, NONE for a new one from TaskQueue.new_group;
           deps, pri: passed on to TaskQueue.enqueue;
           e: the computation, whose value or exception becomes the result.*)
   244   let
   245     val _ = scheduler_check "future check";
   246 
   247     val group = (case opt_group of SOME group => group | NONE => TaskQueue.new_group ());
   248 
   249     val result = ref (NONE: 'a Exn.result option);
           (*the job uses the thread attributes of the forking thread, read here via Thread.getAttributes*)
   250     val run = Multithreading.with_attributes (Thread.getAttributes ())
   251       (fn _ => fn ok =>
               (*not ok: e is not evaluated, the result is Exn.Interrupt;
                 the job yields true iff the stored result is a proper value*)
   252         let val res = if ok then Exn.capture e () else Exn.Exn Exn.Interrupt
   253         in result := SOME res; is_some (Exn.get_result res) end);
   254 
           (*enqueue under the lock, then notify waiting threads before leaving the section*)
   255     val task = SYNCHRONIZED "future" (fn () =>
   256       change_result queue (TaskQueue.enqueue group deps pri run) before notify_all ());
   257   in Future {task = task, group = group, result = result} end;
   258 
   259 fun fork_common pri =
         (*fork without dependencies, within the group of the future running in this thread (if any)*)
         let val inherited = (case thread_data () of SOME (_, _, group) => SOME group | NONE => NONE)
         in future inherited [] pri end;
   260 
   261 fun fork body = fork_common true body;  (*fork_common with pri = true*)
   262 fun fork_background body = fork_common false body;  (*fork_common with pri = false*)
   263 
   264 
   265 (* join: retrieve results *)
   266 
   267 fun join_results [] = []
         (*Wait until all given futures are finished and return their results in order.
           Raises an error within a critical section, or if scheduler_check fails.
           The "all finished?" test of both wait loops is performed while holding the
           lock: a result is stored before the finishing thread enters its synchronized
           notify_all, so testing under the lock cannot miss that broadcast.*)
   268   | join_results xs =
   269       let
   270         val _ = scheduler_check "join check";
   271         val _ = Multithreading.self_critical () andalso
   272           error "Cannot join future values within critical section";
   273 
               (*execute queued work that leads towards the given tasks, as long as the queue offers some*)
   274         fun join_loop _ [] = ()
   275           | join_loop name tasks =
   276               (case SYNCHRONIZED name (fn () =>
   277                   change_result queue (TaskQueue.dequeue_towards tasks)) of
   278                 NONE => ()
   279               | SOME (work, tasks') => (execute name work; join_loop name tasks'));
   280         val _ =
   281           (case thread_data () of
   282             NONE =>
   283               (*alien thread -- refrain from contending for resources*)
   284               while SYNCHRONIZED "join_thread" (fn () =>
   285                 exists (not o is_finished) xs andalso (wait "join_thread"; true)) do ()
   286           | SOME (name, task, _) =>
   287               (*proper task -- actively work towards results*)
   288               let
   289                 val unfinished = xs |> map_filter
   290                   (fn Future {task, result = ref NONE, ...} => SOME task | _ => NONE);
   291                 val _ = SYNCHRONIZED "join" (fn () =>
   292                   (change queue (TaskQueue.depend unfinished task); notify_all ()));
   293                 val _ = join_loop ("join_loop: " ^ name) unfinished;
   294                 val _ =
   295                   while SYNCHRONIZED "join_task" (fn () =>
   296                     exists (not o is_finished) xs andalso (worker_wait "join_task"; true)) do ();
   297               in () end);
   298 
               (*all futures are finished at this point, so every result is SOME*)
   299       in xs |> map (fn Future {result = ref (SOME res), ...} => res) end;
   300 
   301 fun join x = x |> singleton join_results |> Exn.release;  (*join a single future and release its result*)
   302 
   303 
   304 (* misc operations *)
   305 
   306 (*focus: collection of high-priority tasks, handed to TaskQueue.focus under the lock*)
   307 fun focus tasks =
   308   let fun refocus () = change queue (TaskQueue.focus tasks) in SYNCHRONIZED "focus" refocus end;
   309 
   310 (*interrupt: permissive signal via TaskQueue.interrupt_external; may get ignored*)
   311 fun interrupt_task id =
   312   let fun signal () = TaskQueue.interrupt_external (! queue) id in SYNCHRONIZED "interrupt" signal end;
   313 
   314 (*cancel: the group of x is added to canceled, so its present and future members get interrupted eventually*)
   315 fun cancel x =
   316   let val _ = scheduler_check "cancel check"; val group = group_of x
   317   in SYNCHRONIZED "cancel" (fn () => (change canceled (cons group); notify_all ())) end;
   318 
   319 
   320 (*global join and shutdown*)
   321 fun shutdown () =
         (*no-op unless Multithreading.available; otherwise the whole sequence below runs
           within one SYNCHRONIZED section, where each wait hands the lock to other threads*)
   322   if Multithreading.available then
   323    (scheduler_check "shutdown check";
   324     SYNCHRONIZED "shutdown" (fn () =>
   325      (while not (scheduler_active ()) do wait "shutdown: scheduler inactive";
             (*global join: wait until the queue is empty*)
   326       while not (TaskQueue.is_empty (! queue)) do wait "shutdown: join";
             (*request shutdown: scheduler_next then targets 0 workers and finally stops itself*)
   327       do_shutdown := true;
   328       notify_all ();
   329       while not (null (! workers)) do wait "shutdown: workers";
   330       while scheduler_active () do wait "shutdown: scheduler still active";
             (*NOTE(review): this sleep happens while the lock is held; presumably it gives the
               heartbeat thread (100 ms period) time to observe do_shutdown -- confirm*)
   331       OS.Process.sleep (Time.fromMilliseconds 300))))
   332   else ();
   333 
   334 end;