"""Tools for viewing a Subgraph with vis-network.js."""
from itertools import count
from dataclasses import dataclass
from ._datatypes import (
Node,
Edge,
NodeOptions,
NodeShape,
NetworkGraph,
NodeColorOptions,
NodeFontOptions,
EdgeOptions,
EdgeColorOptions,
ArrowOptions,
ArrowStyle,
SmoothOptions,
Button,
)
from ._server import GraphServer
from Karana.Dynamics import SubGraph, HingeType, PhysicalBody, BilateralConstraintType
# Names exported by ``from <module> import *``.
__all__ = [
    "subgraphToGraph",
    "multibodyConstraintEdges",
    "MultibodyGraphServer",
]

# Edge color for each hinge type. Hinge types not listed here
# (e.g. HingeType.FULL6DOF) have no entry in this table.
edge_colors = {
    HingeType.LOCKED: "#FFFF00",       # yellow
    HingeType.SLIDER: "#00FF00",       # green
    HingeType.PIN: "bisque",           # bisque
    HingeType.BALL: "#FF00FF",         # magenta
    HingeType.UJOINT: "#800080",       # purple
    HingeType.CYLINDRICAL: "#00FFFF",  # cyan
    HingeType.CUSTOM: "#0000FF",       # blue
}

# Gray; judging by its name, presumably meant for coupler edges.
# NOTE(review): not referenced elsewhere in this file -- coupler edges in
# multibodyConstraintEdges hard-code "palegreen". Confirm the intended color.
coupler_color = "#808080"
def _fillLabelMap(
    subgraph: SubGraph, label_map: dict[int | str, str] | None = None
) -> dict[int | str, str]:
    """Return a copy of ``label_map`` completed with a label for every body.

    Filling in every body up front ensures the same labels are used
    consistently for different subgraphs of the same multibody.

    The virtual root and every body from ``subgraph.sortedBodiesList()``
    are visited in that order:

    - A body that already has a truthy label, keyed by its id or by its name,
      is left untouched.
    - Otherwise, root bodies get sequential numeric labels ("0", "1", ...).
    - Every other body is labeled with its name.

    New entries are keyed by body id. The caller's map is never mutated.
    """
    # Work on a copy so the user's map is not modified.
    filled = label_map.copy() if label_map else {}
    root_numbers = count()  # yields 0, 1, 2, ... for root-body labels
    bodies = [subgraph.virtualRoot(), *subgraph.sortedBodiesList()]
    for body in bodies:
        body_id = body.id()
        # The name lookup only runs if the id lookup found nothing truthy.
        already_labeled = filled.get(body_id, None) or filled.get(body.name(), None)
        if already_labeled:
            continue
        filled[body_id] = str(next(root_numbers)) if body.isRootBody() else body.name()
    return filled
def subgraphToGraph(
    subgraph: SubGraph,
    title: str = "Multibody System",
    label_map: dict[int | str, str] | None = None,
    constraints: bool = True,
) -> NetworkGraph:
    """Convert a SubGraph to a visjs NetworkGraph.

    One node is created for the virtual root and for every body reachable
    from it through ``subgraph.childrenBodies``. One directed edge is created
    from each parent body to each of its children. When ``constraints`` is
    true, the edges from ``multibodyConstraintEdges`` are appended after the
    tree edges.

    Parameters
    ----------
    subgraph: SubGraph
        The subgraph object to show
    title: str
        Title of the graph, defaults to "Multibody System"
    label_map: dict[int | str, str] | None
        Optional map to look up string labels for bodies. Keys may be
        either body ids or body names.
    constraints: bool
        Whether to create edges for constraints; True by default

    Returns
    -------
    NetworkGraph
        The NetworkGraph for the SubGraph, ready for visualization
    """
    # Complete the label map (on a copy) so every body has a label.
    label_map = _fillLabelMap(subgraph, label_map)
    nodes: list[Node] = []
    edges: list[Edge] = []
    penwidth = 2.0

    def toNode(body: PhysicalBody) -> Node:
        """Create a box node for ``body``.

        Node colors:

        - root bodies: coral
        - compound bodies: pale green, with a 16pt instead of 12pt font
        - base bodies of the subgraph: yellow
        - all other bodies: sky blue
        """
        if body.isRootBody():
            color = "coral"
        elif body.isCompoundBody():
            color = "palegreen"
        elif subgraph.isBaseBody(body):
            color = "#FFFF00"
        else:
            color = "#87CEEB"
        # Prefer a label keyed by body id, then one keyed by body name.
        label = label_map.get(body.id(), None) or label_map.get(body.name(), None)
        if label is None:
            # Only possible for a body _fillLabelMap did not visit; show its
            # name rather than the literal string "None".
            label = body.name()
        fontsz = 16 if body.isCompoundBody() else 12
        return Node(
            id=body.id(),
            label=str(label),
            title=body.name(),
            options=NodeOptions(
                shape="box",
                size=25,
                color=NodeColorOptions(
                    background=color,
                    border=color,
                ),
                font=NodeFontOptions(
                    color="#000000",  # black text
                    size=fontsz,
                    face="Arial",
                ),
            ),
        )

    def toEdge(parent: PhysicalBody, child: PhysicalBody) -> Edge:
        """Create a directed edge from ``parent`` to ``child``.

        The edge describes the child's parent hinge. Its color comes from
        ``edge_colors``, falling back to white for unlisted hinge types, and
        FULL6DOF hinges are drawn dashed.
        """
        hinge_type = child.parentHinge().hingeType()
        color = edge_colors.get(hinge_type, "#FFFFFF")
        return Edge(
            from_=parent.id(),
            to=child.id(),
            title=str(hinge_type),
            options=EdgeOptions(
                color=color,
                width=penwidth,
                dashes=(hinge_type == HingeType.FULL6DOF),
                arrows=ArrowOptions(
                    to=ArrowStyle(
                        enabled=True,
                        scaleFactor=1.2,
                    ),
                ),
                smooth=SmoothOptions(
                    type="cubicBezier",
                    roundness=0.5,
                ),
            ),
        )

    # Start with the virtual root.
    vroot = subgraph.virtualRoot()
    nodes.append(toNode(vroot))

    # Depth-first, pre-order walk of the body tree. An explicit stack of
    # (parent, children-iterator) pairs produces the same node/edge order as
    # a recursive walk, but long serial chains of bodies cannot exceed
    # Python's recursion limit.
    exhausted = object()
    stack = [(vroot, iter(subgraph.childrenBodies(vroot)))]
    while stack:
        parent, children = stack[-1]
        child = next(children, exhausted)
        if child is exhausted:
            stack.pop()
            continue
        nodes.append(toNode(child))
        edges.append(toEdge(parent, child))
        stack.append((child, iter(subgraph.childrenBodies(child))))

    if constraints:
        edges.extend(multibodyConstraintEdges(subgraph))
    return NetworkGraph(nodes=nodes, edges=edges, title=title)
def _dashedConstraintEdge(from_id, to_id, title: str, color: str, penwidth: float = 2.0) -> Edge:
    """Build the dashed, arrowed, curved edge style shared by all constraint edges."""
    return Edge(
        from_=from_id,
        to=to_id,
        title=title,
        options=EdgeOptions(
            color=color,
            width=penwidth,
            dashes=True,
            arrows=ArrowOptions(
                to=ArrowStyle(
                    enabled=True,
                    scaleFactor=1.0,
                ),
            ),
            smooth=SmoothOptions(
                type="cubicBezier",
                roundness=0.3,
            ),
        ),
    )


def multibodyConstraintEdges(
    subgraph: SubGraph,
) -> list[Edge]:
    """Create edges for all constraints in the provided SubGraph.

    Each constraint is handled according to its type:

    - HINGE_LOOP / CONVEL_LOOP with a hinge: an edge from the source node's
      body to the target node's body, colored by hinge type.
    - HINGE_LOOP / CONVEL_LOOP without a hinge: the same edge, colored
      chocolate1 and titled ".../CONVEL".
    - Any other enabled constraint is treated as a coordinate (coupler)
      constraint. Its edge runs between the bodies owning the pnodes of its
      ``osubhinge`` and ``psubhinge`` parent hinges.

    Parameters
    ----------
    subgraph : SubGraph
        The SubGraph whose constraints will be used to create edges.

    Returns
    -------
    list[Edge]
        A list of edges, where every edge is associated with a constraint in the provided SubGraph.
    """
    penwidth = 2.0
    loop_types = (
        BilateralConstraintType.HINGE_LOOP,
        BilateralConstraintType.CONVEL_LOOP,
    )
    edges = []
    for constraint in subgraph.enabledConstraints():
        if constraint.type() in loop_types:
            source_body = constraint.sourceNode().parentBody()
            target_body = constraint.targetNode().parentBody()
            if constraint.hasHinge():
                hinge_type = constraint.hinge().hingeType()
                title = f"{constraint.name()}/{hinge_type}"
                # edge_colors does not list every hinge type (e.g. FULL6DOF);
                # fall back to white instead of raising KeyError, matching
                # the tree-edge coloring in subgraphToGraph.
                color = edge_colors.get(hinge_type, "#FFFFFF")
            else:
                title = f"{constraint.name()}/CONVEL"
                color = "#FF7F24"  # chocolate1
            edges.append(
                _dashedConstraintEdge(source_body.id(), target_body.id(), title, color, penwidth)
            )
        else:
            # NOTE(review): assumes every non-loop enabled constraint exposes
            # osubhinge/psubhinge/getScaleRatio -- confirm if new types are added.
            obody = constraint.osubhinge().parentHinge().pnode().parentBody()
            pbody = constraint.psubhinge().parentHinge().pnode().parentBody()
            scale = constraint.getScaleRatio()
            edges.append(
                _dashedConstraintEdge(
                    obody.id(),
                    pbody.id(),
                    f"{constraint.name()}/COUPLER/{scale}",
                    "palegreen",
                    penwidth,
                )
            )
    return edges
class MultibodyGraphServer(GraphServer):
    """Specialized GraphServer for subgraph viewing.

    Given a subgraph, this automatically generates a set of graphs
    with and without constraints and other sets of extra edges. Buttons
    are automatically added to toggle different parts of the graph on
    and off.
    """

    @dataclass
    class _GraphData:
        # Name of this layer; also its key in ``self.subgraphs`` and the
        # prefix of its toggle buttons.
        label: str
        # Graph this layer contributes when enabled.
        graph: NetworkGraph
        # Whether this layer is merged into the displayed graph.
        enabled: bool

    def __init__(
        self,
        subgraph: SubGraph,
        *,
        title: str = "Multibody System",
        port: int = 8765,
        autorun: bool = True,
        log_level: str = "warning",
        buttons: list[Button] | None = None,
        label_map: dict[int | str, str] | None = None,
        extra_edges: dict[str, list[Edge]] | list[Edge] | None = None,
    ):
        """Create a MultibodyGraphServer instance.

        Parameters
        ----------
        subgraph : SubGraph
            The SubGraph to visualize. Its tree (without constraints) is the
            initially displayed "tree" layer.
        title : str
            Title of the graph.
        port : int
            Passed through to GraphServer.
        autorun : bool
            Passed through to GraphServer.
        log_level : str
            Passed through to GraphServer.
        buttons : list[Button] | None
            Extra buttons appended after the automatically created toggle buttons.
        label_map : dict[int | str, str] | None
            Optional body labels keyed by body id or body name.
        extra_edges : dict[str, list[Edge]] | list[Edge] | None
            Additional edge sets, each becoming a toggleable layer that is
            initially hidden. A bare list is stored under the label "extra".
            A "constraints" layer is always added (overriding any caller
            entry with that key). The caller's dict is not modified.
        """
        tree_graph = subgraphToGraph(subgraph, title=title, constraints=False, label_map=label_map)
        self.subgraphs = {
            "tree": MultibodyGraphServer._GraphData(
                label="tree",
                graph=tree_graph,
                enabled=True,
            )
        }
        if not extra_edges:
            layers = {}
        elif isinstance(extra_edges, list):
            layers = {"extra": extra_edges}
        else:
            # Copy so that adding "constraints" below does not mutate the
            # dict the caller passed in.
            layers = dict(extra_edges)
        layers["constraints"] = multibodyConstraintEdges(subgraph)
        for label, edges in layers.items():
            if not edges:
                # Skip empty edge lists
                continue
            graph = tree_graph.clone()
            # Replace the edges in the tree graph
            graph.edges = edges
            # Remove nodes not touched by the new edges
            graph.removeIsolatedNodes()
            self.subgraphs[label] = MultibodyGraphServer._GraphData(
                label=label,
                graph=graph,
                enabled=False,  # Don't show initially
            )
        buttons = self._defaultButtons() + (buttons or [])
        super().__init__(
            graph=tree_graph, port=port, autorun=autorun, log_level=log_level, buttons=buttons
        )

    def enableSubgraph(self, label: str):
        """Enable the SubGraph with the given label and rebuild the displayed graph.

        Parameters
        ----------
        label : str
            The SubGraph to enable. An unknown label raises KeyError.
        """
        self.subgraphs[label].enabled = True
        self._rebuildGraph()

    def disableSubgraph(self, label: str):
        """Disable the SubGraph with the given label and rebuild the displayed graph.

        Parameters
        ----------
        label : str
            The SubGraph to disable. An unknown label raises KeyError.
        """
        self.subgraphs[label].enabled = False
        self._rebuildGraph()

    def _rebuildGraph(self):
        """Merge all enabled layers into a new graph and hand it to the clients.

        The merged graph is stored as ``self.graph``, after which
        ``updateClientGraphs()`` (not defined in this class, presumably
        inherited from GraphServer) is called.
        """
        # NOTE(review): the merge starts from an empty NetworkGraph(); whether
        # the tree's title survives ``|=`` depends on NetworkGraph -- confirm.
        graph = NetworkGraph()
        for subgraph_data in self.subgraphs.values():
            if subgraph_data.enabled:
                graph |= subgraph_data.graph
        self.graph = graph
        self.updateClientGraphs()

    def _defaultButtons(self) -> list[Button]:
        """Create an "on"/"off" button pair for every layer in ``self.subgraphs``.

        No buttons are created when there is only a single layer.
        """
        if len(self.subgraphs) == 1:
            return []
        buttons = []

        def _addToggle(label):
            # A separate function call per label binds ``label`` for each
            # lambda; this avoids the late-binding closure pitfall.
            buttons.append(
                Button(
                    text=f"{label} on",
                    callback=lambda: self.enableSubgraph(label),
                    style={},  # Add CSS styling here
                )
            )
            buttons.append(
                Button(
                    text=f"{label} off",
                    callback=lambda: self.disableSubgraph(label),
                    style={},  # Add CSS styling here
                )
            )

        for label in self.subgraphs:
            _addToggle(label)
        return buttons