"""This module provides an application-agnostic HTTP/WebSocket server.

See HybridServerBase for more information.
"""
import asyncio
import os
import threading
import time
import webbrowser
from abc import ABC, abstractmethod
from pathlib import Path

import uvicorn
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import FileResponse, Response
from fastapi.staticfiles import StaticFiles
# Only the abstract server base class is exported by `from ... import *`.
__all__ = ["HybridServerBase"]
class HybridServerBase(ABC):
    """An application-agnostic http and websocket server

    This class lets the user register http request handlers at arbitrary
    URLs. Additionally, requests at the `/ws` URL are automatically
    upgraded to a persistent websocket connection. This allows the
    server to send messages to clients without waiting for an http
    request from the client. The server will also automatically serve a
    directory of static files, specified by the `static_path`
    constructor argument. These are served relative to the `/static`
    URL. This class is built on top of fastapi and uvicorn but takes
    care of the details of running the server in an async loop on a
    background thread.

    This class is abstract, so to use it, one must define a subclass
    that implements the missing methods even if the implementations do
    nothing. These include:

    - onConnect: called when a websocket connection is created
    - onMessage: called when receiving a websocket message
    - onDisconnect: called when a websocket connection closes

    Additionally, one may override the setupRoutes method to register
    additional http request handlers. If overriding this method it's
    recommended to call `super().setupRoutes()` to still setup the basic
    handlers created by this class.
    """

    def __init__(
        self,
        *,
        host: str | None = None,
        port: int = 8000,
        autorun: bool = True,
        static_path: Path | str | None = None,
        log_level: str = "warning",
        bind_host: str = "0.0.0.0",
    ):
        """HybridServerBase constructor

        Parameters
        ----------
        host: str | None
            The publicly accessible address for the server, used to
            build `url`. If None (or empty), defaults to
            KARANA_WEBUI_ADDRESS if defined or finally 127.0.0.1.
        port: int
            The port to bind to. The port must not already be
            in use. Specify 0 to pick an arbitrary open port.
            Defaults to 8000.
        autorun: bool
            Whether to automatically run the server after
            construction. Otherwise the run method must be
            called to start the server. Defaults to True.
        static_path: Path | str | None
            If not None (or empty), file directory to serve as static
            files under `/static`. Strings are converted to Path.
        log_level: str
            The log verbosity level. This is a string such as
            "info" or "warning" and is passed through to
            uvicorn without modification. Defaults to
            "warning".
        bind_host: str
            The interface address the listening socket binds to.
            Defaults to "0.0.0.0" (all IPv4 interfaces); pass
            "127.0.0.1" to only accept local connections.
        """
        if host:
            self.host = host
        else:
            self.host = os.getenv("KARANA_WEBUI_ADDRESS", "127.0.0.1")
        self.user_port = port
        # Accept plain strings as well as Path objects; an empty value
        # still means "no static files".
        self.static_path = Path(static_path) if static_path else None
        self.log_level = log_level
        self.bind_host = bind_host
        self.app = FastAPI()
        self.thread: threading.Thread | None = None
        self.clients: set[WebSocket] = set()
        self.loop: asyncio.AbstractEventLoop | None = None  # Server thread's event loop
        self.server: uvicorn.Server | None = None
        self.setupRoutes()
        if autorun:
            self.run()

    @property
    def started(self) -> bool:
        """Whether the server has started

        There are really three possible states: not started,
        starting but not ready, and ready. To simplify logic
        and avoid race conditions, if the server is starting
        but not fully ready this will block until the server
        is fully ready then return True. If the server thread
        exits instead (for instance because startup failed or
        the server has shut down), or `close` replaces it, this
        returns False rather than waiting forever.

        Returns
        -------
        bool
            Whether the server is started and ready to use
        """
        thread = self.thread
        if thread is None:
            return False
        # Poll for uvicorn's readiness flag, but stop waiting as soon as
        # the server thread is gone so a failed startup cannot hang us.
        while thread.is_alive() and self.thread is thread:
            server = self.server
            if server is not None and server.started:
                return True
            time.sleep(0.1)
        return False

    def run(self):
        """Idempotently start the server on a background thread

        Does nothing if a server thread already exists (i.e. `run` was
        called and `close` has not been called since).
        """
        if self.thread is not None:
            return
        # Build the uvicorn server up front so `self.server` is set as soon
        # as `self.thread` is; `close` can then signal a server that is
        # still starting up.
        server = self._makeServer()
        self.server = server
        self.thread = threading.Thread(target=self._run, args=(server,), daemon=True)
        self.thread.start()

    def close(self, timeout: int | float = 5):
        """Idempotently shutdown the server and all connections

        Parameters
        ----------
        timeout: int | float
            Maximum wait time in seconds for the server thread to stop
        """
        server = self.server
        thread = self.thread
        if server is not None:
            # Signal uvicorn to shutdown; its main loop polls this flag
            server.should_exit = True
        # Wait for the thread to exit. Joining the current thread raises,
        # which would happen if close() is called from a handler running
        # on the server thread, so skip the wait in that case.
        if thread is not None and thread is not threading.current_thread():
            thread.join(timeout=timeout)
            if thread.is_alive():
                print(f"[{self._boundPort()}] Warning: Server thread did not exit in time")
        self.thread = None
        self.loop = None
        self.server = None
        self.clients.clear()

    def __del__(self):
        """Ensure the server is shutdown upon going out of scope"""
        # Best effort only: after a failed __init__ (or during interpreter
        # shutdown) attributes that close() relies on may be missing.
        try:
            self.close()
        except Exception:
            pass

    def broadcast(self, message: str):
        """Send a message to all connected clients

        May be called from any thread. The send itself runs
        asynchronously on the server's event loop.

        Parameters
        ----------
        message: str
            Text to send to every connected websocket client

        Returns
        -------
        concurrent.futures.Future
            Completes once the message has been sent to (or has failed
            for) every client connected when sending started. Callers
            may ignore it.

        Raises
        ------
        RuntimeError
            If the server is not running
        """
        loop = self.loop if self.started else None
        if loop is None:
            raise RuntimeError("Server not running")
        # run_coroutine_threadsafe is the thread-safe way to schedule a
        # coroutine on another thread's loop; unlike a bare create_task it
        # hands back a future tied to the scheduled task.
        return asyncio.run_coroutine_threadsafe(self._broadcastToClients(message), loop)

    @abstractmethod
    async def onConnect(self, websocket: WebSocket):
        """Called just after a new websocket client connects

        This must be overridden with the desired behavior, even if
        that's a no-op.

        Parameters
        ----------
        websocket: WebSocket
            The newly accepted connection; it has already been added
            to `self.clients`.
        """

    @abstractmethod
    async def onMessage(self, websocket: WebSocket, message: str):
        """Called whenever a message is received from a websocket client

        This must be overridden with the desired behavior, even if
        that's a no-op.

        Parameters
        ----------
        websocket: WebSocket
            The connection the message arrived on
        message: str
            The text message received from the client
        """

    @abstractmethod
    async def onDisconnect(self, websocket: WebSocket):
        """Called just after a websocket client disconnects

        This must be overridden with the desired behavior, even if
        that's a no-op. It is also called when onConnect or onMessage
        raised, since the connection handler ends in that case too.

        Parameters
        ----------
        websocket: WebSocket
            The closed connection; it has already been removed from
            `self.clients`.
        """

    def setupRoutes(self):
        """Setup the API endpoints including the websocket handler

        This base implementation adds the following routes:

        - `/ws`: upgrade into a websocket connection
        - `/static/*`: recursively serve files in `self.static_path` if
          not None
        - `/`: if serving a file at `/static/index.html` also serve it
          here
        - `/health`: respond with 200 OK. This gives clients a
          lightweight method of checking whether the server is up.

        This may be overridden to add more routes, but in this case it is
        recommended to call `super().setupRoutes()` so that the basic
        routes added by this class are preserved.
        """
        if self.static_path:
            static_path = Path(self.static_path)
            index_path = static_path / "index.html"
            if index_path.is_file():
                # Serve static_path / index.html at /
                @self.app.get("/", response_class=FileResponse)
                async def serve_index():
                    return FileResponse(index_path)

            # Recursively serve static_path under /static route
            self.app.mount("/static", StaticFiles(directory=static_path), name="static")

        @self.app.get("/health")
        def health_check():
            # send 200 OK and indicate not to cache the response
            # anywhere downstream
            return Response(status_code=200, headers={"Cache-Control": "no-store"})

        @self.app.websocket("/ws")
        async def websocket_endpoint(websocket: WebSocket):
            # Upgrade the request to a websocket connection
            await websocket.accept()
            # Add the connection to the clients set
            self.clients.add(websocket)
            try:
                # Call user's onConnect hook
                await self.onConnect(websocket)
                while True:
                    # Await client messages until connection is closed
                    message = await websocket.receive_text()
                    # Call user's onMessage hook
                    await self.onMessage(websocket, message)
            except WebSocketDisconnect:
                # Normal termination: the client closed the connection
                pass
            except Exception as e:
                # Anything else (e.g. an error raised by a user hook) used
                # to vanish silently; report it, then drop the connection.
                print(f"[{self._boundPort()}] Websocket handler error: {e!r}")
            finally:
                # Idempotently remove connection from the clients set
                self.clients.discard(websocket)
                # Call user's onDisconnect hook
                await self.onDisconnect(websocket)

    @property
    def port(self) -> int:
        """Get the server's port

        Usually this is whatever was passed to the constructor, but for
        port=0 this will be replaced with the actual port number after
        the server is started. Like `started`, this blocks while the
        server is starting up.

        Returns
        -------
        int
            The port number
        """
        if not self.started:
            return self.user_port
        return self._boundPort()

    def _boundPort(self) -> int:
        """Return the actually-bound port if known, else the requested port

        Never blocks, so it is safe to call from the server thread and
        from `close`.
        """
        # uvicorn fills `servers` during startup; before that it may be
        # missing or empty, and a closed listener has no sockets.
        for listener in getattr(self.server, "servers", None) or ():
            sockets = listener.sockets
            if sockets:
                return sockets[0].getsockname()[1]
        return self.user_port

    @property
    def url(self) -> str:
        """Get the base URL of the server

        This includes the protocol, address, and port number. Note that
        this is merely a best guess and may be inaccurate, for instance
        if accessing the server via a proxy.

        Returns
        -------
        str
            The server's base url
        """
        host = self.host
        # IPv6 literal addresses must be bracketed inside a URL
        if ":" in host and not host.startswith("["):
            host = f"[{host}]"
        return f"http://{host}:{self.port}"

    def block(self, prompt: str = "Press Ctrl-C to shutdown the server"):
        """Block the thread from which this is called

        This is mainly intended as a way to keep the server alive when
        the main thread is about to exit. Returns on Ctrl-C or once the
        server thread has exited or been closed.

        Parameters
        ----------
        prompt: str
            The message to print before blocking

        Raises
        ------
        RuntimeError
            If the server is not running
        """
        if not self.started:
            raise RuntimeError("Server not running")
        if prompt:
            print(prompt)
        try:
            # Loop until an exception or the server shuts down somehow
            while self.started:
                # Don't spin the CPU
                time.sleep(1)
        except KeyboardInterrupt:
            # Make Ctrl-C silent
            pass

    def launchLocalClient(self):
        """Open a browser tab at the server's URL"""
        webbrowser.open(self.url)

    def _makeServer(self) -> uvicorn.Server:
        """Build, but do not start, a uvicorn server for `self.app`"""
        config = uvicorn.Config(
            self.app,
            host=self.bind_host,
            port=self.user_port,
            log_level=self.log_level,
            loop="asyncio",
        )
        return uvicorn.Server(config)

    def _run(self, server: uvicorn.Server | None = None):
        """Entry point for the server thread

        Parameters
        ----------
        server: uvicorn.Server | None
            The server to run. If None, `self.server` is used, creating
            it first if needed.
        """
        if server is None:
            if self.server is None:
                self.server = self._makeServer()
            server = self.server
        # Create and set the event loop for this thread
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        self.loop = loop
        try:
            loop.run_until_complete(server.serve())
        finally:
            # uvicorn may return from serve() without running its own
            # shutdown (e.g. when asked to exit during startup), so close
            # any listener that is still open; close() is idempotent.
            for listener in getattr(server, "servers", None) or ():
                listener.close()
            # Release the loop's resources instead of leaking it
            loop.close()

    async def _broadcastToClients(self, message: str):
        """Async task to send a message to all connected clients

        Unless you know what you're doing, don't call this directly, and
        use the broadcast method instead.
        """
        # Async task to be executed on the server thread.
        # Do not call from another context.
        # Iterate over a snapshot: clients may connect or disconnect while
        # we are suspended in send_text, which would otherwise mutate the
        # set mid-iteration.
        failed = []
        for client in list(self.clients):
            try:
                await client.send_text(message)
            except Exception as e:
                print(f"[{self._boundPort()}] Failed to send message: {e}")
                failed.append(client)
        self.clients.difference_update(failed)