Implementing an HTTP Server With Scala & Java NIO
These days I find myself working with Scala & the JVM more and more. The ecosystem is very rich, and there are a lot of tools & libraries to help you get things done.
Coming from a language like Go, whose networking library lets you manage connections at whatever level of detail you want, I couldn’t help but start digging into how networking & I/O are done in the JVM world. So I decided to implement a toy HTTP server with I/O multiplexing using Java’s NIO package & Scala.
Handling connections
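```scala
import java.net.InetSocketAddress
import java.nio.channels.{Selector, ServerSocketChannel}
import java.util.concurrent.Executors
import scala.concurrent.ExecutionContext
import scala.util.{Failure, Success, Try}
import sun.misc.Signal

def create: Server = {
  // open a ServerSocketChannel and bind it to port 80 (usually needs elevated privileges)
  val listenerTry = Try(ServerSocketChannel.open().bind(new InetSocketAddress(80)))
  listenerTry match {
    case Success(listener) =>
      implicit val ec: ExecutionContext =
        ExecutionContext.fromExecutorService(Executors.newFixedThreadPool(1))
      // build the routing tree
      val root = RoutingNode("", indexHandler, collection.mutable.Map.empty)
      Routes.insert(root, "GET /", indexHandler)
      Routes.insert(root, "GET /foo", fooHandler)
      Routes.insert(root, "GET /foo/bar", barHandler)
      val selector = Selector.open()
      val server = new Server(listener, selector, root)
      // Server extends SignalHandler, so it can be registered directly
      Signal.handle(new Signal("INT"), server)
      Signal.handle(new Signal("TERM"), server)
      server
    case Failure(e) => // any failure to open/bind the socket ends the program
      Logger.error(s"opening socket: $e")
      sys.exit(1)
  }
}
```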
We first open a socket and bind it to port 80. NIO refers to this as a ServerSocketChannel.
ServerSocketChannel extends SelectableChannel, which, as its name suggests, is a channel that can be multiplexed through a selector: instead of blocking on it, we pick only the types of events we are actually interested in (accepting connections, reading, writing). To do so we also have to create a Selector, which we’ll register the socket channel with, and which is the thing that allows us to multiplex the various I/O operations. We then create a Server object that will handle the server logic.
Opening and binding the socket is wrapped in a Try, in case something goes wrong at that stage (for example, the port is already in use, or we lack the privileges to bind to port 80, which is below 1024). If we’re unable to open or bind the socket, we log the error and exit the program.
We also register the server as a handler for Signal (from the JDK’s internal sun.misc package), so we can gracefully shut it down when we receive a SIGINT or SIGTERM signal.
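To try the open/bind/register sequence without root privileges, here is a minimal sketch under a few assumptions: it binds to 127.0.0.1 on port 0 (so the OS picks a free port) instead of port 80, and ListenerSetup and open are hypothetical names. Like create, it wraps the setup in a Try so the caller decides what to do on failure.

```scala
import java.net.InetSocketAddress
import java.nio.channels.{SelectionKey, Selector, ServerSocketChannel}
import scala.util.Try

// Hypothetical helper: open a listener on an ephemeral port (0 by default),
// make it non-blocking and register it with a fresh selector for OP_ACCEPT.
object ListenerSetup {
  def open(port: Int = 0): Try[(ServerSocketChannel, Selector, SelectionKey)] = Try {
    val listener = ServerSocketChannel.open().bind(new InetSocketAddress("127.0.0.1", port))
    listener.configureBlocking(false) // register() throws IllegalBlockingModeException otherwise
    val selector = Selector.open()
    val key = listener.register(selector, SelectionKey.OP_ACCEPT)
    (listener, selector, key)
  }
}
```

In create, a Failure at this point is where we log and exit; in a quick experiment you can call .get and inspect the returned key instead.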
```scala
import java.nio.channels.{SelectionKey, Selector, ServerSocketChannel, SocketChannel}
import scala.concurrent.ExecutionContext
import sun.misc.{Signal, SignalHandler}

class Server(
  socket: ServerSocketChannel,
  selector: Selector,
  root: RoutingNode
)(implicit ec: ExecutionContext) extends SignalHandler {
  // written by the signal-handling thread, read by the event loop
  @volatile private var isRunning: Boolean = false

  def run(): Unit = {
    socket.configureBlocking(false) // required before registering with a selector
    socket.register(selector, SelectionKey.OP_ACCEPT)
    isRunning = true
    Logger.info("ready to accept new connections")
    while (isRunning) {
      try {
        selector.select() // blocks until a channel is ready or wakeup() is called
        val iterator = selector.selectedKeys().iterator()
        while (iterator.hasNext) {
          val key = iterator.next()
          iterator.remove() // super important: the selector never clears the selected set itself
          if (key.isValid && key.isAcceptable) {
            val client = socket.accept() // may return null in non-blocking mode
            if (client != null) {
              client.configureBlocking(false)
              client.register(selector, SelectionKey.OP_READ)
            }
          } else if (key.isValid && key.isReadable) {
            handleRequest(key.channel().asInstanceOf[SocketChannel])
          }
        }
      } catch {
        case e: java.io.IOException =>
          Logger.error(s"IO exception while running server: $e")
        case e: Exception =>
          Logger.error(s"unexpected error: $e")
      }
    }
  }

  // this is where we catch signals
  override def handle(sig: Signal): Unit = {
    Logger.warn(s"received SIG${sig.getName}; time to shut down...")
    isRunning = false
    selector.wakeup() // unblock select() so the loop can observe isRunning
    socket.close()
  }
}
```
As soon as we enter the run() function, we configure the socket to be non-blocking (a channel must be in non-blocking mode before it can be registered with a selector), register it with the selector, and tell the selector that we are interested in OP_ACCEPT events, which means we want to be notified about new incoming connections.
We then enter a loop and select the keys whose channels are ready for I/O operations. For each selected key, we first check whether there’s a connection to be accepted; if so, we accept it and register the new client channel for reading. If the key is readable, we pass its channel on to the handleRequest function. Technically we could have split reading and writing into separate steps (by also registering for OP_WRITE), but for the sake of simplicity I’ve kept them as a single operation.
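For comparison, here is a minimal sketch of the split version. It assumes we switch a client’s interest set from OP_READ to OP_WRITE once the request has been read, and carry the pending response to the write phase as the key’s attachment. SplitReadWrite, step and the echo reply are hypothetical and not part of the server above.

```scala
import java.net.InetSocketAddress
import java.nio.ByteBuffer
import java.nio.channels.{SelectionKey, Selector, ServerSocketChannel, SocketChannel}
import java.nio.charset.StandardCharsets.US_ASCII

// Hypothetical sketch: reads and writes handled as separate selector events.
object SplitReadWrite {
  // Processes one batch of ready keys; returns true once a response was fully written.
  def step(selector: Selector, listener: ServerSocketChannel): Boolean = {
    var responded = false
    selector.select(1000) // wait at most one second for readiness
    val it = selector.selectedKeys().iterator()
    while (it.hasNext) {
      val key = it.next()
      it.remove()
      if (key.isValid && key.isAcceptable) {
        val client = listener.accept() // may be null in non-blocking mode
        if (client != null) {
          client.configureBlocking(false)
          client.register(selector, SelectionKey.OP_READ)
        }
      } else if (key.isValid && key.isReadable) {
        val channel = key.channel().asInstanceOf[SocketChannel]
        val buffer = ByteBuffer.allocate(256)
        if (channel.read(buffer) == -1) channel.close()
        else {
          buffer.flip()
          val reply = "echo: " + US_ASCII.decode(buffer).toString
          key.attach(ByteBuffer.wrap(reply.getBytes(US_ASCII))) // pending response
          key.interestOps(SelectionKey.OP_WRITE) // stop reading, wait until writable
        }
      } else if (key.isValid && key.isWritable) {
        val channel = key.channel().asInstanceOf[SocketChannel]
        val pending = key.attachment().asInstanceOf[ByteBuffer]
        channel.write(pending) // may write only part of the buffer
        if (!pending.hasRemaining) { channel.close(); responded = true }
      }
    }
    responded
  }
}
```

The extra bookkeeping (attachments, interest-set changes, partial writes that resume on the next OP_WRITE event) is exactly what the single read-then-write step in handleRequest avoids.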
You might have noticed that I am explicitly passing the selector to the Server object. I could have easily created it in my run() function, but it has to be accessible from the signal handler as well. This is because we need to “wake up” the selector when shutting down the server; otherwise select() would keep blocking the main thread until a new request or event comes in.
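The effect of wakeup() is easy to observe in isolation. The sketch below (WakeupDemo is a hypothetical name) parks a thread in select() on a selector with no registered channels, which would otherwise block forever, and then calls wakeup() from the main thread, as our signal handler does. According to the Selector documentation, if no select() is in progress yet, the next one returns immediately, so the outcome doesn’t depend on timing.

```scala
import java.nio.channels.Selector

// Hypothetical demo: a thread blocked in select() is released by wakeup().
object WakeupDemo {
  def selectReturnsAfterWakeup(): Boolean = {
    val selector = Selector.open()
    val selecting = new Thread(new Runnable {
      def run(): Unit = { selector.select(); () } // nothing registered: blocks until woken
    })
    selecting.start()
    Thread.sleep(100)    // give the thread time to enter select()
    selector.wakeup()    // what Server.handle does on SIGINT/SIGTERM
    selecting.join(5000) // wait up to 5 seconds for select() to return
    val finished = !selecting.isAlive
    selector.close()
    finished
  }
}
```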
Routing Tree
For the routing part I decided to implement a trie (prefix tree), a tree data structure commonly used for HTTP routing.
We use a RoutingNode to represent each node in the tree:
```scala
case class RoutingNode(
  segment: String,
  handler: Handler,
  children: collection.mutable.Map[String, RoutingNode],
)
```
Each RoutingNode can have multiple children, which is why I opted to store them in a Map keyed by segment.
Each segment is either the HTTP method (GET, POST etc.) or a part of the URL path. So for a GET request to "https://domain.io/foo/bar", GET, "foo" and "bar" each get their own RoutingNode, nested one below the other.
Each node points to an HTTP Handler, which in our case is a very simple trait definition:
```scala
trait Handler {
  def apply(request: Request): Response
}

case class Request(methodWithPath: String)
case class Response(status: String, body: String)
```
Handlers take a Request, which contains the HTTP method & the path to the requested resource, and return a Response, where the field status is the status code & reason phrase (e.g. 200 OK), and the field body is the HTTP body we want to send back to the client.
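Because Handler has a single abstract method, Scala 2.12+ can convert a function literal into a Handler (SAM conversion), which would turn handlers like indexHandler into one-liners. A small sketch, repeating the trait and case classes so it compiles on its own; LambdaHandlers, echoHandler and notFound are hypothetical:

```scala
// Same shapes as the Handler, Request and Response definitions above.
trait Handler {
  def apply(request: Request): Response
}
case class Request(methodWithPath: String)
case class Response(status: String, body: String)

object LambdaHandlers {
  // Handler has one abstract method, so a function literal converts to it directly.
  val echoHandler: Handler = request => Response("200 OK", s"you asked for ${request.methodWithPath}")
  val notFound: Handler = _ => Response("404 NOT FOUND", "")
}
```

Calling echoHandler(Request("GET /foo")) invokes the generated apply method, just like calling one of the anonymous-class handlers.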
Then we have a Routes object that contains the logic for searching & inserting nodes into the tree.
```scala
object Routes {
  // Walks the tree one segment at a time. When the current node has no child for a
  // segment, segmentOp decides what happens: None stops the walk, Some(node) continues.
  def traverse(root: RoutingNode, input: String)(
      segmentOp: (RoutingNode, String) => Option[RoutingNode]
  ): Option[RoutingNode] = {
    val segments = segmentsOpt(input)
    if (segments.isEmpty) None // malformed input, already logged by segmentsOpt
    else
      segments.foldLeft(Option(root)) { (currentOpt, segment) =>
        currentOpt.flatMap { current =>
          current.children.get(segment).orElse(segmentOp(current, segment))
        }
      }
  }

  def search(root: RoutingNode, input: String): Option[RoutingNode] =
    traverse(root, input)((_, _) => None) // a missing segment means there is no such route

  def insert(root: RoutingNode, input: String, handler: Handler): Option[RoutingNode] = {
    val child = traverse(root, input) { (current, segment) =>
      val node = RoutingNode(segment, handler, collection.mutable.Map.empty[String, RoutingNode])
      current.children(segment) = node
      Some(node)
    }
    // the traversal returns the last node, but when inserting we return the root instead
    Option.when(child.isDefined)(root)
  }

  // returns (METHOD, PATH_SEGMENTS...), e.g. "GET /foo/bar" -> (GET, foo, bar)
  def segmentsOpt(input: String): Array[String] =
    input.split(" ") match {
      case Array(method, path) =>
        // drop empty strings, such as the one produced by the leading '/'
        Array(method) ++ path.split("/").filter(_.nonEmpty)
      case _ =>
        Logger.error(s"invalid search input => $input")
        Array.empty
    }
}
```
The segmentsOpt function is a helper that takes our input (e.g. GET /foo/bar) and returns an array of segments with the slashes removed, e.g. (GET, foo, bar).
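The splitting rule is easy to check on its own. Below is a standalone copy of the same logic (Segments.segments is a hypothetical name, and the error logging is left out), including two edge cases: the bare GET / route collapses to just the method, and input without a space yields no segments at all.

```scala
object Segments {
  // Standalone copy of the splitting rule: "METHOD /a/b" -> Array(METHOD, a, b).
  // Malformed input (not exactly "METHOD PATH") yields an empty array.
  def segments(input: String): Array[String] =
    input.split(" ") match {
      case Array(method, path) => method +: path.split("/").filter(_.nonEmpty)
      case _                   => Array.empty[String]
    }
}
```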
We use the shared traverse helper for both the search & insert functions, each of which passes in its own operation (segmentOp).
During traversal, if the current node has no child for a segment, we apply the operation, which simply returns None for search, and inserts a new node for insert.
The traverse function returns the last node it visited, which is great for searching, but for insert operations we want to return the root node instead.
Originally I implemented the search & insert functions without the traverse helper, but I found the code very repetitive, so this is my hacky way of DRYing it up.
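To exercise traverse, search and insert without any networking, here is a condensed, self-contained sketch of the same tree. It makes two assumptions: a plain String payload stands in for Handler, and the operation returns an Option, so a missing segment ends the walk by returning None rather than through a return statement inside a closure. Node and Trie are hypothetical names.

```scala
import scala.collection.mutable

// Condensed routing trie; `payload` stands in for the Handler.
final case class Node(segment: String, payload: String, children: mutable.Map[String, Node])

object Trie {
  private def segments(input: String): Array[String] = input.split(" ") match {
    case Array(method, path) => method +: path.split("/").filter(_.nonEmpty)
    case _                   => Array.empty[String]
  }

  // Shared walk: `onMissing` decides what happens when a segment has no child.
  def traverse(root: Node, input: String)(onMissing: (Node, String) => Option[Node]): Option[Node] = {
    val segs = segments(input)
    if (segs.isEmpty) None
    else segs.foldLeft(Option(root)) { (acc, seg) =>
      acc.flatMap(current => current.children.get(seg).orElse(onMissing(current, seg)))
    }
  }

  def search(root: Node, input: String): Option[Node] =
    traverse(root, input)((_, _) => None)

  // Creates every missing node along the path; returns the root on success.
  def insert(root: Node, input: String, payload: String): Option[Node] =
    traverse(root, input) { (current, seg) =>
      val child = Node(seg, payload, mutable.Map.empty)
      current.children(seg) = child
      Some(child)
    }.map(_ => root)
}
```

One consequence worth knowing: insert only creates nodes that are missing, and every node it creates along the way gets the payload (handler) being inserted. Inserting GET /foo/bar before GET /foo would therefore leave foo with bar’s handler, and the later insert would not replace it, which is why the routes are registered from shortest to longest.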
Parsing Requests & Building Responses
```scala
object Server {
  def indexHandler: Handler = new Handler {
    override def apply(request: Request): Response = {
      Response("200 OK", "HELLO FROM INDEX")
    }
  }
  def fooHandler: Handler = new Handler {
    override def apply(request: Request): Response = {
      Response("200 OK", "HELLO FROM FOO")
    }
  }
  def barHandler: Handler = new Handler {
    override def apply(request: Request): Response = {
      Response("200 OK", "HELLO FROM BAR")
    }
  }
  ...
  // during server creation (the tree itself is mutable, so the reference can be a val)
  val root = RoutingNode("", indexHandler, collection.mutable.Map.empty)
  Routes.insert(root, "GET /", indexHandler)
  Routes.insert(root, "GET /foo", fooHandler)
  Routes.insert(root, "GET /foo/bar", barHandler)
}
```
Since this is a toy server, I kept this part super simple. First I defined some handlers, which implement the Handler trait defined earlier, and then added them to our routing tree.
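```scala
import java.net.SocketException
import java.nio.ByteBuffer
import java.nio.channels.SocketChannel

private def handleRequest(channel: SocketChannel): Unit = {
  Logger.info(s"received a connection from ${channel.getRemoteAddress}")
  try {
    val buffer = ByteBuffer.allocate(256)
    val readBytes = channel.read(buffer)
    if (readBytes == -1) {
      channel.close() // the client closed the connection
    } else {
      buffer.flip() // limit = end of the data read, position = 0
      val request = parseRequest(buffer)
      val resp = Server.buildResponse(root, request)
      buffer.clear()
      buffer.put(resp) // throws BufferOverflowException if the response exceeds 256 bytes
      buffer.flip()
      while (buffer.hasRemaining) channel.write(buffer) // a non-blocking write may be partial
      channel.close()
    }
  } catch {
    case e: SocketException => // we're unable to talk to the socket anymore
      Logger.error(s"communicating with socket: $e")
      channel.close()
    case e: Exception => // catch all other errors
      Logger.error(s"handling connection: $e")
      channel.close() // otherwise the key stays registered and keeps firing
  }
}
```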
For reading the request we use the amazing ByteBuffer class from the NIO package. It’s essentially a fixed-size byte array wrapped in an object that tracks a position, a limit and a capacity, with convenient methods for moving those indices around, transferring data to and from other types, and creating views of the data.
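In this case I opted to use the same buffer for both reading the request and writing the response, and it is capped at 256 bytes, which is just enough for our use case! Only the first 256 bytes of a request are read, and a response longer than that would make put() throw a BufferOverflowException.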
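The position/limit bookkeeping is where ByteBuffer usually trips people up. The sketch below (BufferCycle and its methods are hypothetical) pushes a request line through the same put, flip, get, clear cycle that handleRequest relies on, and shows that a 256-byte buffer rejects a 257-byte put instead of growing.

```scala
import java.nio.{BufferOverflowException, ByteBuffer}
import java.nio.charset.StandardCharsets.ISO_8859_1

object BufferCycle {
  // Writes `data` into the buffer (as channel.read would), then reads it back out.
  def roundTrip(buffer: ByteBuffer, data: String): String = {
    buffer.put(data.getBytes(ISO_8859_1)) // position advances past the written bytes
    buffer.flip()                          // limit = position, position = 0
    val out = new Array[Byte](buffer.remaining())
    buffer.get(out)                        // consume exactly the bytes that were written
    buffer.clear()                         // position = 0, limit = capacity (bytes are not zeroed)
    new String(out, ISO_8859_1)
  }

  // True when `size` bytes don't fit: put() throws rather than resizing the buffer.
  def overflows(buffer: ByteBuffer, size: Int): Boolean =
    try { buffer.put(new Array[Byte](size)); false }
    catch { case _: BufferOverflowException => true }
    finally { buffer.clear() }
}
```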
First, we read the incoming bytes and flip() the buffer, which sets its limit to the end of the data we just read and moves its position back to 0, so that parsing starts from the beginning of the request.
Then the buildResponse function takes the root node and the request and returns the HTTP response as bytes. We clear() the buffer, put() the response bytes into it, and flip() it again before writing it back to the channel.
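The post doesn’t show parseRequest itself. Here is a minimal sketch of what it could look like, assuming the buffer has already been flipped and that only the method and path of the request line matter; RequestParsing is a hypothetical wrapper, and Request is repeated so the block compiles on its own.

```scala
import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets.ISO_8859_1

case class Request(methodWithPath: String)

object RequestParsing {
  // Hypothetical parseRequest: "GET /foo HTTP/1.1\r\n..." -> Request("GET /foo").
  // Expects a flipped buffer (position at the start of the data).
  def parseRequest(buffer: ByteBuffer): Request = {
    val text = ISO_8859_1.decode(buffer).toString
    val requestLine = text.takeWhile(c => c != '\r' && c != '\n')
    requestLine.split(" ") match {
      case Array(method, path, _) => Request(s"$method $path")
      case _                      => Request("") // malformed request line
    }
  }
}
```

Returning Request("") for a malformed request line keeps the signature the call site expects: segmentsOpt finds no segments for it, so the search fails and buildResponse answers with a 404.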
The buildResponse function itself looks like this:
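```scala
import java.nio.charset.StandardCharsets
import java.time.{ZoneOffset, ZonedDateTime}
import java.time.format.DateTimeFormatter
import java.util.Locale

def buildResponse(root: RoutingNode, request: Request): Array[Byte] = {
  // route the request; unknown paths get a 404
  val response = Routes.search(root, request.methodWithPath) match {
    case Some(node) => node.handler(request)
    case None       => Response("404 NOT FOUND", EMPTY_BODY)
  }
  // HTTP dates are always in GMT; Locale.US keeps day/month names in English
  val formatter = DateTimeFormatter.ofPattern("EEE, dd MMM yyyy HH:mm:ss 'GMT'", Locale.US)
  val nowFormatted = ZonedDateTime.now(ZoneOffset.UTC).format(formatter)
  val body = response.body.getBytes(StandardCharsets.ISO_8859_1)
  val contents = (
    s"$PROTOCOL ${response.status}" + CRLF
      + s"Date: $nowFormatted" + CRLF
      + CONTENT_TYPE + CRLF
      + s"Content-Length: ${body.length}" + CRLF // length in bytes, not characters
      + SERVER_HEADER + CRLF
      + ALLOW_ORIGIN + CRLF
      + ALLOW_CREDENTIALS + CRLF
      + "Connection: close" + CRLF
      + CRLF
  )
  contents.getBytes(StandardCharsets.ISO_8859_1) ++ body
}
```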
It’s a little bit hacky and hardcoded-y, but it’s good enough for our purposes!
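Since the header constants (PROTOCOL, CONTENT_TYPE, SERVER_HEADER and friends) aren’t shown, here is a trimmed-down, hypothetical renderer that hardcodes HTTP/1.1 and text/plain as assumptions. Its point is the two details that are easy to get wrong: the Date header must use English day and month names in GMT (hence Locale.US), and Content-Length counts bytes, not characters.

```scala
import java.nio.charset.StandardCharsets.ISO_8859_1
import java.time.{ZoneOffset, ZonedDateTime}
import java.time.format.DateTimeFormatter
import java.util.Locale

// Hypothetical renderer; "HTTP/1.1" and "text/plain" are assumed values.
object ResponseRendering {
  private val CRLF = "\r\n"
  // HTTP date format (e.g. "Sun, 06 Nov 1994 08:49:37 GMT"), always in English.
  private val HttpDate = DateTimeFormatter.ofPattern("EEE, dd MMM yyyy HH:mm:ss 'GMT'", Locale.US)

  def render(status: String, body: String, now: ZonedDateTime): Array[Byte] = {
    val bodyBytes = body.getBytes(ISO_8859_1)
    val head =
      s"HTTP/1.1 ${status}${CRLF}" +
      s"Date: ${now.withZoneSameInstant(ZoneOffset.UTC).format(HttpDate)}${CRLF}" +
      s"Content-Type: text/plain${CRLF}" +
      s"Content-Length: ${bodyBytes.length}${CRLF}" + // bytes, not characters
      s"Connection: close${CRLF}" +
      CRLF
    head.getBytes(ISO_8859_1) ++ bodyBytes
  }
}
```

Passing the timestamp in, rather than calling ZonedDateTime.now() inside, keeps the output deterministic and easy to check.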
And there you have it! A simple HTTP server that supports I/O multiplexing, implemented with Scala & Java NIO. I hope you liked it!!
For the full code refer to THIS LINK