src/Pure/Concurrent/future.ML
author wenzelm
Fri, 18 Mar 2016 16:26:35 +0100
changeset 62663 bea354f6ff21
parent 62505 9e2a65912111
child 62819 d3ff367a16a0
permissions -rw-r--r--
clarified modules; tuned signature;

(*  Title:      Pure/Concurrent/future.ML
    Author:     Makarius

Value-oriented parallel execution via futures and promises.
*)

signature FUTURE =
sig
  (*tasks and groups, re-exported from Task_Queue*)
  type task = Task_Queue.task
  type group = Task_Queue.group
  val new_group: group option -> group
  (*task context of the current thread, if any*)
  val worker_task: unit -> task option
  val worker_group: unit -> group option
  val the_worker_group: unit -> group
  val worker_subgroup: unit -> group
  (*future values and non-blocking inspection*)
  type 'a future
  val task_of: 'a future -> task
  val peek: 'a future -> 'a Exn.result option
  val is_finished: 'a future -> bool
  (*flag that enables periodic status reports of the scheduler*)
  val ML_statistics: bool Unsynchronized.ref
  val interruptible_task: ('a -> 'b) -> 'a -> 'b
  (*cancellation*)
  val cancel_group: group -> unit
  val cancel: 'a future -> unit
  (*reporting and identification of results*)
  val error_message: Position.T -> (serial * string) * string option -> unit
  val identify_result: Position.T -> 'a Exn.result -> 'a Exn.result
  (*fork*)
  type params = {name: string, group: group option, deps: task list, pri: int, interrupts: bool}
  val default_params: params
  val forks: params -> (unit -> 'a) list -> 'a future list
  val fork: (unit -> 'a) -> 'a future
  (*join*)
  val join_results: 'a future list -> 'a Exn.result list
  val join_result: 'a future -> 'a Exn.result
  val joins: 'a future list -> 'a list
  val join: 'a future -> 'a
  val join_tasks: task list -> unit
  (*run a function on the current thread within a task context*)
  val task_context: string -> group -> ('a -> 'b) -> 'a -> 'b
  (*fast-path operations*)
  val value_result: 'a Exn.result -> 'a future
  val value: 'a -> 'a future
  val cond_forks: params -> (unit -> 'a) list -> 'a future list
  val map: ('a -> 'b) -> 'a future -> 'b future
  (*promised futures, fulfilled by external means*)
  val promise_group: group -> (unit -> unit) -> 'a future
  val promise: (unit -> unit) -> 'a future
  val fulfill_result: 'a future -> 'a Exn.result -> unit
  val fulfill: 'a future -> 'a -> unit
  (*queue inspection and shutdown*)
  val group_snapshot: group -> task list
  val shutdown: unit -> unit
end;

structure Future: FUTURE =
struct

(** future values **)

(*tasks, groups and group creation are taken over unchanged from Task_Queue*)
type task = Task_Queue.task;
type group = Task_Queue.group;
val new_group = Task_Queue.new_group;


(* identifiers *)

local
  (*thread-local slot that holds the task currently executed by this thread*)
  val worker_tag: task option Universal.tag = Universal.tag ();
in
  (*task of the current thread; NONE if the slot is unset or holds NONE*)
  fun worker_task () =
    (case Thread.getLocal worker_tag of
      SOME current => current
    | NONE => NONE);
  (*evaluate f x with the slot temporarily set to SOME task*)
  fun setmp_worker_task task f x =
    let val previous = worker_task ()
    in setmp_thread_data worker_tag previous (SOME task) f x end;
end;

(*group of the task executed by the current thread, if any*)
fun worker_group () = Option.map Task_Queue.group_of_task (worker_task ());

(*like worker_group, but raises Fail outside of a task context*)
fun the_worker_group () =
  let
    fun missing () = raise Fail "Missing worker thread context";
  in
    (case worker_group () of
      NONE => missing ()
    | SOME group => group)
  end;

(*fresh group whose parent is the current worker group, if any*)
fun worker_subgroup () =
  let val parent = worker_group ()
  in new_group parent end;

(*evaluate e (), wrapped by Task_Queue.joining when running inside a task*)
fun worker_joining e =
  let val current = worker_task () in
    if is_some current then Task_Queue.joining (the current) e
    else e ()
  end;

(*evaluate e (), wrapped by Task_Queue.waiting on deps when running inside a task*)
fun worker_waiting deps e =
  let val current = worker_task () in
    if is_none current then e ()
    else Task_Queue.waiting (the current) deps e
  end;


(* datatype future *)

(*single-assignment cell that eventually holds the outcome (value or exception)*)
type 'a result = 'a Exn.result Single_Assignment.var;

(*a future consists of:
    promised: true for futures created by promise_group, false otherwise
    task: the queue task associated with this future
    result: the cell that receives the outcome*)
datatype 'a future = Future of
 {promised: bool,
  task: task,
  result: 'a result};

(*field selectors*)
fun task_of (Future fields) = #task fields;
fun result_of (Future fields) = #result fields;

(*non-blocking access to the outcome assigned so far*)
fun peek x = x |> result_of |> Single_Assignment.peek;

(*true iff an outcome has already been assigned*)
fun is_finished x =
  (case peek x of
    SOME _ => true
  | NONE => false);

(*toplevel pretty printer for futures: shows "<future>" while no result is
  assigned, "<failed>" for an exceptional result, and otherwise prints the
  value itself via the given printer at the given depth*)
val _ =
  PolyML.addPrettyPrinter (fn depth => fn pretty => fn x =>
    (case peek x of
      NONE => PolyML.PrettyString "<future>"
    | SOME (Exn.Exn _) => PolyML.PrettyString "<failed>"
    | SOME (Exn.Res y) => pretty (y, depth)));



(** scheduling **)

(* synchronization *)

(*scheduler_event: awaited by the scheduler loop (with timeout) and by shutdown;
  broadcast by cancel_later and at the end of each scheduler round*)
val scheduler_event = ConditionVar.conditionVar ();
(*work_available: awaited by sleeping workers in worker_next*)
val work_available = ConditionVar.conditionVar ();
(*work_finished: broadcast when a task finishes in worker_exec; awaited by join_next*)
val work_finished = ConditionVar.conditionVar ();

local
  (*the single mutex that guards the global scheduling state*)
  val global_lock = Mutex.mutex ();
in

(*critical section on the global lock, labelled by name*)
fun SYNCHRONIZED name = Multithreading.synchronized name global_lock;

(*unbounded wait on a condition variable*)
fun wait cond = (*requires SYNCHRONIZED*)
  Multithreading.sync_wait NONE NONE cond global_lock;

(*wait on a condition variable, with deadline = now + timeout*)
fun wait_timeout timeout cond = (*requires SYNCHRONIZED*)
  let val deadline = Time.+ (Time.now (), timeout)
  in Multithreading.sync_wait NONE (SOME deadline) cond global_lock end;

(*signal a condition variable*)
fun signal cond = (*requires SYNCHRONIZED*)
  ConditionVar.signal cond;

(*broadcast on a condition variable*)
fun broadcast cond = (*requires SYNCHRONIZED*)
  ConditionVar.broadcast cond;

end;


(* global state *)

(*the global task queue*)
val queue = Unsynchronized.ref Task_Queue.empty;
(*counter used by the scheduler to number newly started workers*)
val next = Unsynchronized.ref 0;
(*the scheduler thread, if one has been forked*)
val scheduler = Unsynchronized.ref (NONE: Thread.thread option);
(*groups registered via cancel_later, re-processed by the scheduler*)
val canceled = Unsynchronized.ref ([]: group list);
(*shutdown request: set by shutdown, reset by scheduler_check*)
val do_shutdown = Unsynchronized.ref false;
(*pool limits, recomputed in each scheduler round: max_active = m, max_workers = 2 * m*)
val max_workers = Unsynchronized.ref 0;
val max_active = Unsynchronized.ref 0;

(*number of scheduler ticks so far; determines when report_status is invoked*)
val status_ticks = Unsynchronized.ref 0;
(*time of the last scheduler tick; reset to zero when the scheduler loop ends*)
val last_round = Unsynchronized.ref Time.zeroTime;
(*minimal distance between ticks, also the timeout of the scheduler delay loop*)
val next_round = seconds 0.05;

(*state of a worker thread: Working is the initial state set by worker_start,
  Waiting is set temporarily by join_next, Sleeping temporarily by worker_next*)
datatype worker_state = Working | Waiting | Sleeping;
(*registered worker threads, each with its mutable state*)
val workers = Unsynchronized.ref ([]: (Thread.thread * worker_state Unsynchronized.ref) list);

(*number of registered workers that are currently in the given state*)
fun count_workers state = (*requires SYNCHRONIZED*)
  let fun in_state (_, state_ref) = (! state_ref = state)
  in length (List.filter in_state (! workers)) end;



(* status *)

(*when true, report_status emits a protocol message; initially disabled*)
val ML_statistics = Unsynchronized.ref false;

(*emit current queue and worker counters, followed by ML_Statistics.get (),
  as protocol message; no-op unless ML_statistics is enabled*)
fun report_status () = (*requires SYNCHRONIZED*)
  if not (! ML_statistics) then ()
  else
    let
      val {ready, pending, running, passive, urgent} = Task_Queue.status (! queue);
      val workers_total = length (! workers);
      val workers_active = count_workers Working;
      val workers_waiting = count_workers Waiting;

      val now = ("now", Markup.print_real (Time.toReal (Time.now ())));
      val counters =
        map (fn (key, count) => (key, Markup.print_int count))
         [("tasks_ready", ready),
          ("tasks_pending", pending),
          ("tasks_running", running),
          ("tasks_passive", passive),
          ("tasks_urgent", urgent),
          ("workers_total", workers_total),
          ("workers_active", workers_active),
          ("workers_waiting", workers_waiting)];
      val props = Markup.ML_statistics :: now :: counters @ ML_Statistics.get ();
    in Output.try_protocol_message props [] end;


(* cancellation primitives *)

(*cancel group in the queue and interrupt its running threads, except for the
  current thread; returns the running threads reported by Task_Queue.cancel*)
fun cancel_now group = (*requires SYNCHRONIZED*)
  let
    fun interrupt_other thread =
      if Standard_Thread.is_self thread then ()
      else Standard_Thread.interrupt_unsynchronized thread;
    val running_threads = Task_Queue.cancel (! queue) group;
  in List.app interrupt_other running_threads; running_threads end;

(*cancel everything in the queue, interrupt all reported threads, and return
  the groups reported by Task_Queue.cancel_all*)
fun cancel_all () = (*requires SYNCHRONIZED*)
  let val (groups, threads) = Task_Queue.cancel_all (! queue)
  in List.app Standard_Thread.interrupt_unsynchronized threads; groups end;

(*register group for repeated cancellation by the scheduler and wake it up*)
fun cancel_later group = (*requires SYNCHRONIZED*)
  let
    val _ = Unsynchronized.change canceled (insert Task_Queue.eq_group group);
    val _ = broadcast scheduler_event;
  in () end;

(*Evaluate f x with interrupts enabled: private_interrupts when running inside
  a task, public_interrupts otherwise.  Afterwards Multithreading.interrupted ()
  is invoked in the original attribute context ("before" returns the result of
  f x, provided the second expression does not raise).*)
fun interruptible_task f x =
  Multithreading.with_attributes
    (if is_some (worker_task ())
     then Multithreading.private_interrupts
     else Multithreading.public_interrupts)
    (fn _ => f x)
  before Multithreading.interrupted ();


(* worker threads *)

(*Execute all jobs of a task on the current thread and finish the task in the
  queue.  Must not be called within SYNCHRONIZED, since it enters the critical
  section itself.*)
fun worker_exec (task, jobs) =
  let
    val group = Task_Queue.group_of_task task;
    (*jobs are told whether the group was still uncanceled at this point*)
    val valid = not (Task_Queue.is_canceled group);
    (*every job is run ("job valid" is evaluated before "andalso ok");
      ok is true iff all jobs returned true*)
    val ok =
      Task_Queue.running task (fn () =>
        setmp_worker_task task (fn () =>
          fold (fn job => fn ok => job valid andalso ok) jobs true) ());
    val _ =
      if ! Multithreading.trace >= 2 then
        Output.try_protocol_message (Markup.task_statistics :: Task_Queue.task_statistics task) []
      else ();
    val _ = SYNCHRONIZED "finish" (fn () =>
      let
        (*NOTE(review): the exact meaning of the flag returned by
          Task_Queue.finish is not visible here -- confirm in Task_Queue*)
        val maximal = Unsynchronized.change_result queue (Task_Queue.finish task);
        (*probe for an interrupt of this thread, without propagating it*)
        val test = Exn.capture Multithreading.interrupted ();
        (*failed job or pending interrupt: cancel the group now; register it for
          later if cancel_now still reports running threads*)
        val _ =
          if ok andalso not (Exn.is_interrupt_exn test) then ()
          else if null (cancel_now group) then ()
          else cancel_later group;
        val _ = broadcast work_finished;
        val _ = if maximal then () else signal work_available;
      in () end);
  in () end;

(*wait on cond; a registered worker temporarily shows worker_state while waiting*)
fun worker_wait worker_state cond = (*requires SYNCHRONIZED*)
  let val self = Thread.self () in
    (case AList.lookup Thread.equal (! workers) self of
      NONE => wait cond
    | SOME state => Unsynchronized.setmp state worker_state wait cond)
  end;

(*Determine the next piece of work for the current worker thread.  Result NONE
  means that this worker should terminate, see worker_loop.*)
fun worker_next () = (*requires SYNCHRONIZED*)
  if length (! workers) > ! max_workers then
    (*excessive worker: unregister this thread and pass the signal on*)
    (Unsynchronized.change workers (AList.delete Thread.equal (Thread.self ()));
     signal work_available;
     NONE)
  else
    (*more Working threads than max_active: dequeue with urgent_only = true*)
    let val urgent_only = count_workers Working > ! max_active in
      (case Unsynchronized.change_result queue (Task_Queue.dequeue (Thread.self ()) urgent_only) of
        (*nothing to do: sleep until work is available, then retry*)
        NONE => (worker_wait Sleeping work_available; worker_next ())
      | some => (signal work_available; some))
    end;

(*main loop of a worker thread: fetch work within the critical section labelled
  by name, execute it outside, and stop as soon as worker_next yields NONE*)
fun worker_loop name =
  (case SYNCHRONIZED name worker_next of
    SOME work => (worker_exec work; worker_loop name)
  | NONE => ());

(*Fork a new worker thread and register it in state Working.  The thread itself
  is always named "worker"; the name argument is only passed to worker_loop,
  where it labels the critical section.  A Fail raised within the body is
  reported as tracing message instead of being propagated.*)
fun worker_start name = (*requires SYNCHRONIZED*)
  let
    (*the option value is multiplied by 1024^3, i.e. presumably given in GiB*)
    val threads_stack_limit =
      Real.floor (Options.default_real "threads_stack_limit" * 1024.0 * 1024.0 * 1024.0);
    (*non-positive value: no explicit stack limit*)
    val stack_limit = if threads_stack_limit <= 0 then NONE else SOME threads_stack_limit;
    val worker =
      Standard_Thread.fork {name = "worker", stack_limit = stack_limit, interrupts = false}
        (fn () => worker_loop name);
  in Unsynchronized.change workers (cons (worker, Unsynchronized.ref Working)) end
  handle Fail msg => Multithreading.tracing 0 (fn () => "SCHEDULER: " ^ msg);


(* scheduler *)

(*One round of the scheduler: status report, disposal of dead workers,
  adjustment of the worker pool, re-cancellation of registered groups, timed
  wait, shutdown check.  The result tells scheduler_loop whether to continue.
  An interrupt of the scheduler cancels all tasks and continues the loop.*)
fun scheduler_next () = (*requires SYNCHRONIZED*)
  let
    val now = Time.now ();
    (*a tick occurs when at least next_round has passed since the last one*)
    val tick = Time.<= (Time.+ (! last_round, next_round), now);
    val _ = if tick then last_round := now else ();


    (* runtime status *)

    val _ =
      if tick then Unsynchronized.change status_ticks (fn i => i + 1) else ();
    (*report on every 2nd tick if tracing is enabled, otherwise on every 10th*)
    val _ =
      if tick andalso ! status_ticks mod (if ! Multithreading.trace >= 1 then 2 else 10) = 0
      then report_status () else ();

    (*on a tick, drop registered workers whose thread is no longer active*)
    val _ =
      if not tick orelse forall (Thread.isActive o #1) (! workers) then ()
      else
        let
          val (alive, dead) = List.partition (Thread.isActive o #1) (! workers);
          val _ = workers := alive;
        in
          Multithreading.tracing 0 (fn () =>
            "SCHEDULER: disposed " ^ string_of_int (length dead) ^ " dead worker threads")
        end;


    (* worker pool adjustments *)

    val max_active0 = ! max_active;
    val max_workers0 = ! max_workers;

    (*target is 0 once shutdown is requested and the queue is all passive*)
    val m =
      if ! do_shutdown andalso Task_Queue.all_passive (! queue) then 0
      else Multithreading.max_threads_value ();
    val _ = max_active := m;
    val _ = max_workers := 2 * m;

    (*start as many workers as are missing; excessive workers terminate
      themselves in worker_next*)
    val missing = ! max_workers - length (! workers);
    val _ =
      if missing > 0 then
        funpow missing (fn () =>
          ignore (worker_start ("worker " ^ string_of_int (Unsynchronized.inc next)))) ()
      else ();

    (*changed limits: wake up a sleeping worker to let it re-check them*)
    val _ =
      if ! max_active = max_active0 andalso ! max_workers = max_workers0 then ()
      else signal work_available;


    (* canceled groups *)

    (*cancel registered groups again; keep those that still have running threads*)
    val _ =
      if null (! canceled) then ()
      else
       (Multithreading.tracing 1 (fn () =>
          string_of_int (length (! canceled)) ^ " canceled groups");
        Unsynchronized.change canceled (filter_out (null o cancel_now));
        signal work_available);


    (* delay loop *)

    (*wait for scheduler_event, at most next_round; an exceptional outcome of
      the wait is re-raised here*)
    val _ = Exn.release (wait_timeout next_round scheduler_event);


    (* shutdown *)

    (*stop only after shutdown was requested and all workers are gone*)
    val continue = not (! do_shutdown andalso null (! workers));
    val _ = if continue then () else (report_status (); scheduler := NONE);

    (*notify other threads waiting on scheduler_event, e.g. shutdown*)
    val _ = broadcast scheduler_event;
  in continue end
  handle exn =>
    if Exn.is_interrupt exn then
     (Multithreading.tracing 1 (fn () => "SCHEDULER: Interrupt");
      List.app cancel_later (cancel_all ());
      signal work_available; true)
    else Exn.reraise exn;

(*Body of the scheduler thread: repeat scheduler_next within the critical
  section "scheduler" until it returns false, with attributes derived from
  public_interrupts via sync_interrupts.  Finally last_round is reset.*)
fun scheduler_loop () =
 (while
    Multithreading.with_attributes
      (Multithreading.sync_interrupts Multithreading.public_interrupts)
      (fn _ => SYNCHRONIZED "scheduler" (fn () => scheduler_next ()))
  do (); last_round := Time.zeroTime);

(*true iff a scheduler thread is registered and still active*)
fun scheduler_active () = (*requires SYNCHRONIZED*)
  Option.getOpt (Option.map Thread.isActive (! scheduler), false);

(*withdraw any shutdown request and make sure that a scheduler thread is running*)
fun scheduler_check () = (*requires SYNCHRONIZED*)
  let
    val _ = do_shutdown := false;
    fun start_scheduler () =
      Standard_Thread.fork {name = "scheduler", stack_limit = NONE, interrupts = false}
        scheduler_loop;
  in
    if scheduler_active () then ()
    else scheduler := SOME (start_scheduler ())
  end;



(** futures **)

(* cancel *)

(*cancel group now; if threads are still running, register it for later
  cancellation; then signal work_available and check the scheduler*)
fun cancel_group_unsynchronized group = (*requires SYNCHRONIZED*)
 (if null (cancel_now group) then () else cancel_later group;
  signal work_available;
  scheduler_check ());

(*synchronized version of cancel_group_unsynchronized*)
fun cancel_group group =
  let fun body () = cancel_group_unsynchronized group
  in SYNCHRONIZED "cancel_group" body end;

(*cancel the group of the task that belongs to the given future*)
fun cancel x = x |> task_of |> Task_Queue.group_of_task |> cancel_group;


(* results *)

(*output an error message in the thread context of pos, unless both pos and
  the message carry an exec id and these ids differ*)
fun error_message pos ((serial, msg), exec_id) =
  let
    fun report () =
      let
        val id = Position.get_id pos;
        val relevant = is_none id orelse is_none exec_id orelse id = exec_id;
      in if relevant then Output.error_message' (serial, msg) else () end;
  in Position.setmp_thread_data pos report () end;

(*identify the exception of a failed result via Par_Exn.identify, using the
  exec id of pos as property if available; proper results are unaffected*)
fun identify_result pos res =
  let
    fun exec_props () =
      (case Position.get_id pos of
        SOME id => [(Markup.exec_idN, id)]
      | NONE => []);
    fun identify exn = Par_Exn.identify (exec_props ()) exn;
  in Exn.map_exn identify res end;

(*Assign res to the result cell and return true iff the cell then holds a
  proper value.  An exceptional content is passed to Task_Queue.cancel_group
  for the given group.  Must not be called within SYNCHRONIZED, since it may
  enter the critical section itself.*)
fun assign_result group result res =
  let
    (*failed assignment: if the cell already holds an interrupt then re-raise
      that interrupt, otherwise re-raise the original Fail*)
    val _ = Single_Assignment.assign result res
      handle exn as Fail _ =>
        (case Single_Assignment.peek result of
          SOME (Exn.Exn e) => Exn.reraise (if Exn.is_interrupt e then e else exn)
        | _ => Exn.reraise exn);
    (*the cell is assigned at this point, so "the" cannot fail*)
    val ok =
      (case the (Single_Assignment.peek result) of
        Exn.Exn exn =>
          (SYNCHRONIZED "cancel" (fn () => Task_Queue.cancel_group group exn); false)
      | Exn.Res _ => true);
  in ok end;


(* future jobs *)

(*Produce a fresh result cell together with a job that fills it.  The position
  and context of the creating thread are captured here and restored around the
  evaluation of e.  The job takes a flag ok: if true, e () is evaluated under
  thread attributes atts and its outcome is captured; if false, the outcome is
  an interrupt.  The job returns the status given by assign_result.*)
fun future_job group atts (e: unit -> 'a) =
  let
    val result = Single_Assignment.var "future" : 'a result;
    val pos = Position.thread_data ();
    val context = Context.thread_data ();
    fun job ok =
      let
        val res =
          if ok then
            Exn.capture (fn () =>
              Multithreading.with_attributes atts (fn _ =>
                (Position.setmp_thread_data pos o Context.setmp_thread_data context) e ())) ()
          else Exn.interrupt_exn;
      in assign_result group result (identify_result pos res) end;
  in (result, job) end;


(* fork *)

(*parameters of forks:
    name: task name passed to Task_Queue.enqueue
    group: explicit group; NONE means a fresh subgroup of the worker group
    deps: tasks passed as dependencies to Task_Queue.enqueue
    pri: priority passed to Task_Queue.enqueue
    interrupts: true runs the job with private_interrupts, false with no_interrupts*)
type params = {name: string, group: group option, deps: task list, pri: int, interrupts: bool};
val default_params: params = {name = "", group = NONE, deps = [], pri = 0, interrupts = true};

(*Fork a list of expressions as separate tasks of the same group, with common
  name, dependencies and priority.  Returns the futures in the order of es; an
  empty list yields [] without touching the queue.*)
fun forks ({name, group, deps, pri, interrupts}: params) es =
  if null es then []
  else
    let
      val grp =
        (case group of
          NONE => worker_subgroup ()
        | SOME grp => grp);
      (*enqueue a single expression, threading the queue through*)
      fun enqueue e queue =
        let
          val atts =
            if interrupts
            then Multithreading.private_interrupts
            else Multithreading.no_interrupts;
          val (result, job) = future_job grp atts e;
          val (task, queue') = Task_Queue.enqueue name grp deps pri job queue;
          val future = Future {promised = false, task = task, result = result};
        in (future, queue') end;
    in
      SYNCHRONIZED "enqueue" (fn () =>
        let
          val (futures, queue') = fold_map enqueue es (! queue);
          val _ = queue := queue';
          (*signal a worker only if none of deps is known to the queue*)
          val minimal = forall (not o Task_Queue.known_task queue') deps;
          val _ = if minimal then signal work_available else ();
          val _ = scheduler_check ();
        in futures end)
    end;

(*fork a single expression: fresh subgroup, no dependencies, priority 0,
  interruptible*)
fun fork e =
  let
    val fork_params: params =
      {name = "fork", group = NONE, deps = [], pri = 0, interrupts = true};
  in singleton (forks fork_params) e end;


(* join *)

(*Current outcome of a future, without waiting.  An unfinished future yields
  Fail "Unfinished future".  An interrupt is replaced by the combined
  exceptions of the group status, if that is non-empty.*)
fun get_result x =
  (case peek x of
    SOME res =>
      if not (Exn.is_interrupt_exn res) then res
      else
        let val group = Task_Queue.group_of_task (task_of x) in
          (case Task_Queue.group_status group of
            [] => res
          | exns => Exn.Exn (Par_Exn.make exns))
        end
  | NONE => Exn.Exn (Fail "Unfinished future"));

local

(*Determine the next work for a joining worker.  NONE means that nothing
  remains to be done for deps.  If no work can be dequeued but dependencies
  remain, wait for work_finished in state Waiting, under the original thread
  attributes atts, and retry with the remaining dependencies.*)
fun join_next atts deps = (*requires SYNCHRONIZED*)
  if null deps then NONE
  else
    (case Unsynchronized.change_result queue (Task_Queue.dequeue_deps (Thread.self ()) deps) of
      (NONE, []) => NONE
    | (NONE, deps') =>
       (worker_waiting deps' (fn () =>
          Multithreading.with_attributes atts (fn _ =>
            Exn.release (worker_wait Waiting work_finished)));
        join_next atts deps')
    | (SOME work, deps') => SOME (work, deps'));

(*repeatedly fetch work for deps within the critical section and execute it
  outside, marked as joining*)
fun join_loop atts deps =
  (case SYNCHRONIZED "join" (fn () => join_next atts deps) of
    NONE => ()
  | SOME (work, deps') => (worker_joining (fn () => worker_exec work); join_loop atts deps'));

in

(*Wait for all given futures and return their results via get_result.  Within
  a task context the current thread takes part in the execution of pending
  work (with interrupts disabled, except while waiting); otherwise it merely
  blocks on each result cell.*)
fun join_results xs =
  let
    val _ =
      if forall is_finished xs then ()
      else if is_some (worker_task ()) then
        Multithreading.with_attributes Multithreading.no_interrupts
          (fn orig_atts => join_loop orig_atts (map task_of xs))
      else List.app (ignore o Single_Assignment.await o result_of) xs;
  in map get_result xs end;

end;

(*join a single future, returning its result as Exn.result*)
fun join_result x = x |> singleton join_results;

(*join several futures; failures are released via Par_Exn.release_all*)
fun joins xs = xs |> join_results |> Par_Exn.release_all;

(*join a single future; a failure is re-raised*)
fun join x = x |> join_result |> Exn.release;

(*wait for the given tasks, by joining a trivial uninterruptible future of a
  fresh toplevel group that depends on all of them; no-op for the empty list*)
fun join_tasks tasks =
  if null tasks then ()
  else
    let
      val join_params: params =
        {name = "join_tasks", group = SOME (new_group NONE),
          deps = tasks, pri = 0, interrupts = false};
    in join (singleton (forks join_params) I) end;


(* task context for running thread *)

(*Evaluate f x on the current thread as a task of the given group: the task is
  enrolled in the queue via Task_Queue.enroll and executed immediately by
  worker_exec.  Interrupts are disabled around the bookkeeping, while f x
  itself runs under the original thread attributes.  The outcome of f x is
  returned or re-raised.*)
fun task_context name group f x =
  Multithreading.with_attributes Multithreading.no_interrupts (fn orig_atts =>
    let
      val (result, job) = future_job group orig_atts (fn () => f x);
      val task =
        SYNCHRONIZED "enroll" (fn () =>
          Unsynchronized.change_result queue (Task_Queue.enroll (Thread.self ()) name group));
      val _ = worker_exec (task, [job]);
    in
      (case Single_Assignment.peek result of
        NONE => raise Fail "Missing task context result"
      | SOME res => Exn.release res)
    end);


(* fast-path operations -- bypass task queue if possible *)

(*finished future for a given result, attached to the dummy task and thus
  without any entry in the task queue*)
fun value_result (res: 'a Exn.result) =
  let
    val dummy = Task_Queue.dummy_task;
    val dummy_group = Task_Queue.group_of_task dummy;
    val cell = Single_Assignment.var "value" : 'a result;
    val identified = identify_result (Position.thread_data ()) res;
    val _ = assign_result dummy_group cell identified;
  in Future {promised = false, task = dummy, result = cell} end;

(*finished future for a given proper value*)
fun value x = x |> Exn.Res |> value_result;

(*like forks if multithreading is enabled; otherwise each expression is
  evaluated right away and wrapped as finished future*)
fun cond_forks args es =
  let
    fun sequential e = value_result (Exn.interruptible_capture e ());
  in
    if Multithreading.enabled () then forks args es
    else map sequential es
  end;

(*Apply f to the eventual value of future x.  A finished future is mapped
  immediately on the current thread.  Otherwise the new job is attached to the
  existing task of x via Task_Queue.extend, if the queue accepts that; as a
  last resort a separate task is forked in the same group, depending on the
  task of x and with the same priority.*)
fun map_future f x =
  if is_finished x then value_result (Exn.interruptible_capture (f o join) x)
  else
    let
      val task = task_of x;
      val group = Task_Queue.group_of_task task;
      val (result, job) =
        future_job group Multithreading.private_interrupts (fn () => f (join x));

      (*true iff the job was added to the existing task*)
      val extended = SYNCHRONIZED "extend" (fn () =>
        (case Task_Queue.extend task job (! queue) of
          SOME queue' => (queue := queue'; true)
        | NONE => false));
    in
      if extended then Future {promised = false, task = task, result = result}
      else
        (singleton o cond_forks)
          {name = "map_future", group = SOME group, deps = [task],
            pri = Task_Queue.pri_of_task task, interrupts = true}
          (fn () => f (join x))
    end;


(* promised futures -- fulfilled by external means *)

(*Create a promised future within the given group: its result is supplied
  later via fulfill_result.  A passive job is registered in the queue; when
  run, that job assigns an interrupt as result and then invokes abort ().
  NOTE(review): the passive job is presumably run on cancellation of the
  promise -- confirm in Task_Queue.*)
fun promise_group group abort : 'a future =
  let
    val result = Single_Assignment.var "promise" : 'a result;
    (*a Fail from the assignment is ignored; an interrupt raised by
      assign_result is turned into Fail*)
    fun assign () = assign_result group result Exn.interrupt_exn
      handle Fail _ => true
        | exn =>
            if Exn.is_interrupt exn
            then raise Fail "Concurrent attempt to fulfill promise"
            else Exn.reraise exn;
    (*abort () is run after assign (), whatever its outcome, which is only
      released afterwards; interrupts are disabled throughout*)
    fun job () =
      Multithreading.with_attributes Multithreading.no_interrupts
        (fn _ => Exn.release (Exn.capture assign () before abort ()));
    val task = SYNCHRONIZED "enqueue_passive" (fn () =>
      Unsynchronized.change_result queue (Task_Queue.enqueue_passive group job));
  in Future {promised = true, task = task, result = result} end;

(*promised future within a fresh subgroup of the current worker group*)
fun promise abort =
  let val group = worker_subgroup ()
  in promise_group group abort end;

(*Supply the result of a promised future; raises Fail for any other future.
  The passive task is taken out of the queue and the assignment is performed
  on the current thread with interrupts disabled.  Finally this waits until
  the result cell is actually assigned.*)
fun fulfill_result (Future {promised, task, result}) res =
  if not promised then raise Fail "Not a promised future"
  else
    let
      val group = Task_Queue.group_of_task task;
      val pos = Position.thread_data ();
      (*assign the given result, or an interrupt if not ok*)
      fun job ok =
        assign_result group result (if ok then identify_result pos res else Exn.interrupt_exn);
      val _ =
        Multithreading.with_attributes Multithreading.no_interrupts (fn _ =>
          let
            val passive_job =
              SYNCHRONIZED "fulfill_result" (fn () =>
                Unsynchronized.change_result queue
                  (Task_Queue.dequeue_passive (Thread.self ()) task));
          in
            (*NOTE(review): meaning of the three outcomes of dequeue_passive is
              not visible here -- confirm in Task_Queue.  SOME true runs the
              job as regular task execution, SOME false does nothing here,
              NONE runs the job directly without finishing the task.*)
            (case passive_job of
              SOME true => worker_exec (task, [job])
            | SOME false => ()
            | NONE => ignore (job (not (Task_Queue.is_canceled group))))
          end);
      (*the assignment may be performed elsewhere: wait for it*)
      val _ =
        if is_some (Single_Assignment.peek result) then ()
        else worker_waiting [task] (fn () => ignore (Single_Assignment.await result));
    in () end;

(*supply a proper value as result of a promised future*)
fun fulfill x res =
  let val proper = Exn.Res res
  in fulfill_result x proper end;


(* group snapshot *)

(*tasks of the given group according to the current state of the queue*)
fun group_snapshot group =
  let fun tasks () = Task_Queue.group_tasks (! queue) group
  in SYNCHRONIZED "group_snapshot" tasks end;


(* shutdown *)

(*Request shutdown of the scheduler and wait until its thread is no longer
  active.  Raises Fail when called from within a task context.  The request
  flag is set again in every iteration; note that scheduler_check resets it,
  so operations that invoke scheduler_check in between withdraw the request.*)
fun shutdown () =
  if is_some (worker_task ()) then
    raise Fail "Cannot shutdown while running as worker thread"
  else
    SYNCHRONIZED "shutdown" (fn () =>
      while scheduler_active () do
       (do_shutdown := true;
        Multithreading.tracing 1 (fn () => "SHUTDOWN: wait");
        wait scheduler_event));


(*final declarations of this structure!*)
(*map_future is exported as Future.map; this binding must come last, because
  it hides the list operation "map" that is used in the definitions above*)
val map = map_future;

end;

(*toplevel abbreviation of the future type*)
type 'a future = 'a Future.future;