Source code for endgame.preprocessing.imbalance_llm

"""LLM-based oversampler for class imbalance handling.

This module provides tabular data generation via language model fine-tuning,
following the GReaT (Generation of Realistic Tabular data) approach.

Algorithms
----------
- GReaTResampler: Fine-tune a causal LM on serialized tabular rows

References
----------
- GReaT (Borisov et al., 2023)
- ImbLLM improvements (2025)
"""

from __future__ import annotations

from typing import Any

import numpy as np
from numpy.typing import ArrayLike
from sklearn.base import BaseEstimator
from sklearn.utils.validation import check_X_y

from endgame.preprocessing.imbalance_geometric import _compute_sampling_targets

try:
    import torch
    from torch.utils.data import Dataset as TorchDataset
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        Trainer,
        TrainingArguments,
    )

    HAS_TRANSFORMERS = True
except ImportError:
    HAS_TRANSFORMERS = False


def _check_transformers():
    if not HAS_TRANSFORMERS:
        raise ImportError(
            "transformers and PyTorch are required for GReaTResampler. "
            "Install with: pip install transformers torch"
        )


def _serialize_row(
    row: np.ndarray, feature_names: list[str], label_name: str, label_value: Any
) -> str:
    """Serialize a tabular row to natural language.

    Example output: "Feature1 is 3.14, Feature2 is -1.00, Class is 1"
    """
    parts = [f"{name} is {val:.4g}" for name, val in zip(feature_names, row)]
    parts.append(f"{label_name} is {label_value}")
    return ", ".join(parts)


def _parse_generated_text(
    text: str, feature_names: list[str], label_name: str, n_features: int
) -> np.ndarray | None:
    """Parse a generated text string back to a feature vector.

    Returns None if parsing fails.
    """
    values = {}
    for part in text.split(","):
        part = part.strip()
        if " is " not in part:
            continue
        name, val_str = part.split(" is ", 1)
        name = name.strip()
        val_str = val_str.strip()
        if name in feature_names:
            try:
                values[name] = float(val_str)
            except ValueError:
                continue

    if len(values) < n_features // 2:
        return None

    row = np.zeros(n_features)
    for i, name in enumerate(feature_names):
        if name in values:
            row[i] = values[name]
    return row


# Internal dataset for fine-tuning
_GreatDataset = None

if HAS_TRANSFORMERS:

    class _GreatDataset(TorchDataset):
        """Dataset of serialized tabular rows for LM fine-tuning."""

        def __init__(self, texts: list[str], tokenizer, max_length: int):
            self.encodings = tokenizer(
                texts,
                truncation=True,
                padding="max_length",
                max_length=max_length,
                return_tensors="pt",
            )

        def __len__(self):
            return self.encodings["input_ids"].shape[0]

        def __getitem__(self, idx):
            item = {k: v[idx] for k, v in self.encodings.items()}
            item["labels"] = item["input_ids"].clone()
            return item


[docs] class GReaTResampler(BaseEstimator): """GReaT oversampler: LLM-based tabular data generation. Serializes tabular rows as natural language strings, fine-tunes a small causal language model (e.g. distilgpt2), and generates new minority samples by prompting with the minority class label prefix. Parameters ---------- sampling_strategy : str, float, or dict, default='auto' See :func:`~imbalance_geometric._compute_sampling_targets`. model_name : str, default='distilgpt2' HuggingFace model name for the causal LM backbone. n_epochs : int, default=5 Fine-tuning epochs. batch_size : int, default=8 Training batch size. max_length : int, default=256 Maximum token length for serialized rows. temperature : float, default=0.7 Sampling temperature for generation. feature_names : list of str or None, default=None Feature names. If None, uses ``f0, f1, ...``. label_name : str, default='Class' Name for the target column in serialization. device : str, default='auto' Computation device. random_state : int or None, default=None Random seed. verbose : bool, default=False Enable verbose output. References ---------- GReaT (Borisov et al., 2023), ImbLLM (2025) """ def __init__( self, sampling_strategy: str | float | dict = "auto", model_name: str = "distilgpt2", n_epochs: int = 5, batch_size: int = 8, max_length: int = 256, temperature: float = 0.7, feature_names: list[str] | None = None, label_name: str = "Class", device: str = "auto", random_state: int | None = None, verbose: bool = False, ): self.sampling_strategy = sampling_strategy self.model_name = model_name self.n_epochs = n_epochs self.batch_size = batch_size self.max_length = max_length self.temperature = temperature self.feature_names = feature_names self.label_name = label_name self.device = device self.random_state = random_state self.verbose = verbose def _log(self, msg: str): if self.verbose: print(f"[GReaT] {msg}") def _get_device(self) -> str: if self.device == "auto": return "cuda" if torch.cuda.is_available() else "cpu" return self.device
[docs] def fit(self, X: ArrayLike, y: ArrayLike) -> GReaTResampler: """Fit (validate input).""" _check_transformers() X, y = check_X_y(X, y) self.targets_ = _compute_sampling_targets(y, self.sampling_strategy) return self
[docs] def fit_resample( self, X: ArrayLike, y: ArrayLike ) -> tuple[np.ndarray, np.ndarray]: """Fit and resample using GReaT LLM generation.""" self.fit(X, y) X = np.asarray(X, dtype=np.float64) y = np.asarray(y) n_features = X.shape[1] feature_names = self.feature_names or [f"f{i}" for i in range(n_features)] device = self._get_device() synthetic_X: list[np.ndarray] = [] synthetic_y: list[np.ndarray] = [] for cls, n_synthetic in self.targets_.items(): if n_synthetic <= 0: continue X_cls = X[y == cls] if len(X_cls) == 0: continue self._log(f"Fine-tuning {self.model_name} on class {cls} " f"({len(X_cls)} samples)...") # Serialize training data texts = [ _serialize_row(row, feature_names, self.label_name, cls) for row in X_cls ] # Load model and tokenizer tokenizer = AutoTokenizer.from_pretrained(self.model_name) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token model = AutoModelForCausalLM.from_pretrained(self.model_name) # Fine-tune dataset = _GreatDataset(texts, tokenizer, self.max_length) training_args = TrainingArguments( output_dir="/tmp/great_training", num_train_epochs=self.n_epochs, per_device_train_batch_size=self.batch_size, logging_steps=max(1, len(texts) // self.batch_size), save_strategy="no", report_to="none", seed=self.random_state or 42, no_cuda=(device == "cpu"), ) trainer = Trainer( model=model, args=training_args, train_dataset=dataset, ) trainer.train() # Generate samples self._log(f"Generating {n_synthetic} samples for class {cls}...") prompt = f"{self.label_name} is {cls}" prompt_ids = tokenizer.encode(prompt, return_tensors="pt").to( model.device ) generated_rows = [] attempts = 0 max_attempts = n_synthetic * 5 model.eval() while len(generated_rows) < n_synthetic and attempts < max_attempts: batch_n = min(self.batch_size, n_synthetic - len(generated_rows)) input_ids = prompt_ids.repeat(batch_n, 1) with torch.no_grad(): outputs = model.generate( input_ids, max_new_tokens=self.max_length, temperature=self.temperature, do_sample=True, pad_token_id=tokenizer.eos_token_id, ) for output in outputs: text = tokenizer.decode(output, skip_special_tokens=True) row = _parse_generated_text( text, feature_names, self.label_name, n_features ) if row is not None: generated_rows.append(row) attempts += 1 # Fill shortfall with random duplicates if len(generated_rows) < n_synthetic: rng = np.random.RandomState(self.random_state) shortfall = n_synthetic - len(generated_rows) extra_idx = rng.choice(len(X_cls), size=shortfall, replace=True) generated_rows.extend(X_cls[extra_idx]) synthetic_X.append(np.array(generated_rows[:n_synthetic])) synthetic_y.append(np.full(n_synthetic, cls)) if synthetic_X: X_out = np.vstack([X] + synthetic_X) y_out = np.concatenate([y] + synthetic_y) else: X_out, y_out = X.copy(), y.copy() return X_out, y_out
# Category dict for registration LLM_SAMPLERS = { "great": GReaTResampler, }