# Source code for Karana.KUtils.DataPlotter

# Copyright (c) 2024-2026 Karana Dynamics Pty Ltd. All rights reserved.
#
# NOTICE TO USER:
#
# This source code and/or documentation (the "Licensed Materials") is
# the confidential and proprietary information of Karana Dynamics Inc.
# Use of these Licensed Materials is governed by the terms and conditions
# of a separate software license agreement between Karana Dynamics and the
# Licensee ("License Agreement"). Unless expressly permitted under that
# agreement, any reproduction, modification, distribution, or disclosure
# of the Licensed Materials, in whole or in part, to any third party
# without the prior written consent of Karana Dynamics is strictly prohibited.
#
# THE LICENSED MATERIALS ARE PROVIDED "AS IS" WITHOUT WARRANTY OF ANY KIND.
# KARANA DYNAMICS DISCLAIMS ALL WARRANTIES, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO WARRANTIES OF MERCHANTABILITY, NON-INFRINGEMENT, AND
# FITNESS FOR A PARTICULAR PURPOSE.
#
# IN NO EVENT SHALL KARANA DYNAMICS BE LIABLE FOR ANY DAMAGES WHATSOEVER,
# INCLUDING BUT NOT LIMITED TO LOSS OF PROFITS, DATA, OR USE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGES, WHETHER IN CONTRACT, TORT,
# OR OTHERWISE ARISING OUT OF OR IN CONNECTION WITH THE LICENSED MATERIALS.
#
# U.S. Government End Users: The Licensed Materials are a "commercial item"
# as defined at 48 C.F.R. 2.101, and are provided to the U.S. Government
# only as a commercial end item under the terms of this license.
#
# Any use of the Licensed Materials in individual or commercial software must
# include, in the user documentation and internal source code comments,
# this Notice, Disclaimer, and U.S. Government Use Provision.

"""Classes and functions used to stream data to a DashApp.

This module contains classes and functions used to stream data to a DashApp. These are useful for 
plotting data while a simulation is running. The DashApp is a useful format, as it allows one to 
view the results from a web browser, and that browser does not have to be on the same machine as
the machine generating the data. This is particularly useful when running a simulation on a 
remote machine without an X session.
"""

import os
from Karana.Core import debug, error
from threading import Thread, Event
from werkzeug.serving import make_server, WSGIRequestHandler
from typing import TYPE_CHECKING, Any, cast, overload
import numpy as np
from dash import Dash, html, Input, Output, dcc
from subprocess import run
from pathlib import Path
import plotly.graph_objects as go
from dash_extensions import WebSocket
import websockets
import asyncio
from json import dumps
from Karana.KUtils._KUtils_Py import SinglePlotData, PlotData


class _DashServer:
    """_DashSever is an internal class used to start/stop the server.

    Users should not create instances of this themselves. We do this in
    two classes to avoid reference cycles. Users need only delete the
    DashApp instance, and the server should shutdown automatically.
    """

    def __init__(self, title: str = "DataPlotter", host: str = "0.0.0.0", port: int = 8050):
        """Initialize the DashServer.

        Parameters
        ----------
        host: str
            The host to host the server on.
        port: int
            The port to use for the server.
        """
        self.host = host
        if self.host == "0.0.0.0":
            self.hostname = run(
                "hostname", shell=True, text=True, capture_output=True
            ).stdout.strip()
        else:
            self.hostname = self.host

        if port != 0:
            port_offset_str = os.getenv("KARANA_PORT_OFFSET", None)
            port_offset = int(port_offset_str) if port_offset_str else 0
            port += port_offset

        asset_folder = Path(__file__).parents[0].joinpath("assets")
        self._websocket_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self._websocket_loop)
        self.port: int
        self.app = Dash(__name__, title=title, assets_folder=str(asset_folder))

        self.update_clients = set()
        self.data_client = None
        self.refresh_clients = set()
        self._startServer(port)

        # Store the figures on the DashServer as well. If we keep them on the main, then we can't cleanup.
        self.figs: list[go.Figure] = []
        self.gs: list[dcc.Graph] = []
        self.fig_labels: list[str] = []

        # Counter for the figures
        self.figure_counter: int = 0

        # New data to be appended to plot. Only used by the DashApp class.
        self.new_plot_data: list[dict[str, Any]] = []

    def _runServer(self, port: int, event: Event):
        """Run the server.

        This is the method run by the _server_thread.
        """

        class SilentHandler(WSGIRequestHandler):
            def log(self, type, message, *args):
                pass  # Do nothing — suppress all logs

        try:
            self._server = make_server(
                self.host, port, self.app.server, request_handler=SilentHandler
            )
            # If port 0 was passed in, then make_server will select an available port.
            # If non-zero, then this should be the same as port anyway. Use this so
            # any printed messages have the real port number.
        except Exception:
            event.set()
            raise
        self.port = self._server.port
        event.set()
        self._server.serve_forever()

    def _runWebsocket(self, event: Event):
        """Run the websocket.

        This is the method run by the _websocket_thread.
        """

        async def handler(websocket):
            if websocket.request.path == "/data":
                if self.data_client is not None:
                    # Turn off pings. Sometimes the boost beast socket will not send pongs fast enough when
                    # the sim thread is bogged down.
                    websocket.ping_interval = None
                    try:
                        async for msg in websocket:
                            await self.data_client.send(msg)
                    except (
                        websockets.exceptions.ConnectionClosed
                    ):  # pragma: no cover - Exclude from code coverage, as this won't run in testing
                        pass
                    except (
                        Exception
                    ) as e:  # pragma: no cover - Exclude from code coverage, as this won't run in testing
                        error(
                            f"Something has caused the data client to error out. A common occurance is a NaN or inf value in a plotting function. Exception follows: {e}"
                        )

                else:
                    self.data_client = websocket
                    try:
                        async for _ in websocket:
                            # Keep the connection alive
                            pass  # pragma: no cover - This doesn't get hit in reg tests, nor should it.
                    except (
                        websockets.exceptions.ConnectionClosed
                    ):  # pragma: no cover - Exclude from code coverage, as this won't run in testing
                        pass
                    finally:
                        debug(lambda: "Client disconnected")
                        self.data_client = None

            else:
                if websocket.request.path == "/refresh":
                    self.refresh_clients.add(websocket)
                else:
                    self.update_clients.add(websocket)
                debug(lambda: "Client connected")
                try:
                    async for _ in websocket:
                        # Keep the connection alive
                        pass  # pragma: no cover - This doesn't get hit in reg tests, nor should it.
                except (
                    websockets.exceptions.ConnectionClosed
                ):  # pragma: no cover - Exclude from code coverage, as this won't run in testing
                    pass
                finally:
                    debug(lambda: "Client disconnected")
                    if websocket.request.path == "/refresh":
                        self.refresh_clients.remove(websocket)
                    else:
                        self.update_clients.remove(websocket)

        self._websocket_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self._websocket_loop)
        self._websocket_stop = asyncio.Event()
        self.websocket_port = asyncio.Future()

        async def startServer():
            server = await websockets.serve(handler, "0.0.0.0", 0)
            self.websocket_port.set_result(server.sockets[0].getsockname()[1])
            event.set()
            debug(lambda: f"Starting WebSocket server on port {self.websocket_port.result()}...")
            await self._websocket_stop.wait()
            # Before sutting down, send a message to shutdown the listener that updates the DashApp
            self.sendWebsocketMessage("SHUTDOWN", path="/data")
            server.close()
            await server.wait_closed()
            debug(lambda: f"WebSocket server on port {self.websocket_port.result()} closed")

        self._websocket_loop.run_until_complete(startServer())
        self._websocket_loop.stop()

    def _startServer(self, port: int):
        """Start the server.

        This creates a thread and adds _run to it.
        """
        if not self.isRunning():
            # We use daemon threads so atexit functions can run without waiting for these threads to finish.
            # Use events so we know when the ports are set
            e1 = Event()
            e2 = Event()
            self._server_thread = Thread(target=self._runServer, args=(port, e1))
            self._server_thread.daemon = True
            self._server_thread.start()

            self._websocket_thread = Thread(target=self._runWebsocket, args=(e2,))
            self._websocket_thread.daemon = True
            self._websocket_thread.start()
            e1.wait()
            e2.wait()
            print(f"Dash app started at http://{self.hostname}:{self.port}")

    def _stopServer(self):
        """Stop the server and join the associated thread."""
        if self._websocket_loop and not self._websocket_loop.is_closed():
            self._websocket_loop.call_soon_threadsafe(lambda: self._websocket_stop.set())
        if self._server:
            self._server.shutdown()
            self._server = None
            print("Dash server shutting down...")
        if self._server_thread:
            self._server_thread.join()
        if self._websocket_thread:
            self._websocket_thread.join()

    async def _sendWebsocketMessage(self, message: str, path="/update"):
        """Async function to ping the websocket."""
        to_remove = set()

        if path == "/update":
            clients = self.update_clients
        elif path == "/data":
            clients = [self.data_client] if self.data_client else []
        else:
            clients = self.refresh_clients
        for ws in clients.copy():
            try:
                await ws.send(message)
            except (
                websockets.exceptions.ConnectionClosed
            ):  # pragma: no cover - Exclude from code coverage, as this won't run in testing
                debug(lambda: "Client disconnected before send")
                to_remove.add(ws)

        clients.difference_update(to_remove)

    def pingWebsocket(self):
        """Ping the websocket.

        Pinging the websocket will also update the plots.
        """
        message = {"msg": "update graph!"}
        asyncio.run_coroutine_threadsafe(
            self._sendWebsocketMessage(dumps(message)), self._websocket_loop
        )

    def sendWebsocketMessage(self, message: str, path="/update"):
        """Ping the websocket.

        Pinging the websocket will also update the plots.
        """
        asyncio.run_coroutine_threadsafe(
            self._sendWebsocketMessage(message, path=path), self._websocket_loop
        )

    def isRunning(self) -> bool:
        """Return a boolean that indicates whether the server is running or not."""
        return hasattr(self, "_server_thread") and self._server_thread.is_alive()

    def __del__(
        self,
    ):  # pragma: no cover - Normal reference cycles keep this from being run. We keep it around in case users ever trigger a strange scenario where this is removed (in which case it should shut down the server).
        """Stop the server if it is running."""
        self._stopServer()


class _DashAppBase:
    """Base class for the DashApp* classes."""

    def __init__(
        self,
        title: str = "DataPlotter",
        host: str = "0.0.0.0",
        port: int = 8050,
    ):
        """Initialize the DashApp.

        This starts a server on the provided host and port. This server should shutdown
        automatically once this instance is deleted.

        Parameters
        ----------
        title: str
            The title of the DashApp webpage.
        host: str
            The host to host the server on.
        port: int
            The port to use for the server.
        """
        self._server = _DashServer(host=host, port=port, title=title)
        app = self._server.app
        app.callback
        self.url = f"http://{self._server.hostname}:{self._server.websocket_port.result()}/update"

        app.layout = [
            html.Header(
                [
                    html.H1(children=title, style={"textAlign": "center", "flex-grow": "1"}),
                    html.Button(
                        r"☀️",
                        title="Toggle light/dark theme",
                        id="theme-toggle",
                        className="toggle-button",
                    ),
                ],
                style={
                    "display": "flex",
                    "justify-content": "space-between",
                    "align-items": "center",
                    "margin-bottom": "20px",
                    "padding-bottom": "10px",
                },
            ),
            dcc.Store(id="update-data", data=0),
            WebSocket(url=self.url, id="ws"),
        ]

        app.clientside_callback(
            r"""
            function(n_clicks) {
                const body = document.body;
                const button = document.getElementById('theme-toggle');
                const current = body.getAttribute('data-theme') || 'light';
                if (current === 'light') {
                    body.setAttribute('data-theme', 'dark');
                    button.textContent = '\u{1F319}';
                } else {
                    body.setAttribute('data-theme', 'light');
                    button.textContent = '\u2600\uFE0F';
                }
                return '';
            }
            """,
            Output("theme-toggle", "n_clicks"),
            Input("theme-toggle", "n_clicks"),
        )

    def printConnectionInfo(self):
        """Print a message about how to connect to the server."""
        http_url = f"http://{self._server.hostname}:{self._server.port}"
        lines = [
            f"[DashApp] Web server is running on port {self._server.port}",
            "You may be able to connect in your browser at:",
            f"\t\033[1m{http_url}\033[0m",  # ANSI bold
        ]
        print("\n".join(lines))

    def __del__(self):
        """Delete method.

        This stops the server.'
        """
        if hasattr(self, "_server"):
            self._server._stopServer()


class DashAppStatic(_DashAppBase):
    """Use this class to create a DashApp that serves up static figures.

    These figures can be changed manually using the figs property. However, they
    are not designed to be updated regularly. For that, please use the DashApp class.
    """

    def __init__(
        self,
        figs: list[go.Figure],
        title: str = "DataPlotter",
        host: str = "0.0.0.0",
        port: int = 8050,
    ):
        """Initialize the DashApp.

        This starts a server on the provided host and port. This server should shutdown
        automatically once this instance is deleted.

        Parameters
        ----------
        figs: list[go.Figure]
            The figures to display.
        title: str
            The title of the DashApp webpage.
        host: str
            The host to host the server on.
        port: int
            The port to use for the server.
        """
        super().__init__(title=title, host=host, port=port)
        self.figs = figs
        self._createPlots(self._server, self._server.app)

    @property
    def figs(self) -> list[go.Figure]:
        """Figures shown by the DashApp."""
        return self._server.figs

    @figs.setter
    def figs(self, figs: list[go.Figure]):
        self._server.figs = figs
        # Notify the "/update" clients; the "ws" component's message triggers updatePlots.
        self._server.pingWebsocket()

    # Using static methods here so we don't create reference cycles that prevent the
    # instance from being deleted.
    @staticmethod
    def _makeGraphs(figs: list[go.Figure]) -> list[dcc.Graph]:
        """Wrap each figure in a dcc.Graph with a unique component id.

        The id is the figure's title, or its index when it has no title. Dash
        requires component ids to be unique within a layout, so a repeated id gets
        a "-<n>" suffix.
        """
        used_ids: set[str] = set()
        graphs = []
        for k, fig in enumerate(figs):
            base_id = fig.layout.title.text or str(k)
            graph_id = base_id
            suffix = 1
            while graph_id in used_ids:
                graph_id = f"{base_id}-{suffix}"
                suffix += 1
            used_ids.add(graph_id)
            graphs.append(dcc.Graph(id=graph_id, figure=fig))
        return graphs

    @staticmethod
    def _createPlots(server, app):
        """Add the plots to the layout and re-render them on every "ws" message."""
        app.layout.append(
            html.Div(id="plots-div", children=DashAppStatic._makeGraphs(server.figs))
        )

        @app.callback(Output("plots-div", "children"), Input("ws", "message"))
        def updatePlots(_):
            return DashAppStatic._makeGraphs(server.figs)
class _DashAppClient(_DashAppBase):
    """Use this class to create a DashApp that you can stream data to.

    This class will automatically create empty plots based on the plot-data
    dictionaries provided. In addition, it will start a websocket and http server
    for clients to connect to view the plots. Calling `update` appends new data
    points to the plots and pushes them to the connected clients.
    """

    def __init__(
        self,
        data: list[dict[str, Any]],
        title: str = "DataPlotter",
        host: str = "0.0.0.0",
        port: int = 8050,
    ):
        """Initialize the DashApp.

        This starts a server on the provided host and port. This server should shutdown
        automatically once this instance is deleted.

        Parameters
        ----------
        data : list[dict[str, Any]]
            One dictionary per plot. `_addPlot` builds an empty figure from its
            "title", "x_data_name" and "y_data" (trace name -> initial value) entries.
        title: str
            The title of the DashApp webpage.
        host: str
            The host to host the server on.
        port: int
            The port to use for the server.
        """
        super().__init__(title=title, host=host, port=port)

        # Add the plots
        for d in data:
            # Using underscore method here to avoid calling the refresh method, since it doesn't exist yet.
            self._addPlot(d)

        ws_host = f"{self._server.hostname}:{self._server.websocket_port.result()}"
        # Add extra HTML elements
        self._server.app.layout.append(
            html.Div(
                [
                    WebSocket(
                        url=f"http://{ws_host}/refresh",
                        id="ws-refresh",
                    ),
                    html.Div(
                        id="ws-url-div",
                        children=f"ws://{ws_host}/update",
                        style={"display": "none"},
                    ),
                    html.Div(id="plots-div"),
                ]
            )
        )

        # Use the static method to create callbacks. This avoids tying up self, which leads to reference cycles
        # that cause issues when tearing things down.
        self._createPlots(self._server, self._server.app)

    @staticmethod
    def _createPlots(server, app):
        """Register the full-refresh callback and trigger an initial refresh."""

        # This callback refreshes the plots entirely
        @app.callback(Output("plots-div", "children"), Input("ws-refresh", "message"))
        def refreshPlots(msg):
            server.gs = [
                dcc.Graph(id=label, figure=fig)
                for label, fig in zip(server.fig_labels, server.figs)
            ]
            return server.gs

        # Refresh the plots with the initial data
        server.sendWebsocketMessage(dumps({"msg": "update graph!"}), path="/refresh")

    def _appendPoint(self, fig: go.Figure, plot_id: str, trace_index: int, x: Any, y: Any):
        """Append one (x, y) point to a trace and queue it for the browser clients."""
        trace = fig.data[trace_index]
        trace.x = list(trace.x) + [x]
        trace.y = list(trace.y) + [y]
        self._server.new_plot_data.append(
            {"plot_id": plot_id, "trace_index": str(trace_index), "x": x, "y": y}
        )

    def update(self, inc_data: list[dict[str, Any]]):
        """Update the plots.

        This updates the figures with the data provided in inc_data. Then, it
        pings the websocket to trigger an update for all the connected clients.

        Parameters
        ----------
        inc_data: list[dict[str, Any]]
            One dictionary per plot, in the same order as the figures. "x_data" is
            the new x value; "y_data" is a sequence with one entry per y quantity.
            Each entry is either a scalar (one trace) or a sequence (one trace per
            element).
        """
        self._server.new_plot_data = []
        for fig, plot_id, data in zip(self._server.figs, self._server.fig_labels, inc_data):
            x_data = data["x_data"]
            trace_index = 0
            for val in data["y_data"]:
                # Scalars feed a single trace. None is treated the same way, matching
                # _addPlot, which maps an incoming None (NaN) to a single-trace value.
                ys = [val] if val is None or isinstance(val, (int, float)) else val
                for y in ys:
                    self._appendPoint(fig, plot_id, trace_index, x_data, y)
                    trace_index += 1
        self._server.sendWebsocketMessage(dumps(self._server.new_plot_data))

    def addPlot(self, data: dict[str, Any]):
        """Add a new plot the window using the associated plot data.

        Parameters
        ----------
        data : dict[str, Any]
            Example data associated with the plot.
        """
        self._addPlot(data)
        self._server.sendWebsocketMessage(dumps({"msg": "new plot added."}), path="/refresh")

    def removePlot(self, index: int):
        """Remove the plot at the associated index.

        The index corresponds to the figure in the slot of self._server.figs.

        Parameters
        ----------
        index : int
            The index of the plot to remove.
        """
        self._server.figs.pop(index)
        self._server.fig_labels.pop(index)
        self._server.gs.pop(index)
        self._server.sendWebsocketMessage(dumps({"msg": "removing plot"}), path="/refresh")

    def _addPlot(self, data: dict[str, Any]):
        """Add a new plot the window using the associated plot data.

        This does not call the plot refresh. Use the addPlot method for that.

        Parameters
        ----------
        data : dict[str, Any]
            The data associated with updating the plot.
        """
        # Create the figure based on the incoming data
        fig = go.Figure()
        for name, val in data["y_data"].items():
            # Handle NaN values coming in. This happens sometimes if a plot is added before
            # things are initialized.
            if val is None:
                val = 0.0
            data_len = 1 if isinstance(val, (int, float)) else len(val)
            if data_len == 1:
                fig.add_trace(go.Scatter(x=[], y=[], mode="lines+markers", name=name))
            else:
                for i in range(data_len):
                    fig.add_trace(
                        go.Scatter(x=[], y=[], mode="lines+markers", name=f"{name}[{i}]")
                    )
        fig.update_layout(
            template="plotly_dark", xaxis_title=data["x_data_name"], title=data["title"]
        )

        # Add this to the server's list of figures
        self._server.figs.append(fig)
        label = f"karana-plot-{self._server.figure_counter}"
        self._server.fig_labels.append(label)

        # Add this to the server's Graph objects
        self._server.gs.append(dcc.Graph(id=label, figure=fig))

        # Increment the figure counter
        self._server.figure_counter += 1


if TYPE_CHECKING:
    from multiprocessing import Event, Queue
    from multiprocessing.synchronize import Event as Evt


def _startProcess(
    shutdown: "Evt",
    change_plot: "Evt",
    q: "Queue",
    data: list[dict[str, Any]],
    title: str,
    host: str,
    port: int,
):
    """Start function for the DashApp multiprocessing process.

    Parameters
    ----------
    shutdown : Evt
        An event to indicate shutdown.
    change_plot : Evt
        An event to indicate we are changing something fundamental about the plots.
    q : Queue
        A multiprocessing Queue to exchange data.
    data : list[dict[str, Any]]
        The plot-data dictionaries used to create the plots (see _DashAppClient).
    title : str
        The title of the DashApp.
    host : str
        The host for the DashApp.
    port : int
        The port to serve the DashApp on.
    """
    client = _DashAppClient(data, title, host, port)
    from websockets.asyncio.client import connect
    from json import loads

    # Set once the listener below is connected as the server's "/data" client.
    e = Event()

    async def getData():
        # Use the 'async with' statement to ensure the connection is closed properly
        uri = f"ws://{client._server.hostname}:{client._server.websocket_port.result()}/data"
        async with connect(uri) as websocket:
            e.set()
            while True:
                raw = await websocket.recv()
                # The shutdown sentinel may arrive either bare (not valid JSON) or
                # JSON-encoded; accept both.
                if raw == "SHUTDOWN":
                    break
                message = loads(raw)
                if isinstance(message, list):
                    # Put the update first, as this is the most likely
                    client.update(message)
                elif isinstance(message, dict) and "REMOVE_PLOT" in message:
                    client.removePlot(message["REMOVE_PLOT"])
                    change_plot.set()
                elif isinstance(message, dict) and "ADD_PLOT" in message:
                    client.addPlot(message["ADD_PLOT"])
                    change_plot.set()
                elif isinstance(message, dict) and "GET_FIGS" in message:
                    q.put(client._server.figs)
                elif isinstance(message, dict) and "SET_FIGS" in message:
                    client._server.figs = q.get()
                    client._server.sendWebsocketMessage(
                        dumps({"msg": "updating plot"}), path="/refresh"
                    )
                    change_plot.set()
                elif message == "SHUTDOWN":
                    break

    # Run this loop to process data and pass it to the client
    asyncio.run_coroutine_threadsafe(getData(), client._server._websocket_loop)
    e.wait()

    # Send data back to the parent process
    q.put(f"http://{client._server.hostname}:{client._server.port}")
    q.put(client._server.hostname)
    q.put(client._server.websocket_port.result())

    # Wait until we get the signal from the parent to shutdown
    shutdown.wait()

    # Explicitly stop the server. del just removes things from the local scope, but
    # the process may exit after that since the interpreter is technically done (while
    # logic from __del__ is still running). This ensures the server is shutdown
    # every time before closing.
    client._server._stopServer()

    # Delete and close everything
    del client
class DashApp(PlotData):
    """Use this class to create a DashApp that you can stream data to.

    This class will automatically create empty plots based on the `SinglePlotData`s provided.
    In addition, it will start a websocket and http server for clients to connect to view the plots.
    Calling `update` will update the data of the plots based on the data provided in the `SinglePlotData`s.
    """

    def __init__(
        self,
        data: list[SinglePlotData],
        title: str = "DataPlotter",
        host: str = "0.0.0.0",
        port: int = 8050,
    ):
        """Initialize the DashApp.

        This starts a server on the provided host and port. This server should shutdown
        automatically once this instance is deleted.

        Parameters
        ----------
        data : list[SinglePlotData]
            Data to be plotted and update functions.
        title: str
            The title of the DashApp webpage.
        host: str
            The host to host the server on.
        port: int
            The port to use for the server.
        """
        from multiprocessing import Event, Process, Queue

        self.shutdown_event = shutdown = Event()
        self.change_plot = change_plot = Event()
        q = self._mp_queue = Queue()

        # Get data for the initial plots: evaluate every data function once and
        # pass plain values (numpy arrays converted to lists) to the child process.
        dat = []
        for d in data:
            plot_dict = {
                "title": d.title,
                "x_data_name": d.x_data_name,
                "x_data": d.x_data(),
                "y_data": {},
            }
            for k, val in zip(d.y_data_names, d.y_data_fns):
                v = val()
                if isinstance(v, np.ndarray):
                    v = v.tolist()
                plot_dict["y_data"][k] = v
            dat.append(plot_dict)

        self.p = Process(
            target=_startProcess, args=(shutdown, change_plot, q, dat, title, host, port)
        )
        self.p.daemon = True
        self.p.start()

        # Wait on the values from the queue (put by _startProcess in this order).
        # This is the server url
        self.url = q.get()

        # This is the websocket hostname/port
        hostname = cast(str, q.get())
        p = str(q.get())
        # The "/data" endpoint lives on the websocket server, so use its port
        # (p), not the HTTP port argument.
        self._ws_url = f"ws://{hostname}:{p}/data"

        # Call the parent constructor
        super().__init__(data, hostname, p, "/data")

    @overload
    def removePlot(self, title: str, /):
        """Remove the plot with the associated title.

        This will only work if there is one and only one title associated with the plot.
        """
        ...

    @overload
    def removePlot(self, index: int, /):
        """Remove the plot at the associated index.

        The index corresponds to the figure in the slot of self._server.figs.

        Parameters
        ----------
        index : int
            The index of the plot to remove.
        """
        ...

    def removePlot(
        self, arg: int | str, /
    ):  # pyright: ignore - override with different signature on purpose
        """Remove the plot based on the value of arg.

        See overloads for more details. Blocks until the child process reports
        the plot change.

        Raises
        ------
        ValueError
            If arg is a title that matches no plot, or more than one plot.
        """
        index = None
        if isinstance(arg, str):
            for k, d in enumerate(self.plot_fns):
                if d.title == arg:
                    if index is not None:
                        raise ValueError(f"More than one plot has the title {arg}")
                    index = k
            if index is None:
                raise ValueError(f"No match found for title {arg}")
        else:
            index = arg

        self.change_plot.clear()
        super().removePlot(index)
        self.change_plot.wait()

    def addPlot(self, plot_data: SinglePlotData):
        """Add a plot.

        Blocks until the child process reports the plot change.

        Parameters
        ----------
        plot_data : SinglePlotData
            The plotting functions associated with the new plot.
        """
        self.change_plot.clear()
        super().addPlot(plot_data)
        self.change_plot.wait()

    @property
    def figs(self) -> list[go.Figure]:
        """Get the figures associated with this DashApp."""
        self.sendWebsocketMessage(dumps({"GET_FIGS": 1}))
        return self._mp_queue.get()

    @figs.setter
    def figs(self, figs: list[go.Figure]):
        """Set the figures associated with this DashApp."""
        # Clear first: change_plot may still be set from an earlier change, in
        # which case wait() would return before these figures are applied.
        self.change_plot.clear()
        self._mp_queue.put(figs)
        self.sendWebsocketMessage(dumps({"SET_FIGS": 1}))
        self.change_plot.wait()

    def __del__(self):
        """Shut everything down."""
        self.shutdown()
        if self.p.is_alive():
            # If p is still alive, try to shut it down cleanly.
            self.shutdown_event.set()
            if os.environ.get("DTEST_RUNNING", False):
                # If we are running tests in parallel, then don't set a timeout.
                # The strain of the tests means this may take a little while if the CPUs
                # are maxed.
                self.p.join()
            else:
                self.p.join(timeout=10.0)
                if self.p.is_alive():
                    # Clean shutdown timed out; don't leave the child process behind.
                    self.p.terminate()
                    self.p.join(timeout=1.0)