src/Pure/Tools/build_schedule.scala
changeset 78929 df323f23dfde
parent 78928 6c2c60b852e0
child 78930 f72f576fea3e
--- a/src/Pure/Tools/build_schedule.scala	Wed Nov 08 11:08:03 2023 +0100
+++ b/src/Pure/Tools/build_schedule.scala	Thu Nov 09 11:41:19 2023 +0100
@@ -330,9 +330,46 @@
 
   /* heuristics */
 
-  class Timing_Heuristic(threshold: Time, timing_data: Timing_Data) extends Scheduler(timing_data) {
+  class Timing_Heuristic(
+    threshold: Time,
+    timing_data: Timing_Data,
+    sessions_structure: Sessions.Structure
+  ) extends Scheduler(timing_data) {
+    /* pre-computed properties for efficient heuristic */
+
+    type Node = String
+
+    val build_graph = sessions_structure.build_graph
+    val all_maximals = build_graph.maximals.toSet
+    val maximals_preds =
+      all_maximals.map(node => node -> build_graph.all_preds(List(node)).toSet).toMap
+
+    val remaining_time = build_graph.node_height(timing_data.best_time(_).ms)
+
+    def elapsed_times(node: Node): Map[Node, Long] =
+      build_graph.reachable_length(timing_data.best_time(_).ms, build_graph.imm_succs, List(node))
+
+    def path_times(node: Node): Map[Node, Long] = {
+      val maximals = all_maximals.intersect(build_graph.all_succs(List(node)).toSet)
+      val elapsed_time = elapsed_times(node)
+
+      maximals
+        .flatMap(node => maximals_preds(node).map(_ -> elapsed_time(node)))
+        .groupMapReduce(_._1)(_._2)(_ max _)
+    }
+
+    def is_critical(ms: Long): Boolean = ms > threshold.ms
+
+    val critical_path_nodes =
+      build_graph.keys.map(node =>
+        node -> path_times(node).filter((_, time) => is_critical(time)).keySet).toMap
+
+
+    /* scheduling */
+
+    val host_infos = timing_data.host_infos
+
     def next(state: Build_Process.State): List[Config] = {
-      val host_infos = timing_data.host_infos
       val resources = host_infos.available(state)
 
       def best_threads(task: Build_Process.Task): Int =
@@ -346,22 +383,16 @@
         resources.try_allocate_tasks(free, ready.map(task => task -> best_threads(task)))._1
       else {
         val pending_tasks = state.pending.map(_.name).toSet
-        val graph = state.sessions.graph.restrict(pending_tasks)
-
-        val accumulated_time =
-          graph.node_depth(timing_data.best_time(_).ms).filter((name, _) => graph.is_maximal(name))
 
-        val path_time =
-          accumulated_time.flatMap((name, ms) => graph.all_preds(List(name)).map(_ -> ms)).toMap
+        val critical_nodes = ready.toSet.flatMap(task => critical_path_nodes(task.name))
+        def is_critical(node: Node): Boolean = critical_nodes.contains(node)
 
-        def is_critical(task: String): Boolean = path_time(task) > threshold.ms
+        def parallel_paths(node: Node): Int =
+          build_graph.imm_succs(node).filter(is_critical).map(parallel_paths(_) max 1).sum max 1
 
         val (critical, other) =
-          ready.sortBy(task => path_time(task.name)).partition(task => is_critical(task.name))
-
-        val critical_graph = graph.restrict(is_critical)
-        def parallel_paths(node: String): Int =
-          critical_graph.imm_succs(node).map(suc => parallel_paths(suc) max 1).sum max 1
+          ready.sortBy(task => remaining_time(task.name)).reverse.partition(task =>
+            is_critical(task.name))
 
         val (critical_hosts, other_hosts) =
           host_infos.hosts.sorted(host_infos.host_speeds).reverse.splitAt(
@@ -554,7 +585,11 @@
       new Scheduled_Build_Process(context, progress, server) {
         def init_scheduler(timing_data: Timing_Data): Scheduler = {
           val heuristics =
-            List(5, 10, 20).map(minutes => Timing_Heuristic(Time.minutes(minutes), timing_data))
+            List(5, 10, 20).map(minutes =>
+              Timing_Heuristic(
+                Time.minutes(minutes),
+                timing_data,
+                context.build_deps.sessions_structure))
           new Meta_Heuristic(heuristics, timing_data)
         }
       }