"""Network diagram visualizer.
Interactive force-directed network diagrams, specifically designed for
Bayesian Network visualization but general enough for any DAG or graph.
Supports:
- Bayesian Network structure (from endgame Bayesian classifiers)
- General directed and undirected graphs
- Node sizing by importance/degree
- Edge coloring by weight/strength
- Force-directed, hierarchical (layered DAG), and circular layouts with draggable nodes
Example
-------
>>> from endgame.visualization import NetworkDiagramVisualizer
>>> viz = NetworkDiagramVisualizer(
... nodes=["Rain", "Sprinkler", "Wet Grass"],
... edges=[("Rain", "Wet Grass"), ("Sprinkler", "Wet Grass"),
... ("Rain", "Sprinkler")],
... title="Bayesian Network: Wet Grass",
... )
>>> viz.save("bayesian_network.html")
"""
from __future__ import annotations

from collections import deque
from collections.abc import Mapping, Sequence
from typing import Any

import numpy as np

from endgame.visualization._base import BaseVisualizer
[docs]
class NetworkDiagramVisualizer(BaseVisualizer):
    """Interactive network diagram visualizer.

    Parameters
    ----------
    nodes : list of str or list of dict
        Node labels (strings) or dicts with 'id', 'label', and optional
        'group', 'size', 'description' keys. Each entry is normalized on its
        own, so strings and dicts may be mixed; any non-mapping entry is used
        as both id and label.
    edges : list of (str, str) or (str, str, float)
        Edges as (source, target) or (source, target, weight) tuples.
        Edges whose endpoints are not node ids are not rendered.
    directed : bool, default=True
        If True, draw arrows on edges (for DAGs like Bayesian Networks).
    node_sizes : dict of str → float, optional
        Custom node sizes (circle radius; the renderer enforces a minimum
        of 6). Takes precedence over a node dict's 'size' key. Nodes with no
        explicit size are sized by degree as ``8 + 3 * degree``.
    layout : str, default='force'
        Layout algorithm: 'force' (force-directed simulation run in the
        browser from seeded random start positions), 'hierarchical'
        (longest-path layering of the DAG drawn as columns, roots on the
        left; nodes that cannot be topologically ordered go in one extra
        column on the right), or 'circular'.
    title : str, optional
        Chart title.
    palette : str, default='tableau'
        Color palette.
    width : int, default=800
        Chart width.
    height : int, default=600
        Chart height.
    theme : str, default='dark'
        'dark' or 'light'.
    """

    def __init__(
        self,
        nodes: Sequence[str] | Sequence[dict[str, Any]],
        edges: Sequence[tuple[str, str] | tuple[str, str, float]],
        *,
        directed: bool = True,
        node_sizes: dict[str, float] | None = None,
        layout: str = "force",
        title: str = "",
        palette: str = "tableau",
        width: int = 800,
        height: int = 600,
        theme: str = "dark",
    ):
        super().__init__(title=title, palette=palette, width=width, height=height, theme=theme)
        # Normalize each node independently (not based on the first element),
        # so mixed str/dict lists and NumPy string arrays both work.
        self._nodes = [
            dict(nd) if isinstance(nd, Mapping) else {"id": nd, "label": nd}
            for nd in nodes
        ]
        self._edges = list(edges)
        self.directed = directed
        self.node_sizes = node_sizes or {}
        self.layout = layout

    # ------------------------------------------------------------------
    # Classmethod constructors
    # ------------------------------------------------------------------

    @classmethod
    def from_bayesian_network(
        cls,
        model: Any,
        *,
        feature_names: Sequence[str] | None = None,
        **kwargs,
    ) -> NetworkDiagramVisualizer:
        """Create from an endgame Bayesian classifier (TAN, KDB, ESKDB).

        Extracts the DAG structure from the model's learned network. The
        structure is read from the first of these attributes that exists:

        * ``edges_`` -- iterable of ``(parent_idx, child_idx)`` pairs. If the
          model also has ``classes_``, a ``"Class"`` node is added as a
          parent of every feature.
        * ``dag_`` -- square adjacency matrix; a non-zero ``dag_[i, j]`` is
          an edge ``i -> j``.
        * ``parents_`` -- one parent index per feature; a negative index (or
          ``None``) means the feature has no parent.

        Parameters
        ----------
        model : estimator
            Fitted Bayesian network classifier with structure information.
        feature_names : list of str, optional
            Feature names. Takes precedence over ``model.feature_names_in_``;
            if neither is available, names ``X0, X1, ...`` are generated.
        **kwargs
            Forwarded to the constructor. ``title``, ``directed`` (True),
            ``layout`` ('hierarchical') and ``node_sizes`` (degree scaled to
            the range 8-28) are filled in only when not supplied.

        Raises
        ------
        ValueError
            If the model has none of ``edges_``, ``dag_`` or ``parents_``.
        """
        nodes: list[Any] = []
        edges: list[tuple[Any, Any]] = []

        # Explicit names win; otherwise use the names recorded at fit time.
        if feature_names is not None:
            feature_names = list(feature_names)
        elif hasattr(model, "feature_names_in_"):
            feature_names = list(model.feature_names_in_)

        if hasattr(model, "edges_"):
            # Materialize as plain int pairs so an (k, 2) ndarray also works
            # (its truth value would otherwise be ambiguous).
            raw_edges = model.edges_
            edge_list = [] if raw_edges is None else [(int(p), int(c)) for p, c in raw_edges]
            if feature_names is None:
                max_idx = max((max(e) for e in edge_list), default=0)
                feature_names = [f"X{i}" for i in range(max_idx + 1)]
            nodes = list(feature_names)
            has_class = hasattr(model, "classes_")
            if has_class:
                nodes.append("Class")
            for parent, child in edge_list:
                src = feature_names[parent] if parent < len(feature_names) else f"X{parent}"
                tgt = feature_names[child] if child < len(feature_names) else f"X{child}"
                edges.append((src, tgt))
            # In TAN/KDB, class is parent of all features
            if has_class:
                edges.extend(("Class", fname) for fname in feature_names)
        elif hasattr(model, "dag_"):
            # Adjacency matrix: non-zero dag[i, j] means i -> j (row-major order).
            dag = np.asarray(model.dag_)
            n = dag.shape[0]
            if feature_names is None:
                feature_names = [f"X{i}" for i in range(n)]
            nodes = list(feature_names)
            for i, j in np.argwhere(dag != 0):
                edges.append((feature_names[i], feature_names[j]))
        elif hasattr(model, "parents_"):
            # Parent-index structure (e.g. tree-augmented networks).
            parents = list(model.parents_)
            if feature_names is None:
                feature_names = [f"X{i}" for i in range(len(parents))]
            nodes = list(feature_names)
            for i, p in enumerate(parents):
                if p is not None and p >= 0:
                    edges.append((feature_names[p], feature_names[i]))
        else:
            raise ValueError(
                f"Cannot extract network structure from {type(model).__name__}. "
                "Model must have 'edges_', 'dag_', or 'parents_' attribute."
            )

        kwargs.setdefault("title", f"Bayesian Network ({type(model).__name__})")
        kwargs.setdefault("directed", True)
        kwargs.setdefault("layout", "hierarchical")

        # Node importance from total degree, scaled to [8, 28]. "or 1" avoids
        # a division by zero when the network has no edges.
        degree = dict.fromkeys(nodes, 0)
        for src, tgt in edges:
            if src in degree:
                degree[src] += 1
            if tgt in degree:
                degree[tgt] += 1
        max_deg = max(degree.values(), default=0) or 1
        kwargs.setdefault(
            "node_sizes", {nd: 8 + (d / max_deg) * 20 for nd, d in degree.items()}
        )
        return cls(nodes, edges, **kwargs)

    @classmethod
    def from_adjacency_matrix(
        cls,
        matrix: Any,
        labels: Sequence[str],
        *,
        threshold: float = 0.0,
        **kwargs,
    ) -> NetworkDiagramVisualizer:
        """Create from an adjacency/weight matrix.

        Every off-diagonal entry with ``abs(value) > threshold`` becomes a
        weighted edge ``labels[i] -> labels[j]``; the diagonal is ignored.

        Parameters
        ----------
        matrix : array-like, shape (n, n)
            Adjacency or weight matrix.
        labels : list of str
            Node labels; must contain at least ``n`` entries.
        threshold : float, default=0.0
            Minimum absolute value to create an edge.
        **kwargs
            Additional keyword arguments.

        Raises
        ------
        ValueError
            If ``matrix`` is not square 2-D or ``labels`` is too short.
        """
        m = np.asarray(matrix)
        if m.ndim != 2 or m.shape[0] != m.shape[1]:
            raise ValueError(f"matrix must be a square 2-D array, got shape {m.shape}")
        n = m.shape[0]
        if len(labels) < n:
            raise ValueError(f"expected at least {n} labels, got {len(labels)}")
        mask = np.abs(m) > threshold
        np.fill_diagonal(mask, False)  # no self-loops
        # np.argwhere yields indices in row-major order, like a nested i/j loop.
        edges = [(labels[i], labels[j], float(m[i, j])) for i, j in np.argwhere(mask)]
        return cls(list(labels), edges, **kwargs)

    @classmethod
    def from_edge_list(
        cls,
        edges: Sequence[tuple[str, str]],
        **kwargs,
    ) -> NetworkDiagramVisualizer:
        """Create from a simple edge list, auto-discovering nodes.

        Parameters
        ----------
        edges : list of (str, str)
            Edge tuples (a third weight element is allowed). Any iterable,
            including a generator, is accepted.
        **kwargs
            Additional keyword arguments.
        """
        # Materialize first: a generator would otherwise be exhausted by node
        # discovery and the instance would end up with no edges.
        edges = list(edges)
        node_set = {e[0] for e in edges} | {e[1] for e in edges}
        return cls(sorted(node_set), edges, **kwargs)

    # ------------------------------------------------------------------
    # BaseVisualizer interface
    # ------------------------------------------------------------------

    def _build_data(self) -> dict[str, Any]:
        """Serialize nodes, edges and initial positions for the JS renderer.

        Edges are expressed as integer indices into the node list; edges
        whose endpoints are not node ids are dropped.
        """
        node_ids = [nd["id"] for nd in self._nodes]
        node_idx = {nid: i for i, nid in enumerate(node_ids)}

        in_deg = dict.fromkeys(node_ids, 0)
        out_deg = dict.fromkeys(node_ids, 0)
        for e in self._edges:
            src, tgt = e[0], e[1]
            if src in out_deg:
                out_deg[src] += 1
            if tgt in in_deg:
                in_deg[tgt] += 1

        nodes = []
        for nd in self._nodes:
            nid = nd["id"]
            # Size precedence: node_sizes mapping, node dict 'size', degree.
            size = self.node_sizes.get(nid)
            if size is None:
                size = nd.get("size")
            if size is None:
                size = 8 + (in_deg[nid] + out_deg[nid]) * 3
            label = nd.get("label")
            group = nd.get("group")
            nodes.append({
                "id": nid,
                # Strings only: the renderer calls .length/.slice on labels
                # and derives colours from group names.
                "label": str(nid if label is None else label),
                "group": "" if group is None else str(group),
                "size": round(float(size), 2),
                "inDeg": in_deg[nid],
                "outDeg": out_deg[nid],
                "description": str(nd.get("description") or ""),
            })

        edges = []
        for e in self._edges:
            src, tgt = e[0], e[1]
            if src not in node_idx or tgt not in node_idx:
                continue
            weight = float(e[2]) if len(e) > 2 else 1.0
            edges.append({
                "source": node_idx[src],
                "target": node_idx[tgt],
                "weight": round(weight, 4),
            })

        positions = self._compute_layout(nodes, edges)
        return {
            "nodes": nodes,
            "edges": edges,
            "directed": self.directed,
            "layout": self.layout,
            "positions": positions,
        }

    def _compute_layout(
        self,
        nodes: list[dict[str, Any]],
        edges: list[dict[str, Any]],
    ) -> list[dict[str, float]]:
        """Compute initial node positions in normalized [0, 1] coordinates.

        Returns one ``{"x", "y"}`` dict per node, in node order; the
        renderer multiplies them by the chart width/height.
        """
        n = len(nodes)
        if n == 0:
            return []
        if self.layout == "circular":
            import math
            positions = []
            for i in range(n):
                # Start at 12 o'clock and go around clockwise on screen.
                angle = 2 * math.pi * i / n - math.pi / 2
                positions.append({
                    "x": round(0.5 + 0.35 * math.cos(angle), 4),
                    "y": round(0.5 + 0.35 * math.sin(angle), 4),
                })
            return positions
        if self.layout == "hierarchical":
            return self._hierarchical_layout(n, edges)
        # Force layout: seeded random start; the browser runs the simulation.
        rng = np.random.RandomState(42)
        return [
            {"x": round(float(rng.uniform(0.2, 0.8)), 4),
             "y": round(float(rng.uniform(0.2, 0.8)), 4)}
            for _ in range(n)
        ]

    @staticmethod
    def _hierarchical_layout(n: int, edges: list[dict[str, Any]]) -> list[dict[str, float]]:
        """Column layout: each node's column is its longest path from a root."""
        remaining = [0] * n  # in-edges not yet processed, per node
        children: list[list[int]] = [[] for _ in range(n)]
        for e in edges:
            children[e["source"]].append(e["target"])
            remaining[e["target"]] += 1

        # Kahn's algorithm: a node is placed once all its parents are placed.
        layers = [0] * n
        visited = [deg == 0 for deg in remaining]
        queue = deque(i for i in range(n) if visited[i])
        while queue:
            cur = queue.popleft()
            for child in children[cur]:
                layers[child] = max(layers[child], layers[cur] + 1)
                remaining[child] -= 1
                if remaining[child] <= 0 and not visited[child]:
                    visited[child] = True
                    queue.append(child)

        # Unreached nodes lie on or below a cycle. Put them in one column right
        # after the deepest *placed* layer (no empty gap column; centred when
        # every node is on a cycle).
        deepest = max((layers[i] for i in range(n) if visited[i]), default=-1)
        for i in range(n):
            if not visited[i]:
                layers[i] = deepest + 1
        max_layer = max(layers)

        columns: dict[int, list[int]] = {}
        for i, layer in enumerate(layers):
            columns.setdefault(layer, []).append(i)

        # Distinct dicts per node (no aliasing); every slot is overwritten below.
        positions = [{"x": 0.5, "y": 0.5} for _ in range(n)]
        for layer, members in columns.items():
            x = (layer + 0.5) / (max_layer + 1) if max_layer > 0 else 0.5
            for rank, ni in enumerate(members):
                y = (rank + 0.5) / len(members)
                positions[ni] = {"x": round(x, 4), "y": round(y, 4)}
        return positions

    def _chart_type(self) -> str:
        return "network"

    def _get_chart_js(self) -> str:
        return _NETWORK_JS
# ---------------------------------------------------------------------------
# JavaScript renderer with force-directed simulation
# ---------------------------------------------------------------------------
_NETWORK_JS = r"""
function renderChart(data, config) {
const container = document.getElementById('chart-container');
const W = config.width, H = config.height;
const svg = EG.svg('svg', {width: W, height: H});
container.appendChild(svg);
const palette = config.palette;
const nodes = data.nodes;
const edges = data.edges;
const directed = data.directed;
const n = nodes.length;
if (n === 0) return;
// Initialize positions from layout
const pos = data.positions.map(function(p) {
return {x: p.x * W, y: p.y * H};
});
// Arrow marker for directed edges
if (directed) {
const defs = EG.svg('defs');
svg.appendChild(defs);
const marker = EG.svg('marker', {
id: 'arrowhead', markerWidth: 10, markerHeight: 7,
refX: 10, refY: 3.5, orient: 'auto', fill: 'var(--text-muted)'
});
marker.appendChild(EG.svg('polygon', {points: '0 0, 10 3.5, 0 7'}));
defs.appendChild(marker);
}
// Edge group (drawn first, behind nodes)
const edgeGroup = EG.svg('g');
svg.appendChild(edgeGroup);
const nodeGroup = EG.svg('g');
svg.appendChild(nodeGroup);
const labelGroup = EG.svg('g');
svg.appendChild(labelGroup);
// Create edge elements
const edgeEls = [];
edges.forEach(function(e, ei) {
const line = EG.svg('line', {
'stroke': 'var(--text-muted)', 'stroke-width': Math.max(1, Math.min(e.weight * 2, 5)),
'opacity': 0.4, 'marker-end': directed ? 'url(#arrowhead)' : ''
});
line.addEventListener('mouseenter', function(ev) {
line.setAttribute('opacity', '0.8');
line.setAttribute('stroke', palette[ei % palette.length]);
EG.tooltip.show(ev, '<b>' + EG.esc(nodes[e.source].label) + ' → ' + EG.esc(nodes[e.target].label) + '</b>' +
(e.weight !== 1 ? '<br>Weight: ' + EG.fmt(e.weight) : ''));
});
line.addEventListener('mouseleave', function() {
line.setAttribute('opacity', '0.4');
line.setAttribute('stroke', 'var(--text-muted)');
EG.tooltip.hide();
});
edgeGroup.appendChild(line);
edgeEls.push(line);
});
// Create node elements
const nodeEls = [];
const labelEls = [];
nodes.forEach(function(nd, i) {
const color = nd.group ? palette[nd.group.charCodeAt(0) % palette.length] : palette[i % palette.length];
const r = Math.max(nd.size, 6);
const circle = EG.svg('circle', {
r: r, fill: color, opacity: 0.85,
stroke: 'var(--bg-card)', 'stroke-width': 2,
style: 'cursor: grab;'
});
circle.addEventListener('mouseenter', function(ev) {
circle.setAttribute('opacity', '1');
circle.setAttribute('r', String(r + 3));
let html = '<b>' + EG.esc(nd.label) + '</b>';
if (nd.description) html += '<br>' + EG.esc(nd.description);
html += '<br>In-degree: ' + nd.inDeg + ', Out-degree: ' + nd.outDeg;
EG.tooltip.show(ev, html);
});
circle.addEventListener('mouseleave', function() {
circle.setAttribute('opacity', '0.85');
circle.setAttribute('r', String(r));
EG.tooltip.hide();
});
// Drag support
let dragging = false;
circle.addEventListener('mousedown', function(ev) {
dragging = true;
circle.style.cursor = 'grabbing';
ev.preventDefault();
});
document.addEventListener('mousemove', function(ev) {
if (!dragging) return;
const rect = svg.getBoundingClientRect();
pos[i].x = ev.clientX - rect.left;
pos[i].y = ev.clientY - rect.top;
updatePositions();
});
document.addEventListener('mouseup', function() {
if (dragging) {
dragging = false;
circle.style.cursor = 'grab';
}
});
nodeGroup.appendChild(circle);
nodeEls.push(circle);
// Label
const label = EG.svg('text', {
'text-anchor': 'middle', fill: 'var(--text-primary)',
'font-size': '11px', 'font-weight': '500',
'pointer-events': 'none'
});
label.textContent = nd.label.length > 14 ? nd.label.slice(0,12) + '…' : nd.label;
labelGroup.appendChild(label);
labelEls.push(label);
});
function updatePositions() {
edges.forEach(function(e, ei) {
const sx = pos[e.source].x, sy = pos[e.source].y;
const tx = pos[e.target].x, ty = pos[e.target].y;
// Shorten line to account for node radius
const sr = nodes[e.source].size || 6;
const tr = nodes[e.target].size || 6;
const dx = tx - sx, dy = ty - sy;
const dist = Math.sqrt(dx * dx + dy * dy) || 1;
const ux = dx / dist, uy = dy / dist;
edgeEls[ei].setAttribute('x1', sx + ux * sr);
edgeEls[ei].setAttribute('y1', sy + uy * sr);
edgeEls[ei].setAttribute('x2', tx - ux * (tr + 5));
edgeEls[ei].setAttribute('y2', ty - uy * (tr + 5));
});
nodes.forEach(function(nd, i) {
nodeEls[i].setAttribute('cx', pos[i].x);
nodeEls[i].setAttribute('cy', pos[i].y);
labelEls[i].setAttribute('x', pos[i].x);
labelEls[i].setAttribute('y', pos[i].y + nd.size + 14);
});
}
// Force simulation (simple Fruchterman-Reingold style)
if (data.layout === 'force') {
const area = W * H;
const k = Math.sqrt(area / n) * 0.8;
const iterations = 120;
const temp0 = W / 5;
for (let iter = 0; iter < iterations; iter++) {
const temp = temp0 * (1 - iter / iterations);
const disp = pos.map(function() { return {x: 0, y: 0}; });
// Repulsive forces
for (let i = 0; i < n; i++) {
for (let j = i + 1; j < n; j++) {
let dx = pos[i].x - pos[j].x;
let dy = pos[i].y - pos[j].y;
let dist = Math.sqrt(dx * dx + dy * dy) || 0.1;
let force = k * k / dist;
let fx = dx / dist * force;
let fy = dy / dist * force;
disp[i].x += fx; disp[i].y += fy;
disp[j].x -= fx; disp[j].y -= fy;
}
}
// Attractive forces
edges.forEach(function(e) {
let dx = pos[e.target].x - pos[e.source].x;
let dy = pos[e.target].y - pos[e.source].y;
let dist = Math.sqrt(dx * dx + dy * dy) || 0.1;
let force = dist * dist / k;
let fx = dx / dist * force;
let fy = dy / dist * force;
disp[e.source].x += fx; disp[e.source].y += fy;
disp[e.target].x -= fx; disp[e.target].y -= fy;
});
// Apply with temperature
for (let i = 0; i < n; i++) {
let dist = Math.sqrt(disp[i].x * disp[i].x + disp[i].y * disp[i].y) || 0.1;
pos[i].x += disp[i].x / dist * Math.min(dist, temp);
pos[i].y += disp[i].y / dist * Math.min(dist, temp);
// Keep in bounds
pos[i].x = Math.max(40, Math.min(W - 40, pos[i].x));
pos[i].y = Math.max(40, Math.min(H - 40, pos[i].y));
}
}
}
updatePositions();
}
"""