--- a/src/HOL/Tools/Nitpick/kodkod.scala Mon Aug 24 22:30:34 2020 +0200
+++ b/src/HOL/Tools/Nitpick/kodkod.scala Tue Aug 25 13:44:09 2020 +0200
@@ -16,6 +16,8 @@
object Kodkod
{
+ /** result **/
+
sealed case class Result(rc: Int, out: String, err: String)
{
def ok: Boolean = rc == 0
@@ -29,6 +31,10 @@
}
}
+
+
+ /** execute **/
+
def execute(source: String,
solve_all: Boolean = false,
prove: Boolean = false,
@@ -37,11 +43,18 @@
timeout: Time = Time.zero,
max_threads: Int = 0): Result =
{
+ /* executor */
+
val pool_size = if (max_threads == 0) Isabelle_Thread.max_threads() else max_threads
- val executor = Executors.newFixedThreadPool(pool_size)
+ val executor = Executors.newFixedThreadPool(pool_size)
+ val executor_killed = Synchronized(false)
def executor_kill(): Unit =
- if (!executor.isShutdown) Isabelle_Thread.fork() { executor.shutdownNow() }
+ executor_killed.change(b =>
+ if (b) b else { Isabelle_Thread.fork() { executor.shutdownNow() }; true })
+
+
+ /* system context */
class Exit extends Exception("EXIT")
@@ -52,8 +65,19 @@
private val err = new StringBuilder
def return_code(i: Int): Unit = synchronized { rc = rc max i}
- override def output(s: String): Unit = synchronized { out ++= s; out += '\n' }
- override def error(s: String): Unit = synchronized { err ++= s; err += '\n' }
+
+ override def output(s: String): Unit = synchronized {
+ Exn.Interrupt.expose()
+ out ++= s
+ out += '\n'
+ }
+
+ override def error(s: String): Unit = synchronized {
+ Exn.Interrupt.expose()
+ err ++= s
+ err += '\n'
+ }
+
override def exit(i: Int): Unit =
synchronized {
return_code(i)
@@ -64,6 +88,9 @@
def result(): Result = synchronized { Result(rc, out.toString, err.toString) }
}
+
+ /* main */
+
try {
val lexer = new KodkodiLexer(new ANTLRInputStream(Bytes(source).stream))
val parser =
@@ -101,6 +128,8 @@
context.return_code(1)
}
+ executor.shutdownNow()
+
context.result()
}
@@ -110,7 +139,8 @@
File.read(Path.explode("$KODKODI/examples/weber3.kki"))).check
- /* scala function */
+
+ /** scala function **/
object Fun extends Scala.Fun("kodkod")
{