clarified signature for Scala functions;
authorwenzelm
Mon, 12 Apr 2021 18:10:13 +0200
changeset 73565 1aa92bc4d356
parent 73564 a021bb558feb
child 73566 4e6b31ed7197
clarified signature for Scala functions;
src/HOL/Tools/ATP/system_on_tptp.scala
src/HOL/Tools/Nitpick/kodkod.scala
src/Pure/PIDE/resources.ML
src/Pure/PIDE/resources.scala
src/Pure/PIDE/session.scala
src/Pure/System/bash.scala
src/Pure/System/isabelle_system.ML
src/Pure/System/isabelle_system.scala
src/Pure/System/isabelle_tool.scala
src/Pure/System/scala.ML
src/Pure/System/scala.scala
src/Pure/Thy/bibtex.scala
src/Pure/Tools/debugger.scala
src/Pure/Tools/doc.scala
--- a/src/HOL/Tools/ATP/system_on_tptp.scala	Mon Apr 12 15:00:03 2021 +0200
+++ b/src/HOL/Tools/ATP/system_on_tptp.scala	Mon Apr 12 18:10:13 2021 +0200
@@ -40,7 +40,7 @@
         "ListStatus" -> "READY",
         "QuietFlag" -> "-q0"))
 
-  object List_Systems extends Scala.Fun("SystemOnTPTP.list_systems", thread = true)
+  object List_Systems extends Scala.Fun_String("SystemOnTPTP.list_systems", thread = true)
   {
     val here = Scala_Project.here
     def apply(url: String): String = list_systems(Url(url)).text
@@ -69,7 +69,7 @@
     post_request(url, paramaters, timeout = timeout + Time.seconds(15))
   }
 
-  object Run_System extends Scala.Fun("SystemOnTPTP.run_system", thread = true)
+  object Run_System extends Scala.Fun_String("SystemOnTPTP.run_system", thread = true)
   {
     val here = Scala_Project.here
     def apply(arg: String): String =
--- a/src/HOL/Tools/Nitpick/kodkod.scala	Mon Apr 12 15:00:03 2021 +0200
+++ b/src/HOL/Tools/Nitpick/kodkod.scala	Mon Apr 12 18:10:13 2021 +0200
@@ -154,7 +154,7 @@
 
   /** scala function **/
 
-  object Fun extends Scala.Fun("kodkod", thread = true)
+  object Fun extends Scala.Fun_String("kodkod", thread = true)
   {
     val here = Scala_Project.here
     def apply(args: String): String =
--- a/src/Pure/PIDE/resources.ML	Mon Apr 12 15:00:03 2021 +0200
+++ b/src/Pure/PIDE/resources.ML	Mon Apr 12 18:10:13 2021 +0200
@@ -13,7 +13,7 @@
      session_chapters: (string * string) list,
      bibtex_entries: (string * string list) list,
      command_timings: Properties.T list,
-     scala_functions: (string * Position.T) list,
+     scala_functions: (string * (bool * Position.T)) list,
      global_theories: (string * string) list,
      loaded_theories: string list} -> unit
   val init_session_yxml: string -> unit
@@ -25,7 +25,7 @@
   val session_chapter: string -> string
   val last_timing: Toplevel.transition -> Time.time
   val scala_functions: unit -> string list
-  val check_scala_function: Proof.context -> string * Position.T -> string
+  val check_scala_function: Proof.context -> string * Position.T -> string * bool
   val master_directory: theory -> Path.T
   val imports_of: theory -> (string * Position.T) list
   val begin_theory: Path.T -> Thy_Header.header -> theory list -> theory
@@ -104,7 +104,7 @@
     session_chapters = Symtab.empty: string Symtab.table,
     bibtex_entries = Symtab.empty: string list Symtab.table,
     timings = empty_timings,
-    scala_functions = Symtab.empty: Position.T Symtab.table},
+    scala_functions = Symtab.empty: (bool * Position.T) Symtab.table},
    {global_theories = Symtab.empty: string Symtab.table,
     loaded_theories = Symtab.empty: unit Symtab.table});
 
@@ -136,7 +136,7 @@
           (pair (list (pair string properties))
             (pair (list (pair string string)) (pair (list (pair string string))
               (pair (list (pair string (list string))) (pair (list properties)
-                (pair (list (pair string properties))
+                (pair (list (pair string (pair bool properties)))
                   (pair (list (pair string string)) (list string))))))))
         end;
   in
@@ -146,7 +146,7 @@
        session_chapters = session_chapters,
        bibtex_entries = bibtex_entries,
        command_timings = command_timings,
-       scala_functions = map (apsnd Position.of_properties) scala_functions,
+       scala_functions = (map o apsnd o apsnd) Position.of_properties scala_functions,
        global_theories = global_theories,
        loaded_theories = loaded_theories}
   end;
@@ -184,23 +184,31 @@
 fun scala_functions () = space_explode "," (getenv "ISABELLE_SCALA_FUNCTIONS");
 
 (*regular resources*)
-fun scala_function_pos name =
-  (name, the_default Position.none (Symtab.lookup (get_session_base1 #scala_functions) name));
+fun scala_function a =
+  (a, the_default (false, Position.none) (Symtab.lookup (get_session_base1 #scala_functions) a));
 
 fun check_scala_function ctxt arg =
-  Completion.check_entity Markup.scala_functionN
-    (scala_functions () |> sort_strings |> map scala_function_pos) ctxt arg;
+  let
+    val funs = scala_functions () |> sort_strings |> map scala_function;
+    val name = Completion.check_entity Markup.scala_functionN (map (apsnd #2) funs) ctxt arg;
+    val multi =
+      (case AList.lookup (op =) funs name of
+        SOME (multi, _) => multi
+      | NONE => false);
+  in (name, multi) end;
 
 val _ = Theory.setup
  (Thy_Output.antiquotation_verbatim_embedded \<^binding>\<open>scala_function\<close>
-    (Scan.lift Parse.embedded_position) check_scala_function #>
+    (Scan.lift Parse.embedded_position) (#1 oo check_scala_function) #>
   ML_Antiquotation.inline_embedded \<^binding>\<open>scala_function\<close>
     (Args.context -- Scan.lift Parse.embedded_position
-      >> (uncurry check_scala_function #> ML_Syntax.print_string)) #>
+      >> (uncurry check_scala_function #> #1 #> ML_Syntax.print_string)) #>
   ML_Antiquotation.value_embedded \<^binding>\<open>scala\<close>
     (Args.context -- Scan.lift Args.embedded_position >> (fn (ctxt, arg) =>
-      let val name = check_scala_function ctxt arg
-      in ML_Syntax.atomic ("Scala.function " ^ ML_Syntax.print_string name) end)));
+      let
+        val (name, multi) = check_scala_function ctxt arg;
+        val func = if multi then "Scala.function" else "Scala.function1";
+      in ML_Syntax.atomic (func ^ " " ^ ML_Syntax.print_string name) end)));
 
 
 (* manage source files *)
--- a/src/Pure/PIDE/resources.scala	Mon Apr 12 15:00:03 2021 +0200
+++ b/src/Pure/PIDE/resources.scala	Mon Apr 12 18:10:13 2021 +0200
@@ -39,14 +39,14 @@
       pair(list(pair(string, string)),
       pair(list(pair(string, list(string))),
       pair(list(properties),
-      pair(list(pair(string, properties)),
+      pair(list(pair(string, pair(bool, properties))),
       pair(list(pair(string, string)), list(string))))))))(
        (sessions_structure.session_positions,
        (sessions_structure.dest_session_directories,
        (sessions_structure.session_chapters,
        (sessions_structure.bibtex_entries,
        (command_timings,
-       (Scala.functions.map(fun => (fun.name, fun.position)),
+       (Scala.functions.map(fun => (fun.name, (fun.multi, fun.position))),
        (session_base.global_theories.toList,
         session_base.loaded_theories.keys)))))))))
   }
--- a/src/Pure/PIDE/session.scala	Mon Apr 12 15:00:03 2021 +0200
+++ b/src/Pure/PIDE/session.scala	Mon Apr 12 18:10:13 2021 +0200
@@ -209,7 +209,8 @@
   private case object Stop
   private case class Get_State(promise: Promise[Document.State])
   private case class Cancel_Exec(exec_id: Document_ID.Exec)
-  private case class Protocol_Command(name: String, args: List[String])
+  private case class Protocol_Command_Raw(name: String, args: List[Bytes])
+  private case class Protocol_Command_Args(name: String, args: List[String])
   private case class Update_Options(options: Options)
   private case object Consolidate_Execution
   private case object Prune_History
@@ -668,7 +669,10 @@
             prover.get.dialog_result(serial, result)
             handle_output(new Prover.Output(Protocol.Dialog_Result(id, serial, result)))
 
-          case Protocol_Command(name, args) if prover.defined =>
+          case Protocol_Command_Raw(name, args) if prover.defined =>
+            prover.get.protocol_command_raw(name, args)
+
+          case Protocol_Command_Args(name, args) if prover.defined =>
             prover.get.protocol_command_args(name, args)
 
           case change: Session.Change if prover.defined =>
@@ -757,8 +761,14 @@
     }
   }
 
+  def protocol_command_raw(name: String, args: List[Bytes]): Unit =
+    manager.send(Protocol_Command_Raw(name, args))
+
+  def protocol_command_args(name: String, args: List[String]): Unit =
+    manager.send(Protocol_Command_Args(name, args))
+
   def protocol_command(name: String, args: String*): Unit =
-    manager.send(Protocol_Command(name, args.toList))
+    protocol_command_args(name, args.toList)
 
   def cancel_exec(exec_id: Document_ID.Exec): Unit =
     manager.send(Cancel_Exec(exec_id))
--- a/src/Pure/System/bash.scala	Mon Apr 12 15:00:03 2021 +0200
+++ b/src/Pure/System/bash.scala	Mon Apr 12 18:10:13 2021 +0200
@@ -206,12 +206,12 @@
 
   /* Scala function */
 
-  object Process extends Scala.Fun("bash_process", thread = true)
+  object Process extends Scala.Fun_Strings("bash_process", thread = true)
   {
     val here = Scala_Project.here
-    def apply(script: String): String =
+    def apply(args: List[String]): List[String] =
     {
-      val result = Exn.capture { Isabelle_System.bash(script) }
+      val result = Exn.capture { Isabelle_System.bash(cat_lines(args)) }
 
       val is_interrupt =
         result match {
@@ -220,15 +220,14 @@
         }
 
       result match {
-        case _ if is_interrupt => ""
-        case Exn.Exn(exn) => Exn.message(exn)
+        case _ if is_interrupt => Nil
+        case Exn.Exn(exn) => List(Exn.message(exn))
         case Exn.Res(res) =>
-          Library.cat_strings0(
-            res.rc.toString ::
-            res.timing.elapsed.ms.toString ::
-            res.timing.cpu.ms.toString ::
-            res.out_lines.length.toString ::
-            res.out_lines ::: res.err_lines)
+          res.rc.toString ::
+          res.timing.elapsed.ms.toString ::
+          res.timing.cpu.ms.toString ::
+          res.out_lines.length.toString ::
+          res.out_lines ::: res.err_lines
       }
     }
   }
--- a/src/Pure/System/isabelle_system.ML	Mon Apr 12 15:00:03 2021 +0200
+++ b/src/Pure/System/isabelle_system.ML	Mon Apr 12 18:10:13 2021 +0200
@@ -35,8 +35,7 @@
 
 fun bash_process script =
   Scala.function "bash_process"
-    ("export ISABELLE_TMP=" ^ Bash.string (getenv "ISABELLE_TMP") ^ "\n" ^ script)
-  |> split_strings0
+    ["export ISABELLE_TMP=" ^ Bash.string (getenv "ISABELLE_TMP"), script]
   |> (fn [] => raise Exn.Interrupt
       | [err] => error err
       | a :: b :: c :: d :: lines =>
@@ -78,7 +77,7 @@
 (* directory and file operations *)
 
 val absolute_path = Path.implode o File.absolute_path;
-fun scala_function0 name = ignore o Scala.function name o cat_strings0;
+fun scala_function0 name = ignore o Scala.function1 name o cat_strings0;
 fun scala_function name = scala_function0 name o map absolute_path;
 
 fun make_directory path = (scala_function "make_directory" [path]; path);
@@ -118,12 +117,12 @@
 (* download file *)
 
 fun download url file =
-  ignore (Scala.function "download" (cat_strings0 [url, absolute_path file]));
+  ignore (Scala.function "download" [url, absolute_path file]);
 
 
 (* Isabelle distribution identification *)
 
-fun isabelle_id () = Scala.function "isabelle_id" "";
+fun isabelle_id () = Scala.function1 "isabelle_id" "";
 
 fun isabelle_identifier () = try getenv_strict "ISABELLE_IDENTIFIER";
 
--- a/src/Pure/System/isabelle_system.scala	Mon Apr 12 15:00:03 2021 +0200
+++ b/src/Pure/System/isabelle_system.scala	Mon Apr 12 18:10:13 2021 +0200
@@ -195,7 +195,7 @@
       else error("Failed to identify Isabelle distribution " + root)
     }
 
-  object Isabelle_Id extends Scala.Fun("isabelle_id")
+  object Isabelle_Id extends Scala.Fun_String("isabelle_id")
   {
     val here = Scala_Project.here
     def apply(arg: String): String = isabelle_id()
@@ -273,13 +273,13 @@
   }
 
 
-  object Make_Directory extends Scala.Fun("make_directory")
+  object Make_Directory extends Scala.Fun_String("make_directory")
   {
     val here = Scala_Project.here
     def apply(arg: String): String = apply_paths1(arg, make_directory)
   }
 
-  object Copy_Dir extends Scala.Fun("copy_dir")
+  object Copy_Dir extends Scala.Fun_String("copy_dir")
   {
     val here = Scala_Project.here
     def apply(arg: String): String = apply_paths2(arg, copy_dir)
@@ -316,13 +316,13 @@
   }
 
 
-  object Copy_File extends Scala.Fun("copy_file")
+  object Copy_File extends Scala.Fun_String("copy_file")
   {
     val here = Scala_Project.here
     def apply(arg: String): String = apply_paths2(arg, copy_file)
   }
 
-  object Copy_File_Base extends Scala.Fun("copy_file_base")
+  object Copy_File_Base extends Scala.Fun_String("copy_file_base")
   {
     val here = Scala_Project.here
     def apply(arg: String): String = apply_paths3(arg, copy_file_base)
@@ -416,7 +416,7 @@
 
   def rm_tree(root: Path): Unit = rm_tree(root.file)
 
-  object Rm_Tree extends Scala.Fun("rm_tree")
+  object Rm_Tree extends Scala.Fun_String("rm_tree")
   {
     val here = Scala_Project.here
     def apply(arg: String): String = apply_paths1(arg, rm_tree)
@@ -604,12 +604,12 @@
     Bytes.write(file, content.bytes)
   }
 
-  object Download extends Scala.Fun("download", thread = true)
+  object Download extends Scala.Fun_Strings("download", thread = true)
   {
     val here = Scala_Project.here
-    def apply(arg: String): String =
-      Library.split_strings0(arg) match {
-        case List(url, file) => download(url, Path.explode(file)); ""
+    override def apply(args: List[String]): List[String] =
+      args match {
+        case List(url, file) => download(url, Path.explode(file)); Nil
       }
   }
 }
--- a/src/Pure/System/isabelle_tool.scala	Mon Apr 12 15:00:03 2021 +0200
+++ b/src/Pure/System/isabelle_tool.scala	Mon Apr 12 18:10:13 2021 +0200
@@ -134,7 +134,7 @@
   def isabelle_tools(): List[Entry] =
     (external_tools() ::: internal_tools).sortBy(_.name)
 
-  object Isabelle_Tools extends Scala.Fun("isabelle_tools")
+  object Isabelle_Tools extends Scala.Fun_String("isabelle_tools")
   {
     val here = Scala_Project.here
     def apply(arg: String): String =
--- a/src/Pure/System/scala.ML	Mon Apr 12 15:00:03 2021 +0200
+++ b/src/Pure/System/scala.ML	Mon Apr 12 18:10:13 2021 +0200
@@ -7,7 +7,8 @@
 signature SCALA =
 sig
   exception Null
-  val function: string -> string -> string
+  val function: string -> string list -> string list
+  val function1: string -> string -> string
 end;
 
 structure Scala: SCALA =
@@ -20,31 +21,31 @@
 val new_id = string_of_int o Counter.make ();
 
 val results =
-  Synchronized.var "Scala.results" (Symtab.empty: string Exn.result Symtab.table);
+  Synchronized.var "Scala.results" (Symtab.empty: string list Exn.result Symtab.table);
 
 val _ =
   Protocol_Command.define "Scala.result"
-    (fn [id, tag, res] =>
+    (fn id :: args =>
       let
         val result =
-          (case tag of
-            "0" => Exn.Exn Null
-          | "1" => Exn.Res res
-          | "2" => Exn.Exn (ERROR res)
-          | "3" => Exn.Exn (Fail res)
-          | "4" => Exn.Exn Exn.Interrupt
-          | _ => raise Fail ("Bad tag: " ^ tag));
+          (case args of
+            ["0"] => Exn.Exn Null
+          | "1" :: rest => Exn.Res rest
+          | ["2", msg] => Exn.Exn (ERROR msg)
+          | ["3", msg] => Exn.Exn (Fail msg)
+          | ["4"] => Exn.Exn Exn.Interrupt
+          | _ => raise Fail "Malformed Scala.result");
       in Synchronized.change results (Symtab.map_entry id (K result)) end);
 
 in
 
-fun function name arg =
+fun function name args =
   Thread_Attributes.uninterruptible (fn restore_attributes => fn () =>
     let
       val id = new_id ();
       fun invoke () =
        (Synchronized.change results (Symtab.update (id, Exn.Exn Match));
-        Output.protocol_message (Markup.invoke_scala name id) [[XML.Text arg]]);
+        Output.protocol_message (Markup.invoke_scala name id) (map (single o XML.Text) args));
       fun cancel () =
        (Synchronized.change results (Symtab.delete_safe id);
         Output.protocol_message (Markup.cancel_scala id) []);
@@ -61,6 +62,8 @@
         handle exn => (if Exn.is_interrupt exn then cancel () else (); Exn.reraise exn)
     end) ();
 
+val function1 = singleton o function;
+
 end;
 
 end;
--- a/src/Pure/System/scala.scala	Mon Apr 12 15:00:03 2021 +0200
+++ b/src/Pure/System/scala.scala	Mon Apr 12 18:10:13 2021 +0200
@@ -18,11 +18,31 @@
   /** registered functions **/
 
   abstract class Fun(val name: String, val thread: Boolean = false)
-    extends Function[String, String]
   {
     override def toString: String = name
+    def multi: Boolean = true
     def position: Properties.T = here.position
     def here: Scala_Project.Here
+    def invoke(args: List[Bytes]): List[Bytes]
+  }
+
+  abstract class Fun_Strings(name: String, thread: Boolean = false)
+    extends Fun(name, thread = thread)
+  {
+    override def invoke(args: List[Bytes]): List[Bytes] =
+      apply(args.map(_.text)).map(Bytes.apply)
+    def apply(args: List[String]): List[String]
+  }
+
+  abstract class Fun_String(name: String, thread: Boolean = false)
+    extends Fun_Strings(name, thread = thread)
+  {
+    override def multi: Boolean = false
+    override def apply(args: List[String]): List[String] =
+      args match {
+        case List(arg) => List(apply(arg))
+        case _ => error("Expected single argument for Scala function " + quote(name))
+      }
     def apply(arg: String): String
   }
 
@@ -35,13 +55,13 @@
 
   /** demo functions **/
 
-  object Echo extends Fun("echo")
+  object Echo extends Fun_String("echo")
   {
     val here = Scala_Project.here
     def apply(arg: String): String = arg
   }
 
-  object Sleep extends Fun("sleep")
+  object Sleep extends Fun_String("sleep")
   {
     val here = Scala_Project.here
     def apply(seconds: String): String =
@@ -127,7 +147,7 @@
     }
   }
 
-  object Toplevel extends Fun("scala_toplevel")
+  object Toplevel extends Fun_String("scala_toplevel")
   {
     val here = Scala_Project.here
     def apply(arg: String): String =
@@ -162,16 +182,16 @@
       case None => false
     }
 
-  def function_body(name: String, arg: String): (Tag.Value, String) =
+  def function_body(name: String, args: List[Bytes]): (Tag.Value, List[Bytes]) =
     functions.find(fun => fun.name == name) match {
       case Some(fun) =>
-        Exn.capture { fun(arg) } match {
-          case Exn.Res(null) => (Tag.NULL, "")
+        Exn.capture { fun.invoke(args) } match {
+          case Exn.Res(null) => (Tag.NULL, Nil)
           case Exn.Res(res) => (Tag.OK, res)
-          case Exn.Exn(Exn.Interrupt()) => (Tag.INTERRUPT, "")
-          case Exn.Exn(e) => (Tag.ERROR, Exn.message(e))
+          case Exn.Exn(Exn.Interrupt()) => (Tag.INTERRUPT, Nil)
+          case Exn.Exn(e) => (Tag.ERROR, List(Bytes(Exn.message(e))))
         }
-      case None => (Tag.FAIL, "Unknown Isabelle/Scala function: " + quote(name))
+      case None => (Tag.FAIL, List(Bytes("Unknown Isabelle/Scala function: " + quote(name))))
     }
 
 
@@ -191,11 +211,11 @@
       futures = Map.empty
     }
 
-    private def result(id: String, tag: Scala.Tag.Value, res: String): Unit =
+    private def result(id: String, tag: Scala.Tag.Value, res: List[Bytes]): Unit =
       synchronized
       {
         if (futures.isDefinedAt(id)) {
-          session.protocol_command("Scala.result", id, tag.id.toString, res)
+          session.protocol_command_raw("Scala.result", Bytes(id) :: Bytes(tag.id.toString) :: res)
           futures -= id
         }
       }
@@ -203,7 +223,7 @@
     private def cancel(id: String, future: Future[Unit]): Unit =
     {
       future.cancel()
-      result(id, Scala.Tag.INTERRUPT, "")
+      result(id, Scala.Tag.INTERRUPT, Nil)
     }
 
     private def invoke_scala(msg: Prover.Protocol_Output): Boolean = synchronized
@@ -212,7 +232,7 @@
         case Markup.Invoke_Scala(name, id) =>
           def body: Unit =
           {
-            val (tag, res) = Scala.function_body(name, msg.text)
+            val (tag, res) = Scala.function_body(name, msg.chunks)
             result(id, tag, res)
           }
           val future =
--- a/src/Pure/Thy/bibtex.scala	Mon Apr 12 15:00:03 2021 +0200
+++ b/src/Pure/Thy/bibtex.scala	Mon Apr 12 18:10:13 2021 +0200
@@ -146,7 +146,7 @@
     })
   }
 
-  object Check_Database extends Scala.Fun("bibtex_check_database")
+  object Check_Database extends Scala.Fun_String("bibtex_check_database")
   {
     val here = Scala_Project.here
     def apply(database: String): String =
--- a/src/Pure/Tools/debugger.scala	Mon Apr 12 15:00:03 2021 +0200
+++ b/src/Pure/Tools/debugger.scala	Mon Apr 12 18:10:13 2021 +0200
@@ -259,7 +259,7 @@
   }
 
   def input(thread_name: String, msg: String*): Unit =
-    session.protocol_command("Debugger.input", (thread_name :: msg.toList):_*)
+    session.protocol_command_args("Debugger.input", thread_name :: msg.toList)
 
   def continue(thread_name: String): Unit = input(thread_name, "continue")
   def step(thread_name: String): Unit = input(thread_name, "step")
--- a/src/Pure/Tools/doc.scala	Mon Apr 12 15:00:03 2021 +0200
+++ b/src/Pure/Tools/doc.scala	Mon Apr 12 18:10:13 2021 +0200
@@ -73,7 +73,7 @@
     examples() ::: release_notes() ::: main_contents
   }
 
-  object Doc_Names extends Scala.Fun("doc_names")
+  object Doc_Names extends Scala.Fun_String("doc_names")
   {
     val here = Scala_Project.here
     def apply(arg: String): String =