# HG changeset patch # User wenzelm # Date 1700837904 -3600 # Node ID 10b6add456d0af3033ee8a21d01e3c93f4be0916 # Parent caddfe4949a8b389cad86dc49f6d8d8d2bfda253 support for Unix-domain sockets, using java.nio.channels.ServerSocketChannel; diff -r caddfe4949a8 -r 10b6add456d0 src/Pure/General/socket_io.ML --- a/src/Pure/General/socket_io.ML Fri Nov 24 14:11:01 2023 +0100 +++ b/src/Pure/General/socket_io.ML Fri Nov 24 15:58:24 2023 +0100 @@ -68,22 +68,39 @@ in (in_stream, out_stream) end; +fun socket_name_inet name = + (case space_explode ":" name of + [h, p] => + (case (NetHostDB.getByName h, Int.fromString p) of + (SOME host, SOME port) => SOME (host, port) + | _ => NONE) + | _ => NONE); -fun open_streams name = +fun open_streams_inet (host, port) = let - fun err () = error ("Bad socket name: " ^ quote name); - val (host, port) = - (case space_explode ":" name of - [h, p] => - (case NetHostDB.getByName h of SOME host => host | NONE => err (), - case Int.fromString p of SOME port => port | NONE => err ()) - | _ => err ()); val socket: Socket.active INetSock.stream_sock = INetSock.TCP.socket (); val _ = Socket.connect (socket, INetSock.toAddr (NetHostDB.addr host, port)); val (socket_host, socket_port) = INetSock.fromAddr (Socket.Ctl.getSockName socket); val socket_name = NetHostDB.toString socket_host ^ ":" ^ string_of_int socket_port; - in make_streams socket_name socket end + in make_streams socket_name socket end; + +fun open_streams_unix path = + \<^if_windows>\raise Fail "Cannot create Unix-domain socket on Windows"\ + \<^if_unix>\ + let + val socket_name = File.platform_path path; + val socket: Socket.active UnixSock.stream_sock = UnixSock.Strm.socket (); + val _ = Socket.connect (socket, UnixSock.toAddr socket_name); + in make_streams socket_name socket end\ + +fun open_streams name = + (case socket_name_inet name of + SOME inet => open_streams_inet inet + | NONE => + (case try Path.explode name of + SOME path => open_streams_unix path + | NONE => error ("Bad socket name: " ^ quote name))) handle OS.SysErr (msg, _) => error (msg ^ ": failed to open socket " ^ name); fun with_streams f = diff -r caddfe4949a8 -r 10b6add456d0 src/Pure/ROOT.ML --- a/src/Pure/ROOT.ML Fri Nov 24 14:11:01 2023 +0100 +++ b/src/Pure/ROOT.ML Fri Nov 24 15:58:24 2023 +0100 @@ -100,7 +100,6 @@ ML_file "PIDE/byte_message.ML"; ML_file "PIDE/protocol_message.ML"; ML_file "PIDE/document_id.ML"; -ML_file "General/socket_io.ML"; ML_file "General/graph.ML"; @@ -301,6 +300,7 @@ ML_file "Proof/extraction.ML"; (*Isabelle system*) +ML_file "General/socket_io.ML"; ML_file "PIDE/protocol_command.ML"; ML_file "System/java.ML"; ML_file "System/scala.ML"; diff -r caddfe4949a8 -r 10b6add456d0 src/Pure/System/system_channel.scala --- a/src/Pure/System/system_channel.scala Fri Nov 24 14:11:01 2023 +0100 +++ b/src/Pure/System/system_channel.scala Fri Nov 24 15:58:24 2023 +0100 @@ -8,18 +8,40 @@ import java.io.{InputStream, OutputStream} -import java.net.{ServerSocket, InetAddress} +import java.net.{InetAddress, InetSocketAddress, ProtocolFamily, ServerSocket, SocketAddress, StandardProtocolFamily, UnixDomainSocketAddress} +import java.nio.channels.{ServerSocketChannel, Channels} object System_Channel { - def apply(): System_Channel = new System_Channel + def apply(unix_domain: Boolean = Platform.is_unix): System_Channel = + if (unix_domain) new Unix else new Inet + + class Inet extends System_Channel(StandardProtocolFamily.INET) { + server.bind(new InetSocketAddress(Server.localhost, 0), 50) + + override def address: String = + Server.print_address(server.getLocalAddress.asInstanceOf[InetSocketAddress].getPort) + } + + class Unix extends System_Channel(StandardProtocolFamily.UNIX) { + private val socket_file = Isabelle_System.tmp_file("socket", initialized = false) + private val socket_file_name = socket_file.getPath + + server.bind(UnixDomainSocketAddress.of(socket_file_name)) + + override def address: String = socket_file_name + override def shutdown(): Unit = { + super.shutdown() + socket_file.delete + } + } } -class System_Channel private { - private val server = new ServerSocket(0, 50, Server.localhost) +abstract class System_Channel private(protocol_family: ProtocolFamily) { + protected val server: ServerSocketChannel = ServerSocketChannel.open(protocol_family) - val address: String = Server.print_address(server.getLocalPort) - val password: String = UUID.random().toString + def address: String + lazy val password: String = UUID.random().toString override def toString: String = address @@ -28,8 +50,8 @@ def rendezvous(): (OutputStream, InputStream) = { val socket = server.accept try { - val out_stream = socket.getOutputStream - val in_stream = socket.getInputStream + val out_stream = Channels.newOutputStream(socket) + val in_stream = Channels.newInputStream(socket) Byte_Message.read_line(in_stream) match { case Some(bs) if bs.text == password => (out_stream, in_stream)