clarified protocol: ML worker thread blocks and awaits result from Scala, to avoid excessive replacement threads;
authorwenzelm
Fri, 14 Aug 2020 13:26:12 +0200
changeset 72151 64df1e514005
parent 72150 510ebf846696
child 72152 3fa75db844f5
clarified protocol: ML worker thread blocks and awaits result from Scala, to avoid excessive replacement threads;
src/Pure/System/scala.ML
src/Pure/System/scala.scala
--- a/src/Pure/System/scala.ML	Thu Aug 13 15:52:40 2020 +0200
+++ b/src/Pure/System/scala.ML	Fri Aug 14 13:26:12 2020 +0200
@@ -6,71 +6,72 @@
 
 signature SCALA =
 sig
+  exception Null
+  val function: string -> string -> string
   val functions: unit -> string list
   val check_function: Proof.context -> string * Position.T -> string
-  val promise_function: string -> string -> string future
-  val function: string -> string -> string
-  exception Null
 end;
 
 structure Scala: SCALA =
 struct
 
-(** invoke Scala functions from ML **)
+(** protocol for Scala function invocation from ML **)
 
-val _ = Session.protocol_handler "isabelle.Scala";
+exception Null;
 
-
-(* pending promises *)
+local
 
 val new_id = string_of_int o Counter.make ();
 
-val promises =
-  Synchronized.var "Scala.promises" (Symtab.empty: string future Symtab.table);
+val results =
+  Synchronized.var "Scala.results" (Symtab.empty: string Exn.result Symtab.table);
 
+val _ =
+  Isabelle_Process.protocol_command "Scala.result"
+    (fn [id, tag, res] =>
+      let
+        val result =
+          (case tag of
+            "0" => Exn.Exn Null
+          | "1" => Exn.Res res
+          | "2" => Exn.Exn (ERROR res)
+          | "3" => Exn.Exn (Fail res)
+          | "4" => Exn.Exn Exn.Interrupt
+          | _ => raise Fail ("Bad tag: " ^ tag));
+      in Synchronized.change results (Symtab.map_entry id (K result)) end);
 
-(* invoke function *)
+val _ = Session.protocol_handler "isabelle.Scala";
+
+in
 
-fun promise_function name arg =
-  let
-    val id = new_id ();
-    fun abort () = Output.protocol_message (Markup.cancel_scala id) [];
-    val promise = Future.promise_name "invoke_scala" abort : string future;
-    val _ = Synchronized.change promises (Symtab.update (id, promise));
-    val _ = Output.protocol_message (Markup.invoke_scala name id) [XML.Text arg];
-  in promise end;
+fun function name arg =
+  Thread_Attributes.uninterruptible (fn restore_attributes => fn () =>
+    let
+      val id = new_id ();
+      fun invoke () =
+       (Synchronized.change results (Symtab.update (id, Exn.Exn Match));
+        Output.protocol_message (Markup.invoke_scala name id) [XML.Text arg]);
+      fun cancel () =
+       (Synchronized.change results (Symtab.delete_safe id);
+        Output.protocol_message (Markup.cancel_scala id) []);
+      fun await_result () =
+        Synchronized.guarded_access results
+          (fn tab =>
+            (case Symtab.lookup tab id of
+              SOME (Exn.Exn Match) => NONE
+            | SOME result => SOME (result, Symtab.delete id tab)
+            | NONE => SOME (Exn.Exn Exn.Interrupt, tab)))
+        handle exn => (if Exn.is_interrupt exn then cancel () else (); Exn.reraise exn);
+    in
+      invoke ();
+      Exn.release (restore_attributes await_result ())
+    end) ();
 
-fun function name arg = Future.join (promise_function name arg);
+end;
 
 
-(* fulfill *)
 
-exception Null;
-
-fun fulfill id tag res =
-  let
-    val result =
-      (case tag of
-        "0" => Exn.Exn Null
-      | "1" => Exn.Res res
-      | "2" => Exn.Exn (ERROR res)
-      | "3" => Exn.Exn (Fail res)
-      | "4" => Exn.Exn Exn.Interrupt
-      | _ => raise Fail "Bad tag");
-    val promise =
-      Synchronized.change_result promises
-        (fn tab => (the (Symtab.lookup tab id), Symtab.delete id tab));
-    val _ = Future.fulfill_result promise result;
-  in () end;
-
-val _ =
-  Isabelle_Process.protocol_command "Scala.fulfill"
-    (fn [id, tag, res] =>
-      fulfill id tag res
-        handle exn => if Exn.is_interrupt exn then () else Exn.reraise exn);
-
-
-(* registered functions *)
+(** registered Scala functions **)
 
 fun functions () = space_explode "," (getenv "ISABELLE_SCALA_FUNCTIONS");
 
--- a/src/Pure/System/scala.scala	Thu Aug 13 15:52:40 2020 +0200
+++ b/src/Pure/System/scala.scala	Fri Aug 14 13:26:12 2020 +0200
@@ -138,11 +138,11 @@
     futures = Map.empty
   }
 
-  private def fulfill(id: String, tag: Scala.Tag.Value, res: String): Unit =
+  private def result(id: String, tag: Scala.Tag.Value, res: String): Unit =
     synchronized
     {
       if (futures.isDefinedAt(id)) {
-        session.protocol_command("Scala.fulfill", id, tag.id.toString, res)
+        session.protocol_command("Scala.result", id, tag.id.toString, res)
         futures -= id
       }
     }
@@ -150,7 +150,7 @@
   private def cancel(id: String, future: Future[Unit])
   {
     future.cancel
-    fulfill(id, Scala.Tag.INTERRUPT, "")
+    result(id, Scala.Tag.INTERRUPT, "")
   }
 
   private def invoke_scala(msg: Prover.Protocol_Output): Boolean = synchronized
@@ -159,8 +159,8 @@
       case Markup.Invoke_Scala(name, id) =>
         futures += (id ->
           Future.fork {
-            val (tag, result) = Scala.function(name, msg.text)
-            fulfill(id, tag, result)
+            val (tag, res) = Scala.function(name, msg.text)
+            result(id, tag, res)
           })
         true
       case _ => false