clarified signature: more specific exists_table --- avoid retrieving full list beforehand;
authorwenzelm
Mon, 17 Jul 2023 12:15:06 +0200
changeset 78375 234f2ff9afe6
parent 78374 f9f1412ea24e
child 78376 36a3a9a8b5fe
clarified signature: more specific exists_table --- avoid retrieving full list beforehand;
src/Pure/General/sql.scala
src/Pure/Thy/store.scala
src/Pure/Tools/server.scala
--- a/src/Pure/General/sql.scala	Mon Jul 17 11:39:32 2023 +0200
+++ b/src/Pure/General/sql.scala	Mon Jul 17 12:15:06 2023 +0200
@@ -536,13 +536,23 @@
 
     /* tables and views */
 
-    def tables: List[String] = {
+    def get_tables(pattern: String = "%"): List[String] = {
       val result = new mutable.ListBuffer[String]
-      val rs = connection.getMetaData.getTables(null, null, "%", null)
+      val rs = connection.getMetaData.getTables(null, null, pattern, null)
       while (rs.next) { result += rs.getString(3) }
       result.toList
     }
 
+    def exists_table(name: String): Boolean = {
+      val escape = connection.getMetaData.getSearchStringEscape
+      val pattern =
+        name.iterator.map(c =>
+          (if (c == '_' || c == '%' || c == escape(0)) escape else "") + c).mkString
+      get_tables(pattern = pattern).nonEmpty
+    }
+
+    def exists_table(table: Table): Boolean = exists_table(table.name)
+
     def create_table(table: Table, strict: Boolean = false, sql: Source = ""): Unit = {
       execute_statement(table.create(strict, sql_type) + SQL.separate(sql))
       if (is_postgresql) {
@@ -558,7 +568,7 @@
       execute_statement(table.create_index(name, columns, strict, unique))
 
     def create_view(table: Table, strict: Boolean = false): Unit = {
-      if (strict || !tables.contains(table.name)) {
+      if (strict || exists_table(table)) {
         execute_statement("CREATE VIEW " + table + " AS " + { table.query; table.body })
       }
     }
--- a/src/Pure/Thy/store.scala	Mon Jul 17 11:39:32 2023 +0200
+++ b/src/Pure/Thy/store.scala	Mon Jul 17 12:15:06 2023 +0200
@@ -150,7 +150,7 @@
       Build_Log.uncompress_errors(read_bytes(db, name, Session_Info.errors), cache = cache)
 
     def read_build(db: SQL.Database, name: String): Option[Store.Build_Info] = {
-      if (db.tables.contains(Session_Info.table.name)) {
+      if (db.exists_table(Session_Info.table)) {
         db.execute_query_statementO[Store.Build_Info](
           Session_Info.table.select(sql = Session_Info.session_name.where_equal(name)),
           { res =>
@@ -421,10 +421,8 @@
 
   /* session info */
 
-  def session_info_exists(db: SQL.Database): Boolean = {
-    val tables = db.tables
-    Store.Data.tables.forall(table => tables.contains(table.name))
-  }
+  def session_info_exists(db: SQL.Database): Boolean =
+    Store.Data.tables.forall(db.exists_table)
 
   def session_info_defined(db: SQL.Database, name: String): Boolean =
     db.execute_query_statementB(
--- a/src/Pure/Tools/server.scala	Mon Jul 17 11:39:32 2023 +0200
+++ b/src/Pure/Tools/server.scala	Mon Jul 17 12:15:06 2023 +0200
@@ -375,7 +375,7 @@
   }
 
   def list(db: SQLite.Database): List[Info] =
-    if (db.tables.contains(Data.table.name)) {
+    if (db.exists_table(Data.table)) {
       db.execute_query_statement(Data.table.select(),
         List.from[Info],
         { res =>