"""Scatterplot and bubble chart visualizer.

Interactive scatterplots for embeddings, actual-vs-predicted plots,
and multi-dimensional data exploration with zoom/pan support.

Examples
--------
>>> from endgame.visualization import ScatterplotVisualizer
>>> import numpy as np
>>> x = np.random.randn(200)
>>> y = x + np.random.randn(200) * 0.3
>>> viz = ScatterplotVisualizer(x, y, title="Correlation")
>>> viz.save("scatter.html")
"""
from __future__ import annotations
from collections.abc import Sequence
from typing import Any
import numpy as np
from endgame.visualization._base import BaseVisualizer
class ScatterplotVisualizer(BaseVisualizer):
    """Interactive scatterplot visualizer.

    Parameters
    ----------
    x : array-like
        X coordinates.
    y : array-like
        Y coordinates.
    labels : array-like of str, optional
        Point labels for coloring (categorical).
    sizes : array-like of float, optional
        Point sizes for bubble mode. Values are linearly rescaled to a
        3-20 px radius when the chart payload is built.
    colors : array-like of float, optional
        Continuous values for color mapping. The renderer uses ``labels``
        for coloring instead when both are given.
    point_labels : list of str, optional
        Individual point hover labels.
    x_label : str, optional
        X-axis label.
    y_label : str, optional
        Y-axis label.
    show_diagonal : bool, default=False
        Show y=x diagonal line.
    show_regression : bool, default=False
        Show linear regression line.
    title : str, optional
        Chart title.
    palette : str, default='tableau'
        Color palette name.
    width : int, default=700
        Chart width.
    height : int, default=600
        Chart height.
    theme : str, default='dark'
        'dark' or 'light'.

    Raises
    ------
    ValueError
        If ``y`` or any optional per-point input (``labels``, ``sizes``,
        ``colors``, ``point_labels``) does not have exactly one entry per
        element of ``x`` (after flattening).

    Notes
    -----
    Points whose ``x`` or ``y`` coordinate is NaN or infinite are dropped
    when the chart payload is built. Every per-point input is filtered with
    the same mask, so the arrays stay aligned.
    """

    def __init__(
        self,
        x: Any,
        y: Any,
        *,
        labels: Any | None = None,
        sizes: Any | None = None,
        colors: Any | None = None,
        point_labels: Sequence[str] | None = None,
        x_label: str = "",
        y_label: str = "",
        show_diagonal: bool = False,
        show_regression: bool = False,
        title: str = "",
        palette: str = "tableau",
        width: int = 700,
        height: int = 600,
        theme: str = "dark",
    ):
        super().__init__(title=title, palette=palette, width=width, height=height, theme=theme)
        self._x = np.asarray(x, dtype=float).ravel()
        self._y = np.asarray(y, dtype=float).ravel()
        self._labels = np.asarray(labels).ravel() if labels is not None else None
        self._sizes = np.asarray(sizes, dtype=float).ravel() if sizes is not None else None
        self._colors = np.asarray(colors, dtype=float).ravel() if colors is not None else None
        self._point_labels = list(point_labels) if point_labels is not None else None
        self.x_label = x_label
        self.y_label = y_label
        self.show_diagonal = show_diagonal
        self.show_regression = show_regression
        # Fail fast with a clear message. Otherwise a length mismatch only
        # surfaces later as a broadcast/IndexError, or is silently truncated.
        self._check_lengths()

    def _check_lengths(self) -> None:
        """Raise ValueError unless every per-point input has len(x) entries."""
        n = self._x.size
        per_point = {
            "y": self._y,
            "labels": self._labels,
            "sizes": self._sizes,
            "colors": self._colors,
            "point_labels": self._point_labels,
        }
        for name, values in per_point.items():
            if values is not None and len(values) != n:
                raise ValueError(
                    f"{name} must have one entry per point ({n}), got {len(values)}"
                )

    # ------------------------------------------------------------------
    # Classmethod constructors
    # ------------------------------------------------------------------

    @classmethod
    def from_predictions(
        cls,
        y_true: Any,
        y_pred: Any,
        **kwargs,
    ) -> ScatterplotVisualizer:
        """Create an actual-vs-predicted scatter plot.

        Defaults the title and axis labels, and turns on both the y=x
        diagonal and the regression line. Any of these can be overridden
        through ``kwargs``.

        Parameters
        ----------
        y_true : array-like
            True values (plotted on the x axis).
        y_pred : array-like
            Predicted values (plotted on the y axis).
        **kwargs
            Additional keyword arguments forwarded to the constructor.

        Returns
        -------
        ScatterplotVisualizer
        """
        kwargs.setdefault("title", "Actual vs Predicted")
        kwargs.setdefault("x_label", "Actual")
        kwargs.setdefault("y_label", "Predicted")
        kwargs.setdefault("show_diagonal", True)
        kwargs.setdefault("show_regression", True)
        return cls(np.asarray(y_true), np.asarray(y_pred), **kwargs)

    @classmethod
    def from_embedding(
        cls,
        embedding: Any,
        labels: Any | None = None,
        **kwargs,
    ) -> ScatterplotVisualizer:
        """Create a scatter plot from 2D embeddings (t-SNE, UMAP, PCA).

        Parameters
        ----------
        embedding : array-like, shape (n_samples, 2)
            2D embedding coordinates. If more than two columns are given,
            only the first two are plotted.
        labels : array-like, optional
            Labels for coloring.
        **kwargs
            Additional keyword arguments forwarded to the constructor.

        Returns
        -------
        ScatterplotVisualizer

        Raises
        ------
        ValueError
            If ``embedding`` is not 2-D or has fewer than two columns.
        """
        emb = np.asarray(embedding)
        if emb.ndim != 2 or emb.shape[1] < 2:
            raise ValueError(
                "embedding must be a 2D array with at least two columns, "
                f"got shape {emb.shape}"
            )
        kwargs.setdefault("title", "2D Embedding")
        kwargs.setdefault("x_label", "Dimension 1")
        kwargs.setdefault("y_label", "Dimension 2")
        return cls(emb[:, 0], emb[:, 1], labels=labels, **kwargs)

    # ------------------------------------------------------------------
    # BaseVisualizer interface
    # ------------------------------------------------------------------

    def _build_data(self) -> dict[str, Any]:
        """Build the payload consumed by the ``renderChart`` JavaScript.

        Returns
        -------
        dict
            Always contains ``x`` and ``y`` (finite points only), ``xLabel``,
            ``yLabel``, ``showDiagonal`` and ``showRegression``. It also contains
            ``labels``/``uniqueLabels``, ``sizes``,
            ``colorValues``/``colorMin``/``colorMax``, ``pointLabels`` and
            ``regression`` when the matching inputs/options are set and
            usable.
        """
        # A point needs a finite position on both linear axes; NaN or inf
        # would poison the renderer's scale extents.
        keep = np.isfinite(self._x) & np.isfinite(self._y)
        xs = self._x[keep]
        ys = self._y[keep]
        result: dict[str, Any] = {
            "x": xs.tolist(),
            "y": ys.tolist(),
            "xLabel": self.x_label,
            "yLabel": self.y_label,
            "showDiagonal": self.show_diagonal,
            "showRegression": self.show_regression,
        }
        if self._labels is not None:
            label_strs = [str(lab) for lab in self._labels[keep]]
            result["labels"] = label_strs
            result["uniqueLabels"] = sorted(set(label_strs))
        if self._sizes is not None:
            result["sizes"] = self._scale_sizes(self._sizes[keep])
        if self._colors is not None:
            result.update(self._color_fields(self._colors[keep]))
        if self._point_labels is not None:
            result["pointLabels"] = [
                text for text, kept in zip(self._point_labels, keep) if kept
            ]
        if self.show_regression:
            regression = self._fit_regression(xs, ys)
            if regression is not None:
                result["regression"] = regression
        return result

    @staticmethod
    def _scale_sizes(sizes: np.ndarray) -> list[float]:
        """Linearly map bubble sizes onto a 3-20 px radius range.

        Min/max are taken over the finite sizes only. If all sizes are equal,
        or none is finite, every point gets 6 px. Each individual
        non-finite size is also drawn at 6 px, so a single NaN no longer
        collapses every bubble to the same radius. Empty input yields ``[]``.
        """
        finite = np.isfinite(sizes)
        if not finite.any():
            return [6.0] * len(sizes)
        smin = float(sizes[finite].min())
        smax = float(sizes[finite].max())
        if smax > smin:
            radii = (sizes - smin) / (smax - smin) * 17 + 3
        else:
            radii = np.full(sizes.shape, 6.0)
        return np.where(finite, radii, 6.0).tolist()

    @staticmethod
    def _color_fields(values: np.ndarray) -> dict[str, Any]:
        """Return ``colorValues``/``colorMin``/``colorMax`` for the payload.

        The bounds come from the finite values only, so a stray inf/NaN does
        not stretch the color scale. If no value is finite (including empty
        input), an empty dict is returned. The renderer then falls back to
        its default coloring instead of receiving NaN bounds.
        """
        finite = values[np.isfinite(values)]
        if finite.size == 0:
            return {}
        return {
            "colorValues": values.tolist(),
            "colorMin": float(finite.min()),
            "colorMax": float(finite.max()),
        }

    @staticmethod
    def _fit_regression(xs: np.ndarray, ys: np.ndarray) -> dict[str, float] | None:
        """Fit the least-squares line ``y = slope * x + intercept`` and its R².

        Returns None when fewer than two points remain or when all x values
        are equal. In that case the best-fit line is vertical and
        ``np.polyfit`` is rank-deficient. R² is reported as 0 when y is
        constant, because it is undefined there.
        """
        if xs.size < 2 or xs.max() <= xs.min():
            return None
        slope, intercept = np.polyfit(xs, ys, 1)
        ss_tot = float(np.sum((ys - ys.mean()) ** 2))
        if ss_tot > 0:
            ss_res = float(np.sum((ys - (slope * xs + intercept)) ** 2))
            r2 = 1.0 - ss_res / ss_tot
        else:
            r2 = 0.0
        return {
            "slope": round(float(slope), 6),
            "intercept": round(float(intercept), 6),
            "r2": round(float(r2), 4),
        }

    def _chart_type(self) -> str:
        """Identifier of this chart kind."""
        return "scatter"

    def _get_chart_js(self) -> str:
        """JavaScript source defining ``renderChart`` for this chart."""
        return _SCATTER_JS
# ---------------------------------------------------------------------------
# JavaScript renderer
# ---------------------------------------------------------------------------
_SCATTER_JS = r"""
function renderChart(data, config) {
const container = document.getElementById('chart-container');
const margin = {top: 20, right: 20, bottom: 50, left: 60};
const ctx = EG.createSVG(container, config.width, config.height, margin);
const {g, width: W, height: H} = ctx;
const palette = config.palette;
const x = data.x, y = data.y;
if (x.length === 0) return;
// Scales
let xMin = Math.min.apply(null, x), xMax = Math.max.apply(null, x);
let yMin = Math.min.apply(null, y), yMax = Math.max.apply(null, y);
const xPad = (xMax - xMin) * 0.05 || 1;
const yPad = (yMax - yMin) * 0.05 || 1;
xMin -= xPad; xMax += xPad; yMin -= yPad; yMax += yPad;
const xScale = EG.scaleLinear([xMin, xMax], [0, W]);
const yScale = EG.scaleLinear([yMin, yMax], [H, 0]);
EG.drawXAxis(g, xScale, H, data.xLabel);
EG.drawYAxis(g, yScale, W, data.yLabel);
// Diagonal line
if (data.showDiagonal) {
const dMin = Math.max(xMin, yMin), dMax = Math.min(xMax, yMax);
g.appendChild(EG.svg('line', {
x1: xScale(dMin), y1: yScale(dMin), x2: xScale(dMax), y2: yScale(dMax),
stroke: 'var(--text-muted)', 'stroke-width': 1, 'stroke-dasharray': '6,4', opacity: 0.6
}));
}
// Regression line
if (data.regression) {
const r = data.regression;
const rx1 = xMin, rx2 = xMax;
const ry1 = r.slope * rx1 + r.intercept;
const ry2 = r.slope * rx2 + r.intercept;
g.appendChild(EG.svg('line', {
x1: xScale(rx1), y1: yScale(ry1), x2: xScale(rx2), y2: yScale(ry2),
stroke: 'var(--accent)', 'stroke-width': 2, opacity: 0.7
}));
const regLabel = EG.svg('text', {
x: W - 5, y: 15, 'text-anchor': 'end',
fill: 'var(--accent)', 'font-size': '11px'
});
regLabel.textContent = 'R² = ' + r.r2.toFixed(3);
g.appendChild(regLabel);
}
// Color setup
let colorFn;
if (data.uniqueLabels) {
const labelMap = {};
data.uniqueLabels.forEach(function(l, i) { labelMap[l] = palette[i % palette.length]; });
colorFn = function(i) { return labelMap[data.labels[i]]; };
} else if (data.colorValues) {
const cs = EG.colorScale(palette, data.colorMin, data.colorMax);
colorFn = function(i) { return cs(data.colorValues[i]); };
} else {
colorFn = function() { return palette[0]; };
}
// Points
for (let i = 0; i < x.length; i++) {
const cx = xScale(x[i]), cy = yScale(y[i]);
const r = data.sizes ? data.sizes[i] : 4;
const color = colorFn(i);
const circle = EG.svg('circle', {cx: cx, cy: cy, r: r, fill: color, opacity: 0.75, stroke: 'none'});
circle.addEventListener('mouseenter', function(e) {
circle.setAttribute('opacity', '1');
circle.setAttribute('stroke', 'var(--text-primary)');
circle.setAttribute('stroke-width', '1.5');
let html = '<b>x:</b> ' + EG.fmt(x[i], 4) + '<br><b>y:</b> ' + EG.fmt(y[i], 4);
if (data.labels) html = '<b>' + EG.esc(data.labels[i]) + '</b><br>' + html;
if (data.pointLabels) html = '<b>' + EG.esc(data.pointLabels[i]) + '</b><br>' + html;
EG.tooltip.show(e, html);
});
circle.addEventListener('mouseleave', function() {
circle.setAttribute('opacity', '0.75');
circle.removeAttribute('stroke');
circle.removeAttribute('stroke-width');
EG.tooltip.hide();
});
g.appendChild(circle);
}
// Legend for categorical labels
if (data.uniqueLabels && data.uniqueLabels.length <= 20) {
const items = data.uniqueLabels.map(function(l, i) {
return {label: l, color: palette[i % palette.length]};
});
EG.drawLegend(container, items);
}
}
"""