# Copyright (c) 2024-2025 Karana Dynamics Pty Ltd. All rights reserved.
#
# NOTICE TO USER:
#
# This source code and/or documentation (the "Licensed Materials") is
# the confidential and proprietary information of Karana Dynamics Inc.
# Use of these Licensed Materials is governed by the terms and conditions
# of a separate software license agreement between Karana Dynamics and the
# Licensee ("License Agreement"). Unless expressly permitted under that
# agreement, any reproduction, modification, distribution, or disclosure
# of the Licensed Materials, in whole or in part, to any third party
# without the prior written consent of Karana Dynamics is strictly prohibited.
#
# THE LICENSED MATERIALS ARE PROVIDED "AS IS" WITHOUT WARRANTY OF ANY KIND.
# KARANA DYNAMICS DISCLAIMS ALL WARRANTIES, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO WARRANTIES OF MERCHANTABILITY, NON-INFRINGEMENT, AND
# FITNESS FOR A PARTICULAR PURPOSE.
#
# IN NO EVENT SHALL KARANA DYNAMICS BE LIABLE FOR ANY DAMAGES WHATSOEVER,
# INCLUDING BUT NOT LIMITED TO LOSS OF PROFITS, DATA, OR USE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGES, WHETHER IN CONTRACT, TORT,
# OR OTHERWISE ARISING OUT OF OR IN CONNECTION WITH THE LICENSED MATERIALS.
#
# U.S. Government End Users: The Licensed Materials are a "commercial item"
# as defined at 48 C.F.R. 2.101, and are provided to the U.S. Government
# only as a commercial end item under the terms of this license.
#
# Any use of the Licensed Materials in individual or commercial software must
# include, in the user documentation and internal source code comments,
# this Notice, Disclaimer, and U.S. Government Use Provision.
"""Classes and functions used to stream data to a DashApp.
This module contains classes and functions used to stream data to a DashApp. These are useful for
plotting data while a simulation is running. The DashApp is a useful format, as it allows one to
view the results from a web browser, and that browser does not have to be on the same machine as
the machine generating the data. This is particularly useful when running a simulation on a
remote machine without an X session.
"""
import asyncio
import os
import socket
from copy import deepcopy
from json import dumps
from numbers import Real
from pathlib import Path
from subprocess import run
from threading import Thread, Event
from typing import Callable, Any, overload

import plotly.graph_objects as go
import websockets
from dash import Dash, html, Input, Output, dcc, ClientsideFunction
from dash_extensions import WebSocket
from werkzeug.serving import make_server, WSGIRequestHandler

from Karana.Core import debug
from Karana.KUtils.DataStruct import DataStruct
from Karana.KUtils.Ktyping import Vec
class PlotDataDS(DataStruct):
    """Hold the callables used to update the data of one plot.

    A list of these is handed to `DashApp`, which creates one figure per `PlotDataDS` and calls
    the stored callables to fetch new data points.

    Parameters
    ----------
    title : str
        Title of the graph.
    x_data_name : str
        The name of the x-axis.
    x_data : Callable[[], float]
        The callable used to get the x-axis data.
    y_data : dict[str, Callable[[], Vec | float]]
        This dictionary holds the names and callables for a plot. The name will be the
        name used in the legend. The callable gets the data. If the data is a vector with size
        greater than 1, then one trace is created per component and each trace is named with the
        original name plus an index, e.g. ``name[0]``.
    """

    # Title of the graph.
    title: str
    # Callable returning the current x-axis value.
    x_data: Callable[[], float]
    # Label shown on the x-axis.
    x_data_name: str
    # Legend name -> callable returning the current y value(s) for that name.
    y_data: dict[str, Callable[[], Vec | float]]
class _DashServer:
    """Internal class used to start/stop the HTTP and websocket servers.

    Users should not create instances of this themselves. We do this in
    two classes to avoid reference cycles. Users need only delete the
    DashApp instance, and the server should shutdown automatically.
    """

    def __init__(self, title: str = "DataPlotter", host: str = "0.0.0.0", port: int = 8050):
        """Initialize the DashServer and start its server threads.

        Parameters
        ----------
        title: str
            The title passed to the Dash application.
        host: str
            The host to host the server on.
        port: int
            The port to use for the server. A non-zero port is shifted by the integer stored in
            the ``KARANA_PORT_OFFSET`` environment variable, if it is set. Port 0 is passed
            through unchanged so the server can select an available port.

        Raises
        ------
        Exception
            The error that kept the HTTP or websocket server from starting, if any.
        """
        self.host = host
        # "0.0.0.0" is a bind address, not something a browser can connect to, so the machine's
        # host name is used in URLs instead.
        self.hostname = socket.gethostname() if self.host == "0.0.0.0" else self.host

        if port != 0:
            port_offset_str = os.getenv("KARANA_PORT_OFFSET", None)
            port += int(port_offset_str) if port_offset_str else 0

        asset_folder = Path(__file__).parents[0].joinpath("assets")

        # NOTE(review): _runWebsocket replaces self._websocket_loop with the loop it actually
        # runs, so this loop only acts as the calling thread's current event loop and is never
        # closed - confirm whether it is still needed.
        self._websocket_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self._websocket_loop)

        self.port: int
        self.app = Dash(__name__, title=title, assets_folder=str(asset_folder))
        self.update_clients = set()
        self.refresh_clients = set()

        # Store the figures on the DashServer as well. If we keep them on the main, then we can't cleanup.
        self.figs: list[go.Figure] = []
        self.gs: list[dcc.Graph] = []
        self.fig_labels: list[str] = []

        # Counter for the figures
        self.figure_counter: int = 0

        # New data to be appended to plot. Only used by the DashApp class.
        self.new_plot_data: list[dict[str, Any]] = []

        # Start the threads last so they never observe a partially initialized instance.
        self._startServer(port)

    def _runServer(self, port: int, event: Event):
        """Run the HTTP server.

        This is the method run by the _server_thread.

        Parameters
        ----------
        port: int
            The port to bind to. Zero lets make_server select an available port.
        event: Event
            Set once the server has been created, or once its creation has failed. In the
            failure case the error is stored in ``self._server_error`` for _startServer.
        """

        class SilentHandler(WSGIRequestHandler):
            def log(self, type, message, *args):
                pass  # Do nothing — suppress all logs

        try:
            server = make_server(self.host, port, self.app.server, request_handler=SilentHandler)
            self._server = server
            # If port 0 was passed in, then make_server will select an available port.
            # If non-zero, then this should be the same as port anyway. Use this so
            # any printed messages have the real port number.
            self.port = server.port
        except Exception as exc:
            # Hand the error to the starting thread; raising here would only kill this thread.
            self._server_error = exc
            return
        finally:
            event.set()

        # Use the local reference: _stopServer may already have reset self._server to None.
        server.serve_forever()

    def _runWebsocket(self, event: Event):
        """Run the websocket server.

        This is the method run by the _websocket_thread.

        Parameters
        ----------
        event: Event
            Set once the websocket server is listening, or once starting it has failed. In the
            failure case the error is stored in ``self._websocket_error`` for _startServer.
        """

        async def handler(websocket):
            # Clients connecting on "/refresh" get refresh messages, all others get updates.
            if websocket.request.path == "/refresh":
                clients = self.refresh_clients
            else:
                clients = self.update_clients
            clients.add(websocket)
            debug(lambda: "Client connected")
            try:
                async for _ in websocket:
                    # Keep the connection alive
                    pass  # pragma: no cover - This doesn't get hit in reg tests, nor should it.
            except (
                websockets.exceptions.ConnectionClosed
            ):  # pragma: no cover - Exclude from code coverage, as this won't run in testing
                pass
            finally:
                debug(lambda: "Client disconnected")
                # discard rather than remove: _sendWebsocketMessage may already have dropped
                # this client after a failed send.
                clients.discard(websocket)

        self._websocket_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self._websocket_loop)
        self._websocket_stop = asyncio.Event()
        self.websocket_port = asyncio.Future()

        async def startServer():
            try:
                server = await websockets.serve(handler, "0.0.0.0", 0)
                self.websocket_port.set_result(server.sockets[0].getsockname()[1])
            except Exception as exc:
                self._websocket_error = exc
                return
            finally:
                # Always release the starting thread, otherwise a failed start would hang it.
                event.set()
            debug(lambda: f"Starting WebSocket server on port {self.websocket_port.result()}...")
            await self._websocket_stop.wait()
            server.close()
            await server.wait_closed()
            debug(lambda: f"WebSocket server on port {self.websocket_port.result()} closed")

        self._websocket_loop.run_until_complete(startServer())
        self._websocket_loop.stop()

    def _startServer(self, port: int):
        """Start the HTTP and websocket servers, each in its own thread.

        Does nothing if the server thread is already running.

        Parameters
        ----------
        port: int
            The port handed to the HTTP server.

        Raises
        ------
        Exception
            The error raised while creating the HTTP or the websocket server. Whatever did
            start is shut down again before the error is re-raised.
        """
        if self.isRunning():
            return

        self._server_error = None
        self._websocket_error = None

        # We use daemon threads so atexit functions can run without waiting for these threads to finish.
        # Use events so we know when the ports are set
        server_ready = Event()
        websocket_ready = Event()

        self._server_thread = Thread(target=self._runServer, args=(port, server_ready))
        self._server_thread.daemon = True
        self._server_thread.start()

        self._websocket_thread = Thread(target=self._runWebsocket, args=(websocket_ready,))
        self._websocket_thread.daemon = True
        self._websocket_thread.start()

        server_ready.wait()
        websocket_ready.wait()

        error = self._server_error if self._server_error is not None else self._websocket_error
        if error is not None:
            self._stopServer()
            raise error

        print(f"Dash app started at http://{self.hostname}:{self.port}")

    def _stopServer(self):
        """Stop the servers and join the associated threads.

        Safe to call more than once and on an instance whose construction failed part-way.
        """
        loop = getattr(self, "_websocket_loop", None)
        stop_event = getattr(self, "_websocket_stop", None)
        if loop is not None and stop_event is not None and not loop.is_closed():
            # The stop event belongs to the websocket loop, so it must be set from that loop.
            loop.call_soon_threadsafe(stop_event.set)

        server = getattr(self, "_server", None)
        if server:
            server.shutdown()
            self._server = None
            print("Dash server shutting down...")

        for thread_name in ("_server_thread", "_websocket_thread"):
            thread = getattr(self, thread_name, None)
            if thread:
                thread.join()

    async def _sendWebsocketMessage(self, message: str, path="/update"):
        """Send a message to every client connected on the given path.

        Clients whose connection turns out to be closed are dropped from the client set.

        Parameters
        ----------
        message: str
            The message to send.
        path: str
            "/update" targets the update clients, anything else targets the refresh clients.
        """
        clients = self.update_clients if path == "/update" else self.refresh_clients
        to_remove = set()
        # Iterate over a copy, since handlers may add/remove clients while we await.
        for ws in clients.copy():
            try:
                await ws.send(message)
            except (
                websockets.exceptions.ConnectionClosed
            ):  # pragma: no cover - Exclude from code coverage, as this won't run in testing
                debug(lambda: "Client disconnected before send")
                to_remove.add(ws)
        clients.difference_update(to_remove)

    def pingWebsocket(self):
        """Ping the websocket.

        Pinging the websocket will also update the plots.
        """
        self.sendWebsocketMessage(dumps({"msg": "update graph!"}))

    def sendWebsocketMessage(self, message: str, path="/update"):
        """Schedule a message to be sent to the websocket clients.

        The send is scheduled on the websocket event loop; this method does not wait for it.

        Parameters
        ----------
        message: str
            The message to send.
        path: str
            "/update" targets the update clients, anything else targets the refresh clients.
        """
        asyncio.run_coroutine_threadsafe(
            self._sendWebsocketMessage(message, path=path), self._websocket_loop
        )

    def isRunning(self) -> bool:
        """Return a boolean that indicates whether the server is running or not."""
        return hasattr(self, "_server_thread") and self._server_thread.is_alive()

    def __del__(
        self,
    ):  # pragma: no cover - Normal reference cycles keep this from being run. We keep it around in case users ever trigger a strange scenario where this is removed (in which case it should shut down the server).
        """Stop the server if it is running."""
        self._stopServer()
class _DashAppBase:
    """Base class for the DashApp* classes.

    Creates the `_DashServer`, the page header with the theme toggle button, and the websocket
    component (id "ws") that the derived classes hook their callbacks onto.
    """

    def __init__(
        self,
        title: str = "DataPlotter",
        host: str = "0.0.0.0",
        port: int = 8050,
    ):
        """Initialize the DashApp.

        This starts a server on the provided host and port. This server should shutdown
        automatically once this instance is deleted.

        Parameters
        ----------
        title: str
            The title of the DashApp webpage.
        host: str
            The host to host the server on.
        port: int
            The port to use for the server.
        """
        self._server = _DashServer(host=host, port=port, title=title)
        app = self._server.app

        # This is the address of the websocket server, so it uses the ws:// scheme, which every
        # browser accepts in the WebSocket constructor.
        self.url = f"ws://{self._server.hostname}:{self._server.websocket_port.result()}/update"

        header = html.Header(
            [
                html.H1(children=title, style={"textAlign": "center", "flex-grow": "1"}),
                html.Button(
                    r"☀️",
                    title="Toggle light/dark theme",
                    id="theme-toggle",
                    className="toggle-button",
                ),
            ],
            style={
                "display": "flex",
                "justify-content": "space-between",
                "align-items": "center",
                "margin-bottom": "20px",
                "padding-bottom": "10px",
            },
        )
        app.layout = [
            header,
            dcc.Store(id="update-data", data=0),
            WebSocket(url=self.url, id="ws"),
        ]

        # Toggle the page theme in the browser; no server round trip is needed for this.
        app.clientside_callback(
            r"""
            function(n_clicks) {
                const body = document.body;
                const button = document.getElementById('theme-toggle');
                const current = body.getAttribute('data-theme') || 'light';
                if (current === 'light') {
                    body.setAttribute('data-theme', 'dark');
                    button.textContent = '\u{1F319}';
                } else {
                    body.setAttribute('data-theme', 'light');
                    button.textContent = '\u2600\uFE0F';
                }
                return '';
            }
            """,
            Output("theme-toggle", "n_clicks"),
            Input("theme-toggle", "n_clicks"),
        )

    def printConnectionInfo(self):
        """Print a message about how to connect to the server."""
        http_url = f"http://{self._server.hostname}:{self._server.port}"
        lines = [
            f"[DashApp] Web server is running on port {self._server.port}",
            "You may be able to connect in your browser at:",
            f"\t\033[1m{http_url}\033[0m",  # ANSI bold
        ]
        print("\n".join(lines))

    def __del__(self):
        """Delete method.

        This stops the server.
        """
        # _server is missing if __init__ failed before the server was created.
        if hasattr(self, "_server"):
            self._server._stopServer()
class DashAppStatic(_DashAppBase):
    """Use this class to create a DashApp that serves up static figures.

    These figures can be changed manually using the figs property. However, they are not designed
    to be updated regularly. For that, please use the DashApp class.
    """

    def __init__(
        self,
        figs: list[go.Figure],
        title: str = "DataPlotter",
        host: str = "0.0.0.0",
        port: int = 8050,
    ):
        """Initialize the DashApp.

        This starts a server on the provided host and port. This server should shutdown
        automatically once this instance is deleted.

        Parameters
        ----------
        figs: list[go.Figure]
            The figures to show.
        title: str
            The title of the DashApp webpage.
        host: str
            The host to host the server on.
        port: int
            The port to use for the server.
        """
        super().__init__(title=title, host=host, port=port)
        self.figs = figs
        self._createPlots(self._server, self._server.app)

    @property
    def figs(self) -> list[go.Figure]:
        """Figures shown by the DashApp."""
        return self._server.figs

    @figs.setter
    def figs(self, figs: list[go.Figure]):
        # Store the figures on the server and ping the websocket so connected clients redraw.
        self._server.figs = figs
        self._server.pingWebsocket()

    @staticmethod
    def _buildGraphs(figs: list[go.Figure]) -> list:
        """Create one dcc.Graph per figure, each with a unique component id.

        The id is the figure's title, or the figure's index if it has no title. An id that is
        already taken gets the index appended, since Dash rejects duplicate component ids.
        """
        graphs = []
        used_ids = set()
        for k, fig in enumerate(figs):
            graph_id = fig.layout.title.text or str(k)
            while graph_id in used_ids:
                graph_id = f"{graph_id}-{k}"
            used_ids.add(graph_id)
            graphs.append(dcc.Graph(id=graph_id, figure=fig))
        return graphs

    # Using a static method here so we don't create reference cycles that prevent d from being deleted.
    @staticmethod
    def _createPlots(server, app):
        """Add the plots to the layout and register the callback that rebuilds them.

        Parameters
        ----------
        server
            The _DashServer holding the figures.
        app
            The Dash application whose layout and callbacks are extended.
        """
        app.layout.append(
            html.Div(id="plots-div", children=DashAppStatic._buildGraphs(server.figs))
        )

        # Any websocket message rebuilds the graphs from the server's current figures.
        @app.callback(Output("plots-div", "children"), Input("ws", "message"))
        def updatePlots(_):
            return DashAppStatic._buildGraphs(server.figs)
class DashApp(_DashAppBase):
    """Use this class to create a DashApp that you can stream data to.

    This class will automatically create empty plots based on the `PlotDataDS` provided. In addition,
    it will start a websocket and http server for clients to connect to view the plots. Calling
    `update` will update the data of the plots based on the data provided in the `PlotDataDS`.
    """

    def __init__(
        self,
        data: list[PlotDataDS],
        title: str = "DataPlotter",
        host: str = "0.0.0.0",
        port: int = 8050,
    ):
        """Initialize the DashApp.

        This starts a server on the provided host and port. This server should shutdown
        automatically once this instance is deleted.

        Parameters
        ----------
        data : list[PlotDataDS]
            Data to be plotted and update functions.
        title: str
            The title of the DashApp webpage.
        host: str
            The host to host the server on.
        port: int
            The port to use for the server.
        """
        super().__init__(title=title, host=host, port=port)

        # Add the data to the server
        self.data: list[PlotDataDS] = []
        for d in data:
            # Using underscore method here to avoid calling the refresh method, since it doesn't exist yet.
            self._addPlot(d)

        # Add extra HTML elements. Both URLs address the websocket server, hence the ws:// scheme.
        websocket_address = (
            f"ws://{self._server.hostname}:{self._server.websocket_port.result()}"
        )
        self._server.app.layout.append(
            html.Div(
                [
                    WebSocket(url=f"{websocket_address}/refresh", id="ws-refresh"),
                    html.Div(
                        id="ws-url-div",
                        children=f"{websocket_address}/update",
                        style={"display": "none"},
                    ),
                    html.Div(id="plots-div"),
                ]
            )
        )

        # Use the static method to create callbacks. This avoids tying up self, which leads to reference cycles
        # that cause issues when tearing things down.
        self._createPlots(self._server, self._server.app)

    @staticmethod
    def _createPlots(server, app):
        """Register the callback that refreshes the plots and trigger a first refresh.

        Parameters
        ----------
        server
            The _DashServer holding the Graph objects.
        app
            The Dash application the callback is registered on.
        """

        # This callback refreshes the plots entirely
        @app.callback(Output("plots-div", "children"), Input("ws-refresh", "message"))
        def refreshPlots(msg):
            return server.gs

        # Refresh the plots with the initial data
        server.sendWebsocketMessage(dumps({"msg": "update graph!"}), path="/refresh")

    @staticmethod
    def _asValues(val) -> list:
        """Return one y_data sample as a list holding one value per trace.

        Any real scalar (for example a float or an int) yields a single value; anything else is
        treated as a vector and yields one value per component.
        """
        if isinstance(val, Real):
            return [val]
        return list(val)

    def update(self):
        """Update the plots.

        This updates the figures with the data provided in `PlotDataDS`. Then, it pings the
        websocket to trigger an update for all the connected clients.
        """
        new_plot_data = self._server.new_plot_data = []
        for fig, plot_id, data in zip(self._server.figs, self._server.fig_labels, self.data):
            x_data = data.x_data()
            # Traces were created in y_data order, one per value, so a running index finds them.
            trace_count = 0
            for get_y in data.y_data.values():
                for y in self._asValues(get_y()):
                    trace = fig.data[trace_count]
                    trace.x = list(trace.x) + [x_data]
                    trace.y = list(trace.y) + [y]
                    new_plot_data.append(
                        {
                            "plot_id": plot_id,
                            "trace_index": str(trace_count),
                            "x": x_data,
                            "y": y,
                        }
                    )
                    trace_count += 1

        self._server.sendWebsocketMessage(dumps(new_plot_data))

    def addPlot(self, data: PlotDataDS):
        """Add a new plot to the window using the associated plot data.

        Parameters
        ----------
        data : PlotDataDS
            The data associated with updating the plot.
        """
        self._addPlot(data)
        self._server.sendWebsocketMessage(dumps({"msg": "new plot added."}), path="/refresh")

    @overload
    def removePlot(self, title: str):
        """Remove the plot with the associated title.

        This will only work if there is one and only one title associated
        with the plot.
        """
        ...

    @overload
    def removePlot(self, index: int):
        """Remove the plot at the associated index.

        The index corresponds to the figure in the slot of self._server.figs.

        Parameters
        ----------
        index : int
            The index of the plot to remove.
        """
        ...

    def removePlot(self, arg):
        """Remove the plot based on the value of arg.

        See overloads for more details.

        Raises
        ------
        ValueError
            If arg is a title that matches no plot, or more than one plot.
        """
        if isinstance(arg, str):
            matches = [k for k, d in enumerate(self.data) if d.title == arg]
            if not matches:
                raise ValueError(f"No match found for title {arg}")
            if len(matches) > 1:
                raise ValueError(f"More than one plot has the title {arg}")
            index = matches[0]
        else:
            index = arg

        # These four lists are kept in lockstep, one entry per plot.
        self._server.figs.pop(index)
        self._server.fig_labels.pop(index)
        self._server.gs.pop(index)
        self.data.pop(index)
        self._server.sendWebsocketMessage(dumps({"msg": "removing plot"}), path="/refresh")

    def _addPlot(self, data: PlotDataDS):
        """Add a new plot to the window using the associated plot data.

        This does not call the plot refresh. Use the addPlot method for that.

        Parameters
        ----------
        data : PlotDataDS
            The data associated with updating the plot.
        """
        # Create the figure based on the incoming data
        fig = go.Figure()
        for name, get_y in data.y_data.items():
            # Sample each callable exactly once; the sample is only used to count the traces.
            data_len = len(self._asValues(get_y()))
            if data_len == 1:
                fig.add_trace(go.Scatter(x=[], y=[], mode="lines+markers", name=name))
            else:
                for i in range(data_len):
                    fig.add_trace(
                        go.Scatter(x=[], y=[], mode="lines+markers", name=f"{name}[{i}]")
                    )
        fig.update_layout(template="plotly_dark", xaxis_title=data.x_data_name, title=data.title)

        # Add the data the DashApp
        self.data.append(data)

        # Add this to the server's list of figures
        self._server.figs.append(fig)
        label = f"karana-plot-{self._server.figure_counter}"
        self._server.fig_labels.append(label)

        # Add this to the server's Graph objects
        self._server.gs.append(dcc.Graph(id=label, figure=fig))

        # Increment the figure counter
        self._server.figure_counter += 1