src/Pure/Thy/store.scala
changeset 78178 a177f71dc79f
child 78179 a49ad8d183af
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/Pure/Thy/store.scala	Tue Jun 20 14:25:06 2023 +0200
@@ -0,0 +1,432 @@
+/*  Title:      Pure/Thy/store.scala
+    Author:     Makarius
+
+Persistent store for session content: within file-system and/or SQL database.
+*/
+
+package isabelle
+
+
+import java.sql.SQLException
+
+
+object Store {
+  def apply(options: Options, cache: Term.Cache = Term.Cache.make()): Store =
+    new Store(options, cache)
+
+
+  /* source files */
+
+  sealed case class Source_File(
+    name: String,
+    digest: SHA1.Digest,
+    compressed: Boolean,
+    body: Bytes,
+    cache: Compress.Cache
+  ) {
+    override def toString: String = name
+
+    def bytes: Bytes = if (compressed) body.uncompress(cache = cache) else body
+  }
+
+  object Sources {
+    val session_name = SQL.Column.string("session_name").make_primary_key
+    val name = SQL.Column.string("name").make_primary_key
+    val digest = SQL.Column.string("digest")
+    val compressed = SQL.Column.bool("compressed")
+    val body = SQL.Column.bytes("body")
+
+    val table =
+      SQL.Table("isabelle_sources", List(session_name, name, digest, compressed, body))
+
+    def where_equal(session_name: String, name: String = ""): SQL.Source =
+      SQL.where_and(
+        Sources.session_name.equal(session_name),
+        if_proper(name, Sources.name.equal(name)))
+
+    def load(session_base: Sessions.Base, cache: Compress.Cache = Compress.Cache.none): Sources =
+      new Sources(
+        session_base.session_sources.foldLeft(Map.empty) {
+          case (sources, (path, digest)) =>
+            def err(): Nothing = error("Incoherent digest for source file: " + path)
+            val name = File.symbolic_path(path)
+            sources.get(name) match {
+              case Some(source_file) =>
+                if (source_file.digest == digest) sources else err()
+              case None =>
+                val bytes = Bytes.read(path)
+                if (bytes.sha1_digest == digest) {
+                  val (compressed, body) =
+                    bytes.maybe_compress(Compress.Options_Zstd(), cache = cache)
+                  val file = Source_File(name, digest, compressed, body, cache)
+                  sources + (name -> file)
+                }
+                else err()
+            }
+        })
+  }
+
+  class Sources private(rep: Map[String, Source_File]) extends Iterable[Source_File] {
+    override def toString: String = rep.values.toList.sortBy(_.name).mkString("Sources(", ", ", ")")
+    override def iterator: Iterator[Source_File] = rep.valuesIterator
+
+    def get(name: String): Option[Source_File] = rep.get(name)
+    def apply(name: String): Source_File =
+      get(name).getOrElse(error("Missing session sources entry " + quote(name)))
+  }
+
+
+
+  /* session info */
+
+  sealed case class Build_Info(
+    sources: SHA1.Shasum,
+    input_heaps: SHA1.Shasum,
+    output_heap: SHA1.Shasum,
+    return_code: Int,
+    uuid: String
+  ) {
+    def ok: Boolean = return_code == 0
+  }
+
+  object Session_Info {
+    val session_name = SQL.Column.string("session_name").make_primary_key
+
+    // Build_Log.Session_Info
+    val session_timing = SQL.Column.bytes("session_timing")
+    val command_timings = SQL.Column.bytes("command_timings")
+    val theory_timings = SQL.Column.bytes("theory_timings")
+    val ml_statistics = SQL.Column.bytes("ml_statistics")
+    val task_statistics = SQL.Column.bytes("task_statistics")
+    val errors = SQL.Column.bytes("errors")
+    val build_log_columns =
+      List(session_name, session_timing, command_timings, theory_timings,
+        ml_statistics, task_statistics, errors)
+
+    // Build_Info
+    val sources = SQL.Column.string("sources")
+    val input_heaps = SQL.Column.string("input_heaps")
+    val output_heap = SQL.Column.string("output_heap")
+    val return_code = SQL.Column.int("return_code")
+    val uuid = SQL.Column.string("uuid")
+    val build_columns = List(sources, input_heaps, output_heap, return_code, uuid)
+
+    val table = SQL.Table("isabelle_session_info", build_log_columns ::: build_columns)
+
+    val augment_table: PostgreSQL.Source =
+      "ALTER TABLE IF EXISTS " + table.ident +
+      " ADD COLUMN IF NOT EXISTS " + uuid.decl(SQL.sql_type_postgresql)
+  }
+}
+
+class Store private(val options: Options, val cache: Term.Cache) {
+  store =>
+
+  override def toString: String = "Store(output_dir = " + output_dir.absolute + ")"
+
+
+  /* directories */
+
+  val system_output_dir: Path = Path.explode("$ISABELLE_HEAPS_SYSTEM/$ML_IDENTIFIER")
+  val user_output_dir: Path = Path.explode("$ISABELLE_HEAPS/$ML_IDENTIFIER")
+
+  def system_heaps: Boolean = options.bool("system_heaps")
+
+  val output_dir: Path =
+    if (system_heaps) system_output_dir else user_output_dir
+
+  val input_dirs: List[Path] =
+    if (system_heaps) List(system_output_dir)
+    else List(user_output_dir, system_output_dir)
+
+  def presentation_dir: Path =
+    if (system_heaps) Path.explode("$ISABELLE_BROWSER_INFO_SYSTEM")
+    else Path.explode("$ISABELLE_BROWSER_INFO")
+
+
+  /* file names */
+
+  def heap(name: String): Path = Path.basic(name)
+  def database(name: String): Path = Path.basic("log") + Path.basic(name).db
+  def log(name: String): Path = Path.basic("log") + Path.basic(name)
+  def log_gz(name: String): Path = log(name).gz
+
+  def output_heap(name: String): Path = output_dir + heap(name)
+  def output_database(name: String): Path = output_dir + database(name)
+  def output_log(name: String): Path = output_dir + log(name)
+  def output_log_gz(name: String): Path = output_dir + log_gz(name)
+
+
+  /* heap */
+
+  def find_heap(name: String): Option[Path] =
+    input_dirs.map(_ + heap(name)).find(_.is_file)
+
+  def find_heap_shasum(name: String): SHA1.Shasum =
+    (for {
+      path <- find_heap(name)
+      digest <- ML_Heap.read_digest(path)
+    } yield SHA1.shasum(digest, name)).getOrElse(SHA1.no_shasum)
+
+  def the_heap(name: String): Path =
+    find_heap(name) getOrElse
+      error("Missing heap image for session " + quote(name) + " -- expected in:\n" +
+        cat_lines(input_dirs.map(dir => "  " + File.standard_path(dir))))
+
+
+  /* databases for build process and session content */
+
+  def find_database(name: String): Option[Path] =
+    input_dirs.map(_ + database(name)).find(_.is_file)
+
+  def build_database_server: Boolean = options.bool("build_database_server")
+  def build_database_test: Boolean = options.bool("build_database_test")
+
+  def open_database_server(): PostgreSQL.Database =
+    PostgreSQL.open_database(
+      user = options.string("build_database_user"),
+      password = options.string("build_database_password"),
+      database = options.string("build_database_name"),
+      host = options.string("build_database_host"),
+      port = options.int("build_database_port"),
+      ssh =
+        proper_string(options.string("build_database_ssh_host")).map(ssh_host =>
+          SSH.open_session(options,
+            host = ssh_host,
+            user = options.string("build_database_ssh_user"),
+            port = options.int("build_database_ssh_port"))),
+      ssh_close = true)
+
+  def open_build_database(path: Path): SQL.Database =
+    if (build_database_server) open_database_server()
+    else SQLite.open_database(path, restrict = true)
+
+  def maybe_open_build_database(path: Path): Option[SQL.Database] =
+    if (!build_database_test) None else Some(open_build_database(path))
+
+  def try_open_database(
+    name: String,
+    output: Boolean = false,
+    server: Boolean = build_database_server
+  ): Option[SQL.Database] = {
+    def check(db: SQL.Database): Option[SQL.Database] =
+      if (output || session_info_exists(db)) Some(db) else { db.close(); None }
+
+    if (server) check(open_database_server())
+    else if (output) Some(SQLite.open_database(output_database(name)))
+    else {
+      (for {
+        dir <- input_dirs.view
+        path = dir + database(name) if path.is_file
+        db <- check(SQLite.open_database(path))
+      } yield db).headOption
+    }
+  }
+
+  def error_database(name: String): Nothing =
+    error("Missing build database for session " + quote(name))
+
+  def open_database(name: String, output: Boolean = false): SQL.Database =
+    try_open_database(name, output = output) getOrElse error_database(name)
+
+  def prepare_output(): Unit = Isabelle_System.make_directory(output_dir + Path.basic("log"))
+
+  def clean_output(name: String): Option[Boolean] = {
+    val relevant_db =
+      build_database_server &&
+        using_option(try_open_database(name))(init_session_info(_, name)).getOrElse(false)
+
+    val del =
+      for {
+        dir <-
+          (if (system_heaps) List(user_output_dir, system_output_dir) else List(user_output_dir))
+        file <- List(heap(name), database(name), log(name), log_gz(name))
+        path = dir + file if path.is_file
+      } yield path.file.delete
+
+    if (relevant_db || del.nonEmpty) Some(del.forall(identity)) else None
+  }
+
+  def init_output(name: String): Unit = {
+    clean_output(name)
+    using(open_database(name, output = true))(init_session_info(_, name))
+  }
+
+  def check_output(
+    name: String,
+    session_options: Options,
+    sources_shasum: SHA1.Shasum,
+    input_shasum: SHA1.Shasum,
+    fresh_build: Boolean,
+    store_heap: Boolean
+  ): (Boolean, SHA1.Shasum) = {
+    try_open_database(name) match {
+      case Some(db) =>
+        using(db)(read_build(_, name)) match {
+          case Some(build) =>
+            val output_shasum = find_heap_shasum(name)
+            val current =
+              !fresh_build &&
+              build.ok &&
+              Sessions.eq_sources(session_options, build.sources, sources_shasum) &&
+              build.input_heaps == input_shasum &&
+              build.output_heap == output_shasum &&
+              !(store_heap && output_shasum.is_empty)
+            (current, output_shasum)
+          case None => (false, SHA1.no_shasum)
+        }
+      case None => (false, SHA1.no_shasum)
+    }
+  }
+
+
+  /* SQL database content */
+
+  def read_bytes(db: SQL.Database, name: String, column: SQL.Column): Bytes =
+    db.execute_query_statementO[Bytes](
+      Store.Session_Info.table.select(List(column),
+        sql = Store.Session_Info.session_name.where_equal(name)),
+      res => res.bytes(column)
+    ).getOrElse(Bytes.empty)
+
+  def read_properties(db: SQL.Database, name: String, column: SQL.Column): List[Properties.T] =
+    Properties.uncompress(read_bytes(db, name, column), cache = cache)
+
+
+  /* session info */
+
+  val all_tables: SQL.Tables =
+    SQL.Tables(Store.Session_Info.table, Store.Sources.table, Export.Data.table,
+      Document_Build.Data.table)
+
+  def init_session_info(db: SQL.Database, name: String): Boolean =
+    db.transaction_lock(all_tables, create = true) {
+      val already_defined = session_info_defined(db, name)
+
+      db.execute_statement(
+        Store.Session_Info.table.delete(sql = Store.Session_Info.session_name.where_equal(name)))
+      if (db.is_postgresql) db.execute_statement(Store.Session_Info.augment_table)
+
+      db.execute_statement(Store.Sources.table.delete(sql = Store.Sources.where_equal(name)))
+
+      db.execute_statement(
+        Export.Data.table.delete(sql = Export.Data.session_name.where_equal(name)))
+
+      db.execute_statement(
+        Document_Build.Data.table.delete(sql = Document_Build.Data.session_name.where_equal(name)))
+
+      already_defined
+    }
+
+  def session_info_exists(db: SQL.Database): Boolean = {
+    val tables = db.tables
+    all_tables.forall(table => tables.contains(table.name))
+  }
+
+  def session_info_defined(db: SQL.Database, name: String): Boolean =
+    db.execute_query_statementB(
+      Store.Session_Info.table.select(List(Store.Session_Info.session_name),
+        sql = Store.Session_Info.session_name.where_equal(name)))
+
+  def write_session_info(
+    db: SQL.Database,
+    session_name: String,
+    sources: Store.Sources,
+    build_log: Build_Log.Session_Info,
+    build: Store.Build_Info
+  ): Unit = {
+    db.transaction_lock(all_tables) {
+      write_sources(db, session_name, sources)
+      db.execute_statement(Store.Session_Info.table.insert(), body =
+        { stmt =>
+          stmt.string(1) = session_name
+          stmt.bytes(2) = Properties.encode(build_log.session_timing)
+          stmt.bytes(3) = Properties.compress(build_log.command_timings, cache = cache.compress)
+          stmt.bytes(4) = Properties.compress(build_log.theory_timings, cache = cache.compress)
+          stmt.bytes(5) = Properties.compress(build_log.ml_statistics, cache = cache.compress)
+          stmt.bytes(6) = Properties.compress(build_log.task_statistics, cache = cache.compress)
+          stmt.bytes(7) = Build_Log.compress_errors(build_log.errors, cache = cache.compress)
+          stmt.string(8) = build.sources.toString
+          stmt.string(9) = build.input_heaps.toString
+          stmt.string(10) = build.output_heap.toString
+          stmt.int(11) = build.return_code
+          stmt.string(12) = build.uuid
+        })
+    }
+  }
+
+  def read_session_timing(db: SQL.Database, name: String): Properties.T =
+    Properties.decode(read_bytes(db, name, Store.Session_Info.session_timing), cache = cache)
+
+  def read_command_timings(db: SQL.Database, name: String): Bytes =
+    read_bytes(db, name, Store.Session_Info.command_timings)
+
+  def read_theory_timings(db: SQL.Database, name: String): List[Properties.T] =
+    read_properties(db, name, Store.Session_Info.theory_timings)
+
+  def read_ml_statistics(db: SQL.Database, name: String): List[Properties.T] =
+    read_properties(db, name, Store.Session_Info.ml_statistics)
+
+  def read_task_statistics(db: SQL.Database, name: String): List[Properties.T] =
+    read_properties(db, name, Store.Session_Info.task_statistics)
+
+  def read_theories(db: SQL.Database, name: String): List[String] =
+    read_theory_timings(db, name).flatMap(Markup.Name.unapply)
+
+  def read_errors(db: SQL.Database, name: String): List[String] =
+    Build_Log.uncompress_errors(read_bytes(db, name, Store.Session_Info.errors), cache = cache)
+
+  def read_build(db: SQL.Database, name: String): Option[Store.Build_Info] = {
+    if (db.tables.contains(Store.Session_Info.table.name)) {
+      db.execute_query_statementO[Store.Build_Info](
+        Store.Session_Info.table.select(sql = Store.Session_Info.session_name.where_equal(name)),
+        { res =>
+          val uuid =
+            try { Option(res.string(Store.Session_Info.uuid)).getOrElse("") }
+            catch { case _: SQLException => "" }
+          Store.Build_Info(
+            SHA1.fake_shasum(res.string(Store.Session_Info.sources)),
+            SHA1.fake_shasum(res.string(Store.Session_Info.input_heaps)),
+            SHA1.fake_shasum(res.string(Store.Session_Info.output_heap)),
+            res.int(Store.Session_Info.return_code),
+            uuid)
+        }
+      )
+    }
+    else None
+  }
+
+
+  /* session sources */
+
+  def write_sources(db: SQL.Database, session_name: String, sources: Store.Sources): Unit =
+    for (source_file <- sources) {
+      db.execute_statement(Store.Sources.table.insert(), body =
+        { stmt =>
+          stmt.string(1) = session_name
+          stmt.string(2) = source_file.name
+          stmt.string(3) = source_file.digest.toString
+          stmt.bool(4) = source_file.compressed
+          stmt.bytes(5) = source_file.body
+        })
+    }
+
+  def read_sources(
+    db: SQL.Database,
+    session_name: String,
+    name: String = ""
+  ): List[Store.Source_File] = {
+    db.execute_query_statement(
+      Store.Sources.table.select(sql =
+        Store.Sources.where_equal(session_name, name = name) + SQL.order_by(List(Store.Sources.name))),
+      List.from[Store.Source_File],
+      { res =>
+        val res_name = res.string(Store.Sources.name)
+        val digest = SHA1.fake_digest(res.string(Store.Sources.digest))
+        val compressed = res.bool(Store.Sources.compressed)
+        val body = res.bytes(Store.Sources.body)
+        Store.Source_File(res_name, digest, compressed, body, cache.compress)
+      }
+    )
+  }
+}