Source code for endgame.visualization.heatmap

"""Heatmap visualizer.

Interactive heatmap for correlation matrices, hyperparameter grids,
and general 2D data visualization.

Example
-------
>>> from endgame.visualization import HeatmapVisualizer
>>> import numpy as np
>>> data = np.random.randn(5, 5)
>>> viz = HeatmapVisualizer(data, x_labels=["a","b","c","d","e"],
...                         y_labels=["a","b","c","d","e"])
>>> viz.save("heatmap.html")
"""

from __future__ import annotations

from collections.abc import Sequence
from typing import Any

import numpy as np

from endgame.visualization._base import BaseVisualizer
from endgame.visualization._palettes import DEFAULT_DIVERGING


[docs] class HeatmapVisualizer(BaseVisualizer): """Interactive heatmap visualizer. Parameters ---------- data : array-like, shape (n_rows, n_cols) 2D data matrix. x_labels : list of str, optional Column labels. y_labels : list of str, optional Row labels. annotate : bool, default=True Show cell values. fmt : str, default='.2f' Format string for annotations. vmin, vmax : float, optional Color scale range. Auto-computed if None. cmap : str, default='rdbu' Color palette for the heatmap (diverging palettes work best). title : str, optional Chart title. width : int, default=700 Chart width. height : int, default=600 Chart height. theme : str, default='dark' 'dark' or 'light'. """ def __init__( self, data: Any, *, x_labels: Sequence[str] | None = None, y_labels: Sequence[str] | None = None, annotate: bool = True, fmt: str = ".2f", vmin: float | None = None, vmax: float | None = None, cmap: str = DEFAULT_DIVERGING, title: str = "", width: int = 700, height: int = 600, theme: str = "dark", ): super().__init__(title=title, palette=cmap, width=width, height=height, theme=theme) self._data = np.asarray(data, dtype=float) if self._data.ndim != 2: raise ValueError(f"data must be 2D, got shape {self._data.shape}") n_rows, n_cols = self._data.shape self.x_labels = list(x_labels) if x_labels is not None else [str(i) for i in range(n_cols)] self.y_labels = list(y_labels) if y_labels is not None else [str(i) for i in range(n_rows)] self.annotate = annotate self.fmt = fmt self.vmin = vmin self.vmax = vmax # ------------------------------------------------------------------ # Classmethod constructors # ------------------------------------------------------------------
[docs] @classmethod def from_correlation( cls, X: Any, *, feature_names: Sequence[str] | None = None, method: str = "pearson", **kwargs, ) -> HeatmapVisualizer: """Create a heatmap from a correlation matrix. Parameters ---------- X : array-like or DataFrame Data to compute correlations from. feature_names : list of str, optional Feature names. method : str, default='pearson' Correlation method (only used if X is a pandas DataFrame). **kwargs Additional keyword arguments passed to the constructor. """ # Handle pandas DataFrame if hasattr(X, "corr"): corr = X.corr(method=method) if feature_names is None: feature_names = list(corr.columns) corr_data = corr.values else: arr = np.asarray(X) corr_data = np.corrcoef(arr, rowvar=False) if feature_names is None: feature_names = [f"feature_{i}" for i in range(arr.shape[1])] kwargs.setdefault("title", "Correlation Matrix") kwargs.setdefault("vmin", -1.0) kwargs.setdefault("vmax", 1.0) return cls( corr_data, x_labels=feature_names, y_labels=feature_names, **kwargs, )
# ------------------------------------------------------------------ # BaseVisualizer interface # ------------------------------------------------------------------ def _build_data(self) -> dict[str, Any]: data = self._data # Handle NaN data_list = [] for row in data: data_list.append([None if np.isnan(v) else round(float(v), 6) for v in row]) vmin = self.vmin if self.vmin is not None else float(np.nanmin(data)) vmax = self.vmax if self.vmax is not None else float(np.nanmax(data)) return { "matrix": data_list, "xLabels": self.x_labels, "yLabels": self.y_labels, "annotate": self.annotate, "fmt": self.fmt, "vmin": vmin, "vmax": vmax, } def _chart_type(self) -> str: return "heatmap" def _get_chart_js(self) -> str: return _HEATMAP_JS
# --------------------------------------------------------------------------- # JavaScript renderer # --------------------------------------------------------------------------- _HEATMAP_JS = r""" function renderChart(data, config) { const container = document.getElementById('chart-container'); const nRows = data.matrix.length; const nCols = data.matrix[0] ? data.matrix[0].length : 0; if (nRows === 0 || nCols === 0) return; const cellW = Math.min(60, Math.floor((config.width - 120) / nCols)); const cellH = Math.min(60, Math.floor((config.height - 100) / nRows)); const margin = {top: 20, right: 60, bottom: 60 + (nCols > 8 ? 20 : 0), left: 90}; const plotW = nCols * cellW; const plotH = nRows * cellH; const totalW = plotW + margin.left + margin.right; const totalH = plotH + margin.top + margin.bottom; const svg = EG.svg('svg', {width: totalW, height: totalH}); container.appendChild(svg); container.style.width = totalW + 'px'; container.style.height = totalH + 'px'; const g = EG.svg('g', {transform: `translate(${margin.left},${margin.top})`}); svg.appendChild(g); const palette = config.palette; const vmin = data.vmin; const vmax = data.vmax; const colorFn = EG.colorScale(palette, vmin, vmax); // Draw cells for (let r = 0; r < nRows; r++) { for (let c = 0; c < nCols; c++) { const v = data.matrix[r][c]; const x = c * cellW; const y = r * cellH; const fill = v === null ? 'var(--bg-secondary)' : colorFn(v); const rect = EG.svg('rect', { x: x, y: y, width: cellW - 1, height: cellH - 1, fill: fill, rx: 2 }); rect.addEventListener('mouseenter', function(e) { rect.setAttribute('stroke', 'var(--text-primary)'); rect.setAttribute('stroke-width', '2'); const valStr = v === null ? 'N/A' : Number(v).toFixed(3); EG.tooltip.show(e, '<b>' + EG.esc(data.yLabels[r]) + ' × ' + EG.esc(data.xLabels[c]) + '</b><br>Value: ' + valStr); }); rect.addEventListener('mousemove', function(e) { const valStr = v === null ? 'N/A' : Number(v).toFixed(3); EG.tooltip.show(e, '<b>' + EG.esc(data.yLabels[r]) + ' × ' + EG.esc(data.xLabels[c]) + '</b><br>Value: ' + valStr); }); rect.addEventListener('mouseleave', function() { rect.removeAttribute('stroke'); rect.removeAttribute('stroke-width'); EG.tooltip.hide(); }); g.appendChild(rect); // Annotation if (data.annotate && v !== null) { const textColor = Math.abs(v - vmin) / (vmax - vmin || 1) > 0.4 && Math.abs(v - vmin) / (vmax - vmin || 1) < 0.6 ? 'var(--text-primary)' : (v > (vmin + vmax)/2 ? '#1a1a2e' : '#e0e0e0'); const txt = EG.svg('text', { x: x + cellW / 2, y: y + cellH / 2 + 4, 'text-anchor': 'middle', fill: textColor, 'font-size': cellW < 35 ? '9px' : '11px', 'font-weight': '500' }); txt.textContent = Number(v).toFixed(2); g.appendChild(txt); } } } // Y labels for (let r = 0; r < nRows; r++) { const lbl = EG.svg('text', { x: -8, y: r * cellH + cellH / 2 + 4, 'text-anchor': 'end', fill: 'var(--text-secondary)', 'font-size': '11px' }); lbl.textContent = data.yLabels[r].length > 12 ? data.yLabels[r].slice(0,10) + '…' : data.yLabels[r]; g.appendChild(lbl); } // X labels for (let c = 0; c < nCols; c++) { const lbl = EG.svg('text', { x: c * cellW + cellW / 2, y: plotH + 16, 'text-anchor': nCols > 8 ? 'end' : 'middle', fill: 'var(--text-secondary)', 'font-size': '11px', transform: nCols > 8 ? `rotate(-45,${c * cellW + cellW / 2},${plotH + 16})` : '' }); lbl.textContent = data.xLabels[c].length > 12 ? data.xLabels[c].slice(0,10) + '…' : data.xLabels[c]; g.appendChild(lbl); } // Color bar const cbW = 15, cbH = plotH; const cbX = plotW + 20; const cbG = EG.svg('g', {transform: `translate(${cbX}, 0)`}); g.appendChild(cbG); const nSteps = 50; for (let i = 0; i < nSteps; i++) { const v = vmax - (i / nSteps) * (vmax - vmin); cbG.appendChild(EG.svg('rect', { x: 0, y: i * (cbH / nSteps), width: cbW, height: cbH / nSteps + 1, fill: colorFn(v) })); } cbG.appendChild(EG.svg('text', {x: cbW + 5, y: 10, fill: 'var(--text-secondary)', 'font-size': '10px'})).textContent = EG.fmt(vmax); cbG.appendChild(EG.svg('text', {x: cbW + 5, y: cbH, fill: 'var(--text-secondary)', 'font-size': '10px'})).textContent = EG.fmt(vmin); } """