more scalable write_entries and Export.consumer via db.execute_batch_statement;
authorwenzelm
Sun, 20 Aug 2023 21:05:56 +0200
changeset 78541 d95497dcd9bc
parent 78540 a6d079e0575d
child 78542 4ffc1933f5d9
more scalable write_entries and Export.consumer via db.execute_batch_statement;
src/Pure/Thy/export.scala
--- a/src/Pure/Thy/export.scala	Sat Aug 19 22:57:06 2023 +0200
+++ b/src/Pure/Thy/export.scala	Sun Aug 20 21:05:56 2023 +0200
@@ -9,6 +9,7 @@
 
 import scala.annotation.tailrec
 import scala.util.matching.Regex
+import scala.collection.mutable
 
 
 object Export {
@@ -54,10 +55,30 @@
     def clean_session(db: SQL.Database, session_name: String): Unit =
       db.execute_statement(Base.table.delete(sql = where_equal(session_name)))
 
-    def readable_entry(db: SQL.Database, entry_name: Entry_Name): Boolean = {
-      db.execute_query_statementB(
-        Base.table.select(List(Base.name),
-          sql = where_equal(entry_name.session, entry_name.theory, entry_name.name)))
+    def known_entries(db: SQL.Database, entry_names: Iterable[Entry_Name]): Set[Entry_Name] = {
+      val it = entry_names.iterator
+      if (it.isEmpty) Set.empty[Entry_Name]
+      else {
+        val sql_preds =
+          List.from(
+            for (entry_name <- it) yield {
+              SQL.and(
+                Base.session_name.equal(entry_name.session),
+                Base.theory_name.equal(entry_name.theory),
+                Base.name.equal(entry_name.name)
+              )
+            })
+        db.execute_query_statement(
+          Base.table.select(List(Base.session_name, Base.theory_name, Base.name),
+            sql = SQL.where(SQL.OR(sql_preds))),
+          Set.from[Entry_Name],
+          { res =>
+            val session_name = res.string(Base.session_name)
+            val theory_name = res.string(Base.theory_name)
+            val name = res.string(Base.name)
+            Entry_Name(session_name, theory_name, name)
+          })
+      }
     }
 
     def read_entry(db: SQL.Database, entry_name: Entry_Name, cache: XML.Cache): Option[Entry] =
@@ -73,17 +94,21 @@
         }
       )
 
-    def write_entry(db: SQL.Database, entry: Entry): Unit = {
-      val (compressed, bs) = entry.body.join
-      db.execute_statement(Base.table.insert(), body = { stmt =>
-        stmt.string(1) = entry.session_name
-        stmt.string(2) = entry.theory_name
-        stmt.string(3) = entry.name
-        stmt.bool(4) = entry.executable
-        stmt.bool(5) = compressed
-        stmt.bytes(6) = bs
+    def write_entries(db: SQL.Database, entries: List[Option[Entry]]): Unit =
+      db.execute_batch_statement(Base.table.insert(), batch = { stmt =>
+        entries.iterator.map({
+          case None => false
+          case Some(entry) =>
+            val (compressed, bs) = entry.body.join
+            stmt.string(1) = entry.session_name
+            stmt.string(2) = entry.theory_name
+            stmt.string(3) = entry.name
+            stmt.bool(4) = entry.executable
+            stmt.bool(5) = compressed
+            stmt.bytes(6) = bs
+            true
+        })
       })
-    }
 
     def read_theory_names(db: SQL.Database, session_name: String): List[String] =
       db.execute_query_statement(
@@ -237,25 +262,45 @@
         bulk = { case (entry, _) => entry.is_finished },
         consume =
           { (args: List[(Entry, Boolean)]) =>
-            val results =
-              private_data.transaction_lock(db, label = "Export.consumer(" + args.length + ")") {
-                for ((entry, strict) <- args)
-                yield {
-                  if (progress.stopped) {
-                    entry.cancel()
-                    Exn.Res(())
+            for ((entry, _) <- args) {
+              if (progress.stopped) entry.cancel() else entry.body.join
+            }
+            private_data.transaction_lock(db, label = "Export.consumer(" + args.length + ")") {
+              var known = private_data.known_entries(db, args.map(p => p._1.entry_name))
+              val buffer = new mutable.ListBuffer[Option[Entry]]
+
+              for ((entry, strict) <- args) {
+                if (progress.stopped) {
+                  buffer += None
+                }
+                else if (known(entry.entry_name)) {
+                  if (strict) {
+                    val msg = message("Duplicate export", entry.theory_name, entry.name)
+                    errors.change(msg :: _)
                   }
-                  else if (private_data.readable_entry(db, entry.entry_name)) {
-                    if (strict) {
-                      val msg = message("Duplicate export", entry.theory_name, entry.name)
-                      errors.change(msg :: _)
-                    }
-                    Exn.Res(())
-                  }
-                  else Exn.capture { private_data.write_entry(db, entry) }
+                  buffer += None
+                }
+                else {
+                  known += entry.entry_name
+                  buffer += Some(entry)
                 }
               }
-            (results, true)
+
+              val entries = buffer.toList
+              val results =
+                try {
+                  private_data.write_entries(db, entries)
+                  val ok = Exn.Res[Unit](())
+                  entries.map(_ => ok)
+                }
+                catch {
+                  case exn: Throwable =>
+                    val err = Exn.Exn[Unit](exn)
+                    entries.map(_ => err)
+                }
+
+              (results, true)
+            }
           })
 
     def make_entry(session_name: String, args: Protocol.Export.Args, body: Bytes): Unit = {