src/Pure/General/sql.scala
changeset 79721 a5629eade476
parent 78891 76d1382d6077
child 79722 5938158733bb
--- a/src/Pure/General/sql.scala	Sat Feb 24 16:30:25 2024 +0100
+++ b/src/Pure/General/sql.scala	Sat Feb 24 22:07:21 2024 +0100
@@ -14,7 +14,7 @@
 
 import org.sqlite.SQLiteConfig
 import org.sqlite.jdbc4.JDBC4Connection
-import org.postgresql.{PGConnection, PGNotification}
+import org.postgresql.PGConnection
 
 import scala.collection.mutable
 
@@ -383,6 +383,11 @@
   }
 
 
+  /* notifications: IPC via database server */
+
+  sealed case class Notification(name: String, parameter: String)
+
+
   /* database */
 
   trait Database extends AutoCloseable {
@@ -581,6 +586,14 @@
         execute_statement("CREATE VIEW " + table + " AS " + { table.query; table.body })
       }
     }
+
+
+    /* notifications (PostgreSQL only) */
+
+    def listen(name: String): Unit = ()
+    def unlisten(name: String = "*"): Unit = ()
+    def send(name: String, parameter: String = ""): Unit = ()
+    def receive(filter: Notification => Boolean): List[Notification] = Nil
   }
 
 
@@ -771,24 +784,86 @@
 
 
     /* notifications: IPC via database server */
-    // see https://www.postgresql.org/docs/current/sql-notify.html
+    /*
+      - see https://www.postgresql.org/docs/current/sql-notify.html
+      - self-notifications and repeated notifications are suppressed
+      - notifications are sorted by local system time (nano seconds)
+    */
 
-    def listen(name: String): Unit =
-      execute_statement("LISTEN " + SQL.ident(name))
+    private var _receiver_buffer: Option[Map[SQL.Notification, Long]] = None
 
-    def unlisten(name: String = "*"): Unit =
-      execute_statement("UNLISTEN " + (if (name == "*") name else SQL.ident(name)))
+    private lazy val _receiver_thread =
+      Isabelle_Thread.fork(name = "PostgreSQL.receiver", daemon = true, uninterruptible = true) {
+        val conn = the_postgresql_connection
+        val self_pid = conn.getBackendPID
 
-    def notify(name: String, payload: String = ""): Unit =
-      execute_statement("NOTIFY " + SQL.ident(name) + if_proper(payload, ", " + SQL.string(payload)))
-
-    def get_notifications(): List[PGNotification] =
-      the_postgresql_connection.getNotifications() match {
-        case null => Nil
-        case array => array.toList
+        try {
+          while (true) {
+            Isabelle_Thread.interruptible { Time.seconds(0.5).sleep() }
+            Option(conn.getNotifications()) match {
+              case Some(array) if array.nonEmpty =>
+                synchronized {
+                  var received = _receiver_buffer.getOrElse(Map.empty)
+                  for (a <- array.iterator if a.getPID != self_pid) {
+                    val msg = SQL.Notification(a.getName, a.getParameter)
+                    if (!received.isDefinedAt(msg)) {
+                      val stamp = System.nanoTime()
+                      received = received + (msg -> stamp)
+                    }
+                  }
+                  _receiver_buffer = Some(received)
+                }
+              case _ =>
+            }
+          }
+        }
+        catch { case Exn.Interrupt() => }
       }
 
+    private def receiver_shutdown(): Unit = synchronized {
+      if (_receiver_buffer.isDefined) {
+        _receiver_thread.interrupt()
+        Some(_receiver_thread)
+      }
+      else None
+    }.foreach(_.join())
 
-    override def close(): Unit = { super.close(); if (server_close) server.close() }
+    private def synchronized_receiver[A](body: => A): A = synchronized {
+      if (_receiver_buffer.isEmpty) {
+        _receiver_buffer = Some(Map.empty)
+        _receiver_thread
+      }
+      body
+    }
+
+    override def listen(name: String): Unit = synchronized_receiver {
+      execute_statement("LISTEN " + SQL.ident(name))
+    }
+
+    override def unlisten(name: String = "*"): Unit = synchronized_receiver {
+      execute_statement("UNLISTEN " + (if (name == "*") name else SQL.ident(name)))
+    }
+
+    override def send(name: String, parameter: String = ""): Unit = synchronized_receiver {
+      execute_statement(
+        "NOTIFY " + SQL.ident(name) + if_proper(parameter, ", " + SQL.string(parameter)))
+    }
+
+    override def receive(filter: SQL.Notification => Boolean = _ => true): List[SQL.Notification] =
+      synchronized {
+        val received = _receiver_buffer.getOrElse(Map.empty)
+        val filtered = received.keysIterator.filter(filter).toList
+        if (_receiver_buffer.isDefined && filtered.nonEmpty) {
+          _receiver_buffer = Some(received -- filtered)
+          filtered.map(msg => msg -> received(msg)).sortBy(_._2).map(_._1)
+        }
+        else Nil
+      }
+
+    override def close(): Unit = {
+      receiver_shutdown()
+      super.close()
+      if (server_close) server.close()
+    }
   }
 }