"""Knowledge Distillation for model compression.
Compress ensemble or complex models into simpler, deployable models
while preserving most of the predictive performance.
Example
-------
>>> from endgame.ensemble import KnowledgeDistiller
>>> from sklearn.ensemble import GradientBoostingClassifier
>>> from sklearn.tree import DecisionTreeClassifier
>>>
>>> teacher = GradientBoostingClassifier(n_estimators=200).fit(X, y)
>>> distiller = KnowledgeDistiller(
...     teacher=teacher,
...     student=DecisionTreeClassifier(max_depth=6),
...     temperature=3.0
... )
>>> distiller = distiller.fit(X, y)  # fit() returns the fitted distiller itself
>>> predictions = distiller.predict(X_test)
"""
from __future__ import annotations

import numpy as np
from sklearn.base import (
    BaseEstimator,
    clone,
    is_classifier,
)
from sklearn.utils.validation import check_is_fitted, has_fit_parameter
def _softmax_with_temperature(logits: np.ndarray, temperature: float) -> np.ndarray:
    """Return the row-wise softmax of ``logits`` divided by ``temperature``.

    Parameters
    ----------
    logits : ndarray of shape (n_samples, n_classes)
        Unnormalised scores; the softmax is taken along axis 1.
    temperature : float
        Divisor applied to the logits before exponentiation.

    Returns
    -------
    ndarray of shape (n_samples, n_classes)
        Each row is non-negative and sums to one.
    """
    tempered = logits / temperature
    # Subtracting each row's maximum leaves the softmax value unchanged but
    # keeps the exponentials from overflowing.
    row_peak = np.max(tempered, axis=1, keepdims=True)
    weights = np.exp(tempered - row_peak)
    return weights / np.sum(weights, axis=1, keepdims=True)
def _munge_augment(
    X: np.ndarray,
    y: np.ndarray,
    n_augmented: int,
    swap_prob: float = 0.1,
    random_state: int | None = None,
) -> tuple:
    """MUNGE-style data augmentation for knowledge distillation.

    Each synthetic row starts as a copy of a randomly chosen training row;
    every feature is then replaced, with probability ``swap_prob``, by the
    value of the same feature in a second, different row. The second row
    is drawn uniformly at random (not by nearest-neighbour search), which
    keeps generation cheap.

    Parameters
    ----------
    X : ndarray of shape (n_samples, n_features)
        Source feature matrix.
    y : ndarray of shape (n_samples,) or (n_samples, n_classes)
        Hard labels or soft probabilities. Each synthetic row inherits the
        target of the row it was copied from.
    n_augmented : int
        Number of augmented samples to generate.
    swap_prob : float
        Probability of swapping each feature.
    random_state : int or None
        Seed for the ``numpy.random.RandomState`` used for all draws.

    Returns
    -------
    X_aug : ndarray of shape (n_augmented, n_features)
    y_aug : ndarray of shape (n_augmented,) + y.shape[1:]

    Raises
    ------
    ValueError
        If ``n_augmented > 0`` but ``X`` contains no rows.
    """
    X = np.asarray(X)
    y = np.asarray(y)
    rng = np.random.RandomState(random_state)
    n_samples, n_features = X.shape

    if n_augmented > 0 and n_samples == 0:
        raise ValueError(
            "Cannot generate augmented samples from an empty dataset."
        )

    X_aug = np.empty((n_augmented, n_features), dtype=X.dtype)
    y_aug = np.empty((n_augmented,) + y.shape[1:], dtype=y.dtype)

    # With a single row there is no distinct donor to swap with; searching
    # for one would loop forever, so the synthetic rows are plain copies.
    has_donor = n_samples > 1

    # NOTE: the order of random draws (base row, donor row with redraws,
    # swap mask) is kept per sample so seeded results stay reproducible.
    for i in range(n_augmented):
        base_idx = rng.randint(0, n_samples)
        row = X[base_idx].copy()

        if has_donor:
            donor_idx = rng.randint(0, n_samples)
            while donor_idx == base_idx:
                donor_idx = rng.randint(0, n_samples)

            swap_mask = rng.random(n_features) < swap_prob
            row[swap_mask] = X[donor_idx, swap_mask]

        X_aug[i] = row
        y_aug[i] = y[base_idx]

    return X_aug, y_aug
# (Stray "[docs]" link text from the rendered documentation page; not code.)
class KnowledgeDistiller(BaseEstimator):
    """Knowledge distillation from teacher to student model.

    Trains a simpler student model to mimic the predictions of a
    complex teacher model (or ensemble), enabling deployment of
    lightweight models with minimal accuracy loss.

    Parameters
    ----------
    teacher : estimator
        Fitted teacher model. Must have predict_proba (classification)
        or predict (regression).
    student : estimator
        Unfitted student model to train. It is cloned before fitting.
    temperature : float, default=3.0
        Softmax temperature for soft label generation (classification only).
        Higher values produce softer probability distributions. Must be
        positive.
    alpha : float, default=0.7
        Weight of the teacher's predictions versus the true targets.

        * Classification, ``alpha >= 0.99``: the student is trained only on
          the teacher's predicted classes; when the student's ``fit``
          accepts ``sample_weight`` each sample is weighted by the
          teacher's (temperature-softened) confidence.
        * Classification, otherwise: each training label is the teacher's
          predicted class with probability ``alpha`` and the true label
          with probability ``1 - alpha``.
        * Regression: the training target is
          ``alpha * teacher_prediction + (1 - alpha) * y``.
    augment : bool, default=False
        Whether to use MUNGE data augmentation to generate additional
        training data labeled by the teacher.
    augment_ratio : float, default=1.0
        Ratio of augmented samples to original samples.
    augment_swap_prob : float, default=0.1
        Feature swap probability for MUNGE augmentation.
    random_state : int or None, default=None
        Seed used for augmentation and for hard/soft label sampling.

    Attributes
    ----------
    student_ : estimator
        The trained student model.
    teacher_score_ : float
        Teacher's accuracy/R2 on the training data (for reference).
    student_score_ : float
        Student's accuracy/R2 on the training data.
    is_classifier_ : bool
        Whether this is a classification task.

    Example
    -------
    >>> from sklearn.ensemble import RandomForestClassifier
    >>> from sklearn.linear_model import LogisticRegression
    >>>
    >>> teacher = RandomForestClassifier(n_estimators=500).fit(X, y)
    >>> distiller = KnowledgeDistiller(
    ...     teacher=teacher,
    ...     student=LogisticRegression(),
    ...     temperature=4.0,
    ...     alpha=0.8,
    ...     augment=True
    ... )
    >>> distiller.fit(X, y)
    >>> y_pred = distiller.predict(X_test)
    """

    def __init__(
        self,
        teacher,
        student,
        temperature: float = 3.0,
        alpha: float = 0.7,
        augment: bool = False,
        augment_ratio: float = 1.0,
        augment_swap_prob: float = 0.1,
        random_state: int | None = None,
    ):
        self.teacher = teacher
        self.student = student
        self.temperature = temperature
        self.alpha = alpha
        self.augment = augment
        self.augment_ratio = augment_ratio
        self.augment_swap_prob = augment_swap_prob
        self.random_state = random_state

    def fit(self, X, y, **fit_params) -> KnowledgeDistiller:
        """Train the student model using knowledge distillation.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training features.
        y : array-like of shape (n_samples,)
            True labels (hard targets).
        **fit_params
            Extra keyword arguments forwarded to the student's ``fit``.

        Returns
        -------
        self
        """
        X = np.asarray(X)
        y = np.asarray(y)
        self.is_classifier_ = is_classifier(self.teacher)
        if self.is_classifier_:
            self._fit_classification(X, y, **fit_params)
        else:
            self._fit_regression(X, y, **fit_params)
        return self

    def _soft_labels(self, proba: np.ndarray) -> np.ndarray:
        """Re-scale teacher probabilities with the configured temperature."""
        if self.temperature == 1.0:
            return proba
        # Probabilities -> log-space "logits"; eps avoids log(0).
        logits = np.log(proba + 1e-10)
        return _softmax_with_temperature(logits, self.temperature)

    def _augmented_inputs(self, X: np.ndarray, y: np.ndarray):
        """Return MUNGE-generated rows, or None when there are none to add."""
        if not self.augment:
            return None
        n_aug = int(len(X) * self.augment_ratio)
        if n_aug <= 0:
            # Nothing to generate; avoids asking the teacher to predict on
            # an empty array.
            return None
        X_aug, _ = _munge_augment(
            X, y, n_aug,
            swap_prob=self.augment_swap_prob,
            random_state=self.random_state,
        )
        return X_aug

    def _fit_classification(self, X, y, **fit_params):
        """Distillation for classification tasks."""
        if self.temperature <= 0:
            raise ValueError(
                f"temperature must be positive, got {self.temperature!r}."
            )

        teacher_proba = self.teacher.predict_proba(X)
        # argmax over predict_proba gives a column index, not a label; map
        # it back through the teacher's classes_ so the student is trained
        # and scored on the same label values as y. Falls back to plain
        # indices for teachers that do not expose classes_.
        classes = np.asarray(
            getattr(self.teacher, "classes_", np.arange(teacher_proba.shape[1]))
        )
        soft_labels = self._soft_labels(teacher_proba)

        X_aug = self._augmented_inputs(X, y)
        if X_aug is not None:
            # Synthetic rows have no ground truth: the teacher labels them.
            teacher_proba_aug = self.teacher.predict_proba(X_aug)
            X_combined = np.vstack([X, X_aug])
            soft_combined = np.vstack(
                [soft_labels, self._soft_labels(teacher_proba_aug)]
            )
            y_aug = classes[np.argmax(teacher_proba_aug, axis=1)]
            y_combined = np.concatenate([y, y_aug])
        else:
            X_combined = X
            soft_combined = soft_labels
            y_combined = y

        # Most sklearn classifiers cannot consume probability targets, so
        # the student is trained on the teacher's hard predictions.
        y_teacher = classes[np.argmax(soft_combined, axis=1)]
        self.student_ = clone(self.student)

        if self.alpha >= 0.99:
            # Pure distillation. Use the teacher's confidence as a sample
            # weight when the student supports it and the caller has not
            # supplied weights of their own.
            if "sample_weight" not in fit_params and has_fit_parameter(
                self.student_, "sample_weight"
            ):
                fit_params = {
                    **fit_params,
                    "sample_weight": np.max(soft_combined, axis=1),
                }
            self.student_.fit(X_combined, y_teacher, **fit_params)
        else:
            # Blend: keep the true label with probability (1 - alpha),
            # otherwise use the teacher's prediction.
            rng = np.random.RandomState(self.random_state)
            use_hard = rng.random(len(y_combined)) > self.alpha
            y_train = np.where(use_hard, y_combined, y_teacher)
            self.student_.fit(X_combined, y_train, **fit_params)

        # Accuracy of both models on the original (non-augmented) data.
        self.teacher_score_ = float(np.mean(self.teacher.predict(X) == y))
        self.student_score_ = float(np.mean(self.student_.predict(X) == y))

    def _fit_regression(self, X, y, **fit_params):
        """Distillation for regression tasks."""
        teacher_preds = self.teacher.predict(X)

        X_aug = self._augmented_inputs(X, y)
        if X_aug is not None:
            teacher_preds_aug = self.teacher.predict(X_aug)
            X_combined = np.vstack([X, X_aug])
            # Synthetic rows have no ground truth, so the teacher's
            # prediction stands in for the "true" target there as well.
            y_combined = np.concatenate([y, teacher_preds_aug])
            teacher_combined = np.concatenate([teacher_preds, teacher_preds_aug])
        else:
            X_combined = X
            y_combined = y
            teacher_combined = teacher_preds

        # Blend teacher predictions with true targets.
        y_train = self.alpha * teacher_combined + (1 - self.alpha) * y_combined
        self.student_ = clone(self.student)
        self.student_.fit(X_combined, y_train, **fit_params)

        # R2 of both models on the original (non-augmented) data.
        from sklearn.metrics import r2_score
        self.teacher_score_ = float(r2_score(y, teacher_preds))
        self.student_score_ = float(r2_score(y, self.student_.predict(X)))

    def predict(self, X) -> np.ndarray:
        """Predict using the trained student model."""
        check_is_fitted(self, 'student_')
        return self.student_.predict(X)

    def predict_proba(self, X) -> np.ndarray:
        """Predict probabilities using the trained student model.

        Raises
        ------
        AttributeError
            If the distiller was fitted on a regression task.
        """
        check_is_fitted(self, 'student_')
        if not self.is_classifier_:
            raise AttributeError("predict_proba is not available for regression.")
        return self.student_.predict_proba(X)

    @property
    def feature_importances_(self):
        """Feature importances from the student model.

        Raises
        ------
        AttributeError
            If the fitted student does not expose ``feature_importances_``.
        """
        check_is_fitted(self, 'student_')
        if hasattr(self.student_, 'feature_importances_'):
            return self.student_.feature_importances_
        raise AttributeError("Student model does not have feature_importances_.")

    def compression_report(self) -> dict:
        """Generate a report comparing teacher and student performance.

        Returns
        -------
        dict with keys:
            teacher_type, student_type, teacher_score, student_score,
            score_retention_pct, metric

        ``score_retention_pct`` is ``student_score / teacher_score * 100``
        rounded to one decimal, or 0 when the teacher score is not positive.
        """
        check_is_fitted(self, 'student_')
        retention = (self.student_score_ / self.teacher_score_ * 100
                     if self.teacher_score_ > 0 else 0)
        return {
            "teacher_type": type(self.teacher).__name__,
            "student_type": type(self.student_).__name__,
            "teacher_score": self.teacher_score_,
            "student_score": self.student_score_,
            "score_retention_pct": round(retention, 1),
            "metric": "accuracy" if self.is_classifier_ else "r2",
        }