src/Pure/System/scala.scala
changeset 75654 21164fd15e3d
parent 75620 44815dc2b8f9
child 75660 45d3497c0baa
--- a/src/Pure/System/scala.scala	Wed Jul 06 13:08:33 2022 +0200
+++ b/src/Pure/System/scala.scala	Tue Jul 05 13:12:04 2022 +0200
@@ -7,11 +7,23 @@
 package isabelle
 
 
-import java.io.{File => JFile, StringWriter, PrintWriter}
+import java.io.{File => JFile, PrintStream, ByteArrayOutputStream, OutputStream}
+
+import scala.collection.mutable
+import scala.annotation.tailrec
 
-import scala.tools.nsc.{GenericRunnerSettings, ConsoleWriter, NewLinePrintWriter}
-import scala.tools.nsc.interpreter.{IMain, Results}
-import scala.tools.nsc.interpreter.shell.ReplReporterImpl
+import dotty.tools.dotc.CompilationUnit
+import dotty.tools.dotc.ast.Trees.PackageDef
+import dotty.tools.dotc.ast.untpd
+import dotty.tools.dotc.core.Contexts.{Context => CompilerContext}
+import dotty.tools.dotc.core.NameOps.moduleClassName
+import dotty.tools.dotc.core.{Phases, StdNames}
+import dotty.tools.dotc.interfaces
+import dotty.tools.dotc.reporting.{Diagnostic, ConsoleReporter}
+import dotty.tools.dotc.util.{SourceFile, SourcePosition, NoSourcePosition}
+import dotty.tools.repl
+import dotty.tools.repl.{ReplCompiler, ReplDriver}
+
 
 object Scala {
   /** registered functions **/
@@ -89,78 +101,109 @@
   /** compiler **/
 
   def class_path(): List[String] =
-    for {
-      prop <- List("isabelle.scala.classpath", "java.class.path")
-      elems = System.getProperty(prop, "") if elems.nonEmpty
-      elem <- space_explode(JFile.pathSeparatorChar, elems) if elem.nonEmpty
-    } yield elem
+    space_explode(JFile.pathSeparatorChar, System.getProperty("java.class.path", ""))
+      .filter(_.nonEmpty)
 
   object Compiler {
-    def default_print_writer: PrintWriter =
-      new NewLinePrintWriter(new ConsoleWriter, true)
+    object Message {
+      object Kind extends Enumeration {
+        val error, warning, info, other = Value
+      }
+      private val Header = """^--.* (Error|Warning|Info): .*$""".r
+      val header_kind: String => Kind.Value =
+        {
+          case "Error" => Kind.error
+          case "Warning" => Kind.warning
+          case "Info" => Kind.info
+          case _ => Kind.other
+        }
+
+      // see compiler/src/dotty/tools/dotc/reporting/MessageRendering.scala
+      def split(str: String): List[Message] = {
+        var kind = Kind.other
+        val text = new mutable.StringBuilder
+        val result = new mutable.ListBuffer[Message]
+
+        def flush(): Unit = {
+          if (text.nonEmpty) { result += Message(kind, text.toString) }
+          kind = Kind.other
+          text.clear()
+        }
+
+        for (line <- Library.trim_split_lines(str)) {
+          line match {
+            case Header(k) => flush(); kind = header_kind(k)
+            case _ => if (line.startsWith("-- ")) flush()
+          }
+          if (text.nonEmpty) { text += '\n' }
+          text ++= line
+        }
+        flush()
+        result.toList
+      }
+    }
+
+    sealed case class Message(kind: Message.Kind.Value, text: String)
+    {
+      def is_error: Boolean = kind == Message.Kind.error
+      override def toString: String = text
+    }
+
+    sealed case class Result(
+      state: repl.State,
+      messages: List[Message],
+      unit: Option[CompilationUnit] = None
+    ) {
+      val errors: List[String] = messages.flatMap(msg => if (msg.is_error) Some(msg.text) else None)
+      def ok: Boolean = errors.isEmpty
+      def check_state: repl.State = if (ok) state else error(cat_lines(errors))
+      override def toString: String = if (ok) "Result(ok)" else "Result(error)"
+    }
 
     def context(
-      print_writer: PrintWriter = default_print_writer,
-      error: String => Unit = Exn.error,
+      settings: List[String] = Nil,
       jar_dirs: List[JFile] = Nil,
       class_loader: Option[ClassLoader] = None
     ): Context = {
+      val isabelle_settings =
+        Word.explode(Isabelle_System.getenv_strict("ISABELLE_SCALAC_OPTIONS"))
+
       def find_jars(dir: JFile): List[String] =
         File.find_files(dir, file => file.getName.endsWith(".jar")).
           map(File.absolute_name)
 
-      val settings = new GenericRunnerSettings(error)
-      settings.classpath.value =
-        (class_path() ::: jar_dirs.flatMap(find_jars)).mkString(JFile.pathSeparator)
-
-      new Context(settings, print_writer, class_loader)
+      val classpath = (class_path() ::: jar_dirs.flatMap(find_jars)).mkString(JFile.pathSeparator)
+      val settings1 = isabelle_settings ::: settings ::: List("-classpath", classpath)
+      new Context(settings1, class_loader)
     }
 
     class Context private [Compiler](
-      val settings: GenericRunnerSettings,
-      val print_writer: PrintWriter,
-      val class_loader: Option[ClassLoader]
+      val settings: List[String],
+      val class_loader: Option[ClassLoader] = None
     ) {
-      override def toString: String = settings.toString
+      private val out_stream = new ByteArrayOutputStream(1024)
+      private val out = new PrintStream(out_stream)
+      private val driver: ReplDriver = new ReplDriver(settings.toArray, out, class_loader)
 
-      val interp: IMain =
-        new IMain(settings, new ReplReporterImpl(settings, print_writer)) {
-          override def parentClassLoader: ClassLoader =
-            class_loader getOrElse super.parentClassLoader
-        }
-    }
+      def init_state: repl.State = driver.initialState
 
-    def toplevel(interpret: Boolean, source: String): List[String] = {
-      val out = new StringWriter
-      val interp = Compiler.context(print_writer = new PrintWriter(out)).interp
-      val marker = '\u000b'
-      val ok =
-        interp.withLabel(marker.toString) {
-          if (interpret) interp.interpret(source) == Results.Success
-          else (new interp.ReadEvalPrint).compile(source)
-        }
-      out.close()
-
-      val Error = """(?s)^\S* error: (.*)$""".r
-      val errors =
-        space_explode(marker, Library.strip_ansi_color(out.toString)).
-          collect({ case Error(msg) => "Scala error: " + Library.trim_line(msg) })
-
-      if (!ok && errors.isEmpty) List("Error") else errors
+      def compile(source: String, state: repl.State = init_state): Result = {
+        out.flush()
+        out_stream.reset()
+        val state1 = driver.run(source)(state)
+        out.flush()
+        val messages = Message.split(out_stream.toString(UTF8.charset))
+        out_stream.reset()
+        Result(state1, messages)
+      }
     }
   }
 
   object Toplevel extends Fun_String("scala_toplevel") {
     val here = Scala_Project.here
-    def apply(arg: String): String = {
-      val (interpret, source) =
-        YXML.parse_body(arg) match {
-          case Nil => (false, "")
-          case List(XML.Text(source)) => (false, source)
-          case body => import XML.Decode._; pair(bool, string)(body)
-        }
+    def apply(source: String): String = {
       val errors =
-        try { Compiler.toplevel(interpret, source) }
+        try { Compiler.context().compile(source).errors.map("Scala error: " + _) }
         catch { case ERROR(msg) => List(msg) }
       locally { import XML.Encode._; YXML.string_of_body(list(string)(errors)) }
     }
@@ -174,7 +217,7 @@
     /* requests */
 
     sealed abstract class Request
-    case class Execute(command: Compiler.Context => Unit) extends Request
+    case class Execute(command: (Compiler.Context, repl.State) => repl.State) extends Request
     case object Shutdown extends Request
 
 
@@ -189,19 +232,21 @@
       known.value.collectFirst(which)
   }
 
-  class Interpreter(context: Compiler.Context) {
+  class Interpreter(context: Compiler.Context, out: OutputStream = Console.out) {
     interpreter =>
 
     private val running = Synchronized[Option[Thread]](None)
     def running_thread(thread: Thread): Boolean = running.value.contains(thread)
     def interrupt_thread(): Unit = running.change({ opt => opt.foreach(_.interrupt()); opt })
 
+    private var state = context.init_state
+
     private lazy val thread: Consumer_Thread[Interpreter.Request] =
       Consumer_Thread.fork("Scala.Interpreter") {
         case Interpreter.Execute(command) =>
           try {
             running.change(_ => Some(Thread.currentThread()))
-            command(context)
+            state = command(context, state)
           }
           finally {
             running.change(_ => None)
@@ -219,9 +264,12 @@
       thread.shutdown()
     }
 
-    def execute(command: Compiler.Context => Unit): Unit =
+    def execute(command: (Compiler.Context, repl.State) => repl.State): Unit =
       thread.send(Interpreter.Execute(command))
 
+    def reset(): Unit =
+      thread.send(Interpreter.Execute((context, _) => context.init_state))
+
     Interpreter.add(interpreter)
     thread
   }