# Source code for endgame.visualization.ridgeline_plot

"""Ridgeline (Joy) plot visualizer.

Overlapping density plots — beautiful for comparing distributions across
models, folds, or feature groups. Each group gets its own KDE density
curve stacked vertically with configurable overlap.

Example
-------
>>> from endgame.visualization import RidgelinePlotVisualizer
>>> data = {"Fold 1": scores_1, "Fold 2": scores_2, "Fold 3": scores_3}
>>> viz = RidgelinePlotVisualizer(data)
>>> viz.save("ridgeline.html")
"""

from __future__ import annotations

from collections.abc import Sequence
from typing import Any

import numpy as np

from endgame.visualization._base import BaseVisualizer


class RidgelinePlotVisualizer(BaseVisualizer):
    """Interactive ridgeline (joy) plot visualizer.

    Each group in ``data`` is summarised by a Gaussian KDE curve plus
    quartile/mean statistics; the curves are stacked vertically with a
    configurable overlap by the JavaScript renderer.

    Parameters
    ----------
    data : dict of str → list of float
        Group name → raw values.
    overlap : float, default=0.5
        Overlap ratio between adjacent ridges (0 = no overlap, 1 = full).
    kde_points : int, default=100
        Number of points for KDE evaluation.
    bandwidth : float, optional
        KDE bandwidth. If None (or not a positive number), a
        rule-of-thumb bandwidth ``1.06 * std * n ** (-1/5)`` is computed
        per group.
    show_quantiles : bool, default=True
        Show median and Q1/Q3 markers.
    title : str, optional
        Chart title. An empty title is replaced by "Ridgeline Plot".
    palette : str, default='tableau'
        Color palette.
    width : int, default=800
        Chart width.
    height : int, default=500
        Chart height.
    theme : str, default='dark'
        'dark' or 'light'.
    """

    def __init__(
        self,
        data: dict[str, Sequence[float]],
        *,
        overlap: float = 0.5,
        kde_points: int = 100,
        bandwidth: float | None = None,
        show_quantiles: bool = True,
        title: str = "",
        palette: str = "tableau",
        width: int = 800,
        height: int = 500,
        theme: str = "dark",
    ):
        super().__init__(
            title=title or "Ridgeline Plot",
            palette=palette,
            width=width,
            height=height,
            theme=theme,
        )
        # Copy every group into a plain list so later mutation of the
        # caller's sequences does not affect the chart.
        self._data = {k: list(v) for k, v in data.items()}
        self.overlap = overlap
        self.kde_points = kde_points
        self.bandwidth = bandwidth
        self.show_quantiles = show_quantiles

    @classmethod
    def from_cv_results(
        cls,
        results: dict[str, Sequence[float]],
        **kwargs,
    ) -> RidgelinePlotVisualizer:
        """Create from cross-validation results.

        Parameters
        ----------
        results : dict of str → list of float
            Model name → fold scores.
        **kwargs
            Additional keyword arguments forwarded to the constructor.
            ``title`` defaults to "CV Score Distributions".
        """
        kwargs.setdefault("title", "CV Score Distributions")
        return cls(results, **kwargs)

    @classmethod
    def from_feature_distributions(
        cls,
        X: Any,
        feature_names: Sequence[str] | None = None,
        *,
        max_features: int = 15,
        **kwargs,
    ) -> RidgelinePlotVisualizer:
        """Create from feature column distributions.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Feature matrix. A 1-D input is treated as a single column.
        feature_names : list of str, optional
            Feature names. When omitted, ``X.columns`` is used if ``X``
            has that attribute, otherwise names of the form
            ``"Feature i"`` are generated. If fewer names than plotted
            columns are supplied, the remaining columns also receive
            generated ``"Feature i"`` names.
        max_features : int, default=15
            Max number of features to show (the leading columns are kept).
        **kwargs
            Additional keyword arguments forwarded to the constructor.
            ``title`` defaults to "Feature Distributions".

        Notes
        -----
        Each selected column is converted to ``float`` and its non-finite
        entries (NaN, ±inf) are dropped. A selected column that cannot be
        converted to float propagates NumPy's conversion error.
        """
        X_arr = np.asarray(X)
        if X_arr.ndim == 1:
            X_arr = X_arr.reshape(-1, 1)
        n_features = min(X_arr.shape[1], max_features)

        if feature_names is None and hasattr(X, "columns"):
            feature_names = X.columns

        # Normalise labels to str (the renderer slices/measures them as
        # strings) and top up with generated names so a short name list
        # cannot raise IndexError below.
        names = (
            []
            if feature_names is None
            else [str(name) for name in list(feature_names)[:n_features]]
        )
        names.extend(f"Feature {i}" for i in range(len(names), n_features))

        data = {}
        for i, name in enumerate(names):
            # Cast per column: a mixed-dtype input yields an object array
            # on which np.isnan/np.isfinite would raise TypeError.
            col = np.asarray(X_arr[:, i], dtype=float)
            data[name] = col[np.isfinite(col)].tolist()

        kwargs.setdefault("title", "Feature Distributions")
        return cls(data, **kwargs)

    def _build_data(self) -> dict[str, Any]:
        """Build the payload consumed by the JavaScript renderer.

        Returns
        -------
        dict
            ``"ridges"`` is a list with one entry per non-empty group
            (keys: ``name``, ``x``, ``density``, ``q1``, ``median``,
            ``q3``, ``mean``, ``n``); ``"overlap"`` and
            ``"showQuantiles"`` echo the instance settings. Groups with
            no finite values are skipped.
        """
        ridges = []
        for name, values in self._data.items():
            arr = np.asarray(values, dtype=float)
            # Drop NaN *and* ±inf: an infinite sample would make the
            # evaluation grid (and therefore the whole chart) non-finite.
            arr = arr[np.isfinite(arr)]
            n = int(arr.size)
            if n == 0:
                continue

            lo, hi = float(arr.min()), float(arr.max())

            # Bandwidth: use the caller's value when it is a positive
            # number, otherwise the normal-reference rule of thumb. A
            # constant group (std == 0) falls back to a fixed 0.1.
            bw = self.bandwidth
            if bw is None or not bw > 0:
                std = float(np.std(arr))
                bw = 1.06 * std * n ** (-1 / 5) if std > 0 else 0.1

            # Extend the grid two bandwidths past the data so the curve
            # tails are drawn rather than being cut at the extremes.
            pad = bw * 2
            x_grid = np.linspace(lo - pad, hi + pad, self.kde_points)
            density = _kde(arr, x_grid, bw)

            q1 = float(np.percentile(arr, 25))
            median = float(np.median(arr))
            q3 = float(np.percentile(arr, 75))
            mean = float(np.mean(arr))

            ridges.append({
                "name": name,
                "x": [round(float(v), 6) for v in x_grid],
                "density": [round(float(v), 8) for v in density],
                "q1": round(q1, 6),
                "median": round(median, 6),
                "q3": round(q3, 6),
                "mean": round(mean, 6),
                "n": n,
            })

        return {
            "ridges": ridges,
            "overlap": self.overlap,
            "showQuantiles": self.show_quantiles,
        }

    def _chart_type(self) -> str:
        """Return the chart type identifier, ``"ridgeline"``."""
        return "ridgeline"

    def _get_chart_js(self) -> str:
        """Return the JavaScript source that renders this chart."""
        return _RIDGELINE_JS
def _kde(data: np.ndarray, x_grid: np.ndarray, bandwidth: float) -> np.ndarray:
    """Evaluate a simple Gaussian kernel density estimate.

    Parameters
    ----------
    data : array-like
        Sample values; flattened to 1-D before use.
    x_grid : array-like
        Points at which the density is evaluated.
    bandwidth : float
        Standard deviation of each Gaussian kernel.

    Returns
    -------
    numpy.ndarray
        Float array with the same shape as ``x_grid``. It is all zeros
        when ``data`` is empty or ``bandwidth`` is not a positive number
        (zero, negative or NaN).
    """
    samples = np.asarray(data, dtype=float).ravel()
    grid = np.asarray(x_grid, dtype=float)
    # Always float, so the degenerate return has the same dtype as the
    # normal one even when the caller passes an integer grid.
    density = np.zeros_like(grid)

    # `not bandwidth > 0` (rather than `bandwidth <= 0`) also rejects NaN.
    if samples.size == 0 or not bandwidth > 0:
        return density

    # Accumulate one kernel at a time: memory stays O(len(grid)) instead
    # of materialising a samples-by-grid matrix.
    for xi in samples:
        density += np.exp(-0.5 * ((grid - xi) / bandwidth) ** 2)
    density /= samples.size * bandwidth * np.sqrt(2 * np.pi)
    return density


# JavaScript renderer for the ridgeline chart. It defines
# `renderChart(data, config)`, where `data` is the dict produced by
# `RidgelinePlotVisualizer._build_data`. The `EG` helpers it calls are
# not defined in this module.
_RIDGELINE_JS = r"""
function renderChart(data, config) {
  const container = document.getElementById('chart-container');
  const ridges = data.ridges;
  const n = ridges.length;
  if (n === 0) return;

  const margin = {top: 20, right: 30, bottom: 45, left: 130};
  const W = config.width - margin.left - margin.right;
  const ridgeH = Math.min(80, (config.height - margin.top - margin.bottom) / (n * (1 - data.overlap * 0.5)));
  const stepY = ridgeH * (1 - data.overlap * 0.5);
  const totalH = stepY * n + ridgeH * 0.5 + margin.top + margin.bottom;

  const svg = EG.svg('svg', {width: config.width, height: totalH});
  container.appendChild(svg);
  const g = EG.svg('g', {transform: 'translate(' + margin.left + ',' + margin.top + ')'});
  svg.appendChild(g);

  const palette = config.palette;

  // Global x range
  let xMin = Infinity, xMax = -Infinity;
  ridges.forEach(function(r) {
    r.x.forEach(function(v) { if (v < xMin) xMin = v; if (v > xMax) xMax = v; });
  });
  const xScale = EG.scaleLinear([xMin, xMax], [0, W]);

  // Find max density for normalization
  let dMax = 0;
  ridges.forEach(function(r) {
    r.density.forEach(function(v) { if (v > dMax) dMax = v; });
  });
  if (dMax === 0) dMax = 1;

  // Draw ridges from bottom to top so earlier ones render behind
  for (let i = n - 1; i >= 0; i--) {
    const r = ridges[i];
    const color = palette[i % palette.length];
    const baseY = i * stepY;
    const yScale = function(d) { return baseY + ridgeH - (d / dMax) * ridgeH * 0.85; };

    // Area path
    let d = 'M' + xScale(r.x[0]) + ' ' + (baseY + ridgeH);
    for (let j = 0; j < r.x.length; j++) {
      d += ' L' + xScale(r.x[j]) + ' ' + yScale(r.density[j]);
    }
    d += ' L' + xScale(r.x[r.x.length - 1]) + ' ' + (baseY + ridgeH) + ' Z';

    const area = EG.svg('path', {
      d: d, fill: color, opacity: 0.55,
      stroke: color, 'stroke-width': 1.5
    });
    area.addEventListener('mouseenter', function(e) {
      area.setAttribute('opacity', '0.8');
      EG.tooltip.show(e,
        '<b>' + EG.esc(r.name) + '</b> (n=' + r.n + ')<br>' +
        'Mean: ' + EG.fmt(r.mean, 4) + '<br>' +
        'Median: ' + EG.fmt(r.median, 4) + '<br>' +
        'Q1: ' + EG.fmt(r.q1, 4) + ' | Q3: ' + EG.fmt(r.q3, 4));
    });
    area.addEventListener('mouseleave', function() {
      area.setAttribute('opacity', '0.55');
      EG.tooltip.hide();
    });
    g.appendChild(area);

    // Quantile markers
    if (data.showQuantiles) {
      [r.q1, r.median, r.q3].forEach(function(qv, qi) {
        const qx = xScale(qv);
        const dAtQ = _interpDensity(r.x, r.density, qv);
        const qy1 = baseY + ridgeH;
        const qy2 = yScale(dAtQ);
        g.appendChild(EG.svg('line', {
          x1: qx, y1: qy1, x2: qx, y2: qy2,
          stroke: qi === 1 ? '#ffffff' : 'rgba(255,255,255,0.5)',
          'stroke-width': qi === 1 ? 2 : 1,
          'stroke-dasharray': qi === 1 ? 'none' : '3,2'
        }));
      });
    }

    // Label
    g.appendChild(EG.svg('text', {
      x: -10, y: baseY + ridgeH * 0.6,
      'text-anchor': 'end', fill: 'var(--text-primary)', 'font-size': '11px'
    })).textContent = r.name.length > 18 ? r.name.slice(0, 16) + '…' : r.name;
  }

  // X axis
  var axG = EG.svg('g', {transform: 'translate(0,' + (n * stepY + ridgeH * 0.3) + ')'});
  g.appendChild(axG);
  axG.appendChild(EG.svg('line', {x1: 0, y1: 0, x2: W, y2: 0, stroke: 'var(--border)'}));
  var ticks = EG.niceTicks(xMin, xMax, 6);
  ticks.forEach(function(v) {
    axG.appendChild(EG.svg('line', {x1: xScale(v), y1: 0, x2: xScale(v), y2: 5, stroke: 'var(--text-muted)'}));
    axG.appendChild(EG.svg('text', {x: xScale(v), y: 18, 'text-anchor': 'middle', fill: 'var(--text-secondary)', 'font-size': '10px'})).textContent = EG.fmt(v, 3);
  });

  function _interpDensity(xs, ds, val) {
    for (var j = 0; j < xs.length - 1; j++) {
      if (xs[j] <= val && xs[j+1] >= val) {
        var t = (val - xs[j]) / (xs[j+1] - xs[j]);
        return ds[j] + t * (ds[j+1] - ds[j]);
      }
    }
    return 0;
  }
}
"""