Source code for Karana.KUtils.visjs._mbody

"""Tools for viewing a Multibody with vis-network.js"""

from itertools import count
from dataclasses import dataclass

from ._datatypes import (
    Node,
    Edge,
    NodeOptions,
    NodeShape,
    NetworkGraph,
    NodeColorOptions,
    NodeFontOptions,
    EdgeOptions,
    EdgeColorOptions,
    ArrowOptions,
    ArrowStyle,
    SmoothOptions,
    Button,
)
from ._server import GraphServer

from Karana.Dynamics import Multibody, HINGE_TYPE, PhysicalBody, BILATERAL_CONSTRAINT_TYPE


__all__ = [
    "multibodyToGraph",
    "multibodyConstraintEdges",
    "MultibodyGraphServer",
]


# Color of an edge, keyed by the hinge type it represents. Not every
# HINGE_TYPE has an entry (e.g. FULL6DOF is absent), so lookups should be
# prepared for missing keys.
edge_colors = {
    HINGE_TYPE.LOCKED: "#FFFF00",  # yellow
    HINGE_TYPE.SLIDER: "#00FF00",  # green
    HINGE_TYPE.PIN: "bisque",
    HINGE_TYPE.BALL: "#FF00FF",  # magenta
    HINGE_TYPE.UJOINT: "#800080",  # purple
    HINGE_TYPE.CYLINDRICAL: "#00FFFF",  # cyan
    HINGE_TYPE.CUSTOM: "#0000FF",  # blue
}

# Gray. NOTE(review): multibodyConstraintEdges colors coupler edges
# "palegreen" rather than using this value -- confirm which is intended.
coupler_color = "#808080"


def _fillLabelMap(
    multibody: Multibody, label_map: dict[int | str, str] | None = None
) -> dict[int | str, str]:
    """Return a copy of ``label_map`` that has a label for every body

    Every body (the virtual root first, then ``physicalBodiesList()`` in
    order) that already has a truthy label in ``label_map`` -- keyed either
    by its id or by its name -- keeps it. Every other body receives an entry
    keyed by its id whose value is the next sequential integer as a string
    ("0", "1", "2", ...).

    Filling the map once up front ensures the same labels are used
    consistently for different subgraphs of the same multibody. The caller's
    ``label_map`` is never modified.

    """
    # Work on a copy so the user's map is left untouched
    filled = label_map.copy() if label_map else {}

    # Source of the 0, 1, 2, ... labels for unlabeled bodies
    sequence = count()

    all_bodies = [multibody.virtualRoot()] + multibody.physicalBodiesList()
    for body in all_bodies:
        has_label = filled.get(body.id()) or filled.get(body.name())
        if not has_label:
            filled[body.id()] = str(next(sequence))
    return filled


def multibodyToGraph(
    multibody: Multibody,
    title: str = "Multibody System",
    label_map: dict[int | str, str] | None = None,
    constraints: bool = True,
) -> NetworkGraph:
    """Convert a Multibody to a visjs NetworkGraph

    One node is created for the virtual root and for every body reachable
    from it through ``multibody.childrenBodies``. One directed edge is created
    from each parent body to each of its children. Each edge is styled by the
    child's parent-hinge type.

    Parameters
    ----------
    multibody: Multibody
        The multibody object to show
    title: str
        Title of the graph, defaults to "Multibody System"
    label_map: dict[int | str, str] | None
        Optional map to look up string labels for bodies. Keys may be
        either body ids or body names. Bodies without a label get
        sequential numeric labels. The map itself is not modified.
    constraints: bool
        Whether to create edges for constraints; True by default

    Returns
    -------
    NetworkGraph
        The NetworkGraph for the Multibody, ready for visualization

    """
    label_map = _fillLabelMap(multibody, label_map)

    nodes = []
    edges = []

    penwidth = 2.0

    def toNode(body: PhysicalBody) -> Node:
        """Create a node based on a body"""
        # The root body and base bodies get distinct fill colors
        if body.isRootBody():
            color = "coral"
        elif multibody.isBaseBody(body):
            color = "#FFFF00"  # yellow
        else:
            color = "#87CEEB"  # sky blue

        # Look up the label by body id first, then fall back to body name
        label = label_map.get(body.id()) or label_map.get(body.name())

        return Node(
            id=body.id(),
            label=str(label),
            title=body.name(),
            options=NodeOptions(
                shape="box",
                size=25,
                color=NodeColorOptions(
                    background=color,
                    border=color,
                ),
                font=NodeFontOptions(
                    color="#000000",  # black text
                    size=12,
                    face="Arial",
                ),
            ),
        )

    def toEdge(parent: PhysicalBody, child: PhysicalBody) -> Edge:
        """Create an edge between adjacent bodies.

        The edge is styled by the child's parent hinge. FULL6DOF hinges are
        drawn dashed, and hinge types missing from ``edge_colors`` are drawn
        in "#FFFFFF".
        """
        hinge_type = child.parentHinge().hingeType()
        return Edge(
            from_=parent.id(),
            to=child.id(),
            title=str(hinge_type),
            options=EdgeOptions(
                color=edge_colors.get(hinge_type, "#FFFFFF"),
                width=penwidth,
                dashes=(hinge_type == HINGE_TYPE.FULL6DOF),
                arrows=ArrowOptions(
                    to=ArrowStyle(
                        enabled=True,
                        scaleFactor=1.2,
                    ),
                ),
                smooth=SmoothOptions(
                    type="cubicBezier",
                    roundness=0.5,
                ),
            ),
        )

    # Walk the body tree depth-first in pre-order, starting at the virtual
    # root. An explicit stack of child iterators is used instead of
    # recursion so that long body chains cannot exceed Python's recursion
    # limit. The visiting order is identical to the recursive version.
    vroot = multibody.virtualRoot()
    nodes.append(toNode(vroot))
    exhausted = object()  # sentinel returned by next() when children run out
    stack = [(vroot, iter(multibody.childrenBodies(vroot)))]
    while stack:
        parent, children = stack[-1]
        child = next(children, exhausted)
        if child is exhausted:
            stack.pop()
            continue
        nodes.append(toNode(child))
        edges.append(toEdge(parent, child))
        # Descend into this child before continuing with its siblings
        stack.append((child, iter(multibody.childrenBodies(child))))

    if constraints:
        edges.extend(multibodyConstraintEdges(multibody))

    return NetworkGraph(nodes=nodes, edges=edges, title=title)


def multibodyConstraintEdges(
    multibody: Multibody,
) -> list[Edge]:
    """Create one dashed edge for every enabled constraint of a Multibody

    Loop constraints (``HINGE_LOOP`` / ``CONVEL_LOOP``) produce an edge from
    the parent body of the constraint's source node to the parent body of
    its target node:

    - If the constraint has a hinge, the edge is titled
      ``"<name>/<hinge type>"`` and colored by hinge type via
      ``edge_colors``. Hinge types missing from that map fall back to
      "#FFFFFF", as tree edges do.
    - Otherwise the edge is titled ``"<name>/CONVEL"`` and colored
      "#FF7F24".

    Every other enabled constraint is treated as a coordinate (coupler)
    constraint. Its edge goes from the body owning the ``osubhinge`` parent
    hinge's pnode to the body owning the ``psubhinge`` parent hinge's pnode.
    It is titled ``"<name>/COUPLER/<scale ratio>"`` and colored "palegreen".

    Parameters
    ----------
    multibody: Multibody
        The multibody whose enabled constraints are converted

    Returns
    -------
    list[Edge]
        One edge per enabled constraint, in ``enabledConstraints()`` order

    """
    penwidth = 2.0
    loop_types = (
        BILATERAL_CONSTRAINT_TYPE.HINGE_LOOP,
        BILATERAL_CONSTRAINT_TYPE.CONVEL_LOOP,
    )

    def constraintEdge(from_id, to_id, title: str, color: str) -> Edge:
        """Build the dashed, arrowed, curved edge shared by all constraint kinds"""
        return Edge(
            from_=from_id,
            to=to_id,
            title=title,
            options=EdgeOptions(
                color=color,
                width=penwidth,
                dashes=True,
                arrows=ArrowOptions(
                    to=ArrowStyle(
                        enabled=True,
                        scaleFactor=1.0,
                    ),
                ),
                smooth=SmoothOptions(
                    type="cubicBezier",
                    roundness=0.3,
                ),
            ),
        )

    edges = []
    for constraint in multibody.enabledConstraints():
        if constraint.type() in loop_types:
            source_body = constraint.sourceNode().parentBody()
            target_body = constraint.targetNode().parentBody()
            if constraint.hasHinge():
                hinge_type = constraint.hinge().hingeType()
                title = f"{constraint.name()}/{hinge_type}"
                # .get avoids a KeyError for hinge types not in edge_colors
                color = edge_colors.get(hinge_type, "#FFFFFF")
            else:
                title = f"{constraint.name()}/CONVEL"
                color = "#FF7F24"  # chocolate1
            edges.append(
                constraintEdge(source_body.id(), target_body.id(), title, color)
            )
        else:
            # NOTE(review): assumes every non-loop enabled constraint is a
            # coordinate constraint exposing osubhinge/psubhinge/getScaleRatio
            obody = constraint.osubhinge().parentHinge().pnode().parentBody()
            pbody = constraint.psubhinge().parentHinge().pnode().parentBody()
            scale = constraint.getScaleRatio()
            edges.append(
                constraintEdge(
                    obody.id(),
                    pbody.id(),
                    f"{constraint.name()}/COUPLER/{scale}",
                    "palegreen",
                )
            )
    return edges


class MultibodyGraphServer(GraphServer):
    """Specialized GraphServer for multibody viewing

    Given a multibody, this automatically generates a set of graphs
    with and without constraints and other sets of extra edges. Buttons
    are automatically added to toggle different parts of the graph on
    and off.

    """

    @dataclass
    class _GraphData:
        """One toggleable subgraph managed by the server"""

        label: str  # name used for the subgraph's on/off buttons
        graph: NetworkGraph
        enabled: bool  # included by _rebuildGraph when True

    def __init__(
        self,
        multibody: Multibody,
        *,
        title: str = "Multibody System",
        port: int = 8765,
        autorun: bool = True,
        log_level: str = "warning",
        buttons: list[Button] | None = None,
        label_map: dict[int | str, str] | None = None,
        extra_edges: dict[str, list[Edge]] | list[Edge] | None = None,
    ):
        """MultibodyGraphServer constructor

        Parameters
        ----------
        multibody: Multibody
            The multibody to show
        title: str
            Title of the graph
        port: int
            Passed through to GraphServer
        autorun: bool
            Passed through to GraphServer
        log_level: str
            Passed through to GraphServer
        buttons: list[Button] | None
            Extra buttons, appended after the automatically created
            subgraph toggle buttons
        label_map: dict[int | str, str] | None
            Optional body labels keyed by body id or body name (see
            ``multibodyToGraph``)
        extra_edges: dict[str, list[Edge]] | list[Edge] | None
            Additional edge sets. Each non-empty set becomes its own
            subgraph, hidden initially. A plain list is labeled "extra".
            Constraint edges are always added under the label
            "constraints". The caller's dict is not modified.

        """
        tree_graph = multibodyToGraph(
            multibody, title=title, constraints=False, label_map=label_map
        )

        self.subgraphs = {
            "tree": MultibodyGraphServer._GraphData(
                label="tree",
                graph=tree_graph,
                enabled=True,
            )
        }

        # Normalize into a fresh dict so the caller's argument is never mutated
        if not extra_edges:
            edge_sets = {}
        elif isinstance(extra_edges, list):
            edge_sets = {"extra": extra_edges}
        else:
            edge_sets = dict(extra_edges)

        edge_sets["constraints"] = multibodyConstraintEdges(multibody)

        for label, edges in edge_sets.items():
            if not edges:
                # Skip empty edge lists
                continue
            graph = tree_graph.clone()
            # Replace the edges in the tree graph
            graph.edges = edges
            # Remove nodes not touched by the new edges
            graph.removeIsolatedNodes()
            self.subgraphs[label] = MultibodyGraphServer._GraphData(
                label=label,
                graph=graph,
                enabled=False,  # Don't show initially
            )

        buttons = self._defaultButtons() + (buttons or [])

        super().__init__(
            graph=tree_graph, port=port, autorun=autorun, log_level=log_level, buttons=buttons
        )

    def enableSubgraph(self, label: str):
        """Mark the subgraph ``label`` as shown and rebuild the displayed graph

        Raises KeyError if no subgraph has that label.
        """
        self.subgraphs[label].enabled = True
        self._rebuildGraph()

    def disableSubgraph(self, label: str):
        """Mark the subgraph ``label`` as hidden and rebuild the displayed graph

        Raises KeyError if no subgraph has that label.
        """
        self.subgraphs[label].enabled = False
        self._rebuildGraph()

    def _rebuildGraph(self):
        """Union all enabled subgraphs into ``self.graph`` and push it to clients"""
        # NOTE(review): starts from a default NetworkGraph(); whether the
        # tree graph's title survives the |= merge depends on NetworkGraph -- confirm.
        graph = NetworkGraph()
        for subgraph_data in self.subgraphs.values():
            if subgraph_data.enabled:
                graph |= subgraph_data.graph
        self.graph = graph
        self.updateClientGraphs()

    def _defaultButtons(self) -> list[Button]:
        """Define UI buttons always created by this class

        Returns an "<label> on" / "<label> off" button pair for every
        subgraph, or no buttons when only the tree subgraph exists.
        """
        if len(self.subgraphs) == 1:
            return []

        buttons = []

        def _addToggle(label):
            # A separate function scope gives each lambda its own ``label``,
            # avoiding the late-binding closure pitfall of a bare loop.
            buttons.append(
                Button(
                    text=f"{label} on",
                    callback=lambda: self.enableSubgraph(label),
                    style={},  # Add CSS styling here
                )
            )
            buttons.append(
                Button(
                    text=f"{label} off",
                    callback=lambda: self.disableSubgraph(label),
                    style={},  # Add CSS styling here
                )
            )

        for label in self.subgraphs:
            _addToggle(label)

        return buttons