from __future__ import annotations
"""Quality guardrails for AutoML pipelines.
This module provides data quality checks that run early in the pipeline
to detect issues like target leakage, redundant features, and data health
problems before expensive model training begins.
"""
import logging
import time
from dataclasses import dataclass, field
from typing import Any
import numpy as np
import pandas as pd
from endgame.automl.orchestrator import BaseStageExecutor, StageResult
logger = logging.getLogger(__name__)
@dataclass
class DataQualityWarning:
    """One finding produced by a guardrail check.

    Attributes
    ----------
    category : str
        Which family of check raised the finding: "leakage",
        "redundancy" or "data_health".
    severity : str
        How serious the finding is: "critical", "warning" or "info".
    message : str
        Description intended for humans (it is what gets logged).
    details : dict
        Free-form structured payload (feature names, values, etc.).
        Defaults to a fresh empty dict per instance.
    """

    category: str
    severity: str
    message: str
    # default_factory gives every instance its own dict (no shared state).
    details: dict[str, Any] = field(default_factory=dict)
@dataclass
class GuardrailsReport:
    """Collected outcome of every guardrail check.

    Attributes
    ----------
    warnings : list of DataQualityWarning
        Every finding recorded so far, in insertion order.
    passed : bool
        Stays True until a critical finding is added.
    n_critical : int
        How many critical findings were added.
    n_warnings : int
        How many warning-level findings were added.
    """

    warnings: list[DataQualityWarning] = field(default_factory=list)
    passed: bool = True
    n_critical: int = 0
    n_warnings: int = 0

    def add(self, warning: DataQualityWarning) -> None:
        """Record a finding and keep the counters in sync.

        Findings whose severity is neither "critical" nor "warning"
        (e.g. "info") are stored but do not affect any counter.
        """
        self.warnings.append(warning)
        level = warning.severity
        if level == "warning":
            self.n_warnings += 1
            return
        if level == "critical":
            self.passed = False
            self.n_critical += 1
class QualityGuardrailsExecutor(BaseStageExecutor):
    """Performs data quality checks early in the pipeline.

    Checks for target leakage, feature redundancy, and general data
    health issues. Findings are always logged; with ``strict=True`` a
    critical finding additionally marks the stage output with
    ``fail_fast=True``.

    Parameters
    ----------
    strict : bool, default=False
        If True and at least one critical issue is found, adds
        ``"fail_fast": True`` to the stage output.
        NOTE(review): how the orchestrator reacts to that key is not
        visible from this module.
    leakage_threshold : float, default=0.95
        Absolute correlation with target above which a feature is
        flagged as potential leakage.
    redundancy_threshold : float, default=0.98
        Absolute pairwise correlation above which a feature pair is
        flagged as redundant.
    """

    # Tunables that used to be inline magic numbers.
    _MIN_SAMPLES = 20                # below this the dataset is "critical"
    _MAX_FEATURE_SAMPLE_RATIO = 10   # features / samples above this warns
    _MAX_TARGET_CLASSES = 100        # imbalance check only for <= this many
    _MIN_CLASS_FRACTION = 0.01       # minority share below this warns
    _MIN_LEAKAGE_ROWS = 10           # rows needed to trust a correlation
    _MAX_REDUNDANCY_FEATURES = 500   # pairwise check skipped above this
    _REDUNDANCY_SAMPLE_ROWS = 5000   # rows used for the pairwise check
    _MAX_REPORTED_ITEMS = 10         # cap on names stored in ``details``

    def __init__(
        self,
        strict: bool = False,
        leakage_threshold: float = 0.95,
        redundancy_threshold: float = 0.98,
    ):
        self.strict = strict
        self.leakage_threshold = leakage_threshold
        self.redundancy_threshold = redundancy_threshold

    def execute(
        self,
        context: dict[str, Any],
        time_budget: float,
    ) -> StageResult:
        """Run all guardrail checks.

        Parameters
        ----------
        context : dict
            Pipeline context; only the ``X`` and ``y`` entries are read.
            If either is missing the stage succeeds with an empty report.
        time_budget : float
            Time budget in seconds. Leakage runs only while less than
            50% of it has elapsed, redundancy only below 80%.

        Returns
        -------
        StageResult
            Always ``success=True``. ``output`` contains
            ``guardrails_report`` and, in strict mode with critical
            findings, ``fail_fast=True``.
        """
        start = time.time()
        report = GuardrailsReport()
        X = context.get("X")
        y = context.get("y")
        if X is None or y is None:
            return StageResult(
                stage_name="quality_guardrails",
                success=True,
                duration=time.time() - start,
                output={"guardrails_report": report},
            )

        # Reduce X to a float matrix plus one name per matrix column.
        if isinstance(X, pd.DataFrame):
            numeric_df = X.select_dtypes(include=[np.number])
            numeric_cols = numeric_df.columns.tolist()
            # dtype=float + na_value also copes with nullable extension
            # dtypes, which ``.values`` would turn into an object array.
            X_arr = numeric_df.to_numpy(dtype=float, na_value=np.nan)
        else:
            X_arr = np.asarray(X, dtype=float)
            if X_arr.ndim == 1:
                # A single feature passed as a flat vector.
                X_arr = X_arr.reshape(-1, 1)
            numeric_cols = [f"feature_{i}" for i in range(X_arr.shape[1])]
        y_arr = np.asarray(y).ravel()
        n_samples, n_features = X_arr.shape

        # --- Data health checks ---
        # The names passed here must line up with X_arr's columns, i.e.
        # the numeric columns only, not every DataFrame column.
        self._check_data_health(
            X, X_arr, y_arr, n_samples, n_features, numeric_cols, report
        )

        # --- Target leakage ---
        if time.time() - start < time_budget * 0.5:
            self._check_leakage(X_arr, y_arr, numeric_cols, report)

        # --- Feature redundancy ---
        if (
            time.time() - start < time_budget * 0.8
            and n_features <= self._MAX_REDUNDANCY_FEATURES
        ):
            self._check_redundancy(X_arr, numeric_cols, report)

        # Log results
        for w in report.warnings:
            if w.severity == "critical":
                logger.warning("[Guardrails CRITICAL] %s", w.message)
            elif w.severity == "warning":
                logger.warning("[Guardrails] %s", w.message)
            else:
                logger.debug("[Guardrails] %s", w.message)

        # Set fail_fast if strict and critical issues found
        output: dict[str, Any] = {"guardrails_report": report}
        if self.strict and not report.passed:
            output["fail_fast"] = True

        return StageResult(
            stage_name="quality_guardrails",
            success=True,
            duration=time.time() - start,
            output=output,
            metadata={
                "n_critical": report.n_critical,
                "n_warnings": report.n_warnings,
            },
        )

    @staticmethod
    def _column_masks(X_arr: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """Return ``(all_missing, varying)`` boolean masks, one per column.

        ``all_missing`` marks columns with no non-NaN value; ``varying``
        marks columns holding at least two distinct non-NaN values. With
        zero rows both masks are all False.
        """
        n_cols = X_arr.shape[1]
        all_missing = np.zeros(n_cols, dtype=bool)
        varying = np.zeros(n_cols, dtype=bool)
        if X_arr.shape[0] == 0 or n_cols == 0:
            return all_missing, varying
        all_missing = np.isnan(X_arr).all(axis=0)
        observed = ~all_missing
        if observed.any():
            # min != max is an exact constancy test; a std-based test can
            # leave a tiny non-zero value from floating-point rounding.
            # Restricting to observed columns avoids all-NaN warnings.
            block = X_arr[:, observed]
            varying[observed] = (
                np.nanmin(block, axis=0) != np.nanmax(block, axis=0)
            )
        return all_missing, varying

    def _check_data_health(
        self,
        X_raw: Any,
        X_arr: np.ndarray,
        y_arr: np.ndarray,
        n_samples: int,
        n_features: int,
        feature_names: list[str],
        report: GuardrailsReport,
    ) -> None:
        """Check general data health.

        Parameters
        ----------
        X_raw : Any
            Features as supplied by the caller; only inspected (for
            ID-like columns) when it is a DataFrame.
        X_arr : np.ndarray
            2-D numeric feature matrix.
        y_arr : np.ndarray
            Flattened target.
        n_samples, n_features : int
            Shape of ``X_arr``.
        feature_names : list of str
            Names aligned with the columns of ``X_arr``.
        report : GuardrailsReport
            Findings are appended here.
        """
        limit = self._MAX_REPORTED_ITEMS

        # Too few samples
        if n_samples < self._MIN_SAMPLES:
            report.add(DataQualityWarning(
                category="data_health",
                severity="critical",
                message=f"Very few samples ({n_samples}). Results will be unreliable.",
                details={"n_samples": n_samples},
            ))

        # Feature-to-sample ratio
        if (
            n_features > 0
            and n_samples > 0
            and n_features / n_samples > self._MAX_FEATURE_SAMPLE_RATIO
        ):
            ratio = n_features / n_samples
            report.add(DataQualityWarning(
                category="data_health",
                severity="warning",
                message=(
                    f"High feature-to-sample ratio ({n_features}/{n_samples} = "
                    f"{ratio:.1f}). Consider feature selection."
                ),
                details={"ratio": ratio},
            ))

        def names_for(mask: np.ndarray) -> list:
            return [
                feature_names[i] if i < len(feature_names) else f"col_{i}"
                for i in np.flatnonzero(mask)
            ]

        # Column-level checks are meaningless (and vacuously true) with
        # zero rows, so they are skipped in that case.
        if X_arr.shape[1] > 0 and n_samples > 0:
            all_missing, varying = self._column_masks(X_arr)

            # Constant columns: something observed, but only one value.
            constant_cols = names_for(~all_missing & ~varying)
            if constant_cols:
                report.add(DataQualityWarning(
                    category="data_health",
                    severity="warning",
                    message=f"{len(constant_cols)} constant column(s) detected.",
                    details={"columns": constant_cols[:limit]},
                ))

            # All-missing columns
            missing_cols = names_for(all_missing)
            if missing_cols:
                report.add(DataQualityWarning(
                    category="data_health",
                    severity="critical",
                    message=f"{len(missing_cols)} all-missing column(s) detected.",
                    details={"columns": missing_cols[:limit]},
                ))

        # Minority class check (classification)
        # NOTE(review): this runs for any target with 2..100 distinct
        # values, including low-cardinality regression targets; the task
        # type is not available in this method.
        unique, counts = np.unique(y_arr[~pd.isna(y_arr)], return_counts=True)
        if 1 < len(unique) <= self._MAX_TARGET_CLASSES:
            min_frac = counts.min() / counts.sum()
            if min_frac < self._MIN_CLASS_FRACTION:
                report.add(DataQualityWarning(
                    category="data_health",
                    severity="warning",
                    message=(
                        f"Extreme class imbalance: minority class has "
                        f"{min_frac:.2%} of samples ({counts.min()} samples)."
                    ),
                    details={"min_class_fraction": float(min_frac)},
                ))

        # ID-like columns (unique count == n_samples for integer/object cols)
        if isinstance(X_raw, pd.DataFrame) and n_samples > self._MIN_SAMPLES:
            # items() yields one Series per column even when names repeat.
            for col, series in X_raw.items():
                # Continuous floats are almost always all-unique; that
                # does not make them identifiers.
                if pd.api.types.is_float_dtype(series.dtype):
                    continue
                if series.nunique() == n_samples:
                    report.add(DataQualityWarning(
                        category="data_health",
                        severity="info",
                        message=f"Column '{col}' has all unique values (possible ID column).",
                        details={"column": col},
                    ))

    def _check_leakage(
        self,
        X_arr: np.ndarray,
        y_arr: np.ndarray,
        feature_names: list[str],
        report: GuardrailsReport,
    ) -> None:
        """Flag features highly correlated with target.

        Each feature is correlated with the target over the rows where
        both are non-NaN, so missing values in one column cannot disable
        the check for the others. Features with fewer than
        ``_MIN_LEAKAGE_ROWS`` usable rows, constant features, and a
        constant target are skipped. A target that cannot be cast to
        float is skipped entirely.
        """
        if X_arr.shape[1] == 0:
            return

        # Only check numeric target
        try:
            y_numeric = np.asarray(y_arr).astype(float)
        except (ValueError, TypeError):
            return

        if y_numeric.shape[0] != X_arr.shape[0]:
            logger.debug(
                "[Guardrails] Skipping leakage check: X has %d rows, y has %d.",
                X_arr.shape[0], y_numeric.shape[0],
            )
            return

        y_observed = ~np.isnan(y_numeric)
        for i in range(X_arr.shape[1]):
            col = X_arr[:, i]
            rows = y_observed & ~np.isnan(col)
            if rows.sum() < self._MIN_LEAKAGE_ROWS:
                continue
            x_valid = col[rows]
            y_valid = y_numeric[rows]
            # Correlation is undefined when either side has no variance.
            if x_valid.min() == x_valid.max() or y_valid.min() == y_valid.max():
                continue
            # Infinite values produce NaN here; handled just below.
            with np.errstate(invalid="ignore", divide="ignore"):
                corr = np.abs(np.corrcoef(x_valid, y_valid)[0, 1])
            if np.isnan(corr) or corr <= self.leakage_threshold:
                continue
            name = feature_names[i] if i < len(feature_names) else f"feature_{i}"
            report.add(DataQualityWarning(
                category="leakage",
                severity="critical",
                message=(
                    f"Potential target leakage: '{name}' has "
                    f"|corr| = {corr:.3f} with target."
                ),
                details={"feature": name, "correlation": float(corr)},
            ))

    def _check_redundancy(
        self,
        X_arr: np.ndarray,
        feature_names: list[str],
        report: GuardrailsReport,
    ) -> None:
        """Flag highly correlated feature pairs.

        Works on the first ``_REDUNDANCY_SAMPLE_ROWS`` rows, drops
        columns that are constant or entirely missing there, mean-imputes
        remaining NaNs, and adds a single warning listing (up to
        ``_MAX_REPORTED_ITEMS``) pairs whose absolute correlation exceeds
        ``redundancy_threshold``. The input array is not modified.
        """
        if X_arr.shape[1] < 2:
            return

        # Subsample for speed (leading rows, so the result is deterministic).
        n = min(X_arr.shape[0], self._REDUNDANCY_SAMPLE_ROWS)
        X_sub = np.asarray(X_arr[:n], dtype=float)

        # Remove constant / all-missing columns
        _, varying = self._column_masks(X_sub)
        kept = np.flatnonzero(varying)
        if kept.size < 2:
            return
        X_sub = X_sub[:, kept]  # fancy indexing -> independent copy
        if len(feature_names) == X_arr.shape[1]:
            names_sub = [feature_names[i] for i in kept]
        else:
            # Fall back to positional names that still refer to the
            # original column positions.
            names_sub = [f"feature_{i}" for i in kept]

        # Fill NaN with column mean for correlation
        col_means = np.nanmean(X_sub, axis=0)
        X_sub = np.where(np.isnan(X_sub), col_means, X_sub)

        try:
            with np.errstate(invalid="ignore", divide="ignore"):
                corr_matrix = np.corrcoef(X_sub.T)
        except Exception:
            # Best-effort check: never fail the stage, but leave a trace.
            logger.debug(
                "[Guardrails] Redundancy check skipped: correlation failed.",
                exc_info=True,
            )
            return

        # Strict upper triangle = each unordered pair once, in (i, j)
        # row-major order. NaN entries compare False and are ignored.
        too_close = np.triu(
            np.abs(corr_matrix) > self.redundancy_threshold, k=1
        )
        redundant_pairs = [
            (names_sub[i], names_sub[j]) for i, j in zip(*np.nonzero(too_close))
        ]
        if redundant_pairs:
            report.add(DataQualityWarning(
                category="redundancy",
                severity="warning",
                message=(
                    f"{len(redundant_pairs)} redundant feature pair(s) detected "
                    f"(|corr| > {self.redundancy_threshold})."
                ),
                details={"pairs": redundant_pairs[:self._MAX_REPORTED_ITEMS]},
            ))