clarified signature;
authorwenzelm
Fri, 05 May 2017 17:20:50 +0200
changeset 65730 7ae61e72a678
parent 65729 3f40afe30feb
child 65731 393d34045ffb
clarified signature;
src/Pure/General/sql.scala
--- a/src/Pure/General/sql.scala	Fri May 05 11:38:09 2017 +0200
+++ b/src/Pure/General/sql.scala	Fri May 05 17:20:50 2017 +0200
@@ -14,6 +14,9 @@
 {
   /** SQL language **/
 
+  type Source = String
+
+
   /* concrete syntax */
 
   def escape_char(c: Char): String =
@@ -30,23 +33,23 @@
       case _ => c.toString
     }
 
-  def string(s: String): String =
+  def string(s: String): Source =
     "'" + s.map(escape_char(_)).mkString + "'"
 
-  def ident(s: String): String =
+  def ident(s: String): Source =
     Long_Name.implode(Long_Name.explode(s).map(a => quote(a.replace("\"", "\"\""))))
 
-  def enclose(s: String): String = "(" + s + ")"
-  def enclosure(ss: Iterable[String]): String = ss.mkString("(", ", ", ")")
+  def enclose(s: Source): Source = "(" + s + ")"
+  def enclosure(ss: Iterable[Source]): Source = ss.mkString("(", ", ", ")")
 
-  def select(columns: List[Column], distinct: Boolean = false): String =
+  def select(columns: List[Column], distinct: Boolean = false): Source =
     "SELECT " + (if (distinct) "DISTINCT " else "") + commas(columns.map(_.ident)) + " FROM "
 
-  def join(table1: Table, table2: Table, sql: String = "", outer: Boolean = false): String =
+  def join(table1: Table, table2: Table, sql: Source = "", outer: Boolean = false): Source =
     table1.ident + (if (outer) " LEFT OUTER JOIN " else " INNER JOIN ") + table2.ident +
       (if (sql == "") "" else " ON " + sql)
 
-  def join_outer(table1: Table, table2: Table, sql: String = ""): String =
+  def join_outer(table1: Table, table2: Table, sql: Source = ""): Source =
     join(table1, table2, sql, outer = true)
 
 
@@ -63,14 +66,14 @@
     val Date = Value("TIMESTAMP WITH TIME ZONE")
   }
 
-  def sql_type_default(T: Type.Value): String = T.toString
+  def sql_type_default(T: Type.Value): Source = T.toString
 
-  def sql_type_sqlite(T: Type.Value): String =
+  def sql_type_sqlite(T: Type.Value): Source =
     if (T == Type.Boolean) "INTEGER"
     else if (T == Type.Date) "TEXT"
     else sql_type_default(T)
 
-  def sql_type_postgresql(T: Type.Value): String =
+  def sql_type_postgresql(T: Type.Value): Source =
     if (T == Type.Bytes) "BYTEA"
     else sql_type_default(T)
 
@@ -101,20 +104,20 @@
     def apply(table: Table): Column =
       Column(Long_Name.qualify(table.name, name), T, strict = strict, primary_key = primary_key)
 
-    def ident: String = SQL.ident(name)
+    def ident: Source = SQL.ident(name)
 
-    def decl(sql_type: Type.Value => String): String =
+    def decl(sql_type: Type.Value => Source): Source =
       ident + " " + sql_type(T) + (if (strict || primary_key) " NOT NULL" else "")
 
-    def where_equal(s: String): String = "WHERE " + ident + " = " + string(s)
+    def where_equal(s: String): Source = "WHERE " + ident + " = " + string(s)
 
-    override def toString: String = ident
+    override def toString: Source = ident
   }
 
 
   /* tables */
 
-  sealed case class Table(name: String, columns: List[Column], body: String = "")
+  sealed case class Table(name: String, columns: List[Column], body: Source = "")
   {
     private val columns_index: Map[String, Int] =
       columns.iterator.map(_.name).zipWithIndex.toMap
@@ -124,15 +127,15 @@
       case bad => error("Duplicate column names " + commas_quote(bad) + " for table " + quote(name))
     }
 
-    def ident: String = SQL.ident(name)
+    def ident: Source = SQL.ident(name)
 
-    def query: String =
+    def query: Source =
       if (body == "") error("Missing SQL body for table " + quote(name))
       else SQL.enclose(body)
 
-    def query_name: String = query + " AS " + SQL.ident(name)
+    def query_name: Source = query + " AS " + SQL.ident(name)
 
-    def create(strict: Boolean = false, sql_type: Type.Value => String): String =
+    def create(strict: Boolean = false, sql_type: Type.Value => Source): Source =
     {
       val primary_key =
         columns.filter(_.primary_key).map(_.name) match {
@@ -144,26 +147,26 @@
     }
 
     def create_index(index_name: String, index_columns: List[Column],
-        strict: Boolean = false, unique: Boolean = false): String =
+        strict: Boolean = false, unique: Boolean = false): Source =
       "CREATE " + (if (unique) "UNIQUE " else "") + "INDEX " +
         (if (strict) "" else "IF NOT EXISTS ") + SQL.ident(index_name) + " ON " +
         ident + " " + enclosure(index_columns.map(_.name))
 
-    def insert_cmd(cmd: String, sql: String = ""): String =
+    def insert_cmd(cmd: Source, sql: Source = ""): Source =
       cmd + " INTO " + ident + " VALUES " + enclosure(columns.map(_ => "?")) +
         (if (sql == "") "" else " " + sql)
 
-    def insert(sql: String = ""): String = insert_cmd("INSERT", sql)
+    def insert(sql: Source = ""): Source = insert_cmd("INSERT", sql)
 
-    def delete(sql: String = ""): String =
+    def delete(sql: Source = ""): Source =
       "DELETE FROM " + ident +
         (if (sql == "") "" else " " + sql)
 
-    def select(select_columns: List[Column], sql: String = "", distinct: Boolean = false): String =
+    def select(select_columns: List[Column], sql: Source = "", distinct: Boolean = false): Source =
       SQL.select(select_columns, distinct = distinct) + ident +
         (if (sql == "") "" else " " + sql)
 
-    override def toString: String = ident
+    override def toString: Source = ident
   }
 
 
@@ -183,7 +186,7 @@
   {
     /* types */
 
-    def sql_type(T: Type.Value): String
+    def sql_type(T: Type.Value): Source
 
 
     /* connection */
@@ -211,13 +214,13 @@
 
     /* statements */
 
-    def statement(sql: String): PreparedStatement =
+    def statement(sql: Source): PreparedStatement =
       connection.prepareStatement(sql)
 
-    def using_statement[A](sql: String)(f: PreparedStatement => A): A =
+    def using_statement[A](sql: Source)(f: PreparedStatement => A): A =
       using(statement(sql))(f)
 
-    def insert_permissive(table: Table, sql: String = ""): String
+    def insert_permissive(table: Table, sql: Source = ""): Source
 
 
     /* input */
@@ -304,7 +307,7 @@
     def tables: List[String] =
       iterator(connection.getMetaData.getTables(null, null, "%", null))(_.getString(3)).toList
 
-    def create_table(table: Table, strict: Boolean = false, sql: String = ""): Unit =
+    def create_table(table: Table, strict: Boolean = false, sql: Source = ""): Unit =
       using_statement(
         table.create(strict, sql_type) + (if (sql == "") "" else " " + sql))(_.execute())
 
@@ -347,7 +350,7 @@
   {
     override def toString: String = name
 
-    def sql_type(T: SQL.Type.Value): String = SQL.sql_type_sqlite(T)
+    def sql_type(T: SQL.Type.Value): SQL.Source = SQL.sql_type_sqlite(T)
 
     def set_date(stmt: PreparedStatement, i: Int, date: Date): Unit =
       if (date == null) set_string(stmt, i, null: String)
@@ -356,7 +359,7 @@
     def date(rs: ResultSet, column: SQL.Column): Date =
       date_format.parse(string(rs, column))
 
-    def insert_permissive(table: SQL.Table, sql: String = ""): String =
+    def insert_permissive(table: SQL.Table, sql: SQL.Source = ""): SQL.Source =
       table.insert_cmd("INSERT OR IGNORE", sql = sql)
 
     def rebuild { using_statement("VACUUM")(_.execute()) }
@@ -419,7 +422,7 @@
   {
     override def toString: String = name
 
-    def sql_type(T: SQL.Type.Value): String = SQL.sql_type_postgresql(T)
+    def sql_type(T: SQL.Type.Value): SQL.Source = SQL.sql_type_postgresql(T)
 
     // see https://jdbc.postgresql.org/documentation/head/8-date-time.html
     def set_date(stmt: PreparedStatement, i: Int, date: Date): Unit =
@@ -432,7 +435,7 @@
       if (obj == null) null else Date.instant(obj.toInstant)
     }
 
-    def insert_permissive(table: SQL.Table, sql: String = ""): String =
+    def insert_permissive(table: SQL.Table, sql: SQL.Source = ""): SQL.Source =
       table.insert_cmd("INSERT",
         sql = sql + (if (sql == "") "" else " ") + "ON CONFLICT DO NOTHING")