# Source code for endgame.visualization.venn

"""Venn diagram visualizer.

Interactive Venn diagrams for showing set overlaps (2 or 3 sets).

Example
-------
>>> from endgame.visualization import VennDiagramVisualizer
>>> viz = VennDiagramVisualizer(
...     sets={"Model A": 150, "Model B": 120},
...     intersections={"Model A&Model B": 80},
...     title="Prediction Overlap",
... )
>>> viz.save("venn.html")
"""

from __future__ import annotations

from typing import Any

from endgame.visualization._base import BaseVisualizer


class VennDiagramVisualizer(BaseVisualizer):
    """Interactive Venn diagram visualizer (2 or 3 sets).

    Parameters
    ----------
    sets : dict of str -> int
        Maps each set name to its total count. Values are coerced with
        ``int()``.
    intersections : dict of str -> int
        Maps an intersection key (e.g., 'A&B', 'A&B&C') to its count.
        Values are coerced with ``int()``.
    title : str, optional
        Chart title.
    palette : str, default='tableau'
        Color palette.
    width : int, default=600
        Chart width.
    height : int, default=500
        Chart height.
    theme : str, default='dark'
        'dark' or 'light'.
    """

    def __init__(
        self,
        sets: dict[str, int],
        intersections: dict[str, int],
        *,
        title: str = "",
        palette: str = "tableau",
        width: int = 600,
        height: int = 500,
        theme: str = "dark",
    ):
        # Presentation options are forwarded unchanged to the base class.
        super().__init__(
            title=title,
            palette=palette,
            width=width,
            height=height,
            theme=theme,
        )
        # Normalise every count to a plain int, keeping insertion order.
        self._sets = {name: int(count) for name, count in sets.items()}
        self._intersections = {
            combo: int(count) for combo, count in intersections.items()
        }

    def _build_data(self) -> dict[str, Any]:
        """Return the payload handed to the JavaScript renderer.

        ``names`` and ``sizes`` are parallel lists in the insertion order of
        ``sets``; ``intersections`` is a shallow copy of the stored mapping;
        ``nSets`` is the number of sets.
        """
        labels = list(self._sets)
        return {
            "names": labels,
            "sizes": list(self._sets.values()),
            "intersections": dict(self._intersections),
            "nSets": len(labels),
        }

    def _chart_type(self) -> str:
        """Return the chart-type identifier, ``"venn"``."""
        return "venn"

    def _get_chart_js(self) -> str:
        """Return the JavaScript source that draws the diagram."""
        return _VENN_JS
_VENN_JS = r""" function renderChart(data, config) { const container = document.getElementById('chart-container'); const size = Math.min(config.width, config.height); const svg = EG.svg('svg', {width: size, height: size}); container.appendChild(svg); container.style.width = size + 'px'; container.style.height = size + 'px'; const palette = config.palette; const names = data.names; const sizes = data.sizes; const n = data.nSets; const cx = size / 2, cy = size / 2; // Circle positions const R = size * 0.22; const offset = size * 0.12; const circles = []; if (n === 2) { circles.push({x: cx - offset * 0.6, y: cy, r: R}); circles.push({x: cx + offset * 0.6, y: cy, r: R}); } else if (n >= 3) { const angleStep = (2 * Math.PI) / Math.min(n, 3); for (let i = 0; i < Math.min(n, 3); i++) { const a = -Math.PI / 2 + i * angleStep; circles.push({x: cx + offset * Math.cos(a), y: cy + offset * Math.sin(a), r: R}); } } // Draw circles circles.forEach(function(c, i) { const color = palette[i % palette.length]; const circle = EG.svg('circle', { cx: c.x, cy: c.y, r: c.r, fill: color, opacity: 0.25, stroke: color, 'stroke-width': 2.5 }); circle.addEventListener('mouseenter', function(e) { circle.setAttribute('opacity', '0.4'); EG.tooltip.show(e, '<b>' + EG.esc(names[i]) + '</b><br>Size: ' + sizes[i]); }); circle.addEventListener('mouseleave', function() { circle.setAttribute('opacity', '0.25'); EG.tooltip.hide(); }); svg.appendChild(circle); // Label const lx = c.x + (c.x > cx ? 20 : c.x < cx ? -20 : 0); const ly = c.y + (c.y > cy ? R + 20 : c.y < cy ? -R - 10 : -R - 10); const label = EG.svg('text', { x: lx, y: ly, 'text-anchor': 'middle', fill: 'var(--text-primary)', 'font-size': '13px', 'font-weight': '600' }); label.textContent = names[i]; svg.appendChild(label); // Size text on circle const sx = c.x + (c.x > cx ? R*0.3 : c.x < cx ? -R*0.3 : 0); const sy = c.y + (n === 2 ? 0 : (c.y > cy ? 
R*0.3 : -R*0.3)); const sizeText = EG.svg('text', { x: sx, y: sy + 4, 'text-anchor': 'middle', fill: 'var(--text-secondary)', 'font-size': '14px', 'font-weight': '500' }); sizeText.textContent = sizes[i]; svg.appendChild(sizeText); }); // Intersection labels const intKeys = Object.keys(data.intersections); intKeys.forEach(function(key) { const val = data.intersections[key]; // Show at center for now const txt = EG.svg('text', { x: cx, y: cy + 4, 'text-anchor': 'middle', fill: 'var(--text-primary)', 'font-size': '16px', 'font-weight': '700' }); txt.textContent = val; svg.appendChild(txt); }); } """