Source code for endgame.utils.reproducibility
from __future__ import annotations
"""Reproducibility utilities for consistent experiments."""
import os
import random
from typing import Any

import numpy as np
[docs]
def seed_everything(seed: int = 42) -> None:
    """Set random seeds for reproducibility.

    Sets seeds for:

    - Python random
    - NumPy
    - PyTorch (if available)
    - TensorFlow (if available)
    - CUDA (if available)

    Also sets environment variables for deterministic behavior.

    Parameters
    ----------
    seed : int, default=42
        Random seed to use. Integer seeds must lie in ``[0, 2**32 - 1]``,
        the range accepted by both ``numpy.random.seed`` and the
        ``PYTHONHASHSEED`` environment variable.

    Raises
    ------
    ValueError
        If ``seed`` is an integer outside ``[0, 2**32 - 1]``. The check runs
        before anything is seeded, so a rejected seed leaves the Python and
        NumPy generators and ``os.environ`` untouched.

    Notes
    -----
    ``PYTHONHASHSEED`` is read by the interpreter at start-up, so assigning
    it here does not change string hashing in the current process; it is
    inherited by child processes started afterwards.

    Examples
    --------
    >>> from endgame.utils import seed_everything
    >>> seed_everything(42)
    """
    # Validate first: previously an out-of-range seed reseeded ``random`` and
    # wrote an invalid PYTHONHASHSEED before NumPy finally raised ValueError.
    # Non-int values are left to the underlying libraries' own type checks.
    max_seed = 2**32 - 1
    if isinstance(seed, int) and not 0 <= seed <= max_seed:
        raise ValueError(
            f"seed must be between 0 and {max_seed} (2**32 - 1), got {seed}"
        )

    # Python random
    random.seed(seed)

    # Environment variables
    os.environ["PYTHONHASHSEED"] = str(seed)

    # NumPy
    np.random.seed(seed)

    # PyTorch (optional dependency: skipped silently when not installed)
    try:
        import torch
    except ImportError:
        pass
    else:
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        # For deterministic behavior (may impact performance)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # TensorFlow (optional dependency: skipped silently when not installed)
    try:
        import tensorflow as tf
    except ImportError:
        pass
    else:
        tf.random.set_seed(seed)
        # For TF 2.x
        os.environ["TF_DETERMINISTIC_OPS"] = "1"
[docs]
class SeedEverything:
    """Context manager for reproducible experiments.

    Sets random seeds on entry and optionally restores state on exit.

    Parameters
    ----------
    seed : int, default=42
        Random seed to use.
    restore : bool, default=False
        Whether to restore random state on exit. When enabled, the Python,
        NumPy and (if installed) PyTorch CPU/CUDA generator states captured
        on entry are put back on exit. Environment variables and cuDNN flags
        set by ``seed_everything`` are not reverted.

    Examples
    --------
    >>> with SeedEverything(42):
    ...     # Reproducible code here
    ...     pass

    >>> seed_ctx = SeedEverything(42)
    >>> with seed_ctx:
    ...     result = train_model()
    """

    def __init__(self, seed: int = 42, restore: bool = False) -> None:
        self.seed = seed
        self.restore = restore
        # Generator snapshots; filled in by __enter__ only when restore=True.
        self._python_state: tuple | None = None
        # np.random.get_state() returns a tuple in its default (legacy) form.
        self._numpy_state: tuple | None = None
        # Torch states are typed Any because torch is an optional dependency.
        self._torch_state: Any | None = None
        self._cuda_state: Any | None = None

    def __enter__(self) -> SeedEverything:
        """Snapshot RNG state if ``restore`` is set, then seed everything."""
        if self.restore:
            # Save current state
            self._python_state = random.getstate()
            self._numpy_state = np.random.get_state()
            try:
                import torch
            except ImportError:
                pass
            else:
                self._torch_state = torch.get_rng_state()
                if torch.cuda.is_available():
                    self._cuda_state = torch.cuda.get_rng_state_all()

        # Set seeds
        seed_everything(self.seed)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Restore the saved RNG state; never suppresses exceptions."""
        if not self.restore:
            return

        # Restore state
        if self._python_state is not None:
            random.setstate(self._python_state)
        if self._numpy_state is not None:
            np.random.set_state(self._numpy_state)
        try:
            import torch
        except ImportError:
            return
        if self._torch_state is not None:
            torch.set_rng_state(self._torch_state)
        if self._cuda_state is not None and torch.cuda.is_available():
            torch.cuda.set_rng_state_all(self._cuda_state)
class ReproducibleRun:
    """Context manager for fully reproducible ML experiments.

    Combines seed setting with logging of environment info.

    Parameters
    ----------
    seed : int, default=42
        Random seed.
    log_environment : bool, default=True
        Whether to log environment information.

    Attributes
    ----------
    environment_info : dict
        Interpreter, platform and package versions collected on entry.
        Stays empty until the context is entered, and whenever
        ``log_environment`` is False.

    Examples
    --------
    >>> with ReproducibleRun(seed=42) as run:
    ...     print(run.environment_info)
    ...     # Train model
    ...     pass
    """

    def __init__(self, seed: int = 42, log_environment: bool = True) -> None:
        self.seed = seed
        self.log_environment = log_environment
        self.environment_info: dict = {}
        self._seed_ctx: SeedEverything | None = None

    def __enter__(self) -> ReproducibleRun:
        """Seed all RNGs and, if enabled, record environment information."""
        # Set seeds
        self._seed_ctx = SeedEverything(self.seed)
        self._seed_ctx.__enter__()

        # Log environment
        if self.log_environment:
            self.environment_info = self._get_environment_info()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Leave the inner seeding context; exceptions are not suppressed."""
        if self._seed_ctx is not None:
            self._seed_ctx.__exit__(exc_type, exc_val, exc_tb)

    def _get_environment_info(self) -> dict:
        """Get environment information for logging.

        Returns
        -------
        dict
            ``python_version``, ``platform``, ``numpy_version`` and ``seed``,
            plus a ``<package>_version`` entry for each optional package
            that can be imported.
        """
        import platform
        import sys

        info = {
            "python_version": sys.version,
            "platform": platform.platform(),
            "numpy_version": np.__version__,
            "seed": self.seed,
        }

        # Add package versions; packages that are not installed are skipped.
        packages = [
            "sklearn",
            "pandas",
            "polars",
            "torch",
            "lightgbm",
            "xgboost",
            "catboost",
        ]
        for pkg in packages:
            try:
                mod = __import__(pkg)
            except ImportError:
                continue
            info[f"{pkg}_version"] = getattr(mod, "__version__", "unknown")
        return info

    def save_config(self, filepath: str | os.PathLike[str]) -> None:
        """Save reproducibility configuration to a JSON file.

        Parameters
        ----------
        filepath : str or path-like
            Output file path. The content is always written as JSON,
            regardless of the file extension.
        """
        import json

        config = {
            "seed": self.seed,
            "environment": self.environment_info,
        }
        # Explicit encoding keeps the output independent of the platform locale.
        with open(filepath, "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2)