"""Chord diagram visualizer.
Interactive chord diagrams for showing pairwise relationships in a
circular layout. Ideal for confusion-style matrices, feature
interaction strengths, and model agreement visualization.
Example
-------
>>> from endgame.visualization import ChordDiagramVisualizer
>>> import numpy as np
>>> matrix = np.array([[0, 5, 3], [5, 0, 4], [3, 4, 0]])
>>> viz = ChordDiagramVisualizer(
... matrix=matrix,
... labels=["Model A", "Model B", "Model C"],
... title="Model Agreement",
... )
>>> viz.save("chord.html")
"""
from __future__ import annotations
import math
from collections.abc import Sequence
from typing import Any
import numpy as np
from endgame.visualization._base import BaseVisualizer
class ChordDiagramVisualizer(BaseVisualizer):
    """Interactive chord diagram visualizer.

    Each group is drawn as an arc around a circle whose angular span is
    proportional to the group's row sum; every pair of groups with a
    positive flow in either direction is joined by a chord.

    Parameters
    ----------
    matrix : array-like, shape (n, n)
        Non-negative flow matrix (symmetric or asymmetric).
        ``matrix[i][j]`` is the flow from group *i* to group *j*.
    labels : list of str
        Group labels, one per row of ``matrix``.
    title : str, optional
        Chart title.
    palette : str, default='tableau'
        Color palette.
    width : int, default=650
        Chart width.
    height : int, default=650
        Chart height.
    theme : str, default='dark'
        'dark' or 'light'.

    Raises
    ------
    ValueError
        If ``matrix`` is not a square 2D array of finite, non-negative
        values, or if ``len(labels)`` differs from its size.
    """

    def __init__(
        self,
        matrix: Any,
        labels: Sequence[str],
        *,
        title: str = "",
        palette: str = "tableau",
        width: int = 650,
        height: int = 650,
        theme: str = "dark",
    ):
        super().__init__(title=title, palette=palette, width=width, height=height, theme=theme)
        self._matrix = np.asarray(matrix, dtype=float)
        if self._matrix.ndim != 2:
            raise ValueError("matrix must be 2D")
        n_rows, n_cols = self._matrix.shape
        # _build_data indexes m[i, j] and labels[i] for i, j < n_rows, so a
        # non-square matrix or a label-count mismatch would otherwise fail
        # late with IndexError or silently drop data.
        if n_rows != n_cols:
            raise ValueError(f"matrix must be square, got shape {self._matrix.shape}")
        # NaN/inf would propagate into the arc angles and negative entries
        # would produce negative (backwards) arc spans.
        if not np.isfinite(self._matrix).all():
            raise ValueError("matrix must contain only finite values")
        if (self._matrix < 0).any():
            raise ValueError("matrix must be non-negative")
        self.labels = list(labels)
        if len(self.labels) != n_rows:
            raise ValueError(
                f"expected {n_rows} labels for a {n_rows}x{n_rows} matrix, "
                f"got {len(self.labels)}"
            )

    @classmethod
    def from_confusion_matrix(
        cls,
        matrix: Any,
        class_names: Sequence[str],
        *,
        include_diagonal: bool = True,
        **kwargs,
    ) -> ChordDiagramVisualizer:
        """Create a chord diagram from a confusion matrix.

        ``matrix[i][j]`` is drawn as flow from ``class_names[i]`` to
        ``class_names[j]``.

        Parameters
        ----------
        matrix : array-like, shape (n, n)
            Confusion matrix. It is never modified in place.
        class_names : list of str
            Class names.
        include_diagonal : bool, default=True
            Keep the diagonal entries (correctly classified counts). Pass
            ``False`` to show only misclassification flows between
            distinct classes.
        **kwargs
            Additional keyword arguments forwarded to the constructor;
            ``title`` defaults to ``"Classification Flow"``.

        Returns
        -------
        ChordDiagramVisualizer
        """
        kwargs.setdefault("title", "Classification Flow")
        if not include_diagonal:
            # np.array copies, so the caller's matrix is left untouched.
            matrix = np.array(matrix, dtype=float)
            # Leave non-2D input for the constructor to reject with its own message.
            if matrix.ndim == 2:
                np.fill_diagonal(matrix, 0.0)
        return cls(matrix, class_names, **kwargs)

    def _build_data(self) -> dict[str, Any]:
        """Compute arc angles and chord records for the JS renderer.

        Returns
        -------
        dict
            ``groups``: one entry per group with ``label``, ``startAngle``
            and ``endAngle`` (radians) and ``total`` (row sum);
            ``chords``: one entry per unordered pair ``(i, j)`` with
            ``i <= j`` where either direction has positive flow, carrying
            both directions as ``valueIJ`` / ``valueJI``; ``labels``: the
            group labels. ``groups`` and ``chords`` are empty when the
            matrix sums to zero.
        """
        m = self._matrix
        n = m.shape[0]
        total = float(m.sum())
        if total <= 0:
            # Nothing to draw; keep the same keys as the normal payload.
            return {"groups": [], "chords": [], "labels": self.labels}

        row_sums = m.sum(axis=1)
        # Gap between neighbouring arcs, in radians. Capped at pi / n so the
        # gaps never consume more than half the circle; with a fixed 0.04
        # the arc budget would turn negative once n exceeds ~157 groups.
        gap = min(0.04, math.pi / n)
        available = 2 * math.pi - n * gap

        groups = []
        angle = 0.0
        for i, row_total in enumerate(row_sums):
            span = float(row_total) / total * available
            groups.append({
                "label": self.labels[i],
                "startAngle": round(angle, 6),
                "endAngle": round(angle + span, 6),
                "total": round(float(row_total), 4),
            })
            angle += span + gap

        # One chord per unordered pair (diagonal included) carrying both
        # directions; pairs with no positive flow either way are skipped.
        chords = []
        for i in range(n):
            for j in range(i, n):
                val_ij = float(m[i, j])
                val_ji = float(m[j, i])
                if val_ij <= 0 and val_ji <= 0:
                    continue
                chords.append({
                    "source": i,
                    "target": j,
                    "valueIJ": round(val_ij, 4),
                    "valueJI": round(val_ji, 4),
                })

        return {
            "groups": groups,
            "chords": chords,
            "labels": self.labels,
        }

    def _chart_type(self) -> str:
        """Return the chart-type identifier, ``"chord"``."""
        return "chord"

    def _get_chart_js(self) -> str:
        """Return the JavaScript source defining ``renderChart`` for this chart."""
        return _CHORD_JS
_CHORD_JS = r"""
function renderChart(data, config) {
const container = document.getElementById('chart-container');
const size = Math.min(config.width, config.height);
const svg = EG.svg('svg', {width: size, height: size});
container.appendChild(svg);
container.style.width = size + 'px';
container.style.height = size + 'px';
const palette = config.palette;
const cx = size / 2, cy = size / 2;
const outerR = size / 2 - 50;
const innerR = outerR - 20;
const groups = data.groups;
const chords = data.chords;
const labels = data.labels;
if (groups.length === 0) return;
function arcPath(cx, cy, r, startAngle, endAngle) {
const x1 = cx + r * Math.cos(startAngle - Math.PI/2);
const y1 = cy + r * Math.sin(startAngle - Math.PI/2);
const x2 = cx + r * Math.cos(endAngle - Math.PI/2);
const y2 = cy + r * Math.sin(endAngle - Math.PI/2);
const large = (endAngle - startAngle > Math.PI) ? 1 : 0;
return {x1, y1, x2, y2, large};
}
// Draw group arcs
groups.forEach(function(gr, i) {
const color = palette[i % palette.length];
const o = arcPath(cx, cy, outerR, gr.startAngle, gr.endAngle);
const inn = arcPath(cx, cy, innerR, gr.startAngle, gr.endAngle);
const d = [
'M', o.x1, o.y1,
'A', outerR, outerR, 0, o.large, 1, o.x2, o.y2,
'L', inn.x2, inn.y2,
'A', innerR, innerR, 0, inn.large, 0, inn.x1, inn.y1,
'Z'
].join(' ');
const arc = EG.svg('path', {d: d, fill: color, opacity: 0.85, stroke: 'var(--bg-card)', 'stroke-width': 1});
arc.addEventListener('mouseenter', function(e) {
arc.setAttribute('opacity', '1');
EG.tooltip.show(e, '<b>' + EG.esc(gr.label) + '</b><br>Total: ' + EG.fmt(gr.total));
});
arc.addEventListener('mouseleave', function() { arc.setAttribute('opacity', '0.85'); EG.tooltip.hide(); });
svg.appendChild(arc);
// Label
const midAngle = (gr.startAngle + gr.endAngle) / 2 - Math.PI / 2;
const lx = cx + (outerR + 18) * Math.cos(midAngle);
const ly = cy + (outerR + 18) * Math.sin(midAngle);
const anchor = lx > cx ? 'start' : 'end';
const label = EG.svg('text', {
x: lx, y: ly + 4, 'text-anchor': Math.abs(lx - cx) < 10 ? 'middle' : anchor,
fill: 'var(--text-primary)', 'font-size': '11px', 'font-weight': '500'
});
label.textContent = gr.label;
svg.appendChild(label);
});
// Draw chords
chords.forEach(function(ch) {
const sg = groups[ch.source];
const tg = groups[ch.target];
const color = palette[ch.source % palette.length];
// Simplified: draw a bezier from source arc midpoint to target arc midpoint
const sa = (sg.startAngle + sg.endAngle) / 2 - Math.PI / 2;
const ta = (tg.startAngle + tg.endAngle) / 2 - Math.PI / 2;
const sx = cx + innerR * Math.cos(sa);
const sy = cy + innerR * Math.sin(sa);
const tx = cx + innerR * Math.cos(ta);
const ty = cy + innerR * Math.sin(ta);
const totalVal = ch.valueIJ + ch.valueJI;
const thickness = Math.max(1, Math.min(totalVal / (data.groups[0].total || 1) * 30, 15));
const d = 'M' + sx + ',' + sy + ' Q' + cx + ',' + cy + ' ' + tx + ',' + ty;
const path = EG.svg('path', {
d: d, fill: 'none', stroke: color,
'stroke-width': thickness, opacity: 0.25,
'stroke-linecap': 'round'
});
path.addEventListener('mouseenter', function(e) {
path.setAttribute('opacity', '0.6');
EG.tooltip.show(e,
'<b>' + EG.esc(labels[ch.source]) + ' ↔ ' + EG.esc(labels[ch.target]) + '</b><br>' +
labels[ch.source] + ' → ' + labels[ch.target] + ': ' + EG.fmt(ch.valueIJ) + '<br>' +
labels[ch.target] + ' → ' + labels[ch.source] + ': ' + EG.fmt(ch.valueJI));
});
path.addEventListener('mouseleave', function() { path.setAttribute('opacity', '0.25'); EG.tooltip.hide(); });
svg.appendChild(path);
});
}
"""