"""Parallel coordinates visualizer.
Interactive parallel coordinates plot for hyperparameter search
visualization and multi-dimensional data exploration.
Example
-------
>>> from endgame.visualization import ParallelCoordinatesVisualizer
>>> data = [
... {"lr": 0.001, "depth": 6, "n_est": 100, "score": 0.92},
... {"lr": 0.01, "depth": 4, "n_est": 200, "score": 0.89},
... ]
>>> viz = ParallelCoordinatesVisualizer(data, color_by="score")
>>> viz.save("parallel_coords.html")
"""
from __future__ import annotations
from collections.abc import Sequence
from typing import Any
from endgame.visualization._base import BaseVisualizer
[docs]
class ParallelCoordinatesVisualizer(BaseVisualizer):
"""Interactive parallel coordinates visualizer.
Parameters
----------
data : list of dict
List of records, each a dict of dimension_name → value.
dimensions : list of str, optional
Which dimensions to show (and in what order). If None, use all.
color_by : str, optional
Dimension name to use for line coloring.
cmap : str, default='viridis_seq'
Color palette for continuous coloring.
title : str, optional
Chart title.
width : int, default=900
Chart width.
height : int, default=500
Chart height.
theme : str, default='dark'
'dark' or 'light'.
"""
def __init__(
self,
data: Sequence[dict[str, Any]],
*,
dimensions: Sequence[str] | None = None,
color_by: str | None = None,
cmap: str = "viridis_seq",
title: str = "",
width: int = 900,
height: int = 500,
theme: str = "dark",
):
super().__init__(title=title, palette=cmap, width=width, height=height, theme=theme)
self._records = list(data)
if dimensions is None:
# Use all keys from first record
dimensions = list(self._records[0].keys()) if self._records else []
self.dimensions = list(dimensions)
self.color_by = color_by
[docs]
@classmethod
def from_optuna_study(
cls,
study: Any,
*,
n_trials: int | None = None,
**kwargs,
) -> ParallelCoordinatesVisualizer:
"""Create from an Optuna study.
Parameters
----------
study : optuna.Study
Completed Optuna study.
n_trials : int, optional
Max number of trials to show.
**kwargs
Additional keyword arguments.
"""
trials = [t for t in study.trials if t.state.name == "COMPLETE"]
if n_trials:
trials = trials[:n_trials]
records = []
for t in trials:
rec = dict(t.params)
rec["objective"] = t.value
records.append(rec)
kwargs.setdefault("color_by", "objective")
kwargs.setdefault("title", "Hyperparameter Search")
return cls(records, **kwargs)
def _build_data(self) -> dict[str, Any]:
dims = self.dimensions
axes = []
for dim in dims:
values = [r.get(dim) for r in self._records]
# Determine if numeric
numeric = all(isinstance(v, (int, float)) for v in values if v is not None)
if numeric:
clean = [float(v) for v in values if v is not None]
axes.append({
"name": dim,
"type": "numeric",
"values": [float(v) if v is not None else None for v in values],
"min": min(clean) if clean else 0,
"max": max(clean) if clean else 1,
})
else:
unique = sorted(set(str(v) for v in values if v is not None))
axes.append({
"name": dim,
"type": "categorical",
"values": [str(v) if v is not None else "" for v in values],
"categories": unique,
})
color_data = None
if self.color_by:
color_vals = [r.get(self.color_by) for r in self._records]
clean = [float(v) for v in color_vals if v is not None and isinstance(v, (int, float))]
if clean:
color_data = {
"values": [float(v) if isinstance(v, (int, float)) else None for v in color_vals],
"min": min(clean),
"max": max(clean),
"name": self.color_by,
}
return {
"axes": axes,
"nRecords": len(self._records),
"colorData": color_data,
}
def _chart_type(self) -> str:
return "parallel_coordinates"
def _get_chart_js(self) -> str:
return _PARCOORD_JS
_PARCOORD_JS = r"""
function renderChart(data, config) {
const container = document.getElementById('chart-container');
const margin = {top: 30, right: 30, bottom: 30, left: 30};
const ctx = EG.createSVG(container, config.width, config.height, margin);
const {g, width: W, height: H} = ctx;
const palette = config.palette;
const axes = data.axes;
const nAxes = axes.length;
const nRec = data.nRecords;
if (nAxes < 2 || nRec === 0) return;
const axisSpacing = W / (nAxes - 1);
// Color function
let colorFn;
if (data.colorData) {
const cs = EG.colorScale(palette, data.colorData.min, data.colorData.max);
colorFn = function(i) {
const v = data.colorData.values[i];
return v !== null ? cs(v) : 'var(--text-muted)';
};
} else {
colorFn = function() { return palette[0] || 'var(--accent)'; };
}
// Scale functions per axis
const scales = axes.map(function(ax, ai) {
const x = ai * axisSpacing;
if (ax.type === 'numeric') {
return function(v) {
if (v === null) return H / 2;
return H - (v - ax.min) / (ax.max - ax.min || 1) * H;
};
} else {
const cats = ax.categories;
return function(v) {
const idx = cats.indexOf(String(v));
return H - (idx + 0.5) / cats.length * H;
};
}
});
// Draw lines
for (let i = 0; i < nRec; i++) {
let d = '';
for (let a = 0; a < nAxes; a++) {
const x = a * axisSpacing;
const y = scales[a](axes[a].values[i]);
d += (a === 0 ? 'M' : ' L') + x + ' ' + y;
}
const line = EG.svg('path', {
d: d, fill: 'none', stroke: colorFn(i),
'stroke-width': 1.5, opacity: 0.4
});
line.addEventListener('mouseenter', function(e) {
line.setAttribute('opacity', '1');
line.setAttribute('stroke-width', '3');
let html = '';
axes.forEach(function(ax) {
html += '<b>' + EG.esc(ax.name) + ':</b> ' + ax.values[i] + '<br>';
});
EG.tooltip.show(e, html);
});
line.addEventListener('mouseleave', function() {
line.setAttribute('opacity', '0.4');
line.setAttribute('stroke-width', '1.5');
EG.tooltip.hide();
});
g.appendChild(line);
}
// Draw axes
for (let a = 0; a < nAxes; a++) {
const x = a * axisSpacing;
const ax = axes[a];
g.appendChild(EG.svg('line', {x1: x, y1: 0, x2: x, y2: H, stroke: 'var(--border)', 'stroke-width': 1.5}));
// Axis label
const label = EG.svg('text', {x: x, y: -10, 'text-anchor': 'middle', fill: 'var(--text-primary)', 'font-size': '11px', 'font-weight': '600'});
label.textContent = ax.name.length > 12 ? ax.name.slice(0,10) + '…' : ax.name;
g.appendChild(label);
// Ticks
if (ax.type === 'numeric') {
const ticks = EG.niceTicks(ax.min, ax.max, 4);
ticks.forEach(function(v) {
const y = scales[a](v);
const t = EG.svg('text', {x: x - 5, y: y + 3, 'text-anchor': 'end', fill: 'var(--text-muted)', 'font-size': '9px'});
t.textContent = EG.fmt(v);
g.appendChild(t);
});
} else {
ax.categories.forEach(function(cat, ci) {
const y = H - (ci + 0.5) / ax.categories.length * H;
const t = EG.svg('text', {x: x - 5, y: y + 3, 'text-anchor': 'end', fill: 'var(--text-muted)', 'font-size': '9px'});
t.textContent = cat.length > 8 ? cat.slice(0,6) + '…' : cat;
g.appendChild(t);
});
}
}
// Color bar legend
if (data.colorData) {
const cbH = H * 0.6, cbW = 12;
const cbG = EG.svg('g', {transform: `translate(${W + 15}, ${(H - cbH)/2})`});
g.appendChild(cbG);
const nSteps = 30;
const cs = EG.colorScale(palette, data.colorData.min, data.colorData.max);
for (let i = 0; i < nSteps; i++) {
const v = data.colorData.max - (i / nSteps) * (data.colorData.max - data.colorData.min);
cbG.appendChild(EG.svg('rect', {x:0, y: i * (cbH/nSteps), width: cbW, height: cbH/nSteps+1, fill: cs(v)}));
}
cbG.appendChild(EG.svg('text', {x: cbW+4, y: 10, fill:'var(--text-secondary)', 'font-size':'9px'})).textContent = EG.fmt(data.colorData.max);
cbG.appendChild(EG.svg('text', {x: cbW+4, y: cbH, fill:'var(--text-secondary)', 'font-size':'9px'})).textContent = EG.fmt(data.colorData.min);
cbG.appendChild(EG.svg('text', {x: cbW/2, y: -8, 'text-anchor':'middle', fill:'var(--text-muted)', 'font-size':'9px'})).textContent = data.colorData.name;
}
}
"""