merged
authorwenzelm
Tue, 12 Mar 2024 15:31:44 +0100
changeset 79872 85ff8d62c414
parent 79865 53d0d2860ed8 (current diff)
parent 79871 630a82f87310 (diff)
child 79873 6c19c29ddcbe
merged
--- a/src/Pure/Build/build_process.scala	Tue Mar 12 12:11:39 2024 +0000
+++ b/src/Pure/Build/build_process.scala	Tue Mar 12 15:31:44 2024 +0100
@@ -44,10 +44,8 @@
 
   object Task {
     type Entry = (String, Task)
-    def entry(name: String, deps: List[String], build_uuid: String): Entry =
-      name -> Task(name, deps, build_uuid)
     def entry(session: Build_Job.Session_Context, build_context: isabelle.Build.Context): Entry =
-      entry(session.name, session.deps, build_context.build_uuid)
+      session.name -> Task(session.name, session.deps, build_context.build_uuid)
   }
 
   sealed case class Task(
@@ -111,12 +109,15 @@
           })
       }
 
-    def pull(
-      data_domain: Set[String],
-      data: Set[String] => List[Build_Job.Session_Context]
-    ): Sessions = {
-      val dom = data_domain -- iterator.map(_.name)
-      make(data(dom).foldLeft(graph.restrict(dom)) { case (g, e) => g.new_node(e.name, e) })
+    def update(updates: List[Library.Update.Op[Build_Job.Session_Context]]): Sessions = {
+      val graph1 =
+        updates.foldLeft(graph) {
+          case (g, Library.Update.Delete(name)) => g.del_node(name)
+          case (g, Library.Update.Insert(session)) =>
+            (if (g.defined(session.name)) g.del_node(session.name) else g)
+              .new_node(session.name, session)
+        }
+      make(graph1)
     }
 
     def init(
@@ -211,7 +212,7 @@
     results: State.Results)     // finished results
 
   object State {
-    def inc_serial(serial: Long) = {
+    def inc_serial(serial: Long): Long = {
       require(serial < Long.MaxValue, "serial overflow")
       serial + 1
     }
@@ -229,10 +230,7 @@
     running: State.Running = Map.empty,
     results: State.Results = Map.empty
   ) {
-    require(serial >= 0, "serial underflow")
-
     def next_serial: Long = State.inc_serial(serial)
-    def inc_serial: State = copy(serial = next_serial)
 
     def ready: List[Task] = pending.valuesIterator.filter(_.is_ready).toList.sortBy(_.name)
     def next_ready: List[Task] = ready.filter(entry => !is_running(entry.name))
@@ -303,32 +301,6 @@
     private lazy val build_id_tables =
       tables.filter(t => Generic.build_id_table(t) && !Generic.build_uuid_table(t))
 
-    def pull[A <: Library.Named](
-      data_domain: Set[String],
-      data_iterator: Set[String] => Iterator[A],
-      old_data: Map[String, A]
-    ): Map[String, A] = {
-      val dom = data_domain -- old_data.keysIterator
-      val data = old_data -- old_data.keysIterator.filterNot(data_domain)
-      if (dom.isEmpty) data
-      else data_iterator(dom).foldLeft(data) { case (map, a) => map + (a.name -> a) }
-    }
-
-    def pull0[A <: Library.Named](
-      new_data: Map[String, A],
-      old_data: Map[String, A]
-    ): Map[String, A] = {
-      pull(new_data.keySet, dom => new_data.valuesIterator.filter(a => dom(a.name)), old_data)
-    }
-
-    def pull1[A <: Library.Named](
-      data_domain: Set[String],
-      data_base: Set[String] => Map[String, A],
-      old_data: Map[String, A]
-    ): Map[String, A] = {
-      pull(data_domain, dom => data_base(dom).valuesIterator, old_data)
-    }
-
     object Generic {
       val build_id = SQL.Column.long("build_id")
       val build_uuid = SQL.Column.string("build_uuid")
@@ -391,7 +363,7 @@
       build_id: Long,
       serial_seen: Long,
       get: SQL.Result => A
-    ): List[(String, Option[A])] = {
+    ): List[Library.Update.Op[A]] = {
       val domain_columns = List(Updates.dom_name)
       val domain_table =
         SQL.Table("domain", domain_columns, body =
@@ -408,12 +380,10 @@
           domain_table.query_named + SQL.join_outer + table +
             " ON " + Updates.dom + " = " + Generic.name)
 
-      db.execute_query_statement(select_sql, List.from[(String, Option[A])],
-        { res =>
-          val delete = res.bool(Updates.delete)
-          val name = res.string(Updates.name)
-          if (delete) name -> None else name -> Some(get(res))
-        })
+      db.execute_query_statement(select_sql, List.from[Library.Update.Op[A]],
+        res =>
+          if (res.bool(Updates.delete)) Library.Update.Delete(res.string(Updates.name))
+          else Library.Update.Insert(get(res)))
     }
 
     def write_updates(
@@ -477,8 +447,7 @@
           })
 
       for (build <- builds.sortBy(_.start)(Date.Ordering)) yield {
-        val sessions = private_data.read_sessions_domain(db, build_uuid = build.build_uuid)
-        build.copy(sessions = sessions.toList.sorted)
+        build.copy(sessions = private_data.read_sessions(db, build_uuid = build.build_uuid).sorted)
       }
     }
 
@@ -550,39 +519,29 @@
       lazy val table_index: Int = tables.index(table)
     }
 
-    def read_sessions_domain(db: SQL.Database, build_uuid: String = ""): Set[String] =
+    def read_sessions(db: SQL.Database, build_uuid: String = ""): List[String] =
       db.execute_query_statement(
         Sessions.table.select(List(Sessions.name),
           sql = if_proper(build_uuid, Sessions.build_uuid.where_equal(build_uuid))),
-        Set.from[String], res => res.string(Sessions.name))
+        List.from[String], res => res.string(Sessions.name))
 
-    def read_sessions(db: SQL.Database,
-      names: Iterable[String] = Nil,
-      build_uuid: String = ""
-    ): List[Build_Job.Session_Context] = {
-      db.execute_query_statement(
-        Sessions.table.select(
-          sql =
-            SQL.where_and(
-              if_proper(names, Sessions.name.member(names)),
-              if_proper(build_uuid, Sessions.build_uuid.equal(build_uuid)))
-        ),
-        List.from[Build_Job.Session_Context],
-        { res =>
-          val name = res.string(Sessions.name)
-          val deps = split_lines(res.string(Sessions.deps))
-          val ancestors = split_lines(res.string(Sessions.ancestors))
-          val options = res.string(Sessions.options)
-          val sources_shasum = SHA1.fake_shasum(res.string(Sessions.sources))
-          val timeout = Time.ms(res.long(Sessions.timeout))
-          val old_time = Time.ms(res.long(Sessions.old_time))
-          val old_command_timings_blob = res.bytes(Sessions.old_command_timings)
-          val build_uuid = res.string(Sessions.build_uuid)
-          Build_Job.Session_Context(name, deps, ancestors, options, sources_shasum,
-            timeout, old_time, old_command_timings_blob, build_uuid)
-        }
+    def pull_sessions(db: SQL.Database, build_id: Long, state: State): Sessions =
+      state.sessions.update(
+        read_updates(db, Sessions.table, build_id, state.serial,
+          { res =>
+            val name = res.string(Sessions.name)
+            val deps = split_lines(res.string(Sessions.deps))
+            val ancestors = split_lines(res.string(Sessions.ancestors))
+            val options = res.string(Sessions.options)
+            val sources_shasum = SHA1.fake_shasum(res.string(Sessions.sources))
+            val timeout = Time.ms(res.long(Sessions.timeout))
+            val old_time = Time.ms(res.long(Sessions.old_time))
+            val old_command_timings_blob = res.bytes(Sessions.old_command_timings)
+            val build_uuid = res.string(Sessions.build_uuid)
+            Build_Job.Session_Context(name, deps, ancestors, options, sources_shasum,
+              timeout, old_time, old_command_timings_blob, build_uuid)
+          })
       )
-    }
 
     def update_sessions(
       db: SQL.Database,
@@ -725,16 +684,16 @@
       lazy val table_index: Int = tables.index(table)
     }
 
-    def read_pending(db: SQL.Database): State.Pending =
-      db.execute_query_statement(
-        Pending.table.select(),
-        Map.from[String, Task],
-        { res =>
-          val name = res.string(Pending.name)
-          val deps = res.string(Pending.deps)
-          val build_uuid = res.string(Pending.build_uuid)
-          Task.entry(name, split_lines(deps), build_uuid)
+    def pull_pending(db: SQL.Database, build_id: Long, state: State): State.Pending =
+      Library.Update.data(state.pending,
+        read_updates(db, Pending.table, build_id, state.serial,
+          { res =>
+            val name = res.string(Pending.name)
+            val deps = res.string(Pending.deps)
+            val build_uuid = res.string(Pending.build_uuid)
+            Task(name, split_lines(deps), build_uuid)
         })
+      )
 
     def update_pending(
       db: SQL.Database,
@@ -781,22 +740,21 @@
       lazy val table_index: Int = tables.index(table)
     }
 
-    def read_running(db: SQL.Database): State.Running =
-      db.execute_query_statement(
-        Running.table.select(),
-        Map.from[String, Job],
-        { res =>
-          val name = res.string(Running.name)
-          val worker_uuid = res.string(Running.worker_uuid)
-          val build_uuid = res.string(Running.build_uuid)
-          val hostname = res.string(Running.hostname)
-          val numa_node = res.get_int(Running.numa_node)
-          val rel_cpus = res.string(Running.rel_cpus)
-          val start_date = res.date(Running.start_date)
+    def pull_running(db: SQL.Database, build_id: Long, state: State): State.Running =
+      Library.Update.data(state.running,
+        read_updates(db, Running.table, build_id, state.serial,
+          { res =>
+            val name = res.string(Running.name)
+            val worker_uuid = res.string(Running.worker_uuid)
+            val build_uuid = res.string(Running.build_uuid)
+            val hostname = res.string(Running.hostname)
+            val numa_node = res.get_int(Running.numa_node)
+            val rel_cpus = res.string(Running.rel_cpus)
+            val start_date = res.date(Running.start_date)
+            val node_info = Host.Node_Info(hostname, numa_node, Host.Range.from(rel_cpus))
 
-          val node_info = Host.Node_Info(hostname, numa_node, Host.Range.from(rel_cpus))
-          name -> Job(name, worker_uuid, build_uuid, node_info, start_date, None)
-        }
+            Job(name, worker_uuid, build_uuid, node_info, start_date, None)
+          })
       )
 
     def update_running(
@@ -856,44 +814,37 @@
       lazy val table_index: Int = tables.index(table)
     }
 
-    def read_results_domain(db: SQL.Database): Set[String] =
-      db.execute_query_statement(
-        Results.table.select(List(Results.name)),
-        Set.from[String], res => res.string(Results.name))
-
-    def read_results(db: SQL.Database, names: Iterable[String] = Nil): State.Results =
-      db.execute_query_statement(
-        Results.table.select(sql = if_proper(names, Results.name.where_member(names))),
-        Map.from[String, Result],
-        { res =>
-          val name = res.string(Results.name)
-          val worker_uuid = res.string(Results.worker_uuid)
-          val build_uuid = res.string(Results.build_uuid)
-          val hostname = res.string(Results.hostname)
-          val numa_node = res.get_int(Results.numa_node)
-          val rel_cpus = res.string(Results.rel_cpus)
-          val node_info = Host.Node_Info(hostname, numa_node, Host.Range.from(rel_cpus))
+    def pull_results(db: SQL.Database, build_id: Long, state: State): State.Results =
+      Library.Update.data(state.results,
+        read_updates(db, Results.table, build_id, state.serial,
+          { res =>
+            val name = res.string(Results.name)
+            val worker_uuid = res.string(Results.worker_uuid)
+            val build_uuid = res.string(Results.build_uuid)
+            val hostname = res.string(Results.hostname)
+            val numa_node = res.get_int(Results.numa_node)
+            val rel_cpus = res.string(Results.rel_cpus)
+            val node_info = Host.Node_Info(hostname, numa_node, Host.Range.from(rel_cpus))
 
-          val rc = res.int(Results.rc)
-          val out = res.string(Results.out)
-          val err = res.string(Results.err)
-          val timing =
-            res.timing(
-              Results.timing_elapsed,
-              Results.timing_cpu,
-              Results.timing_gc)
-          val process_result =
-            Process_Result(rc,
-              out_lines = split_lines(out),
-              err_lines = split_lines(err),
-              timing = timing)
+            val rc = res.int(Results.rc)
+            val out = res.string(Results.out)
+            val err = res.string(Results.err)
+            val timing =
+              res.timing(
+                Results.timing_elapsed,
+                Results.timing_cpu,
+                Results.timing_gc)
+            val process_result =
+              Process_Result(rc,
+                out_lines = split_lines(out),
+                err_lines = split_lines(err),
+                timing = timing)
 
-          val output_shasum = SHA1.fake_shasum(res.string(Results.output_shasum))
-          val current = res.bool(Results.current)
+            val output_shasum = SHA1.fake_shasum(res.string(Results.output_shasum))
+            val current = res.bool(Results.current)
 
-          name ->
             Result(name, worker_uuid, build_uuid, node_info, process_result, output_shasum, current)
-        }
+          })
       )
 
     def update_results(
@@ -936,17 +887,17 @@
 
     /* collective operations */
 
-    def pull_database(db: SQL.Database, worker_uuid: String, state: State): State = {
+    def pull_database(db: SQL.Database, build_id: Long, worker_uuid: String, state: State): State = {
       val serial_db = read_serial(db)
       if (serial_db == state.serial) state
       else {
         val serial = serial_db max state.serial
         stamp_worker(db, worker_uuid, serial)
 
-        val sessions = state.sessions.pull(read_sessions_domain(db), read_sessions(db, _))
-        val pending = read_pending(db)
-        val running = pull0(read_running(db), state.running)
-        val results = pull1(read_results_domain(db), read_results(db, _), state.results)
+        val sessions = pull_sessions(db, build_id, state)
+        val pending = pull_pending(db, build_id, state)
+        val running = pull_running(db, build_id, state)
+        val results = pull_results(db, build_id, state)
 
         state.copy(serial = serial, sessions = sessions, pending = pending,
           running = running, results = results)
@@ -1134,7 +1085,8 @@
         case None => body
         case Some(db) =>
           Build_Process.private_data.transaction_lock(db, label = label) {
-            val old_state = Build_Process.private_data.pull_database(db, worker_uuid, _state)
+            val old_state =
+              Build_Process.private_data.pull_database(db, build_id, worker_uuid, _state)
             _state = old_state
             val res = body
             _state =
@@ -1267,7 +1219,7 @@
 
   protected final def start_worker(): Unit = synchronized_database("Build_Process.start_worker") {
     for (db <- _build_database) {
-      _state = _state.inc_serial
+      _state = _state.copy(serial = _state.next_serial)
       Build_Process.private_data.start_worker(db, worker_uuid, build_uuid, _state.serial)
     }
   }
--- a/src/Pure/Build/build_schedule.scala	Tue Mar 12 12:11:39 2024 +0000
+++ b/src/Pure/Build/build_schedule.scala	Tue Mar 12 15:31:44 2024 +0100
@@ -468,8 +468,6 @@
     graph: Schedule.Graph,
     serial: Long = 0,
   ) {
-    require(serial >= 0, "serial underflow")
-
     def next_serial: Long = Build_Process.State.inc_serial(serial)
 
     def end: Date =
@@ -851,7 +849,8 @@
           case None => body
           case Some(db) =>
             db.transaction_lock(Build_Schedule.private_data.all_tables, label = label) {
-              val old_state = Build_Process.private_data.pull_database(db, worker_uuid, _state)
+              val old_state =
+                Build_Process.private_data.pull_database(db, build_id, worker_uuid, _state)
               val old_schedule = Build_Schedule.private_data.pull_schedule(db, _schedule)
               _state = old_state
               _schedule = old_schedule
--- a/src/Pure/library.scala	Tue Mar 12 12:11:39 2024 +0000
+++ b/src/Pure/library.scala	Tue Mar 12 15:31:44 2024 +0100
@@ -285,11 +285,26 @@
     }
 
 
+  /* named items */
+
+  trait Named { def name: String }
+
+
   /* data update */
 
   object Update {
     type Data[A] = Map[String, A]
 
+    sealed abstract class Op[A]
+    case class Delete[A](name: String) extends Op[A]
+    case class Insert[A](item: A) extends Op[A]
+
+    def data[A <: Named](old_data: Data[A], updates: List[Op[A]]): Data[A] =
+      updates.foldLeft(old_data) {
+        case (map, Delete(name)) => map - name
+        case (map, Insert(item)) => map + (item.name -> item)
+      }
+
     val empty: Update = Update()
 
     def make[A](a: Data[A], b: Data[A], kind: Int = 0): Update =
@@ -343,9 +358,4 @@
 
   def as_subclass[C](c: Class[C])(x: AnyRef): Option[C] =
     if (x == null || is_subclass(x.getClass, c)) Some(x.asInstanceOf[C]) else None
-
-
-  /* named items */
-
-  trait Named { def name: String }
 }