src/Pure/Tools/build_schedule.scala
changeset 79101 4e47b34fbb8e
parent 79091 06f380099b2e
child 79102 4d5f878665a3
--- a/src/Pure/Tools/build_schedule.scala	Fri Dec 01 10:10:59 2023 +0100
+++ b/src/Pure/Tools/build_schedule.scala	Fri Dec 01 20:32:34 2023 +0100
@@ -519,27 +519,29 @@
     type Node = String
 
     val build_graph = sessions_structure.build_graph
-    val all_maximals = build_graph.maximals.toSet
-    val maximals_preds =
-      all_maximals.map(node => node -> build_graph.all_preds(List(node)).toSet).toMap
+
+    val minimals = build_graph.minimals
+    val maximals = build_graph.maximals
+
+    def all_preds(node: Node): Set[Node] = build_graph.all_preds(List(node)).toSet
+    val maximals_all_preds = maximals.map(node => node -> all_preds(node)).toMap
 
     val best_times = build_graph.keys.map(node => node -> best_time(node)).toMap
-    val remaining_time_ms = build_graph.node_height(best_times(_).ms)
 
-    def elapsed_times(node: Node): Map[Node, Time] =
-      build_graph.reachable_length(best_times(_).ms, build_graph.imm_succs, List(node)).map(
-        (node, ms) => node -> Time.ms(ms))
+    val succs_max_time_ms = build_graph.node_height(best_times(_).ms)
+    def max_time(node: Node): Time = Time.ms(succs_max_time_ms(node)) + best_times(node)
+    def max_time(task: Build_Process.Task): Time = max_time(task.name)
 
-    def path_times(node: Node): Map[Node, Time] = {
-      val maximals = all_maximals.intersect(build_graph.all_succs(List(node)).toSet)
-      val elapsed_time = elapsed_times(node)
-
-      maximals
-        .flatMap(node => maximals_preds(node).map(_ -> elapsed_time(node)))
-        .groupMapReduce(_._1)(_._2)(_ max _)
+    def path_times(minimals: List[Node]): Map[Node, Time] = {
+      def time_ms(node: Node): Long = best_times(node).ms
+      val path_times_ms = build_graph.reachable_length(time_ms, build_graph.imm_succs, minimals)
+      path_times_ms.view.mapValues(Time.ms).toMap
     }
 
-    def parallel_paths(minimals: Set[Node], pred: Node => Boolean = _ => true): Int = {
+    def path_max_times(minimals: List[Node]): Map[Node, Time] =
+      path_times(minimals).toList.map((node, time) => node -> (time + max_time(node))).toMap
+
+    def parallel_paths(minimals: List[Node], pred: Node => Boolean = _ => true): Int = {
       def start(node: Node): (Node, Time) = node -> best_times(node)
 
       def pass_time(elapsed: Time)(node: Node, time: Time): (Node, Time) =
@@ -573,9 +575,6 @@
     sessions_structure: Sessions.Structure,
     max_threads_limit: Int = 8
   ) extends Path_Heuristic(timing_data, sessions_structure, max_threads_limit) {
-    val critical_path_nodes =
-      build_graph.keys.map(node =>
-        node -> path_times(node).filter((_, time) => time > threshold).keySet).toMap
 
     def next(state: Build_Process.State): List[Config] = {
       val resources = host_infos.available(state)
@@ -586,24 +585,24 @@
       val rev_ordered_hosts = ordered_hosts.reverse.map(_ -> max_threads)
 
       val resources0 = host_infos.available(state.copy(running = Map.empty))
-      val max_parallel = parallel_paths(state.ready.map(_.name).toSet)
+      val max_parallel = parallel_paths(state.ready.map(_.name))
       val fully_parallelizable = max_parallel <= resources0.unused_nodes(max_threads).length
 
-      val ready_sorted = state.next_ready.sortBy(task => remaining_time_ms(task.name)).reverse
+      val next_sorted = state.next_ready.sortBy(max_time(_).ms).reverse
 
       if (fully_parallelizable) {
-        val all_tasks = ready_sorted.map(task => (task, best_threads(task), best_threads(task)))
+        val all_tasks = next_sorted.map(task => (task, best_threads(task), best_threads(task)))
         resources.try_allocate_tasks(rev_ordered_hosts, all_tasks)._1
       }
       else {
-        val critical_nodes = state.ready.toSet.flatMap(task => critical_path_nodes(task.name))
+        val critical_minimals = state.ready.filter(max_time(_) > threshold).map(_.name)
+        val critical_nodes = path_max_times(critical_minimals).filter(_._2 > threshold).keySet
 
-        val (critical, other) = ready_sorted.partition(task => critical_nodes.contains(task.name))
+        val (critical, other) = next_sorted.partition(task => critical_nodes.contains(task.name))
 
         val critical_tasks = critical.map(task => (task, best_threads(task), best_threads(task)))
         val other_tasks = other.map(task => (task, 1, best_threads(task)))
 
-        val critical_minimals = critical_nodes.intersect(state.ready.map(_.name).toSet)
         val max_critical_parallel = parallel_paths(critical_minimals, critical_nodes.contains)
         val (critical_hosts, other_hosts) = rev_ordered_hosts.splitAt(max_critical_parallel)