clarified handling of SQL.Type;
authorwenzelm
Thu, 09 Feb 2017 11:44:55 +0100
changeset 65008 ed2eedf786f3
parent 65007 b6a1a1d42f5d
child 65009 eda9366bbfac
clarified handling of SQL.Type;
src/Pure/General/sql.scala
--- a/src/Pure/General/sql.scala	Thu Feb 09 11:03:22 2017 +0100
+++ b/src/Pure/General/sql.scala	Thu Feb 09 11:44:55 2017 +0100
@@ -39,6 +39,21 @@
   def enclosure(ss: Iterable[String]): String = ss.mkString("(", ", ", ")")
 
 
+  /* types */
+
+  object Type extends Enumeration
+  {
+    val Int = Value("INTEGER")
+    val Long = Value("BIGINT")
+    val Double = Value("DOUBLE PRECISION")
+    val String = Value("TEXT")
+    val Bytes = Value("BLOB")
+  }
+
+  type Type_Name = Type.Value => String
+  def default_type_name(t: Type.Value): String = t.toString
+
+
   /* columns */
 
   object Column
@@ -52,19 +67,18 @@
     def string(name: String, strict: Boolean = true, primary_key: Boolean = false): Column[String] =
       new Column_String(name, strict, primary_key)
     def bytes(name: String, strict: Boolean = true, primary_key: Boolean = false): Column[Bytes] =
-      new Column_Bytes(name, strict, primary_key, "BLOB")  // SQL standard
-    def bytea(name: String, strict: Boolean = true, primary_key: Boolean = false): Column[Bytes] =
-      new Column_Bytes(name, strict, primary_key, "BYTEA")  // PostgreSQL
+      new Column_Bytes(name, strict, primary_key)
   }
 
   abstract class Column[+A] private[SQL](
-      val name: String, val strict: Boolean, val primary_key: Boolean)
-    extends Function[ResultSet, A]
+      val name: String,
+      val strict: Boolean,
+      val primary_key: Boolean,
+      val sql_type: Type.Value) extends Function[ResultSet, A]
   {
     def sql_name: String = quote_ident(name)
-    def sql_type: String
-    def sql_decl: String =
-      sql_name + " " + sql_type +
+    def sql_decl(type_name: Type_Name): String =
+      sql_name + " " + type_name(sql_type) +
       (if (strict) " NOT NULL" else "") +
       (if (primary_key) " PRIMARY KEY" else "")
 
@@ -80,34 +94,30 @@
       if (rs.wasNull) None else Some(x)
     }
 
-    override def toString: String = sql_decl
+    override def toString: String = sql_decl(default_type_name)
   }
 
   class Column_Int private[SQL](name: String, strict: Boolean, primary_key: Boolean)
-    extends Column[Int](name, strict, primary_key)
+    extends Column[Int](name, strict, primary_key, Type.Int)
   {
-    def sql_type: String = "INTEGER"
     def apply(rs: ResultSet): Int = rs.getInt(name)
   }
 
   class Column_Long private[SQL](name: String, strict: Boolean, primary_key: Boolean)
-    extends Column[Long](name, strict, primary_key)
+    extends Column[Long](name, strict, primary_key, Type.Long)
   {
-    def sql_type: String = "BIGINT"
     def apply(rs: ResultSet): Long = rs.getLong(name)
   }
 
   class Column_Double private[SQL](name: String, strict: Boolean, primary_key: Boolean)
-    extends Column[Double](name, strict, primary_key)
+    extends Column[Double](name, strict, primary_key, Type.Double)
   {
-    def sql_type: String = "DOUBLE PRECISION"
     def apply(rs: ResultSet): Double = rs.getDouble(name)
   }
 
   class Column_String private[SQL](name: String, strict: Boolean, primary_key: Boolean)
-    extends Column[String](name, strict, primary_key)
+    extends Column[String](name, strict, primary_key, Type.String)
   {
-    def sql_type: String = "TEXT"
     def apply(rs: ResultSet): String =
     {
       val s = rs.getString(name)
@@ -115,9 +125,8 @@
     }
   }
 
-  class Column_Bytes private[SQL](
-    name: String, strict: Boolean, primary_key: Boolean, val sql_type: String)
-    extends Column[Bytes](name, strict, primary_key)
+  class Column_Bytes private[SQL](name: String, strict: Boolean, primary_key: Boolean)
+    extends Column[Bytes](name, strict, primary_key, Type.Bytes)
   {
     def apply(rs: ResultSet): Bytes =
     {
@@ -147,9 +156,9 @@
       case _ =>
     }
 
-    def sql_create(strict: Boolean, rowid: Boolean): String =
+    def sql_create(strict: Boolean, rowid: Boolean, type_name: Type_Name): String =
       "CREATE TABLE " + (if (strict) "" else "IF NOT EXISTS ") +
-        quote_ident(name) + " " + enclosure(columns.map(_.sql_decl)) +
+        quote_ident(name) + " " + enclosure(columns.map(_.sql_decl(type_name))) +
         (if (rowid) "" else " WITHOUT ROWID")
 
     def sql_drop(strict: Boolean): String =
@@ -192,6 +201,11 @@
 
   trait Database
   {
+    /* types */
+
+    def type_name(t: Type.Value): String = default_type_name(t)
+
+
     /* connection */
 
     def connection: Connection
@@ -231,7 +245,7 @@
       iterator(connection.getMetaData.getTables(null, null, "%", null))(_.getString(3)).toList
 
     def create_table(table: Table, strict: Boolean = true, rowid: Boolean = true): Unit =
-      using(statement(table.sql_create(strict, rowid)))(_.execute())
+      using(statement(table.sql_create(strict, rowid, type_name)))(_.execute())
 
     def drop_table(table: Table, strict: Boolean = true): Unit =
       using(statement(table.sql_drop(strict)))(_.execute())
@@ -295,5 +309,9 @@
   class Database private[PostgreSQL](name: String, val connection: Connection) extends SQL.Database
   {
     override def toString: String = name
+
+    override def type_name(t: SQL.Type.Value): String =
+      if (t == SQL.Type.Bytes) "BYTEA"
+      else SQL.default_type_name(t)
   }
 }