src/Pure/Tools/sql.scala
changeset 63781 af9fe0b6b78e
parent 63780 163244cefb4e
child 63783 baa20f3b6cea
--- a/src/Pure/Tools/sql.scala	Sun Sep 04 20:31:23 2016 +0200
+++ b/src/Pure/Tools/sql.scala	Sun Sep 04 21:09:18 2016 +0200
@@ -42,18 +42,28 @@
 
   object Column
   {
-    def int(name: String, strict: Boolean = true): Column[Int] = new Column_Int(name, strict)
-    def long(name: String, strict: Boolean = true): Column[Long] = new Column_Long(name, strict)
-    def double(name: String, strict: Boolean = true): Column[Double] = new Column_Double(name, strict)
-    def string(name: String, strict: Boolean = true): Column[String] = new Column_String(name, strict)
-    def bytes(name: String, strict: Boolean = true): Column[Bytes] = new Column_Bytes(name, strict)
+    def int(name: String, strict: Boolean = true, primary_key: Boolean = false): Column[Int] =
+      new Column_Int(name, strict, primary_key)
+    def long(name: String, strict: Boolean = true, primary_key: Boolean = false): Column[Long] =
+      new Column_Long(name, strict, primary_key)
+    def double(name: String, strict: Boolean = true, primary_key: Boolean = false): Column[Double] =
+      new Column_Double(name, strict, primary_key)
+    def string(name: String, strict: Boolean = true, primary_key: Boolean = false): Column[String] =
+      new Column_String(name, strict, primary_key)
+    def bytes(name: String, strict: Boolean = true, primary_key: Boolean = false): Column[Bytes] =
+      new Column_Bytes(name, strict, primary_key)
   }
 
-  abstract class Column[+A] private[SQL](val name: String, val strict: Boolean)
+  abstract class Column[+A] private[SQL](
+    val name: String, val strict: Boolean, val primary_key: Boolean)
   {
     def sql_name: String = quote_ident(name)
     def sql_type: String
-    def sql_decl: String = sql_name + " " + sql_type + (if (strict) " NOT NULL" else "")
+    def sql_decl: String =
+      sql_name + " " + sql_type +
+      (if (strict) " NOT NULL" else "") +
+      (if (primary_key) " PRIMARY KEY" else "")
+
     def string(rs: ResultSet): String =
     {
       val s = rs.getString(name)
@@ -69,29 +79,29 @@
     override def toString: String = sql_decl
   }
 
-  class Column_Int private[SQL](name: String, strict: Boolean)
-    extends Column[Int](name, strict)
+  class Column_Int private[SQL](name: String, strict: Boolean, primary_key: Boolean)
+    extends Column[Int](name, strict, primary_key)
   {
     def sql_type: String = "INTEGER"
     def apply(rs: ResultSet): Int = rs.getInt(name)
   }
 
-  class Column_Long private[SQL](name: String, strict: Boolean)
-    extends Column[Long](name, strict)
+  class Column_Long private[SQL](name: String, strict: Boolean, primary_key: Boolean)
+    extends Column[Long](name, strict, primary_key)
   {
     def sql_type: String = "INTEGER"
     def apply(rs: ResultSet): Long = rs.getLong(name)
   }
 
-  class Column_Double private[SQL](name: String, strict: Boolean)
-    extends Column[Double](name, strict)
+  class Column_Double private[SQL](name: String, strict: Boolean, primary_key: Boolean)
+    extends Column[Double](name, strict, primary_key)
   {
     def sql_type: String = "REAL"
     def apply(rs: ResultSet): Double = rs.getDouble(name)
   }
 
-  class Column_String private[SQL](name: String, strict: Boolean)
-    extends Column[String](name, strict)
+  class Column_String private[SQL](name: String, strict: Boolean, primary_key: Boolean)
+    extends Column[String](name, strict, primary_key)
   {
     def sql_type: String = "TEXT"
     def apply(rs: ResultSet): String =
@@ -101,8 +111,8 @@
     }
   }
 
-  class Column_Bytes private[SQL](name: String, strict: Boolean)
-    extends Column[Bytes](name, strict)
+  class Column_Bytes private[SQL](name: String, strict: Boolean, primary_key: Boolean)
+    extends Column[Bytes](name, strict, primary_key)
   {
     def sql_type: String = "BLOB"
     def apply(rs: ResultSet): Bytes =
@@ -115,8 +125,19 @@
 
   /* tables */
 
-  sealed case class Table(name: String, columns: Column[Any]*)
+  sealed case class Table(name: String, columns: List[Column[Any]])
   {
+    Library.duplicates(columns.map(_.name)) match {
+      case Nil =>
+      case bad => error("Duplicate column names " + commas_quote(bad) + " for table " + quote(name))
+    }
+
+    columns.filter(_.primary_key) match {
+      case bad if bad.length > 1 =>
+        error("Multiple primary keys " + commas_quote(bad.map(_.name)) + " for table " + quote(name))
+      case _ =>
+    }
+
     def sql_create(strict: Boolean, rowid: Boolean): String =
       "CREATE TABLE " + (if (strict) "" else " IF NOT EXISTS ") +
         quote_ident(name) + " " + columns.map(_.sql_decl).mkString("(", ", ", ")") +