Source code for endgame.visualization.confusion_matrix

"""Confusion matrix visualizer.

Interactive confusion matrix for classification evaluation with
normalized/raw views and per-class metrics.

Example
-------
>>> from endgame.visualization import ConfusionMatrixVisualizer
>>> viz = ConfusionMatrixVisualizer(
...     matrix=[[50, 3], [7, 40]],
...     class_names=["negative", "positive"],
... )
>>> viz.save("confusion_matrix.html")
"""

from __future__ import annotations

from collections.abc import Sequence
from typing import Any

import numpy as np

from endgame.visualization._base import BaseVisualizer
from endgame.visualization._palettes import DEFAULT_SEQUENTIAL


class ConfusionMatrixVisualizer(BaseVisualizer):
    """Interactive confusion matrix visualizer.

    Parameters
    ----------
    matrix : array-like, shape (n_classes, n_classes)
        Confusion matrix (rows = true, columns = predicted).
    class_names : list of str, optional
        Class label names, one per matrix row.  When omitted or empty, the
        row indices ``"0"``, ``"1"``, ... are used.
    normalize : bool, default=False
        If True, show row-normalized values (recall per class).
    title : str, optional
        Chart title.  An empty title falls back to ``"Confusion Matrix"``.
    cmap : str, default=DEFAULT_SEQUENTIAL
        Color palette for the matrix cells (forwarded to the base class as
        ``palette``).
    width : int, default=650
        Chart width.
    height : int, default=600
        Chart height.
    theme : str, default='dark'
        'dark' or 'light'.

    Raises
    ------
    ValueError
        If ``matrix`` is not a square 2D array, or if ``class_names`` is
        non-empty and its length differs from the number of matrix rows.
    """

    def __init__(
        self,
        matrix: Any,
        *,
        class_names: Sequence[str] | None = None,
        normalize: bool = False,
        title: str = "",
        cmap: str = DEFAULT_SEQUENTIAL,
        width: int = 650,
        height: int = 600,
        theme: str = "dark",
    ):
        super().__init__(
            title=title or "Confusion Matrix",
            palette=cmap,
            width=width,
            height=height,
            theme=theme,
        )
        self._matrix = np.asarray(matrix, dtype=float)
        if self._matrix.ndim != 2 or self._matrix.shape[0] != self._matrix.shape[1]:
            raise ValueError("matrix must be square 2D array")
        n = self._matrix.shape[0]

        # Materialise first instead of testing ``class_names`` for truthiness:
        # numpy arrays raise on ``bool()`` and one-shot iterables have no
        # ``len()``.  ``str()`` keeps the names JSON-serialisable.
        names = [] if class_names is None else [str(name) for name in class_names]
        if not names:
            names = [str(i) for i in range(n)]
        elif len(names) != n:
            # The renderer draws one row/column per name, so a mismatch would
            # silently truncate the chart or break it in the browser.
            raise ValueError(
                f"class_names has {len(names)} entries but matrix has {n} classes"
            )
        self.class_names = names
        self.normalize = normalize

    # ------------------------------------------------------------------
    # Classmethod constructors
    # ------------------------------------------------------------------

    @classmethod
    def from_estimator(
        cls,
        model: Any,
        X: Any,
        y: Any,
        *,
        class_names: Sequence[str] | None = None,
        **kwargs,
    ) -> ConfusionMatrixVisualizer:
        """Create a confusion matrix from a fitted classifier.

        Parameters
        ----------
        model : estimator
            Fitted sklearn-compatible classifier.
        X : array-like
            Features.
        y : array-like
            True labels.
        class_names : list of str, optional
            Class names.  When omitted and the model exposes ``classes_``,
            the names are taken from ``classes_`` and the matrix is computed
            over exactly those labels so that rows and names line up.
        **kwargs
            Additional keyword arguments passed to the constructor.
        """
        from sklearn.metrics import confusion_matrix

        y_pred = model.predict(X)
        classes = getattr(model, "classes_", None)

        if class_names is None and classes is not None:
            # Without ``labels=`` sklearn only covers labels seen in
            # ``y``/``y_pred``; if the evaluation set lacks a class the
            # matrix would be smaller than ``classes_``.
            cm = confusion_matrix(y, y_pred, labels=classes)
            class_names = [str(c) for c in classes]
        else:
            cm = confusion_matrix(y, y_pred)
        return cls(cm, class_names=class_names, **kwargs)

    @classmethod
    def from_predictions(
        cls,
        y_true: Any,
        y_pred: Any,
        *,
        class_names: Sequence[str] | None = None,
        **kwargs,
    ) -> ConfusionMatrixVisualizer:
        """Create a confusion matrix from predictions.

        Parameters
        ----------
        y_true : array-like
            True labels.
        y_pred : array-like
            Predicted labels.
        class_names : list of str, optional
            Class names.  When omitted, the labels observed in ``y_true``
            and ``y_pred`` are used as names.
        **kwargs
            Additional keyword arguments.
        """
        from sklearn.metrics import confusion_matrix
        from sklearn.utils.multiclass import unique_labels

        cm = confusion_matrix(y_true, y_pred)
        if class_names is None:
            labels = unique_labels(y_true, y_pred)
            # Only adopt the labels when they match the matrix size;
            # otherwise fall back to the constructor's index names.
            if len(labels) == cm.shape[0]:
                class_names = [str(label) for label in labels]
        return cls(cm, class_names=class_names, **kwargs)

    # ------------------------------------------------------------------
    # BaseVisualizer interface
    # ------------------------------------------------------------------

    @staticmethod
    def _safe_ratio(numerator: np.ndarray, denominator: np.ndarray) -> np.ndarray:
        """Element-wise ``numerator / denominator``, 0 where denominator <= 0."""
        result = np.zeros(denominator.shape, dtype=float)
        np.divide(numerator, denominator, out=result, where=denominator > 0)
        return result

    def _build_data(self) -> dict[str, Any]:
        """Assemble the payload consumed by the JavaScript renderer.

        Returns
        -------
        dict
            Raw and row-normalized matrices, class names, the ``normalize``
            flag, per-class precision/recall/F1 and overall accuracy.  All
            derived ratios are rounded to 4 decimals.
        """
        cm = self._matrix

        true_totals = cm.sum(axis=1)  # row sums: samples per true class
        pred_totals = cm.sum(axis=0)  # column sums: samples per predicted class
        correct = np.diag(cm)

        # Row-normalized version; empty rows divide by 1 and stay all-zero.
        row_divisors = np.where(true_totals == 0, 1.0, true_totals)
        cm_norm = cm / row_divisors[:, np.newaxis]

        # Per-class metrics (0 wherever the denominator is not positive).
        precision = self._safe_ratio(correct, pred_totals)
        recall = self._safe_ratio(correct, true_totals)
        f1 = self._safe_ratio(2 * precision * recall, precision + recall)

        total = cm.sum()
        accuracy = float(np.trace(cm) / total) if total > 0 else 0.0

        return {
            "matrix": cm.tolist(),
            "matrixNorm": [[round(float(v), 4) for v in row] for row in cm_norm],
            "classNames": self.class_names,
            "normalize": self.normalize,
            "precision": [round(float(v), 4) for v in precision],
            "recall": [round(float(v), 4) for v in recall],
            "f1": [round(float(v), 4) for v in f1],
            "accuracy": round(accuracy, 4),
        }

    def _chart_type(self) -> str:
        """Return the chart type identifier for this visualizer."""
        return "confusion_matrix"

    def _get_chart_js(self) -> str:
        """Return the JavaScript source of the chart renderer."""
        return _CM_JS
# ---------------------------------------------------------------------------
# JavaScript renderer
# ---------------------------------------------------------------------------

# Browser-side renderer for the confusion matrix.  It defines
# ``renderChart(data, config)``, reads the payload keys ``matrix``,
# ``matrixNorm``, ``classNames``, ``normalize`` and ``accuracy``, and draws
# into the ``#chart-container`` element using the ``EG`` helper namespace
# (``EG.svg``, ``EG.colorScale``, ``EG.tooltip``, ``EG.esc``, ``EG.pct``).
_CM_JS = r"""
function renderChart(data, config) {
  const container = document.getElementById('chart-container');

  // Size the grid from the matrix itself so a names/matrix length mismatch
  // can never index past the data; missing names fall back to the index.
  const n = data.matrix.length;
  const names = data.classNames || [];
  const labelOf = (i) => (i < names.length ? String(names[i]) : String(i));

  // Largest cell (capped at 70px) that fits the configured size, but never
  // below 3px so the (cellSize - 2) rect width stays positive.
  const cellSize = Math.max(3, Math.min(
    70,
    Math.floor((config.width - 180) / n),
    Math.floor((config.height - 140) / n)
  ));
  const margin = {top: 30, right: 30, bottom: 70, left: 100};
  const plotW = n * cellSize;
  const plotH = n * cellSize;
  const totalW = plotW + margin.left + margin.right;
  const totalH = plotH + margin.top + margin.bottom;

  const svg = EG.svg('svg', {width: totalW, height: totalH});
  container.appendChild(svg);
  container.style.width = totalW + 'px';
  container.style.height = totalH + 'px';

  const g = EG.svg('g', {transform: `translate(${margin.left},${margin.top})`});
  svg.appendChild(g);

  // Create a <text> node, set its content and attach it to the plot group.
  const addText = (attrs, content) => {
    const node = EG.svg('text', attrs);
    node.textContent = content;
    g.appendChild(node);
    return node;
  };

  const palette = config.palette;
  const cm = data.normalize ? data.matrixNorm : data.matrix;
  const cmRaw = data.matrix;
  const cmNorm = data.matrixNorm;

  // Find max for color scaling
  let maxVal = 0;
  for (let r = 0; r < n; r++) {
    for (let c = 0; c < n; c++) {
      if (cm[r][c] > maxVal) maxVal = cm[r][c];
    }
  }
  const colorFn = EG.colorScale(palette, 0, maxVal || 1);

  // Draw cells
  for (let r = 0; r < n; r++) {
    for (let c = 0; c < n; c++) {
      const v = cm[r][c];
      const x = c * cellSize;
      const y = r * cellSize;
      const isDiag = r === c;
      // Diagonal (correct predictions) cells carry an accent outline.
      const restStroke = isDiag ? 'var(--accent)' : 'none';
      const restStrokeWidth = isDiag ? '2' : '0';

      const rect = EG.svg('rect', {
        x: x, y: y,
        width: cellSize - 2, height: cellSize - 2,
        fill: colorFn(v),
        rx: 3,
        stroke: restStroke,
        'stroke-width': restStrokeWidth
      });

      // Single source for the tooltip shown on both enter and move.
      const showTip = (e) => EG.tooltip.show(e,
        '<b>True: ' + EG.esc(labelOf(r)) + '</b><br>' +
        '<b>Pred: ' + EG.esc(labelOf(c)) + '</b><br>' +
        'Count: ' + Math.round(cmRaw[r][c]) + '<br>' +
        'Rate: ' + EG.pct(cmNorm[r][c])
      );

      rect.addEventListener('mouseenter', function(e) {
        rect.setAttribute('stroke', 'var(--text-primary)');
        rect.setAttribute('stroke-width', '2');
        showTip(e);
      });
      rect.addEventListener('mousemove', showTip);
      rect.addEventListener('mouseleave', function() {
        rect.setAttribute('stroke', restStroke);
        rect.setAttribute('stroke-width', restStrokeWidth);
        EG.tooltip.hide();
      });
      g.appendChild(rect);

      // Annotation
      const brightness = v / (maxVal || 1);
      const textColor = brightness > 0.5 ? '#1a1a2e' : '#e0e0e0';
      const valStr = data.normalize ? EG.pct(v) : Math.round(v).toString();
      addText({
        x: x + cellSize / 2, y: y + cellSize / 2 + 5,
        'text-anchor': 'middle',
        fill: textColor,
        'font-size': cellSize < 40 ? '10px' : '13px',
        'font-weight': isDiag ? '700' : '500'
      }, valStr);
    }
  }

  // Y labels (True)
  for (let r = 0; r < n; r++) {
    addText({
      x: -10, y: r * cellSize + cellSize / 2 + 4,
      'text-anchor': 'end', fill: 'var(--text-secondary)', 'font-size': '11px'
    }, labelOf(r));
  }
  addText({
    x: -10, y: -10,
    'text-anchor': 'end', fill: 'var(--text-muted)',
    'font-size': '10px', 'font-style': 'italic'
  }, 'True ↓');

  // X labels (Predicted)
  for (let c = 0; c < n; c++) {
    addText({
      x: c * cellSize + cellSize / 2, y: plotH + 20,
      'text-anchor': 'middle', fill: 'var(--text-secondary)', 'font-size': '11px'
    }, labelOf(c));
  }
  addText({
    x: plotW / 2, y: plotH + 40,
    'text-anchor': 'middle', fill: 'var(--text-muted)',
    'font-size': '10px', 'font-style': 'italic'
  }, 'Predicted →');

  // Accuracy badge
  addText({
    x: plotW / 2, y: plotH + 58,
    'text-anchor': 'middle', fill: 'var(--accent)',
    'font-size': '12px', 'font-weight': '600'
  }, 'Accuracy: ' + EG.pct(data.accuracy));
}
"""