Implementing an HTTP Server With Scala & Java NIO

These days I find myself working with Scala & the JVM more and more. The ecosystem is very rich, with plenty of tools & libraries to help you get things done. Coming from Go, whose networking library lets you manage connections at whatever level of detail you want, I couldn’t help but dig into how networking & I/O are done in the JVM world. So I decided to implement a toy HTTP server with I/O multiplexing using Java’s NIO package & Scala.

Handling Connections

def create: Server = {
  val listenerTry = Try(ServerSocketChannel.open().bind(new InetSocketAddress(80)))

  listenerTry match {
    case Success(listener) => {
      implicit val ec = ExecutionContext.fromExecutorService(Executors.newFixedThreadPool(1))

      val root = RoutingNode("", indexHandler, collection.mutable.Map.empty) // never reassigned; only its children map is mutated
      Routes.insert(root, "GET /", indexHandler)
      Routes.insert(root, "GET /foo", fooHandler)
      Routes.insert(root, "GET /foo/bar", barHandler)

      val selector = Selector.open()
      val server = new Server(listener, selector, root)

      Signal.handle(new Signal("INT"), server)  // Server extends SignalHandler, so it can be passed directly
      Signal.handle(new Signal("TERM"), server)

      server
    }
    case Failure(e) => { // match every failure: bind can also throw non-IOException errors, which would otherwise end in a MatchError
      Logger.error(s"opening socket: $e")
      scala.sys.exit(1)
    }
  }
}

We first open a server socket and bind it to port 80. NIO calls this a ServerSocketChannel. ServerSocketChannel extends SelectableChannel, which means it can be registered with a Selector and multiplexed alongside other channels. When registering, we tell the selector which types of events we are actually curious about. So we also create a Selector; the socket channel gets registered with it later on, and it’s the thing that allows us to multiplex the various I/O operations. We then create a Server object that will handle the server logic.

The open/bind call is wrapped in a Try, just in case something goes wrong. If we’re unable to open or bind the socket, we log the error and exit the program. We also add some Signal handlers to the server, so we can gracefully shut it down when we receive a SIGINT or SIGTERM signal.

class Server(
  socket: ServerSocketChannel,
  selector: Selector,
  var root: RoutingNode,
  @volatile var isRunning: Boolean = false, // written by the signal handler thread, read by the run loop
)(implicit ec: ExecutionContext) extends SignalHandler {
  def run() = {
    socket.configureBlocking(false)
    socket.register(selector, SelectionKey.OP_ACCEPT)

    isRunning = true

    Logger.info("ready to accept new connections")

    while (isRunning) {
      try {
        selector.select()
        val selectedKeys = selector.selectedKeys()
        val iterator = selectedKeys.iterator()

        while(iterator.hasNext) {
          val key = iterator.next()
          iterator.remove() // super important to remove the selected key!

          if (key.isAcceptable) {
            // in non-blocking mode accept() returns null when no connection is pending
            Option(socket.accept()).foreach { client =>
              client.configureBlocking(false)
              client.register(selector, SelectionKey.OP_READ)
            }
          }

          if (key.isReadable) {
            handleRequest(key.channel().asInstanceOf[SocketChannel])
          }
        }
      } catch {
        // printStackTrace() returns Unit, so interpolating it would log "()"; log the exception itself
        case e: java.io.IOException => {
          Logger.error(s"IO exception while running server: $e")
        }
        case e: Exception => {
          Logger.error(s"unexpected error: $e")
        }
      }
    }
  }

  // this is where we catch signals
  override def handle(sig: Signal): Unit = {
    Logger.warn("received an interrupt; time to shutdown...")
    isRunning = false
    selector.wakeup()
    socket.close()
  }
}

As soon as we enter the run() method we configure the socket to be non-blocking and register it with the selector for OP_ACCEPT events, which means we want to be notified about new connections. We then enter a loop and select the keys whose channels are ready for I/O operations. For each selected key we first check whether there’s a connection to be accepted; if so, we accept it and register the client channel for reading. If the key is readable, we pass its channel on to the handleRequest function. Technically we could have split reading & writing into separate events, but for the sake of simplicity I’ve decided to keep it as a single operation.
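For the curious, here is a rough sketch of what splitting reads & writes could look like. It is not the code from this server: SplitReadWriteDemo, serveOne, roundTrip and the hardcoded reply are all made up for illustration, and request parsing is skipped entirely. The idea is that after reading we park the response on the key with attach() and switch the key's interest to OP_WRITE, then write only once the selector reports the channel as writable.

```scala
import java.net.InetSocketAddress
import java.nio.ByteBuffer
import java.nio.channels.{SelectionKey, Selector, ServerSocketChannel, SocketChannel}
import java.nio.charset.StandardCharsets

object SplitReadWriteDemo {
  // Hardcoded reply, only here to have something to send back.
  private val reply = "HTTP/1.1 200 OK\r\nContent-Length: 2\r\nConnection: close\r\n\r\nok"

  // Serves exactly one connection and returns. Reads and writes are separate events.
  def serveOne(listener: ServerSocketChannel): Unit = {
    val selector = Selector.open()
    listener.configureBlocking(false)
    listener.register(selector, SelectionKey.OP_ACCEPT)
    var done = false

    while (!done) {
      selector.select()
      val iterator = selector.selectedKeys().iterator()
      while (iterator.hasNext) {
        val key = iterator.next()
        iterator.remove()

        if (key.isValid && key.isAcceptable) {
          val client = listener.accept()
          if (client != null) {
            client.configureBlocking(false)
            client.register(selector, SelectionKey.OP_READ)
          }
        } else if (key.isValid && key.isReadable) {
          val channel = key.channel().asInstanceOf[SocketChannel]
          val in = ByteBuffer.allocate(256)
          if (channel.read(in) == -1) {
            channel.close()
            done = true
          } else {
            // Parsing omitted: park the response on the key and wait for OP_WRITE.
            key.attach(ByteBuffer.wrap(reply.getBytes(StandardCharsets.ISO_8859_1)))
            key.interestOps(SelectionKey.OP_WRITE)
          }
        } else if (key.isValid && key.isWritable) {
          val channel = key.channel().asInstanceOf[SocketChannel]
          val out = key.attachment().asInstanceOf[ByteBuffer]
          channel.write(out) // may be partial; the key stays registered until drained
          if (!out.hasRemaining) {
            channel.close()
            done = true
          }
        }
      }
    }
    selector.close()
  }

  // Blocking client on the loopback interface, used only to exercise serveOne.
  def roundTrip(): String = {
    val listener = ServerSocketChannel.open().bind(new InetSocketAddress("127.0.0.1", 0))
    val port = listener.getLocalAddress.asInstanceOf[InetSocketAddress].getPort
    val server = new Thread(() => serveOne(listener))
    server.start()

    val client = SocketChannel.open(new InetSocketAddress("127.0.0.1", port))
    client.write(ByteBuffer.wrap("GET / HTTP/1.1\r\n\r\n".getBytes(StandardCharsets.ISO_8859_1)))
    val buf = ByteBuffer.allocate(256)
    while (client.read(buf) != -1) {} // read until the server closes the connection
    client.close()
    server.join()
    listener.close()
    buf.flip()
    StandardCharsets.ISO_8859_1.decode(buf).toString
  }
}
```

The upside of this shape is that a slow client can no longer stall the loop in the middle of a write; the cost is more bookkeeping per connection, which is why I skipped it in the toy server.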

You might have noticed that I am explicitly passing the selector to the Server object. I could have easily initialized it in my run() method, but I had to make sure it is accessible from the signal handler. This is because we need to “wake up” the selector when shutting down the server, otherwise select() would keep blocking the main thread until a new request/event comes in.
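To see the wake-up behaviour in isolation, here is a small sketch (WakeupDemo and its method name are made up for this example). One thread blocks in select() on a selector with no channels registered, and another thread calls wakeup(), just like the signal handler does during shutdown.

```scala
import java.nio.channels.Selector
import java.util.concurrent.{CountDownLatch, TimeUnit}

object WakeupDemo {
  // Returns true if the selecting thread came back from select() after wakeup().
  def selectReturnsAfterWakeup(): Boolean = {
    val selector = Selector.open()
    val returned = new CountDownLatch(1)

    val worker = new Thread(() => {
      selector.select() // nothing is registered, so only wakeup() can end this call
      returned.countDown()
    })
    worker.setDaemon(true)
    worker.start()

    Thread.sleep(100) // give the worker a moment to block; not needed for correctness
    selector.wakeup() // if select() has not started yet, the next select() returns immediately

    val ok = returned.await(2, TimeUnit.SECONDS)
    selector.close()
    ok
  }
}
```

Without the wakeup() call the worker thread would sit in select() indefinitely, which is exactly the hang the Server avoids on shutdown.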

Routing Tree

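For the routing part I decided to implement a trie (prefix tree), a tree data structure that is commonly used for HTTP routing. Each level of the tree matches one segment of the request, so routes that share a prefix also share nodes.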

We use a RoutingNode to represent each node in the tree:

case class RoutingNode(
  segment: String,
  handler: Handler,
  children: collection.mutable.Map[String, RoutingNode],
)

Each RoutingNode can have multiple children, which is why I opted to use a Map to store them. Each segment refers to either the HTTP method (GET, POST etc.) or a part of the URL. So in the case of "https://domain.io/foo/bar", both "foo" and "bar" will be assigned to a RoutingNode. Each node points to an HTTP Handler, which in our case is a very simple trait definition:

trait Handler {
  def apply(request: Request): Response
}

case class Request(methodWithPath: String)
case class Response(status: String, body: String)

The handlers take a Request, which contains the HTTP method & the path of the requested resource, and return a Response, where the status field holds the status code & message and the body field holds the HTTP body we want to send back to the client.
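Since Handler has a single abstract method, a handler can also be written as a plain function literal (this relies on SAM conversion, available from Scala 2.12 on). The sketch below repeats the three definitions so it compiles on its own; echoHandler and HandlerDemo are names I made up for the example.

```scala
case class Request(methodWithPath: String)
case class Response(status: String, body: String)

trait Handler {
  def apply(request: Request): Response
}

object HandlerDemo {
  // The function literal is converted to a Handler because the trait has one abstract method.
  val echoHandler: Handler =
    (request: Request) => Response("200 OK", s"you asked for ${request.methodWithPath}")

  // Handlers are applied like functions, which is how a routing node would invoke them.
  val response: Response = echoHandler(Request("GET /foo"))
}
```

Here HandlerDemo.response is Response("200 OK", "you asked for GET /foo").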

Then we have a Routes object that contains the logic for searching & inserting nodes into the tree.

object Routes {
  // Walks the tree one segment at a time. When a segment has no matching child,
  // segmentOp decides what happens: None aborts the traversal, Some(node) continues from that node.
  def traverse(root: RoutingNode, input: String)(
    segmentOp: (RoutingNode, String) => Option[RoutingNode]
  ): Option[RoutingNode] = {
    val segments = segmentsOpt(input)
    if (segments.isEmpty) {
      None // segmentsOpt has already logged the invalid input
    } else {
      segments.foldLeft(Option(root)) { (currentOpt, segment) =>
        currentOpt.flatMap { current =>
          current.children.get(segment).orElse(segmentOp(current, segment))
        }
      }
    }
  }

  def search(root: RoutingNode, input: String): Option[RoutingNode] = {
    traverse(root, input)((_, _) => None) // a missing segment means there is no such route
  }

  def insert(root: RoutingNode, input: String, handler: Handler): Option[RoutingNode] = {
    val child = traverse(root, input) { (current, segment) =>
      val node = RoutingNode(segment, handler, collection.mutable.Map.empty[String, RoutingNode])
      current.children(segment) = node
      Some(node)
    }
    child.map(_ => root)  // the traversal function returns a child, but when inserting
                          // we want to present the root node instead
  }

  // we return a list of (METHOD, PATH_SEGMENTS...) eg: (GET, foo, bar, baz)
  def segmentsOpt(input: String): Array[String] = {
    input.split(" ") match {
      case Array(method, path) => {
        Array(method) ++ path.split("/").filter(_.nonEmpty) // here we take the tail/nonEmpty because the first '/' creates an empty string
      }
      case _ => {
        Logger.error(s"invalid search input => $input")
        Array.empty
      }
    }
  }
}

The segmentsOpt function is a helper that takes our input (GET /foo/bar) and returns an array of segments with the slashes stripped out, e.g. (GET, foo, bar). Both search & insert use the shared traverse helper, which takes an operation (segmentOp). During traversal, if a segment has no matching child node we apply the operation, which makes search give up and return None, and inserts a new node for insert. The traverse function returns the last node it reached, which is great for searching, but for insert operations we swap it out for the root node.
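To make the walk through the tree a bit more concrete, here is a stripped-down sketch of the same idea. It is not the Routes object from above: TrieDemo, Node and label are stand-ins I made up, with a plain string label taking the place of the handler so the example runs on its own.

```scala
import scala.collection.mutable

object TrieDemo {
  // Simplified stand-in for RoutingNode: a label replaces the handler.
  final case class Node(
    segment: String,
    label: String,
    children: mutable.Map[String, Node] = mutable.Map.empty
  )

  // "GET /foo/bar" -> List("GET", "foo", "bar"); anything else -> Nil
  def segments(input: String): List[String] =
    input.split(" ") match {
      case Array(method, path) => method :: path.split("/").filter(_.nonEmpty).toList
      case _                   => Nil
    }

  // Creates every missing node along the path. As in the article's version, nodes
  // created on the way get the label of the route that is being inserted.
  def insert(root: Node, input: String, label: String): Unit = {
    segments(input).foldLeft(root) { (current, segment) =>
      current.children.getOrElseUpdate(segment, Node(segment, label))
    }
    ()
  }

  // Follows one child per segment and stops with None at the first missing one.
  def search(root: Node, input: String): Option[Node] =
    segments(input) match {
      case Nil  => None
      case segs => segs.foldLeft(Option(root))((node, segment) => node.flatMap(_.children.get(segment)))
    }
}
```

After inserting "GET /", "GET /foo" and "GET /foo/bar" in that order, searching for "GET /foo/bar" visits the nodes GET, foo and bar, while "GET /nope" stops with None right below GET.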

Originally I implemented the search & insert functions without the traverse helper, but I found that the code was very repetitive, so this is my hacky way of DRYing it up.

Parsing Requests & Building Responses

object Server {
  def indexHandler: Handler = new Handler {
    override def apply(request: Request): Response = {
      Response("200 OK", "HELLO FROM INDEX")
    }
  }

  def fooHandler: Handler = new Handler {
    override def apply(request: Request): Response = {
      Response("200 OK", "HELLO FROM FOO")
    }
  }

  def barHandler: Handler = new Handler {
    override def apply(request: Request): Response = {
      Response("200 OK", "HELLO FROM BAR")
    }
  }

  ...

  // during server creation
  val root = RoutingNode("", indexHandler, collection.mutable.Map.empty) // never reassigned; only its children map is mutated
  Routes.insert(root, "GET /", indexHandler)
  Routes.insert(root, "GET /foo", fooHandler)
  Routes.insert(root, "GET /foo/bar", barHandler)
}

Since this is a toy server I kept this part super simple. First I’ve defined some handlers, which implement the Handler trait defined earlier, and added them to our routing tree.

private def handleRequest(channel: SocketChannel) = {
  Logger.info(s"received a connection from ${channel.getRemoteAddress}")

  try {
    val buffer = ByteBuffer.allocate(256)

    val readBytes = channel.read(buffer)
    if (readBytes == -1) {
      channel.close()
    } else {
      buffer.flip()

      val request = parseRequest(buffer)
      val resp = Server.buildResponse(root, request)

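      buffer.clear()
      buffer.put(resp) // throws BufferOverflowException if resp is larger than the buffer
      buffer.flip()

      // a non-blocking write may send only part of the buffer, so loop until it is drained
      while (buffer.hasRemaining) {
        channel.write(buffer)
      }
      channel.close()
    }
  } catch {
    // printStackTrace() returns Unit, so interpolating it would log "()"; log the exception itself
    case e: SocketException => {  // we're unable to talk to the socket anymore
      Logger.error(s"communicating with socket: $e")
      channel.close() // don't leak the connection when something goes wrong
    }
    case e: Exception => { // catch all errors
      Logger.error(s"handling connection: $e")
      channel.close()
    }
  }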
}

For reading the request we use the amazing ByteBuffer class that is part of the NIO package. It’s basically a byte array wrapped in an object that keeps track of a position, a limit and a capacity. It has very convenient methods for dealing with those indices, transferring data between various data types & viewing the data. In this case I opted to use the same buffer for both reading the request and writing the response. The buffer is currently limited to 256 bytes, which is just enough for our use case! A larger response would make buffer.put throw a BufferOverflowException, and only the first 256 bytes of a longer request would be read.
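If the position/limit dance is new to you, this little sketch shows how the indices move (BufferDemo and its field names are made up for the example; the ByteBuffer calls are the standard java.nio ones).

```scala
import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets

object BufferDemo {
  val buffer: ByteBuffer = ByteBuffer.allocate(256)

  // Writing advances the position; the limit stays at the capacity.
  buffer.put("GET / HTTP/1.1".getBytes(StandardCharsets.ISO_8859_1))
  val afterPut: (Int, Int) = (buffer.position(), buffer.limit()) // (14, 256)

  // flip() sets the limit to the current position and rewinds the position to 0,
  // so the next reader sees exactly the 14 bytes that were written.
  buffer.flip()
  val afterFlip: (Int, Int) = (buffer.position(), buffer.limit()) // (0, 14)

  // Decoding consumes everything between position and limit.
  val text: String = StandardCharsets.ISO_8859_1.decode(buffer).toString

  // clear() resets position and limit for the next write; it does not zero the bytes.
  buffer.clear()
  val afterClear: (Int, Int) = (buffer.position(), buffer.limit()) // (0, 256)
}
```

Forgetting the flip() is the classic mistake here: the reader would start at position 14 and see only the untouched remainder of the buffer.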

First, we read the incoming bytes and flip() the buffer, which sets its limit to the end of the data we just read and its position back to zero, so that the parsing can start from the beginning of the data. Then we have a function buildResponse that takes the root node and a request, and returns the HTTP response in byte format. We clear our buffer, put in the response bytes, and flip() it again before writing it back to the channel.
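The parseRequest function isn't shown in this post, so here is a guess at what a minimal version could look like. Treat it as a hypothetical stand-in, not the actual implementation: it assumes the buffer has already been flipped, only looks at the request line, and drops any query string. Request is repeated so the snippet compiles on its own, and ParseDemo is a made-up name.

```scala
import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets

case class Request(methodWithPath: String)

object ParseDemo {
  // Hypothetical stand-in for parseRequest. Expects a flipped buffer whose first
  // line looks like "GET /foo HTTP/1.1" and produces Request("GET /foo").
  def parseRequest(buffer: ByteBuffer): Request = {
    val text = StandardCharsets.ISO_8859_1.decode(buffer).toString
    val requestLine = text.takeWhile(c => c != '\r' && c != '\n')

    requestLine.split(" ") match {
      case Array(method, target, _) =>
        Request(s"$method ${target.takeWhile(_ != '?')}") // drop the query string, if any
      case _ =>
        Request("") // malformed request line; the route lookup finds nothing for it
    }
  }
}
```

With this shape, a malformed request ends up as an empty methodWithPath, which the routing lookup rejects, so the client gets the 404 response.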

The buildResponse itself looks like this:

def buildResponse(root: RoutingNode, request: Request): Array[Byte] = {
  val response = (Routes.search(root, request.methodWithPath)) match {
    case Some(node) => {
      node.handler(request)
    }
    case _ => {
      Response("404 NOT FOUND", EMPTY_BODY)
    }
  }

  val now = ZonedDateTime.now(ZoneId.of("GMT"))
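  // pin the locale: HTTP dates use English day and month names regardless of the JVM default
  val formatter = DateTimeFormatter.ofPattern("EEE, dd MMM yyyy HH:mm:ss z", java.util.Locale.ENGLISH)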
  val nowFormatted = now.format(formatter)

  val contents = (
    s"$PROTOCOL ${response.status}" + CRLF
      + s"Date: $nowFormatted" + CRLF
      + CONTENT_TYPE + CRLF
      + s"Content-Length: ${response.body.length}" + CRLF
      + SERVER_HEADER + CRLF
      + ALLOW_ORIGIN + CRLF
      + ALLOW_CREDENTIALS + CRLF
      + "Connection: close" + CRLF
      + CRLF
      + response.body
    )
  contents.getBytes("ISO-8859-1")
}

It’s a little bit hacky and hardcoded-y, but it’s good enough for our purposes!

And there you have it! A simple HTTP server that supports I/O multiplexing, implemented with Scala & Java NIO. I hope you liked it!!

For the full code refer to THIS LINK