src/Pure/Build/build_schedule.scala
changeset 79594 f933e9153624
parent 79593 587a7dfeb03c
child 79614 58c0636e0ef5
--- a/src/Pure/Build/build_schedule.scala	Tue Feb 13 11:57:41 2024 +0100
+++ b/src/Pure/Build/build_schedule.scala	Tue Feb 13 16:03:55 2024 +0100
@@ -356,8 +356,8 @@
 
     def available(state: Build_Process.State): Resources = {
       val allocated =
-        state.running.values.map(_.node_info).groupMapReduce(the_host)(List(_))(_ ::: _)
-      Resources(this, allocated)
+        state.running.values.map(_.node_info).groupMapReduce(_.hostname)(List(_))(_ ::: _)
+      new Resources(this, allocated)
     }
   }
 
@@ -369,9 +369,9 @@
       Build_Process.Job(job_name, "", "", node_info, Date(start_time), None)
   }
 
-  case class Resources(
-    host_infos: Host_Infos,
-    allocated_nodes: Map[Host, List[Node_Info]]
+  class Resources(
+    val host_infos: Host_Infos,
+    allocated_nodes: Map[String, List[Node_Info]]
   ) {
     def unused_nodes(host: Host, threads: Int): List[Node_Info] =
       if (!available(host, threads)) Nil
@@ -383,11 +383,11 @@
     def unused_nodes(threads: Int): List[Node_Info] =
       host_infos.hosts.flatMap(unused_nodes(_, threads))
 
-    def allocated(host: Host): List[Node_Info] = allocated_nodes.getOrElse(host, Nil)
+    def allocated(host: Host): List[Node_Info] = allocated_nodes.getOrElse(host.name, Nil)
 
     def allocate(node_info: Node_Info): Resources = {
       val host = host_infos.the_host(node_info)
-      copy(allocated_nodes = allocated_nodes + (host -> (node_info :: allocated(host))))
+      new Resources(host_infos, allocated_nodes + (host.name -> (node_info :: allocated(host))))
     }
 
     def try_allocate_tasks(
@@ -547,7 +547,8 @@
 
         val host_preds =
           for {
-            (name, (pred_node, _)) <- finished.graph.iterator.toSet
+            name <- finished.graph.keys
+            pred_node = finished.graph.get_node(name)
             if pred_node.node_info.hostname == job.node_info.hostname
             if pred_node.end.time <= node.start.time
           } yield name
@@ -693,9 +694,12 @@
     def all_preds(node: Node): Set[Node] = build_graph.all_preds(List(node)).toSet
     val maximals_all_preds = maximals.map(node => node -> all_preds(node)).toMap
 
+    val best_threads =
+      build_graph.keys.map(node => node -> timing_data.best_threads(node, max_threads)).toMap
+
     def best_time(node: Node): Time = {
       val host = ordered_hosts.last
-      val threads = timing_data.best_threads(node, max_threads) min host.info.num_cpus
+      val threads = best_threads(node) min host.info.num_cpus
       timing_data.estimate(node, host.name, threads)
     }
     val best_times = build_graph.keys.map(node => node -> best_time(node)).toMap
@@ -713,38 +717,47 @@
     def path_max_times(minimals: List[Node]): Map[Node, Time] =
       path_times(minimals).toList.map((node, time) => node -> (time + max_time(node))).toMap
 
-    def parallel_paths(running: List[(Node, Time)], pred: Node => Boolean = _ => true): Int = {
-      def start(node: Node): (Node, Time) = node -> best_times(node)
-
-      def pass_time(elapsed: Time)(node: Node, time: Time): (Node, Time) =
-        node -> (time - elapsed)
+    val node_degrees =
+      build_graph.keys.map(node => node -> build_graph.imm_succs(node).size).toMap
 
-      def parallel_paths(running: Map[Node, Time]): (Int, Map[Node, Time]) =
-        if (running.isEmpty) (0, running)
-        else {
-          def get_next(node: Node): List[Node] =
-            build_graph.imm_succs(node).filter(pred).filter(
-              build_graph.imm_preds(_).intersect(running.keySet) == Set(node)).toList
+    def parallel_paths(
+      running: List[(Node, Time)],
+      nodes: Set[Node] = build_graph.keys.toSet,
+      max: Int = Int.MaxValue
+    ): Int =
+      if (nodes.nonEmpty && nodes.map(node_degrees.apply).max > max) max
+      else {
+        def start(node: Node): (Node, Time) = node -> best_times(node)
+
+        def pass_time(elapsed: Time)(node: Node, time: Time): (Node, Time) =
+          node -> (time - elapsed)
 
-          val (next, elapsed) = running.minBy(_._2.ms)
-          val (remaining, finished) =
-            running.toList.map(pass_time(elapsed)).partition(_._2 > Time.zero)
+        def parallel_paths(running: Map[Node, Time]): (Int, Map[Node, Time]) =
+          if (running.size >= max) (max, running)
+          else if (running.isEmpty) (0, running)
+          else {
+            def get_next(node: Node): List[Node] =
+              build_graph.imm_succs(node).intersect(nodes).filter(
+                build_graph.imm_preds(_).intersect(running.keySet) == Set(node)).toList
 
-          val running1 =
-            remaining.map(pass_time(elapsed)).toMap ++
-              finished.map(_._1).flatMap(get_next).map(start)
-          val (res, running2) = parallel_paths(running1)
-          (res max running.size, running2)
-        }
+            val (next, elapsed) = running.minBy(_._2.ms)
+            val (remaining, finished) =
+              running.toList.map(pass_time(elapsed)).partition(_._2 > Time.zero)
 
-      parallel_paths(running.toMap)._1
-    }
+            val running1 =
+              remaining.map(pass_time(elapsed)).toMap ++
+                finished.map(_._1).flatMap(get_next).map(start)
+            val (res, running2) = parallel_paths(running1)
+            (res max running.size, running2)
+          }
+
+        parallel_paths(running.toMap)._1
+      }
 
     def select_next(state: Build_Process.State): List[Config] = {
       val resources = host_infos.available(state)
 
-      def best_threads(task: Build_Process.Task): Int =
-        timing_data.best_threads(task.name, max_threads)
+      def best_threads(task: Build_Process.Task): Int = this.best_threads(task.name)
 
       val rev_ordered_hosts = ordered_hosts.reverse.map(_ -> max_threads)
 
@@ -755,7 +768,7 @@
 
       def remaining_time(node: Node): (Node, Time) =
         state.running.get(node) match {
-          case None => node -> best_time(node)
+          case None => node -> best_times(node)
           case Some(job) =>
             val estimate =
               timing_data.estimate(job.name, job.node_info.hostname,
@@ -763,10 +776,13 @@
             node -> ((Time.now() - job.start_date.time + estimate) max Time.zero)
         }
 
-      val max_parallel = parallel_paths(state.ready.map(_.name).map(remaining_time))
       val next_sorted = state.next_ready.sortBy(max_time(_).ms).reverse
+      val is_parallelizable =
+        available_nodes.length >= parallel_paths(
+          state.ready.map(_.name).map(remaining_time),
+          max = available_nodes.length + 1)
 
-      if (max_parallel <= available_nodes.length) {
+      if (is_parallelizable) {
         val all_tasks = next_sorted.map(task => (task, best_threads(task), best_threads(task)))
         resources.try_allocate_tasks(rev_ordered_hosts, all_tasks)._1
       }
@@ -788,13 +804,13 @@
         def parallel_threads(task: Build_Process.Task): Int =
           this.parallel_threads match {
             case Fixed_Thread(threads) => threads
-            case Time_Based_Threads(f) => f(best_time(task.name))
+            case Time_Based_Threads(f) => f(best_times(task.name))
           }
 
         val other_tasks = other.map(task => (task, parallel_threads(task), best_threads(task)))
 
         val max_critical_parallel =
-          parallel_paths(critical_minimals.map(remaining_time), critical_nodes.contains)
+          parallel_paths(critical_minimals.map(remaining_time), critical_nodes)
         val max_critical_hosts =
           available_nodes.take(max_critical_parallel).map(_.hostname).distinct.length
 
@@ -1318,6 +1334,7 @@
       def schedule_msg(res: Exn.Result[Schedule]): String =
         res match { case Exn.Res(schedule) => schedule.message case _ => "" }
 
+      progress.echo("Building schedule...")
       Timing.timeit(scheduler.schedule(build_state), schedule_msg, output = progress.echo(_))
     }