Source code for janus.viz_mpl

from __future__ import annotations

from typing import TYPE_CHECKING, Any

try:
    import matplotlib.pyplot as plt
    import networkx as nx

    MPL_AVAILABLE = True
except ImportError:
    MPL_AVAILABLE = False

if TYPE_CHECKING:
    from janus.base import JanusBase


[docs] class MatplotlibBackend: """ Renders the multiverse DAG using Matplotlib and NetworkX. This backend provides a more traditional node-link diagram with customizable colors and layout. """
[docs] def plot(self, obj: JanusBase, **kwargs: Any) -> Any: """ Renders the multiverse graph using Matplotlib. Args: obj: The Multiverse instance to visualize. show: Whether to call plt.show() immediately. figsize: Tuple of (width, height) for the figure. node_size: Size of the nodes in the plot. font_size: Size of the text labels. title: Custom title for the plot. Returns: A Matplotlib Figure object. """ if not MPL_AVAILABLE: raise ImportError( "Matplotlib backend requires 'matplotlib' and 'networkx'. " "Install them to use this feature." ) data = obj._engine.get_graph_data() G = nx.DiGraph() node_labels = {} node_colors = [] for node in data: nid = node["id"] labels = node["labels"] is_current = node["is_current"] G.add_node(nid) label_suffix = f"\n({', '.join(labels)})" if labels else "" node_labels[nid] = f"N{nid}{label_suffix}" if is_current: node_colors.append("#ff9ce6") elif labels: node_colors.append("#e1f5fe") else: node_colors.append("#f5f5f5") for node in data: nid = node["id"] for p_id in node["parents"]: G.add_edge(p_id, nid) # Layout: Rank by topological generation for a balanced timeline look try: generations = list(nx.topological_generations(G)) pos = {} for x, nodes in enumerate(generations): # Spread nodes vertically within the same generation offset = (len(nodes) - 1) / 2.0 for y, nid in enumerate(nodes): pos[nid] = (x, offset - y) except (nx.NetworkXUnfeasible, nx.NetworkXError): pos = nx.spring_layout(G) fig, ax = plt.subplots(figsize=kwargs.get("figsize", (8, 5))) nx.draw( G, pos, labels=node_labels, with_labels=True, node_color=node_colors, node_size=kwargs.get("node_size", 1200), node_shape="o", font_size=kwargs.get("font_size", 7), arrows=True, ax=ax, edge_color="#888", width=1.0, edgecolors="#333", ) ax.set_title(kwargs.get("title", f"Janus History: {type(obj).__name__}")) if kwargs.get("show", False): plt.show() return fig