"""This module provides an application-agnostic http/ws server.
See HybridServerBase for more information.
"""
import asyncio
import os
import socket
import threading
import time
import webbrowser
from abc import ABC, abstractmethod
from pathlib import Path

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import FileResponse, Response
from fastapi.staticfiles import StaticFiles
import uvicorn
# Public API of this module: only the abstract server base class is exported.
__all__ = ["HybridServerBase"]
class HybridServerBase(ABC):
    """An application-agnostic http and websocket server.

    This class lets the user register http request handlers at arbitrary
    URLs. Additionally, requests at the `/ws` URL are automatically
    upgraded to a persistent websocket connection. This allows the
    server to send messages to clients without waiting for an http
    request from the client. The server will also automatically serve a
    directory of static files, specified by the `static_path`
    constructor argument. These are served relative to the `/static`
    URL. This class is built on top of fastapi and uvicorn but takes
    care of the details of running the server in an async loop on a
    background thread.

    This class is abstract, so to use it, one must define a subclass
    that implements the missing methods even if the implementations do
    nothing. These include:

    - onConnect: called when a websocket connection is created
    - onMessage: called when receiving a websocket message
    - onDisconnect: called when a websocket connection closes

    Additionally, one may override the setupRoutes method to register
    additional http request handlers. If overriding this method it's
    recommended to call `super().setupRoutes()` to still setup the basic
    handlers created by this class.
    """

    def __init__(
        self,
        *,
        host: str | None = None,
        port: int = 8000,
        autorun: bool = True,
        static_path: Path | None = None,
        log_level: str = "warning",
    ):
        """Create a HybridServerBase instance.

        Parameters
        ----------
        host: str | None
            The publicly accessible address for the server. It is only
            used to build `url`; the listening socket always binds to
            all interfaces. If None (or empty), defaults to the
            KARANA_WEBUI_ADDRESS environment variable if it is set and
            non-empty, or finally 127.0.0.1.
        port: int
            The port to bind to. The port must not already be
            in use. Specify 0 to pick an arbitrary open port.
            Defaults to 8000. For a non-zero port, the integer value of
            the KARANA_PORT_OFFSET environment variable (if set) is
            added to it.
        autorun: bool
            Whether to automatically run the server after
            construction. Otherwise the run method must be
            called to start the server. Defaults to True.
        static_path: Path | None
            If not None, file directory to serve as static
            files under `/static`. A plain string path is also
            accepted and converted to a Path.
        log_level: str
            The log verbosity level. This is a string such as
            "info" or "warning" and is passed through to
            uvicorn without modification. Defaults to
            "warning".

        Raises
        ------
        ValueError
            If KARANA_PORT_OFFSET is set but is not an integer.
        """
        # `or` (rather than a getenv default) so that an empty environment
        # variable does not produce an empty host in `url`.
        self.host = host or os.getenv("KARANA_WEBUI_ADDRESS") or "127.0.0.1"

        if port != 0:
            # Port 0 means "let the OS choose", so an offset is meaningless there.
            port_offset_str = os.getenv("KARANA_PORT_OFFSET", None)
            port_offset = int(port_offset_str) if port_offset_str else 0
            port += port_offset
        self.user_port = port  # requested port; see `port` for the bound one

        # Falsy values (None, "") mean "serve no static files".
        self.static_path = Path(static_path) if static_path else None
        self.log_level = log_level
        self.server_error: str | None = None  # set by the server thread on failure
        self.app = FastAPI()
        self.thread: threading.Thread | None = None
        self.clients = set()  # currently connected websockets
        self.loop = None  # Server's event loop
        self.server = None  # uvicorn.Server, created on the server thread

        self.setupRoutes()
        if autorun:
            self.run()

    @property
    def started(self) -> bool:
        """Whether the server has started.

        There are really three possible states: not started,
        starting but not ready, and ready. To simplify logic
        and avoid race conditions, if the server is starting
        but not fully ready this will block until the server
        is fully ready then return True.

        Returns
        -------
        bool
            Whether the server is started and ready to use

        Raises
        ------
        RuntimeError
            If the server thread reported an error while starting.
        """
        if not self.thread:
            return False
        while not self.server or not self.server.started:
            if self.server_error:
                raise RuntimeError(f"[visjs] Error starting server: {self.server_error}")
            # the server is starting but isn't fully ready
            time.sleep(0.1)
        return True

    def run(self) -> None:
        """Idempotently start the server on a background thread."""
        if self.thread is None:
            # Forget any failure from a previous run so that `started`
            # reports on this attempt only.
            self.server_error = None
            self.thread = threading.Thread(target=self._run, daemon=True)
            self.thread.start()

    def close(self, timeout: int | float = 5) -> None:
        """Idempotently shutdown the server and all connections.

        Parameters
        ----------
        timeout: int | float
            Maximum wait time in seconds for the server thread to stop
        """
        loop, server, thread = self.loop, self.server, self.thread
        if loop and server:
            # Signal uvicorn to shutdown
            server.should_exit = True
        # Wait for thread to exit
        if thread:
            thread.join(timeout=timeout)
            if thread.is_alive():
                print(f"[{self.port}] Warning: Server thread did not exit in time")
            elif loop is not None and not loop.is_closed():
                # The thread is gone, so nothing runs the loop anymore;
                # release its resources now rather than at garbage collection.
                loop.close()
        self.thread = None
        self.loop = None
        self.server = None
        self.clients.clear()

    def __del__(self):
        """Ensure the server is shutdown upon going out of scope."""
        try:
            self.close()
        except Exception:
            # Best effort only: the instance may be partially constructed
            # or the interpreter may be shutting down.
            pass

    def broadcast(self, message: str) -> None:
        """Send a message to all connected clients.

        May be called from any thread. The send happens asynchronously
        on the server's event loop; this method does not wait for it.

        Parameters
        ----------
        message: str
            The text to send to every connected websocket client

        Raises
        ------
        RuntimeError
            If the server is not running.
        """
        if not self.started:
            raise RuntimeError("Server not running")
        loop = self.loop
        if loop is None:
            # close() ran between the check above and here
            raise RuntimeError("Server not running")
        # run_coroutine_threadsafe is the supported way to hand a coroutine
        # to a loop running in another thread.
        asyncio.run_coroutine_threadsafe(self._broadcastToClients(message), loop)

    @abstractmethod
    async def onConnect(self, websocket: WebSocket):
        """When a new websocket client connects, this method is called.

        This must be overridden with the desired behavior, even if
        that's a no-op.
        """

    @abstractmethod
    async def onMessage(self, websocket: WebSocket, message: str):
        """When a message is received from a websocket client, this method is called.

        This must be overridden with the desired behavior, even if
        that's a no-op.
        """

    @abstractmethod
    async def onDisconnect(self, websocket: WebSocket):
        """When a websocket client disconnects, this method is called.

        This must be overridden with the desired behavior, even if
        that's a no-op.
        """

    def setupRoutes(self) -> None:
        """Set up the API endpoints including the websocket handler.

        This base implementation adds the following routes:

        - `/ws`: upgrade into a websocket connection
        - `/static/*`: recursively serve files in `self.static_path` if
          not None
        - `/`: if serving a file at `/static/index.html` also serve it
          here
        - `/health`: respond with 200 OK. This gives clients a
          lightweight method of checking whether the server is up.

        This may be overridden to add more routes, but in this case it is
        recommended to call `super().setupRoutes()` so that the basic
        routes added by this class are preserved.
        """
        if self.static_path:
            index_path = self.static_path / "index.html"
            if index_path.is_file():
                # Serve static_path / index.html at /
                @self.app.get("/", response_class=FileResponse)
                async def serveIndex():
                    return FileResponse(index_path)

            # Recursively serve static_path under /static route
            self.app.mount("/static", StaticFiles(directory=self.static_path), name="static")

        @self.app.get("/health")
        def healthCheck():
            # send 200 OK and indicate not to cache the response
            # anywhere downstream
            return Response(status_code=200, headers={"Cache-Control": "no-store"})

        @self.app.websocket("/ws")
        async def websocketEndpoint(websocket: WebSocket):
            # Upgrade the request to a websocket connection
            await websocket.accept()
            # Add the connection to the clients set
            self.clients.add(websocket)
            try:
                # Call user's onConnect hook
                await self.onConnect(websocket)
                while True:
                    # Await client messages until connection is closed
                    message = await websocket.receive_text()
                    # Call user's onMessage hook
                    await self.onMessage(websocket, message)
            except WebSocketDisconnect:
                # The client went away; this is the normal way out of the loop.
                pass
            except Exception as err:
                # Anything else is most likely a bug in a user hook. Keep the
                # server alive, but don't hide the failure.
                print(f"[{self.port}] Websocket handler error: {err!r}")
            finally:
                # Idempotently remove connection from the clients set
                self.clients.discard(websocket)
                # Call user's onDisconnect hook
                await self.onDisconnect(websocket)

    @property
    def port(self) -> int:
        """Get the server's port.

        Usually this is whatever was passed to the constructor, but for
        port=0 this will be replaced with the actual port number after
        the server is started.

        Returns
        -------
        int
            The port number
        """
        if not self.started:
            return self.user_port
        server = self.server
        if server is not None:
            for listener in server.servers:
                # A listener that has been closed no longer exposes sockets.
                if listener.sockets:
                    return listener.sockets[0].getsockname()[1]
        # No live listening socket to ask; fall back to the requested port.
        return self.user_port

    @property
    def url(self) -> str:
        """Get the base URL of the server.

        This includes the protocol, address, and port number. Note that
        this is merely a best guess and may be inaccurate, for instance
        if accessing the server via a proxy.

        Returns
        -------
        str
            The server's base url
        """
        return f"http://{self.host}:{self.port}"

    def block(self, prompt: str = "Press Ctrl-C to shutdown the server") -> None:
        """Block the thread from which this is called.

        This is mainly intended as a way to keep the server alive when
        the main thread is about to exit. Returns on Ctrl-C, when the
        server is closed, or when the server thread has exited.

        Parameters
        ----------
        prompt: str
            The message to print before blocking

        Raises
        ------
        RuntimeError
            If the server is not running.
        """
        if not self.started:
            raise RuntimeError("Server not running")
        if prompt:
            print(prompt)
        try:
            # Loop until an exception or the server shuts down somehow
            while self.started:
                thread = self.thread
                if thread is None or not thread.is_alive():
                    # The server thread is gone, so there is nothing left
                    # to keep alive.
                    break
                # Don't spin the CPU
                time.sleep(1)
        except KeyboardInterrupt:
            # Make Ctrl-C silent
            pass

    def launchLocalClient(self) -> None:
        """Open a browser tab at the server's URL."""
        webbrowser.open(self.url)

    def printConnectionInfo(self) -> None:
        """Print a message about how to connect to the server.

        Raises
        ------
        ValueError
            If the server is not running.
        """
        if not self.started:
            raise ValueError("Server is not running")
        lines = [
            f"[visjs] Web server is running on port {self.port}",
            "You may be able to connect in your browser at:",
            f"\t\033[1m{self.url}\033[0m",  # ANSI bold
        ]
        print("\n".join(lines))

    def _run(self) -> None:
        """Entry point for the server thread.

        Creates the thread's event loop, checks that the requested port
        is free, and serves until uvicorn is told to exit. Any failure
        is recorded in `self.server_error` so that `started` can report
        it instead of waiting forever.
        """
        # Create and set the event loop for this thread
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        config = uvicorn.Config(
            self.app, host="0.0.0.0", port=self.user_port, log_level=self.log_level, loop="asyncio"
        )
        try:
            self.server = uvicorn.Server(config)
            # Check that the port is free
            if config.port != 0:
                # `with` closes the probe socket even if bind() fails.
                with socket.socket() as probe:
                    probe.bind((config.host, config.port))
            self.loop.run_until_complete(self.server.serve())
        except OSError as err:
            # An empty message would read as "no error" to `started`.
            self.server_error = str(err) or repr(err)
        except SystemExit as err:
            # uvicorn calls sys.exit() when its own startup fails (for
            # instance if the port was taken after the probe above).
            self.server_error = f"uvicorn exited (exit code {err.code})"
        except Exception as err:
            # Record the failure for `started`, then let the thread's
            # excepthook print the traceback as before.
            self.server_error = str(err) or repr(err)
            raise

    async def _broadcastToClients(self, message: str) -> None:
        """Async task to send a message to all connected clients.

        Unless you know what you're doing, don't call this directly, and
        use the broadcast method instead.

        Clients for which the send fails are dropped from `self.clients`.
        """
        # Async task to be executed on the server thread.
        # Do not call from another context.
        to_remove = []
        # Iterate over a snapshot: clients may connect or disconnect while
        # this task is suspended in `await`, which would otherwise change
        # the set during iteration.
        for client in tuple(self.clients):
            try:
                await client.send_text(message)
            except Exception as e:
                print(f"[{self.port}] Failed to send message: {e}")
                to_remove.append(client)
        self.clients.difference_update(to_remove)