--- a/src/Pure/Tools/sql.scala Sun Sep 04 20:31:23 2016 +0200
+++ b/src/Pure/Tools/sql.scala Sun Sep 04 21:09:18 2016 +0200
@@ -42,18 +42,28 @@
object Column
{
- def int(name: String, strict: Boolean = true): Column[Int] = new Column_Int(name, strict)
- def long(name: String, strict: Boolean = true): Column[Long] = new Column_Long(name, strict)
- def double(name: String, strict: Boolean = true): Column[Double] = new Column_Double(name, strict)
- def string(name: String, strict: Boolean = true): Column[String] = new Column_String(name, strict)
- def bytes(name: String, strict: Boolean = true): Column[Bytes] = new Column_Bytes(name, strict)
+ def int(name: String, strict: Boolean = true, primary_key: Boolean = false): Column[Int] =
+ new Column_Int(name, strict, primary_key)
+ def long(name: String, strict: Boolean = true, primary_key: Boolean = false): Column[Long] =
+ new Column_Long(name, strict, primary_key)
+ def double(name: String, strict: Boolean = true, primary_key: Boolean = false): Column[Double] =
+ new Column_Double(name, strict, primary_key)
+ def string(name: String, strict: Boolean = true, primary_key: Boolean = false): Column[String] =
+ new Column_String(name, strict, primary_key)
+ def bytes(name: String, strict: Boolean = true, primary_key: Boolean = false): Column[Bytes] =
+ new Column_Bytes(name, strict, primary_key)
}
- abstract class Column[+A] private[SQL](val name: String, val strict: Boolean)
+ abstract class Column[+A] private[SQL](
+ val name: String, val strict: Boolean, val primary_key: Boolean)
{
def sql_name: String = quote_ident(name)
def sql_type: String
- def sql_decl: String = sql_name + " " + sql_type + (if (strict) " NOT NULL" else "")
+ def sql_decl: String =
+ sql_name + " " + sql_type +
+ (if (strict) " NOT NULL" else "") +
+ (if (primary_key) " PRIMARY KEY" else "")
+
def string(rs: ResultSet): String =
{
val s = rs.getString(name)
@@ -69,29 +79,29 @@
override def toString: String = sql_decl
}
- class Column_Int private[SQL](name: String, strict: Boolean)
- extends Column[Int](name, strict)
+ class Column_Int private[SQL](name: String, strict: Boolean, primary_key: Boolean)
+ extends Column[Int](name, strict, primary_key)
{
def sql_type: String = "INTEGER"
def apply(rs: ResultSet): Int = rs.getInt(name)
}
- class Column_Long private[SQL](name: String, strict: Boolean)
- extends Column[Long](name, strict)
+ class Column_Long private[SQL](name: String, strict: Boolean, primary_key: Boolean)
+ extends Column[Long](name, strict, primary_key)
{
def sql_type: String = "INTEGER"
def apply(rs: ResultSet): Long = rs.getLong(name)
}
- class Column_Double private[SQL](name: String, strict: Boolean)
- extends Column[Double](name, strict)
+ class Column_Double private[SQL](name: String, strict: Boolean, primary_key: Boolean)
+ extends Column[Double](name, strict, primary_key)
{
def sql_type: String = "REAL"
def apply(rs: ResultSet): Double = rs.getDouble(name)
}
- class Column_String private[SQL](name: String, strict: Boolean)
- extends Column[String](name, strict)
+ class Column_String private[SQL](name: String, strict: Boolean, primary_key: Boolean)
+ extends Column[String](name, strict, primary_key)
{
def sql_type: String = "TEXT"
def apply(rs: ResultSet): String =
@@ -101,8 +111,8 @@
}
}
- class Column_Bytes private[SQL](name: String, strict: Boolean)
- extends Column[Bytes](name, strict)
+ class Column_Bytes private[SQL](name: String, strict: Boolean, primary_key: Boolean)
+ extends Column[Bytes](name, strict, primary_key)
{
def sql_type: String = "BLOB"
def apply(rs: ResultSet): Bytes =
@@ -115,8 +125,19 @@
/* tables */
- sealed case class Table(name: String, columns: Column[Any]*)
+ sealed case class Table(name: String, columns: List[Column[Any]])
{
+ Library.duplicates(columns.map(_.name)) match {
+ case Nil =>
+ case bad => error("Duplicate column names " + commas_quote(bad) + " for table " + quote(name))
+ }
+
+ columns.filter(_.primary_key) match {
+ case bad if bad.length > 1 =>
+ error("Multiple primary keys " + commas_quote(bad.map(_.name)) + " for table " + quote(name))
+ case _ =>
+ }
+
def sql_create(strict: Boolean, rowid: Boolean): String =
"CREATE TABLE " + (if (strict) "" else " IF NOT EXISTS ") +
quote_ident(name) + " " + columns.map(_.sql_decl).mkString("(", ", ", ")") +