clarified scope of "serial" and "numa_index" within database;
authorwenzelm
Tue, 28 Feb 2023 20:29:44 +0100
changeset 77416 d88c12f22ab0
parent 77415 6b928419f109
child 77417 9bd6c78b3b77
clarified scope of "serial" and "numa_index" within database;
src/Pure/Tools/build_process.scala
--- a/src/Pure/Tools/build_process.scala	Tue Feb 28 19:12:31 2023 +0100
+++ b/src/Pure/Tools/build_process.scala	Tue Feb 28 20:29:44 2023 +0100
@@ -257,12 +257,17 @@
       val table = make_table("", List(instance, ml_platform, options))
     }
 
-    object State {
-      val instance = Generic.instance.make_primary_key
+    object Serial {
       val serial = SQL.Column.long("serial")
+
+      val table = make_table("serial", List(serial))
+    }
+
+    object Node_Info {
+      val hostname = SQL.Column.string("hostname").make_primary_key
       val numa_index = SQL.Column.int("numa_index")
 
-      val table = make_table("state", List(instance, serial, numa_index))
+      val table = make_table("node_info", List(hostname, numa_index))
     }
 
     object Pending {
@@ -297,6 +302,39 @@
           List(name, hostname, numa_node, rc, out, err, timing_elapsed, timing_cpu, timing_gc))
     }
 
+    def get_serial(db: SQL.Database): Long =
+      db.using_statement(Serial.table.select())(stmt =>
+        stmt.execute_query().iterator(_.long(Serial.serial)).nextOption.getOrElse(0L))
+
+    def set_serial(db: SQL.Database, serial: Long): Unit =
+      if (get_serial(db) != serial) {
+        db.using_statement(Serial.table.delete())(_.execute())
+        db.using_statement(Serial.table.insert()) { stmt =>
+          stmt.long(1) = serial
+          stmt.execute()
+        }
+      }
+
+    def read_numa_index(db: SQL.Database, hostname: String): Int =
+      db.using_statement(
+        Node_Info.table.select(List(Node_Info.numa_index),
+          sql = Node_Info.hostname.where_equal(hostname))
+      )(stmt => stmt.execute_query().iterator(_.int(Node_Info.numa_index)).nextOption.getOrElse(0))
+
+    def update_numa_index(db: SQL.Database, hostname: String, numa_index: Int): Boolean =
+      if (read_numa_index(db, hostname) != numa_index) {
+        db.using_statement(
+          Node_Info.table.delete(sql = Node_Info.hostname.where_equal(hostname))
+        )(_.execute())
+        db.using_statement(Node_Info.table.insert()) { stmt =>
+          stmt.string(1) = hostname
+          stmt.int(2) = numa_index
+          stmt.execute()
+        }
+        true
+      }
+      else false
+
     def read_pending(db: SQL.Database): List[Entry] =
       db.using_statement(Pending.table.select(sql = SQL.order_by(List(Pending.name)))) { stmt =>
         List.from(
@@ -425,32 +463,15 @@
         stmt.execute()
       }
 
-    def read_state(db: SQL.Database, instance: String): (Long, Int) =
-      db.using_statement(
-        State.table.select(sql = SQL.where(Generic.sql_equal(instance = instance)))
-      ) { stmt =>
-          (stmt.execute_query().iterator { res =>
-            val serial = res.long(State.serial)
-            val numa_index = res.int(State.numa_index)
-            (serial, numa_index)
-          }).nextOption.getOrElse(error("No build state instance " + instance + " in database " + db))
-        }
-
-    def write_state(db: SQL.Database, instance: String, serial: Long, numa_index: Int): Unit =
-      db.using_statement(State.table.insert()) { stmt =>
-        stmt.string(1) = instance
-        stmt.long(2) = serial
-        stmt.int(3) = numa_index
-        stmt.execute()
-      }
-
-    def reset_state(db: SQL.Database, instance: String): Unit =
-      db.using_statement(
-        State.table.delete(sql = SQL.where(Generic.sql_equal(instance = instance))))(_.execute())
-
     def init_database(db: SQL.Database, build_context: Build_Process.Context): Unit = {
       val tables =
-        List(Config.table, State.table, Pending.table, Running.table, Results.table)
+        List(
+          Config.table,
+          Serial.table,
+          Node_Info.table,
+          Pending.table,
+          Running.table,
+          Results.table)
 
       for (table <- tables) db.create_table(table)
 
@@ -463,21 +484,25 @@
       for (table <- tables) db.using_statement(table.delete())(_.execute())
 
       write_config(db, build_context.instance, build_context.hostname, build_context.store.options)
-      write_state(db, build_context.instance, 0, 0)
     }
 
-    def update_database(db: SQL.Database, instance: String, state: State): State = {
-      val ch1 = update_pending(db, state.pending)
-      val ch2 = update_running(db, state.running)
-      val ch3 = update_results(db, state.results)
+    def update_database(
+      db: SQL.Database,
+      instance: String,
+      hostname: String,
+      state: State
+    ): State = {
+      val changed =
+        List(
+          update_numa_index(db, hostname, state.numa_index),
+          update_pending(db, state.pending),
+          update_running(db, state.running),
+          update_results(db, state.results))
 
-      val (serial0, _) = read_state(db, instance)
-      val serial = if (ch1 || ch2 || ch3) serial0 + 1 else serial0
-      if (serial != serial0) {
-        reset_state(db, instance)
-        write_state(db, instance, serial, state.numa_index)
-      }
+      val serial0 = get_serial(db)
+      val serial = if (changed.exists(identity)) serial0 + 1 else serial0
 
+      set_serial(db, serial)
       state.copy(serial = serial)
     }
   }
@@ -638,7 +663,9 @@
     for (db <- database) {
       synchronized {
         db.transaction {
-          _state = Build_Process.Data.update_database(db, build_context.instance, _state)
+          _state =
+            Build_Process.Data.update_database(
+              db, build_context.instance, build_context.hostname, _state)
         }
       }
     }