src/Pure/Tools/build_process.scala
changeset 77652 5f706f7c624b
parent 77651 b7fe1d822dc1
child 77653 26bb79d17910
--- a/src/Pure/Tools/build_process.scala	Tue Mar 14 11:14:50 2023 +0100
+++ b/src/Pure/Tools/build_process.scala	Tue Mar 14 17:05:49 2023 +0100
@@ -172,7 +172,7 @@
     build_uuid: String,
     node_info: Host.Node_Info,
     build: Option[Build_Job]
-  ) {
+  ) extends Library.Named {
     def no_build: Job = copy(build = None)
   }
 
@@ -184,7 +184,7 @@
     process_result: Process_Result,
     output_shasum: SHA1.Shasum,
     current: Boolean
-  ) {
+  ) extends Library.Named {
     def ok: Boolean = process_result.ok
   }
 
@@ -284,6 +284,32 @@
     def make_table(name: String, columns: List[SQL.Column], body: String = ""): SQL.Table =
       SQL.Table("isabelle_build" + if_proper(name, "_" + name), columns, body = body)
 
+    def pull_data[A <: Library.Named](
+      data_domain: Set[String],
+      data_iterator: Set[String] => Iterator[A],
+      old_data: Map[String, A]
+    ): Map[String, A] = {
+      val dom = data_domain -- old_data.keysIterator
+      val data = old_data -- old_data.keysIterator.filterNot(dom)
+      if (dom.isEmpty) data
+      else data_iterator(dom).foldLeft(data) { case (map, a) => map + (a.name -> a) }
+    }
+
+    def pull0[A <: Library.Named](
+      new_data: Map[String, A],
+      old_data: Map[String, A]
+    ): Map[String, A] = {
+      pull_data(new_data.keySet, dom => new_data.valuesIterator.filter(a => dom(a.name)), old_data)
+    }
+
+    def pull1[A <: Library.Named](
+      data_domain: Set[String],
+      data_base: Set[String] => Map[String, A],
+      old_data: Map[String, A]
+    ): Map[String, A] = {
+      pull_data(data_domain, dom => data_base(dom).valuesIterator, old_data)
+    }
+
     object Generic {
       val build_uuid = SQL.Column.string("build_uuid")
       val worker_uuid = SQL.Column.string("worker_uuid")
@@ -551,8 +577,9 @@
       build_uuid: String,
       hostname: String,
       java_pid: Long,
-      java_start: Date
-    ): Long = {
+      java_start: Date,
+      serial: Long
+    ): Unit = {
       def err(msg: String): Nothing =
         error("Cannot start worker " + worker_uuid + if_proper(msg, "\n" + msg))
 
@@ -567,7 +594,6 @@
         case None => err("for unknown build process " + build_uuid)
       }
 
-      val serial = serial_max(db)
       db.execute_statement(Workers.table.insert(), body =
         { stmt =>
           val now = db.now()
@@ -581,7 +607,6 @@
           stmt.date(8) = None
           stmt.long(9) = serial
         })
-      serial
     }
 
     def stamp_worker(
@@ -658,22 +683,22 @@
       val table = make_table("running", List(name, worker_uuid, build_uuid, hostname, numa_node))
     }
 
-    def read_running(db: SQL.Database): List[Job] =
+    def read_running(db: SQL.Database): State.Running =
       db.execute_query_statement(
         Running.table.select(sql = SQL.order_by(List(Running.name))),
-        List.from[Job],
+        Map.from[String, Job],
         { res =>
           val name = res.string(Running.name)
           val worker_uuid = res.string(Running.worker_uuid)
           val build_uuid = res.string(Running.build_uuid)
           val hostname = res.string(Running.hostname)
           val numa_node = res.get_int(Running.numa_node)
-          Job(name, worker_uuid, build_uuid, Host.Node_Info(hostname, numa_node), None)
+          name -> Job(name, worker_uuid, build_uuid, Host.Node_Info(hostname, numa_node), None)
         }
       )
 
     def update_running(db: SQL.Database, running: State.Running): Boolean = {
-      val running0 = read_running(db)
+      val running0 = read_running(db).valuesIterator.toList
       val running1 = running.valuesIterator.map(_.no_build).toList
 
       val (delete, insert) = Library.symmetric_difference(running0, running1)
@@ -726,7 +751,7 @@
         Results.table.select(List(Results.name)),
         Set.from[String], res => res.string(Results.name))
 
-    def read_results(db: SQL.Database, names: List[String] = Nil): State.Results =
+    def read_results(db: SQL.Database, names: Iterable[String] = Nil): State.Results =
       db.execute_query_statement(
         Results.table.select(sql = if_proper(names, Results.name.where_member(names))),
         Map.from[String, Result],
@@ -801,6 +826,30 @@
         Results.table,
         Host.Data.Node_Info.table)
 
+    def pull_database(
+      db: SQL.Database,
+      worker_uuid: String,
+      hostname: String,
+      state: State
+    ): State = {
+      val serial0 = serial_max(db)
+      if (serial0 == state.serial) state
+      else {
+        val serial = serial0 max state.serial
+        stamp_worker(db, worker_uuid, serial)
+
+        val numa_next = Host.Data.read_numa_next(db, hostname)
+        val sessions = pull1(read_sessions_domain(db), read_sessions(db, _), state.sessions)
+        val workers = read_workers(db)
+        val pending = read_pending(db)
+        val running = pull0(read_running(db), state.running)
+        val results = pull1(read_results_domain(db), read_results(db, _), state.results)
+
+        state.copy(serial = serial, numa_next = numa_next, sessions = sessions,
+          workers = workers, pending = pending, running = running, results = results)
+      }
+    }
+
     def update_database(
       db: SQL.Database,
       worker_uuid: String,
@@ -816,7 +865,7 @@
           update_results(db, state.results),
           Host.Data.update_numa_next(db, hostname, state.numa_next))
 
-      val serial0 = serial_max(db)
+      val serial0 = state.serial
       val serial = if (changed.exists(identity)) State.inc_serial(serial0) else serial0
 
       stamp_worker(db, worker_uuid, serial)
@@ -849,7 +898,7 @@
 
   /* global state: internal var vs. external database */
 
-  private var _state: Build_Process.State = init_state(Build_Process.State())
+  private var _state: Build_Process.State = Build_Process.State()
 
   private val _database: Option[SQL.Database] = store.open_build_database()
 
@@ -860,34 +909,36 @@
       _database match {
         case None => body
         case Some(db) =>
-          @tailrec def loop(): A = {
-            val sync_progress =
-              db.transaction_lock(Build_Process.Data.all_tables) {
-                val (messages, sync) =
-                  Build_Process.Data.sync_progress(
-                    db, _state.progress_seen, build_uuid, build_progress)
-                if (sync) Left(body) else Right(messages)
-              }
-            sync_progress match {
+          def pull_database(): Unit = {
+            _state = Build_Process.Data.pull_database(db, worker_uuid, build_context.hostname, _state)
+          }
+
+          def sync_database(): Unit = {
+            _state =
+              Build_Process.Data.update_database(
+                db, worker_uuid, build_uuid, build_context.hostname, _state)
+          }
+
+          def attempt(): Either[A, Build_Process.Progress_Messages] = {
+            val (messages, sync) =
+              Build_Process.Data.sync_progress(
+                db, _state.progress_seen, build_uuid, build_progress)
+            if (sync) Left { pull_database(); val res = body; sync_database(); res }
+            else Right(messages)
+          }
+
+          @tailrec def attempts(): A = {
+            db.transaction_lock(Build_Process.Data.all_tables) { attempt() } match {
               case Left(res) => res
               case Right(messages) =>
                 for ((message_serial, message) <- messages) {
                   _state = _state.progress_serial(message_serial = message_serial)
                   if (build_progress.do_output(message)) build_progress.output(message)
                 }
-                loop()
+                attempts()
             }
           }
-          loop()
-      }
-    }
-
-  private def sync_database(): Unit =
-    synchronized_database {
-      for (db <- _database) {
-        _state =
-          Build_Process.Data.update_database(
-            db, worker_uuid, build_uuid, build_context.hostname, _state)
+          attempts()
       }
     }
 
@@ -900,6 +951,7 @@
       for (db <- _database) {
         Build_Process.Data.write_progress(db, _state.serial, message, build_uuid)
         Build_Process.Data.stamp_worker(db, worker_uuid, _state.serial)
+        _state = _state.set_workers(Build_Process.Data.read_workers(db))
       }
       build_progress_output
     }
@@ -1034,16 +1086,17 @@
       val java = ProcessHandle.current()
       val java_pid = java.pid
       val java_start = Date.instant(java.info.startInstant.get)
-      val serial =
-        Build_Process.Data.start_worker(
-          db, worker_uuid, build_uuid, build_context.hostname, java_pid, java_start)
-      _state = _state.set_serial(serial)
+      _state = _state.inc_serial
+      Build_Process.Data.start_worker(
+        db, worker_uuid, build_uuid, build_context.hostname, java_pid, java_start, _state.serial)
+      _state = _state.set_workers(Build_Process.Data.read_workers(db))
     }
   }
 
   final def stop_worker(): Unit = synchronized_database {
     for (db <- _database) {
       Build_Process.Data.stamp_worker(db, worker_uuid, _state.serial, stop = true)
+      _state = _state.set_workers(Build_Process.Data.read_workers(db))
     }
   }
 
@@ -1051,6 +1104,8 @@
   /* run */
 
   def run(): Map[String, Process_Result] = {
+    if (build_context.master) synchronized_database { _state = init_state(_state) }
+
     def finished(): Boolean = synchronized_database { _state.finished }
 
     def sleep(): Unit =
@@ -1097,10 +1152,7 @@
             }
           }
 
-          if (!start_job()) {
-            sync_database()
-            sleep()
-          }
+          if (!start_job()) sleep()
         }
       }
       finally {