src/Pure/System/scala.scala
author wenzelm
Fri, 14 Aug 2020 13:59:09 +0200
changeset 72152 3fa75db844f5
parent 72151 64df1e514005
child 72157 d1ca82e27cbc
permissions -rw-r--r--
clarified demo functions;

/*  Title:      Pure/System/scala.scala
    Author:     Makarius

Support for Scala at runtime.
*/

package isabelle


import java.io.{File => JFile, StringWriter, PrintWriter}

import scala.tools.nsc.{GenericRunnerSettings, ConsoleWriter, NewLinePrintWriter}
import scala.tools.nsc.interpreter.IMain


object Scala
{
  /** demo functions **/

  /* identity: the argument is handed back unchanged */
  def echo(arg: String): String = arg

  /* sleep for the number of seconds given as a string (parsed via
     Value.Double) and report the duration measured by Time.now() around the
     sleep; a malformed argument is an error */
  def sleep(seconds: String): String =
  {
    val duration =
      seconds match {
        case Value.Double(s) => Time.seconds(s)
        case _ => error("Malformed argument: " + quote(seconds))
      }
    val start = Time.now()
    duration.sleep
    (Time.now() - start).toString
  }



  /** compiler **/

  object Compiler
  {
    /* compiler context whose class path consists of the entries of the
       system properties "isabelle.scala.classpath" and "java.class.path",
       followed by all .jar files found below the given jar_dirs */
    def context(
      error: String => Unit = Exn.error,
      jar_dirs: List[JFile] = Nil): Context =
    {
      def jars_below(dir: JFile): List[String] =
        File.find_files(dir, file => file.getName.endsWith(".jar")).map(File.absolute_name)

      val class_path: List[String] =
        List("isabelle.scala.classpath", "java.class.path").flatMap { prop =>
          val path = System.getProperty(prop, "")
          // a property whose value is literally "" (two quote characters) is skipped
          if (path == "\"\"") Nil
          else space_explode(JFile.pathSeparatorChar, path)
        }

      val settings = new GenericRunnerSettings(error)
      val entries = class_path ::: jar_dirs.flatMap(jars_below)
      settings.classpath.value = entries.mkString(JFile.pathSeparator)

      new Context(settings)
    }

    def default_print_writer: PrintWriter =
      new NewLinePrintWriter(new ConsoleWriter, true)

    class Context private [Compiler](val settings: GenericRunnerSettings)
    {
      override def toString: String = settings.toString

      /* fresh interpreter on these settings; a non-null class_loader replaces
         the interpreter's default parent class loader */
      def interpreter(
        print_writer: PrintWriter = default_print_writer,
        class_loader: ClassLoader = null): IMain =
      {
        new IMain(settings, print_writer)
        {
          override def parentClassLoader: ClassLoader =
            if (class_loader != null) class_loader
            else super.parentClassLoader
        }
      }

      /* compile source in a fresh interpreter and return its error messages,
         each prefixed by "Scala error: "; a failed compilation without any
         recognizable message yields List("Error") */
      def toplevel(source: String): List[String] =
      {
        val buffer = new StringWriter
        val interp = interpreter(new PrintWriter(buffer))
        val rep = new interp.ReadEvalPrint
        // output is produced under the label "\u0001", which then serves as
        // the separator for splitting the captured output into messages
        val ok = interp.withLabel("\u0001") { rep.compile(source) }
        buffer.close

        val Error = """(?s)^\S* error: (.*)$""".r
        val chunks = space_explode('\u0001', Library.strip_ansi_color(buffer.toString))
        val errors =
          for (Error(msg) <- chunks)
            yield "Scala error: " + Library.trim_line(msg)

        if (ok || errors.nonEmpty) errors else List("Error")
      }
    }
  }

  /* compile source in a default context; the error messages (or the
     message of an ERROR raised meanwhile) are returned as a YXML-encoded
     list of strings */
  def toplevel_yxml(source: String): String =
  {
    val errors: List[String] =
      try Compiler.context().toplevel(source)
      catch { case ERROR(msg) => List(msg) }
    YXML.string_of_body(XML.Encode.list(XML.Encode.string)(errors))
  }



  /** invoke Scala functions from ML **/

  /* registered functions */

  sealed case class Fun(name: String, apply: String => String)
  {
    override def toString: String = name
  }

  /* the functions of all Isabelle_Scala_Functions services, in service order */
  lazy val functions: List[Fun] =
    Isabelle_System.services.
      collect { case c: Isabelle_Scala_Functions => c }.
      flatMap(c => c.functions)


  /* invoke function */

  object Tag extends Enumeration
  {
    val NULL, OK, ERROR, FAIL, INTERRUPT = Value
  }

  /* apply the first registered function called name to arg:
       NULL       -- the function returned null
       OK         -- regular result
       INTERRUPT  -- the function was interrupted
       ERROR      -- any other exception, with its message
       FAIL       -- no function of that name is registered */
  def function(name: String, arg: String): (Tag.Value, String) =
    functions.find(_.name == name) match {
      case None => (Tag.FAIL, "Unknown Isabelle/Scala function: " + quote(name))
      case Some(fun) =>
        Exn.capture { fun.apply(arg) } match {
          case Exn.Res(null) => (Tag.NULL, "")
          case Exn.Res(res) => (Tag.OK, res)
          case Exn.Exn(Exn.Interrupt()) => (Tag.INTERRUPT, "")
          case Exn.Exn(e) => (Tag.ERROR, Exn.message(e))
        }
    }
}


/* protocol handler */

class Scala extends Session.Protocol_Handler
{
  /* state: all access happens while holding this handler's monitor */

  private var session: Session = null
  // running invocations, keyed by the id carried by the Invoke_Scala message
  private var futures = Map.empty[String, Future[Unit]]

  override def init(init_session: Session): Unit =
    synchronized { session = init_session }

  /* cancel all running invocations, each reported as INTERRUPT */
  override def exit(): Unit = synchronized
  {
    // iterate over an explicit snapshot: cancel (via result) removes entries from futures
    val pending = futures
    for ((id, future) <- pending) cancel(id, future)
    futures = Map.empty
  }

  /* send the "Scala.result" protocol command for id, at most once: ids that
     are no longer registered (already reported or cancelled) are ignored */
  private def result(id: String, tag: Scala.Tag.Value, res: String): Unit =
    synchronized
    {
      if (futures.isDefinedAt(id)) {
        session.protocol_command("Scala.result", id, tag.id.toString, res)
        futures -= id
      }
    }

  /* cancel a running invocation and report it as interrupted */
  private def cancel(id: String, future: Future[Unit]): Unit =
  {
    future.cancel
    result(id, Scala.Tag.INTERRUPT, "")
  }

  /* Invoke_Scala: fork evaluation of the named function on the message text */
  private def invoke_scala(msg: Prover.Protocol_Output): Boolean = synchronized
  {
    msg.properties match {
      case Markup.Invoke_Scala(name, id) =>
        // the id is registered while this monitor is held; result is also
        // synchronized, so (assuming Future.fork evaluates on another thread)
        // the forked computation cannot report before the registration
        futures += (id ->
          Future.fork {
            val (tag, res) = Scala.function(name, msg.text)
            result(id, tag, res)
          })
        true
      case _ => false
    }
  }

  /* Cancel_Scala: cancel the invocation with the given id, if still running */
  private def cancel_scala(msg: Prover.Protocol_Output): Boolean = synchronized
  {
    msg.properties match {
      case Markup.Cancel_Scala(id) =>
        futures.get(id).foreach(future => cancel(id, future))
        true
      case _ => false
    }
  }

  val functions =
    List(
      Markup.Invoke_Scala.name -> invoke_scala,
      Markup.Cancel_Scala.name -> cancel_scala)
}


/* registered functions */

/* Isabelle_System service carrying a list of Isabelle/Scala functions;
   Scala.functions gathers the functions of all such services */
class Isabelle_Scala_Functions(val functions: Scala.Fun*)
  extends Isabelle_System.Service

/* standard registration: demo functions, the Scala toplevel, and the
   Bibtex database check (listing order preserved) */
class Functions extends Isabelle_Scala_Functions(
  Scala.Fun(name = "echo", apply = Scala.echo),
  Scala.Fun(name = "sleep", apply = Scala.sleep),
  Scala.Fun(name = "scala_toplevel", apply = Scala.toplevel_yxml),
  Scala.Fun(name = "check_bibtex_database", apply = Bibtex.check_database_yxml))