support for Unix-domain sockets, using java.nio.channels.ServerSocketChannel;
authorwenzelm
Fri, 24 Nov 2023 15:58:24 +0100
changeset 79049 10b6add456d0
parent 79048 caddfe4949a8
child 79050 4d8716098d41
support for Unix-domain sockets, using java.nio.channels.ServerSocketChannel;
src/Pure/General/socket_io.ML
src/Pure/ROOT.ML
src/Pure/System/system_channel.scala
--- a/src/Pure/General/socket_io.ML	Fri Nov 24 14:11:01 2023 +0100
+++ b/src/Pure/General/socket_io.ML	Fri Nov 24 15:58:24 2023 +0100
@@ -68,22 +68,39 @@
 
   in (in_stream, out_stream) end;
 
+fun socket_name_inet name =
+  (case space_explode ":" name of
+    [h, p] =>
+     (case (NetHostDB.getByName h, Int.fromString p) of
+       (SOME host, SOME port) => SOME (host, port)
+     | _ => NONE)
+  | _ => NONE);
 
-fun open_streams name =
+fun open_streams_inet (host, port) =
   let
-    fun err () = error ("Bad socket name: " ^ quote name);
-    val (host, port) =
-      (case space_explode ":" name of
-        [h, p] =>
-         (case NetHostDB.getByName h of SOME host => host | NONE => err (),
-          case Int.fromString p of SOME port => port | NONE => err ())
-      | _ => err ());
     val socket: Socket.active INetSock.stream_sock = INetSock.TCP.socket ();
     val _ = Socket.connect (socket, INetSock.toAddr (NetHostDB.addr host, port));
 
     val (socket_host, socket_port) = INetSock.fromAddr (Socket.Ctl.getSockName socket);
     val socket_name = NetHostDB.toString socket_host ^ ":" ^ string_of_int socket_port;
-  in make_streams socket_name socket end
+  in make_streams socket_name socket end;
+
+fun open_streams_unix path =
+  \<^if_windows>\<open>raise Fail "Cannot create Unix-domain socket on Windows"\<close>
+  \<^if_unix>\<open>
+    let
+      val socket_name = File.platform_path path;
+      val socket: Socket.active UnixSock.stream_sock = UnixSock.Strm.socket ();
+      val _ = Socket.connect (socket, UnixSock.toAddr socket_name);
+    in make_streams socket_name socket end\<close>
+
+fun open_streams name =
+  (case socket_name_inet name of
+    SOME inet => open_streams_inet inet
+  | NONE =>
+    (case try Path.explode name of
+      SOME path => open_streams_unix path
+    | NONE => error ("Bad socket name: " ^ quote name)))
   handle OS.SysErr (msg, _) => error (msg ^ ": failed to open socket " ^ name);
 
 fun with_streams f =
--- a/src/Pure/ROOT.ML	Fri Nov 24 14:11:01 2023 +0100
+++ b/src/Pure/ROOT.ML	Fri Nov 24 15:58:24 2023 +0100
@@ -100,7 +100,6 @@
 ML_file "PIDE/byte_message.ML";
 ML_file "PIDE/protocol_message.ML";
 ML_file "PIDE/document_id.ML";
-ML_file "General/socket_io.ML";
 
 ML_file "General/graph.ML";
 
@@ -301,6 +300,7 @@
 ML_file "Proof/extraction.ML";
 
 (*Isabelle system*)
+ML_file "General/socket_io.ML";
 ML_file "PIDE/protocol_command.ML";
 ML_file "System/java.ML";
 ML_file "System/scala.ML";
--- a/src/Pure/System/system_channel.scala	Fri Nov 24 14:11:01 2023 +0100
+++ b/src/Pure/System/system_channel.scala	Fri Nov 24 15:58:24 2023 +0100
@@ -8,18 +8,40 @@
 
 
 import java.io.{InputStream, OutputStream}
-import java.net.{ServerSocket, InetAddress}
+import java.net.{InetAddress, InetSocketAddress, ProtocolFamily, ServerSocket, SocketAddress, StandardProtocolFamily, UnixDomainSocketAddress}
+import java.nio.channels.{ServerSocketChannel, Channels}
 
 
 object System_Channel {
-  def apply(): System_Channel = new System_Channel
+  def apply(unix_domain: Boolean = Platform.is_unix): System_Channel =
+    if (unix_domain) new Unix else new Inet
+
+  class Inet extends System_Channel(StandardProtocolFamily.INET) {
+    server.bind(new InetSocketAddress(Server.localhost, 0), 50)
+
+    override def address: String =
+      Server.print_address(server.getLocalAddress.asInstanceOf[InetSocketAddress].getPort)
+  }
+
+  class Unix extends System_Channel(StandardProtocolFamily.UNIX) {
+    private val socket_file = Isabelle_System.tmp_file("socket", initialized = false)
+    private val socket_file_name = socket_file.getPath
+
+    server.bind(UnixDomainSocketAddress.of(socket_file_name))
+
+    override def address: String = socket_file_name
+    override def shutdown(): Unit = {
+      super.shutdown()
+      socket_file.delete
+    }
+  }
 }
 
-class System_Channel private {
-  private val server = new ServerSocket(0, 50, Server.localhost)
+abstract class System_Channel private(protocol_family: ProtocolFamily) {
+  protected val server: ServerSocketChannel = ServerSocketChannel.open(protocol_family)
 
-  val address: String = Server.print_address(server.getLocalPort)
-  val password: String = UUID.random().toString
+  def address: String
+  lazy val password: String = UUID.random().toString
 
   override def toString: String = address
 
@@ -28,8 +50,8 @@
   def rendezvous(): (OutputStream, InputStream) = {
     val socket = server.accept
     try {
-      val out_stream = socket.getOutputStream
-      val in_stream = socket.getInputStream
+      val out_stream = Channels.newOutputStream(socket)
+      val in_stream = Channels.newInputStream(socket)
 
       Byte_Message.read_line(in_stream) match {
         case Some(bs) if bs.text == password => (out_stream, in_stream)