# Source code for endgame.ensemble.snapshot

"""Snapshot Ensemble: Ensemble from learning rate schedule snapshots.

Collects model snapshots at the end of each warm-restart cycle
(cosine annealing learning rate schedule) and averages their
predictions. Gets the diversity of an ensemble for the cost of
training a single model.

Theory: Huang et al. (2017), "Snapshot Ensembles: Train 1, Get M for Free."

Example
-------
>>> from sklearn.neural_network import MLPClassifier
>>> from endgame.ensemble import SnapshotEnsemble
>>> snap = SnapshotEnsemble(
...     base_estimator=MLPClassifier(max_iter=1),
...     n_snapshots=5,
...     epochs_per_cycle=40,
...     initial_lr=0.1,
... )
>>> snap.fit(X_train, y_train)
>>> snap.predict(X_test)
"""

from __future__ import annotations

import copy

import numpy as np
from sklearn.base import BaseEstimator, clone, is_classifier, is_regressor
from sklearn.utils.validation import check_is_fitted


class SnapshotEnsemble(BaseEstimator):
    """Snapshot Ensemble via cosine annealing warm restarts.

    Trains a single estimator through repeated ``partial_fit`` calls while
    the learning rate follows a cyclic cosine schedule. At the end of each
    cycle a deep copy ("snapshot") of the model is stored. The ensemble
    prediction is the average over all snapshots.

    Parameters
    ----------
    base_estimator : estimator
        A model supporting ``partial_fit`` (e.g., ``MLPClassifier``,
        ``SGDClassifier``, ``SGDRegressor``). The schedule is applied
        through its ``learning_rate_init`` and/or ``eta0`` attribute.
    n_snapshots : int, default=5
        Number of snapshots (cycles) to collect. Must be at least 1.
    epochs_per_cycle : int, default=40
        Training epochs (``partial_fit`` calls) per cycle. Must be at
        least 1.
    initial_lr : float, default=0.1
        Peak learning rate at the start of each cycle.
    min_lr : float, default=1e-5
        Lower bound of the cosine schedule.
    verbose : bool, default=False
        If True, print cycle progress and each snapshot's training score.

    Attributes
    ----------
    snapshots_ : list of estimator
        Saved model snapshots, one per cycle.
    lr_history_ : list of float
        Learning rate used at each epoch, across all cycles.
    is_classifier_ : bool
        Whether the fitted task is treated as classification.
    classes_ : ndarray
        Sorted unique labels seen in ``y``; only set for classification.

    References
    ----------
    Huang, G., Li, Y., Pleiss, G., Liu, Z., Hopcroft, J.E., &
    Weinberger, K.Q. (2017). Snapshot Ensembles: Train 1, Get M for Free.
    *ICLR*.
    """

    def __init__(
        self,
        base_estimator: BaseEstimator,
        n_snapshots: int = 5,
        epochs_per_cycle: int = 40,
        initial_lr: float = 0.1,
        min_lr: float = 1e-5,
        verbose: bool = False,
    ):
        self.base_estimator = base_estimator
        self.n_snapshots = n_snapshots
        self.epochs_per_cycle = epochs_per_cycle
        self.initial_lr = initial_lr
        self.min_lr = min_lr
        self.verbose = verbose

    def _cosine_lr(self, epoch_in_cycle: int) -> float:
        """Return the cosine-annealed learning rate for an epoch in a cycle.

        Epoch 0 yields ``initial_lr``. Because epochs are numbered
        ``0 .. epochs_per_cycle - 1``, the last epoch of a cycle is
        slightly above ``min_lr`` rather than exactly equal to it.
        """
        t = epoch_in_cycle / max(self.epochs_per_cycle, 1)
        amplitude = self.initial_lr - self.min_lr
        return self.min_lr + 0.5 * amplitude * (1 + np.cos(np.pi * t))

    @staticmethod
    def _infer_is_classifier(model, y) -> bool:
        """Decide whether the task is classification.

        The estimator's own declared type wins. The label-based heuristic
        is only a fallback for estimators that declare neither type.
        """
        if is_classifier(model):
            return True
        if is_regressor(model):
            return False
        n_unique = len(np.unique(y))
        return n_unique <= 30 and (
            np.issubdtype(y.dtype, np.integer) or n_unique <= 10
        )

    @staticmethod
    def _set_learning_rate(model, lr: float) -> None:
        """Push ``lr`` into every learning-rate attribute the model has."""
        if hasattr(model, "learning_rate_init"):
            model.learning_rate_init = lr
        if hasattr(model, "eta0"):
            model.eta0 = lr
        # NOTE(review): scikit-learn's MLP estimators are believed to build
        # their internal optimizer on the first partial_fit call and not
        # re-read ``learning_rate_init`` afterwards, which would make the
        # schedule a no-op after epoch 0. Updating the live optimizer is a
        # guarded best-effort; confirm against the installed version.
        optimizer = getattr(model, "_optimizer", None)
        if optimizer is not None:
            for attr in ("learning_rate_init", "learning_rate"):
                if hasattr(optimizer, attr):
                    setattr(optimizer, attr, float(lr))

    def fit(self, X, y, **fit_params):
        """Train with a cyclic learning rate and collect snapshots.

        Parameters
        ----------
        X : array-like
            Training data, passed in full to every ``partial_fit`` call.
        y : array-like
            Training targets.
        **fit_params
            Extra keyword arguments forwarded to every ``partial_fit``
            call (e.g. ``sample_weight``). For classification, a
            ``classes`` entry is overridden by the labels found in ``y``.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If ``n_snapshots`` or ``epochs_per_cycle`` is smaller than 1,
            or if ``base_estimator`` has no ``partial_fit`` method.
        """
        if self.n_snapshots < 1:
            raise ValueError(
                f"n_snapshots must be >= 1, got {self.n_snapshots!r}."
            )
        if self.epochs_per_cycle < 1:
            raise ValueError(
                f"epochs_per_cycle must be >= 1, got {self.epochs_per_cycle!r}."
            )

        X = np.asarray(X)
        y = np.asarray(y)

        model = clone(self.base_estimator)
        if not hasattr(model, "partial_fit"):
            raise ValueError(
                "base_estimator must support partial_fit for snapshot ensemble. "
                "Use MLPClassifier, SGDClassifier, SGDRegressor, etc."
            )

        self.is_classifier_ = self._infer_is_classifier(model, y)

        partial_fit_kwargs = dict(fit_params)
        if self.is_classifier_:
            self.classes_ = np.unique(y)
            # Passing ``classes`` on every call is accepted by scikit-learn
            # classifiers and is required on the first one.
            partial_fit_kwargs["classes"] = self.classes_

        self.snapshots_ = []
        self.lr_history_ = []

        for cycle in range(self.n_snapshots):
            if self.verbose:
                print(f"[Snapshot] Cycle {cycle + 1}/{self.n_snapshots}")

            for epoch in range(self.epochs_per_cycle):
                lr = self._cosine_lr(epoch)
                self.lr_history_.append(lr)
                self._set_learning_rate(model, lr)
                model.partial_fit(X, y, **partial_fit_kwargs)

            # Snapshot at the end of the cycle, where the LR is lowest.
            # Deep copy so further training does not mutate it.
            snapshot = copy.deepcopy(model)
            self.snapshots_.append(snapshot)

            if self.verbose and hasattr(snapshot, "score"):
                score = snapshot.score(X, y)
                print(f" Snapshot score: {score:.4f}")

        return self

    def predict(self, X):
        """Predict with the snapshot ensemble.

        For classification, returns the label with the highest averaged
        probability. For regression, returns the mean of the snapshots'
        predictions.
        """
        check_is_fitted(self, "snapshots_")
        X = np.asarray(X)
        if self.is_classifier_:
            proba = self.predict_proba(X)
            return self.classes_[np.argmax(proba, axis=1)]
        return np.mean([s.predict(X) for s in self.snapshots_], axis=0)

    def predict_proba(self, X):
        """Return class probabilities averaged over all snapshots.

        Raises
        ------
        ValueError
            If the ensemble was fitted as a regressor.
        """
        check_is_fitted(self, "snapshots_")
        if not self.is_classifier_:
            raise ValueError("predict_proba only for classification.")
        X = np.asarray(X)
        return np.mean([s.predict_proba(X) for s in self.snapshots_], axis=0)