from __future__ import annotations
"""Metadata collection for model persistence."""
import platform
from dataclasses import asdict, dataclass, field
from datetime import datetime, timezone
from typing import Any
import numpy as np
# Persistence format version stamped into every ModelMetadata built by
# collect_metadata (as ``format_version``).
# NOTE(review): presumably bumped whenever the on-disk layout changes so that
# loaders can detect older files — confirm against the loading code.
CURRENT_FORMAT_VERSION = 1
def _sanitize_params(params: dict) -> dict:
    """Return a JSON-serializable copy of an estimator parameter dict.

    Values are converted recursively:

    * NumPy scalars become the equivalent Python scalar (``.item()``).
    * NumPy arrays become (nested) lists whose elements are sanitized too,
      so object-dtype arrays holding arbitrary objects are handled.
    * Dicts are sanitized recursively (keys are kept as-is).
    * Lists and tuples become lists with every element sanitized, so numpy
      scalars or estimator objects nested inside them (e.g. a Pipeline's
      ``steps`` list of ``(name, estimator)`` tuples) no longer break JSON
      encoding.
    * ``str``, ``int``, ``float``, ``bool`` and ``None`` pass through.
    * Anything else is replaced by ``str(value)``.

    Parameters
    ----------
    params : dict
        Parameter mapping, e.g. the result of ``estimator.get_params()``.

    Returns
    -------
    dict
        A new dict; the input mapping is not modified.
    """
    return {key: _sanitize_value(value) for key, value in params.items()}


def _sanitize_value(value: Any) -> Any:
    """Convert a single parameter value into a JSON-serializable form."""
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, np.ndarray):
        # tolist() yields Python scalars for numeric dtypes, but object
        # arrays can still contain arbitrary objects, so sanitize the result.
        return _sanitize_value(value.tolist())
    if isinstance(value, dict):
        return _sanitize_params(value)
    if isinstance(value, (list, tuple)):
        # JSON has no tuple type; emit a list rather than a lossy str(tuple).
        return [_sanitize_value(item) for item in value]
    if isinstance(value, (str, int, float, bool, type(None))):
        return value
    return str(value)
def _get_dependency_versions() -> dict[str, str]:
    """Collect versions of key ML dependencies.

    NumPy is always reported. For each optional package:

    * if its module is already imported, its ``__version__`` is used (the
      version actually running in this process);
    * otherwise the installed distribution version is read with
      ``importlib.metadata``. This avoids importing heavy libraries (torch,
      lightgbm, ...) merely to read a version string, and it cannot fail with
      import-time errors from broken optional installs;
    * a loaded module with no discoverable version is reported as
      ``"unknown"``; packages that are neither loaded nor installed are
      omitted.

    Returns
    -------
    dict[str, str]
        Mapping of distribution name (e.g. ``"scikit-learn"``) to version.
    """
    import sys
    from importlib import metadata

    deps = {"numpy": np.__version__}
    for dist_name, module_name in (
        ("scikit-learn", "sklearn"),
        ("polars", "polars"),
        ("pandas", "pandas"),
        ("torch", "torch"),
        ("lightgbm", "lightgbm"),
        ("xgboost", "xgboost"),
        ("catboost", "catboost"),
        ("joblib", "joblib"),
    ):
        module = sys.modules.get(module_name)
        if module is not None and hasattr(module, "__version__"):
            # str() normalizes str subclasses such as torch's TorchVersion.
            deps[dist_name] = str(module.__version__)
            continue
        try:
            version = metadata.version(dist_name)
        except metadata.PackageNotFoundError:
            version = None
        if version:
            deps[dist_name] = version
        elif module is not None:
            deps[dist_name] = "unknown"
    return deps
def _check_is_fitted(estimator) -> bool:
    """Report whether *estimator* appears to be fitted.

    Checks, in order:

    1. an ``_is_fitted`` attribute (EndgameEstimator convention),
    2. an ``is_fitted_`` attribute,
    3. scikit-learn's ``check_is_fitted`` (fitted estimators expose
       attributes ending in ``_``).

    The answer is best-effort. ``False`` is returned instead of an error
    being propagated when scikit-learn is not installed or when its check
    raises for any reason (e.g. the object is not an estimator).

    Returns
    -------
    bool
        Always a plain Python ``bool``. Flag attributes are coerced with
        ``bool()`` so numpy booleans do not leak into serialized metadata.
    """
    # EndgameEstimator uses _is_fitted
    if hasattr(estimator, "_is_fitted"):
        return bool(estimator._is_fitted)
    if hasattr(estimator, "is_fitted_"):
        return bool(estimator.is_fitted_)
    try:
        # Imported lazily and guarded: scikit-learn may not be installed, and
        # that must not abort metadata collection.
        from sklearn.utils.validation import check_is_fitted as sklearn_check
    except ImportError:
        return False
    try:
        sklearn_check(estimator)
    except Exception:
        # NotFittedError for unfitted estimators, TypeError for non-estimators;
        # any failure means "cannot confirm the estimator is fitted".
        return False
    return True
def _first_attr(obj, *names):
    """Return the first attribute in *names* that exists on *obj* and is not None."""
    for name in names:
        value = getattr(obj, name, None)
        if value is not None:
            return value
    return None


def _to_plain_list(values) -> list:
    """Convert an array-like (ndarray, Index, tensor, list, ...) to a list of plain Python values."""
    if hasattr(values, "tolist"):
        # ndarray / pandas / torch tolist() already yields native Python scalars.
        return values.tolist()
    # Plain sequences (e.g. a list of per-output arrays) -> convert elementwise.
    return [v.tolist() if hasattr(v, "tolist") else v for v in values]


def collect_metadata(
    estimator,
    backend: str,
    compression: int | None,
) -> ModelMetadata:
    """Gather metadata by introspecting the estimator and environment.

    Parameters
    ----------
    estimator : estimator object
        The sklearn-compatible estimator to inspect.
    backend : str
        Backend name ("joblib", "torch", or "pickle").
    compression : int or None
        Compression level.

    Returns
    -------
    ModelMetadata
        Populated metadata object. ``n_features_in_``, ``feature_names_in_``
        and ``classes_`` are converted to plain Python values (no numpy
        scalars) so the metadata can be serialized.
    """
    # Imported here rather than at module level. NOTE(review): presumably to
    # avoid a circular import with the endgame package — confirm.
    import endgame

    cls = type(estimator)
    model_class = f"{cls.__module__}.{cls.__qualname__}"

    # Parameters are best-effort: if get_params() fails, the model is still
    # described, just with an empty parameter record.
    params = {}
    if hasattr(estimator, "get_params"):
        try:
            params = estimator.get_params(deep=False)
        except Exception:
            pass

    # Feature metadata: public sklearn attribute first, private fallback second.
    n_features = _first_attr(estimator, "n_features_in_", "_n_features_in")
    if isinstance(n_features, np.integer):
        n_features = int(n_features)

    feature_names = _first_attr(estimator, "feature_names_in_", "_feature_names_in")
    if feature_names is not None:
        feature_names = _to_plain_list(feature_names)

    classes = getattr(estimator, "classes_", None)
    if classes is not None:
        classes = _to_plain_list(classes)

    return ModelMetadata(
        endgame_version=endgame.__version__,
        format_version=CURRENT_FORMAT_VERSION,
        model_class=model_class,
        model_params=params,
        created_at=datetime.now(timezone.utc).isoformat(),
        python_version=platform.python_version(),
        dependencies=_get_dependency_versions(),
        n_features_in_=n_features,
        feature_names_in_=feature_names,
        classes_=classes,
        is_fitted=_check_is_fitted(estimator),
        backend=backend,
        compression=compression,
    )