src/Tools/WWW_Find/scgi_server.ML
changeset 33817 f6a4da31f2f1
child 33823 24090eae50b6
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/Tools/WWW_Find/scgi_server.ML	Fri Nov 20 18:36:44 2009 +1100
@@ -0,0 +1,125 @@
+(*  Title:      scgi_echo.ML
+    Author:     Timothy Bourke, NICTA
+
+Simple SCGI server.
+*)
+
+signature SCGI_SERVER =
+sig
+  val max_threads : int Unsynchronized.ref
+  type handler = ScgiReq.t * Word8Vector.vector * (string -> unit) -> unit
+  val register : (string * Mime.t option * handler) -> unit
+  val server : string -> int -> unit
+  val server' : int -> string -> int -> unit (* keeps trying for port *)
+end;
+
+structure ScgiServer : SCGI_SERVER =
+struct
+val max_threads = Unsynchronized.ref 5;
+
+type handler = ScgiReq.t * Word8Vector.vector * (string -> unit) -> unit;
+
+local
+val servers = Unsynchronized.ref (Symtab.empty : (Mime.t option * handler) Symtab.table);
+in
+fun register (name, mime, f) =
+  Unsynchronized.change servers (Symtab.update_new (name, (mime, f)));
+fun lookup name = Symtab.lookup (!servers) name;
+
+fun dump_handlers () = (
+    tracing("  with handlers:");
+    app (fn (x, _) => tracing ("    - " ^ x)) (Symtab.dest (!servers)))
+end;
+
+fun server server_prefix port =
+  let
+    val passive_sock = SocketUtil.init_server_socket (SOME "localhost") port;
+
+    val thread_wait = ConditionVar.conditionVar ();
+    val thread_wait_mutex = Mutex.mutex ();
+
+    local
+    val threads = Unsynchronized.ref ([] : Thread.thread list);
+    fun purge () = Unsynchronized.change threads (filter Thread.isActive);
+    in
+    fun add_thread th = Unsynchronized.change threads (cons th);
+
+    fun launch_thread threadf =
+      (purge ();
+       if length (!threads) < (!max_threads) then ()
+       else (tracing ("Waiting for a free thread...");
+             ConditionVar.wait (thread_wait, thread_wait_mutex));
+       add_thread
+         (Thread.fork
+            (fn () => exception_trace threadf,
+             [Thread.EnableBroadcastInterrupt true,
+              Thread.InterruptState
+              Thread.InterruptAsynchOnce])))
+    end;
+
+    fun loop () =
+      let
+        val (sock, _)= Socket.accept passive_sock;
+
+        val (sin, sout) = SocketUtil.make_streams sock;
+
+        fun send msg = BinIO.output (sout, Byte.stringToBytes msg);
+        fun send_log msg = (tracing msg; send msg);
+
+        fun get_content (st, 0) = Word8Vector.fromList []
+          | get_content x = BinIO.inputN x;
+
+        fun do_req () =
+          let
+            val (req as ScgiReq.Req {path_info, request_method, ...},
+                 content_is) =
+              ScgiReq.parse sin
+              handle ScgiReq.InvalidReq s =>
+                (send
+                   (HttpUtil.reply_header (HttpStatus.bad_request, NONE, []));
+                 raise Fail ("Invalid request: " ^ s));
+            val () = tracing ("request: " ^ path_info);
+          in
+            (case lookup (unprefix server_prefix path_info) of
+               NONE => send (HttpUtil.reply_header (HttpStatus.not_found, NONE, []))
+             | SOME (NONE, f) => f (req, get_content content_is, send)
+             | SOME (t, f) =>
+                (send (HttpUtil.reply_header (HttpStatus.ok, t, []));
+                 if request_method = ScgiReq.Head then ()
+                 else f (req, get_content content_is, send)))
+          end;
+
+        fun thread_req () =
+          (do_req () handle e => (warning (exnMessage e));
+           BinIO.closeOut sout handle e => warning (exnMessage e);
+           BinIO.closeIn sin handle e => warning (exnMessage e);
+           Socket.close sock handle e => warning (exnMessage e);
+           tracing ("request done.");
+           ConditionVar.signal thread_wait);
+      in
+        launch_thread thread_req;
+        loop ()
+      end;
+  in
+    tracing ("SCGI server started.");
+    dump_handlers ();
+    loop ();
+    Socket.close passive_sock
+  end;
+
+local
+val delay = 5;
+in
+fun server' 0 server_prefix port = (warning "Giving up."; exit 1)
+  | server' countdown server_prefix port =
+      server server_prefix port
+        handle OS.SysErr ("bind failed", _) =>
+          (warning ("Could not acquire port "
+                    ^ Int.toString port ^ ". Trying again in "
+                    ^ Int.toString delay ^ " seconds...");
+           OS.Process.sleep (Time.fromSeconds delay);
+           server' (countdown - 1) server_prefix port);
+end;
+
+end;
+