# Source code for Karana.KUtils.visjs._baseserver

"""This module provides an application-agnostic http/ws server

See HybridServerBase for more information.

"""

import asyncio
import os
import threading
import time
import webbrowser
from abc import ABC, abstractmethod
from pathlib import Path

import uvicorn
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import FileResponse, Response
from fastapi.staticfiles import StaticFiles

# Public API of this module: only the abstract server base class is exported.
__all__ = ["HybridServerBase"]


class HybridServerBase(ABC):
    """An application-agnostic http and websocket server

    This class lets the user register http request handlers at arbitrary
    URLs. Additionally, requests at the `/ws` URL are automatically
    upgraded to a persistent websocket connection. This allows the
    server to send messages to clients without waiting for an http
    request from the client. The server will also automatically serve a
    directory of static files, specified by the `static_path`
    constructor argument. These are served relative to the `/static`
    URL. This class is built on top of fastapi and uvicorn but takes
    care of the details of running the server in an asyncio event loop
    on a background thread.

    This class is abstract, so to use it, one must define a subclass
    that implements the missing methods even if the implementations do
    nothing. These include:

    - onConnect: called when a websocket connection is created
    - onMessage: called when receiving a websocket message
    - onDisconnect: called when a websocket connection closes

    Additionally, one may override the setupRoutes method to register
    additional http request handlers. If overriding this method it's
    recommended to call `super().setupRoutes()` to still setup the basic
    handlers created by this class.

    """

    def __init__(
        self,
        *,
        host: str | None = None,
        port: int = 8000,
        autorun: bool = True,
        static_path: Path | str | None = None,
        log_level: str = "warning",
        bind_host: str = "0.0.0.0",
    ):
        """HybridServerBase constructor

        Parameters
        ----------
        host: str | None
            The publicly accessible address for the server. It is
            only used to build `url`; the address the server
            listens on is `bind_host`. If None or empty, defaults
            to KARANA_WEBUI_ADDRESS if defined or finally 127.0.0.1.
        port: int
            The port to bind to. The port must not already be
            in use. Specify 0 to pick an arbitrary open port.
            Defaults to 8000.
        autorun: bool
            Whether to automatically run the server after
            construction. Otherwise the run method must be
            called to start the server. Defaults to True.
        static_path: Path | str | None
            If given (and non-empty), file directory to serve as
            static files under `/static`. Strings are converted
            to Path.
        log_level: str
            The log verbosity level. This is a string such as
            "info" or "warning" and is passed through to
            uvicorn without modification. Defaults to
            "warning".
        bind_host: str
            The interface address uvicorn listens on. Defaults to
            "0.0.0.0" (all IPv4 interfaces). Pass "127.0.0.1" to
            only accept connections from the local machine.

        """
        self.host = host or os.getenv("KARANA_WEBUI_ADDRESS", "127.0.0.1")
        self.user_port = port
        # Coerce so that `self.static_path / "index.html"` also works for str
        self.static_path = Path(static_path) if static_path else None
        self.log_level = log_level
        self.bind_host = bind_host

        self.app = FastAPI()
        self.thread = None  # background thread running _run
        self.clients = set()  # currently connected websockets
        self.loop = None  # Server's event loop
        self.server = None  # uvicorn.Server created by _run

        self.setupRoutes()

        if autorun:
            self.run()

    @property
    def started(self) -> bool:
        """Whether the server has started

        There are really three possible states: not started,
        starting but not ready, and ready. To simplify logic
        and avoid race conditions, if the server is starting
        but not fully ready this will block until the server
        is fully ready then return True. If the server thread
        exits before becoming ready, this returns False instead
        of blocking forever.

        Returns
        -------
        bool
            Whether the server is started and ready to use

        """
        return self._waitUntilReady()

    def _waitUntilReady(self, timeout: int | float | None = None) -> bool:
        """Wait for a starting server to become ready

        Parameters
        ----------
        timeout: int | float | None
            Maximum time in seconds to wait. None waits for as long
            as the server thread is alive.

        Returns
        -------
        bool
            True if the server thread is alive and uvicorn reports
            the server as started. False if there is no server
            thread, the thread has exited, or the timeout elapsed
            first.

        """
        # Local copy: close() on another thread may reset self.thread
        thread = self.thread
        if thread is None:
            return False
        deadline = None if timeout is None else time.monotonic() + timeout
        while thread.is_alive():
            server = self.server
            if server is not None and server.started:
                return True
            if deadline is not None and time.monotonic() >= deadline:
                return False
            # the server is starting but isn't fully ready
            time.sleep(0.1)
        # The thread is gone, so nothing is serving anymore
        return False

    def run(self):
        """Idempotently start the server on a background thread"""
        if self.thread is None:
            self.thread = threading.Thread(target=self._run, daemon=True)
            self.thread.start()

    def close(self, timeout: int | float | None = 5):
        """Idempotently shutdown the server and all connections

        If the server is still starting, this first waits (within
        the timeout) for it to become ready so it can be told to
        exit cleanly.

        Parameters
        ----------
        timeout: int | float | None
            Maximum wait time in seconds for the server thread to
            stop. None waits indefinitely.

        """
        thread = self.thread
        if thread is None:
            return
        deadline = None if timeout is None else time.monotonic() + timeout

        port = self.user_port
        if self._waitUntilReady(timeout):
            # Capture the real port before shutdown closes the sockets
            port = self.port
        server = self.server
        if server is not None:
            # Signal uvicorn to shutdown
            server.should_exit = True

        # Wait for thread to exit using whatever time remains
        remaining = (
            None if deadline is None else max(0.0, deadline - time.monotonic())
        )
        thread.join(timeout=remaining)
        if thread.is_alive():
            print(f"[{port}] Warning: Server thread did not exit in time")

        self.thread = None
        self.loop = None
        self.server = None
        self.clients.clear()

    def __del__(self):
        """Ensure the server is shutdown upon going out of scope

        While the server thread runs, it references this object
        through its target (the bound `_run` method). In practice,
        this finalizer therefore only runs once that thread has
        finished.

        """
        try:
            self.close()
        except Exception:
            # Never raise from a finalizer
            pass

    def broadcast(self, message: str):
        """Send a message to all connected clients

        May be called from any thread.

        Parameters
        ----------
        message: str
            Text sent to every connected websocket client

        Returns
        -------
        concurrent.futures.Future
            Completes once sending to every connected client has been
            attempted. It may be ignored. Do not block on its result
            from the server's own event loop thread (e.g. inside
            onMessage): that would stall the loop that must run the
            broadcast.

        Raises
        ------
        RuntimeError
            If the server is not running

        """
        if not self.started:
            raise RuntimeError("Server not running")
        loop = self.loop
        if loop is None:
            raise RuntimeError("Server not running")
        # run_coroutine_threadsafe is the thread-safe way to schedule a
        # coroutine on another thread's loop; the returned future keeps
        # the work observable instead of spawning an unreferenced task
        return asyncio.run_coroutine_threadsafe(
            self._broadcastToClients(message), loop
        )

    @abstractmethod
    async def onConnect(self, websocket: WebSocket):
        """Called just after a new websocket client connects

        This must be overridden with the desired behavior, even if
        that's a no-op.

        """

    @abstractmethod
    async def onMessage(self, websocket: WebSocket, message: str):
        """Called whenever a message is received from a websocket client

        This must be overridden with the desired behavior, even if
        that's a no-op.

        """

    @abstractmethod
    async def onDisconnect(self, websocket: WebSocket):
        """Called just after a websocket client disconnects

        This must be overridden with the desired behavior, even if
        that's a no-op.

        """

    def setupRoutes(self):
        """Setup the API endpoints including the websocket handler

        This base implementation adds the following routes:

        - `/ws`: upgrade into a websocket connection
        - `/static/*`: recursively serve files in `self.static_path` if
          not None
        - `/`: if serving a file at `/static/index.html` also serve it
          here
        - `/health`: respond with 200 OK. This gives clients a
          lightweight method of checking whether the server is up.

        This may be overridden to add more routes, but in this case it
        is recommended to call `super().setupRoutes()` so that the
        basic routes added by this class are preserved.

        """
        if self.static_path:
            index_path = self.static_path / "index.html"
            if index_path.is_file():
                # Serve static_path / index.html at /
                @self.app.get("/", response_class=FileResponse)
                async def serve_index():
                    return FileResponse(index_path)

            # Recursively serve static_path under /static route
            self.app.mount(
                "/static", StaticFiles(directory=self.static_path), name="static"
            )

        @self.app.get("/health")
        def health_check():
            # send 200 OK and indicate not to cache the response
            # anywhere downstream
            return Response(status_code=200, headers={"Cache-Control": "no-store"})

        @self.app.websocket("/ws")
        async def websocket_endpoint(websocket: WebSocket):
            # Upgrade the request to a websocket connection
            await websocket.accept()
            # Add the connection to the clients set
            self.clients.add(websocket)
            try:
                # Call user's onConnect hook
                await self.onConnect(websocket)
                while True:
                    # Await client messages until connection is closed
                    message = await websocket.receive_text()
                    # Call user's onMessage hook
                    await self.onMessage(websocket, message)
            except WebSocketDisconnect:
                # Normal termination: the client went away
                pass
            except Exception as e:
                # Keep the server alive, but don't hide errors raised by
                # the user hooks or by unexpected client traffic
                print(f"[{self.port}] Websocket handler error: {e!r}")
            finally:
                # Idempotently remove connection from the clients set
                self.clients.discard(websocket)
                # Call user's onDisconnect hook
                await self.onDisconnect(websocket)

    @property
    def port(self) -> int:
        """Get the server's port

        Usually this is whatever was passed to the constructor, but for
        port=0 this will be replaced with the actual port number after
        the server is started.

        Returns
        -------
        int
            The port number

        """
        if not self.started:
            return self.user_port
        server = self.server
        if server is None:
            return self.user_port
        # Report the port of the first listening socket. Fall back to the
        # requested port if no socket is available (for instance if the
        # listeners have already been closed).
        for listener in server.servers:
            for sock in listener.sockets:
                return sock.getsockname()[1]
        return self.user_port

    @property
    def url(self) -> str:
        """Get the base URL of the server

        This includes the protocol, address, and port number. Note that
        this is merely a best guess and may be inaccurate, for instance
        if accessing the server via a proxy.

        Returns
        -------
        str
            The server's base url

        """
        host = self.host
        # IPv6 literals must be wrapped in brackets inside a URL
        if ":" in host and not host.startswith("["):
            host = f"[{host}]"
        return f"http://{host}:{self.port}"

    def block(self, prompt: str = "Press Ctrl-C to shutdown the server"):
        """Block the thread from which this is called

        This is mainly intended as a way to keep the server alive when
        the main thread is about to exit.

        Parameters
        ----------
        prompt: str
            The message to print before blocking

        Raises
        ------
        RuntimeError
            If the server is not running

        """
        if not self.started:
            raise RuntimeError("Server not running")
        if prompt:
            print(prompt)
        try:
            # Loop until an exception or the server shuts down somehow;
            # `started` becomes False once the server thread exits or
            # close() resets the state
            while self.started:
                # Don't spin the CPU
                time.sleep(1)
        except KeyboardInterrupt:
            # Make Ctrl-C silent
            pass

    def launchLocalClient(self):
        """Open a browser tab at the server's URL"""
        webbrowser.open(self.url)

    def _run(self):
        """Entry point for the server thread"""
        # Create and set the event loop for this thread
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        self.loop = loop
        config = uvicorn.Config(
            self.app,
            host=self.bind_host,
            port=self.user_port,
            log_level=self.log_level,
            loop="asyncio",
        )
        server = uvicorn.Server(config)
        self.server = server
        try:
            loop.run_until_complete(server.serve())
        finally:
            # Release the loop's resources once serving has ended,
            # whether normally or via an exception
            loop.close()

    async def _broadcastToClients(self, message: str):
        """Async task to send a message to all connected clients

        Unless you know what you're doing, don't call this directly,
        and use the broadcast method instead.

        """
        # Async task to be executed on the server thread.
        # Do not call from another context.
        to_remove = []
        # Iterate over a snapshot: each await yields to the event loop,
        # where websocket_endpoint may add/discard clients, and mutating
        # a set while iterating it raises RuntimeError.
        for client in list(self.clients):
            try:
                await client.send_text(message)
            except Exception as e:
                print(f"[{self.port}] Failed to send message: {e}")
                to_remove.append(client)
        for client in to_remove:
            self.clients.discard(client)