# Source code for endgame.persistence._core
from __future__ import annotations
"""Core save/load functions for model persistence."""
from pathlib import Path
from typing import Any
from endgame.persistence._backends import (
load_joblib,
load_pickle,
load_torch,
save_joblib,
save_pickle,
save_torch,
)
from endgame.persistence._detection import detect_backend
from endgame.persistence._metadata import collect_metadata
[docs]
def save(
    estimator: Any,
    path: str,
    backend: str = "auto",
    compress: int | None = None,
) -> str:
    """Persist an sklearn-compatible estimator to disk.

    The backend is resolved through ``detect_backend``, metadata is
    gathered with ``collect_metadata``, and the estimator is then handed
    to the matching backend writer together with that metadata.

    Parameters
    ----------
    estimator : estimator object
        The sklearn-compatible estimator to store; it may be fitted or
        unfitted.
    path : str
        Target file or directory. The suitable extension (``.egm`` for
        single-file saves, ``.egd`` for PyTorch directory saves) is
        appended automatically.
    backend : str, default="auto"
        Which serialization backend to use. With ``"auto"``, ``"torch"``
        is picked for estimators holding ``nn.Module`` attributes and
        ``"joblib"`` for everything else. It can also be set explicitly
        to ``"joblib"``, ``"torch"`` or ``"pickle"``.
    compress : int or None
        Compression level (0-9); ``None`` disables compression. Only the
        joblib backend makes use of it.

    Returns
    -------
    str
        The path the model was actually written to.

    Examples
    --------
    >>> from sklearn.linear_model import LogisticRegression
    >>> import endgame as eg
    >>> model = LogisticRegression().fit(X_train, y_train)
    >>> eg.save(model, "/tmp/my_model")
    '/tmp/my_model.egm'
    >>> loaded = eg.load("/tmp/my_model.egm")
    """
    chosen = detect_backend(estimator, preferred=backend)
    meta = collect_metadata(estimator, chosen, compress)
    target = Path(path)

    # Any backend name other than "torch" or "pickle" is written with joblib.
    writers = {"torch": save_torch, "pickle": save_pickle}
    writer = writers.get(chosen, save_joblib)

    return str(writer(estimator, meta, target, compress))
[docs]
def load(
    path: str,
    map_location: str | None = None,
) -> Any:
    """Restore an estimator that was previously written to disk.

    A directory is read with the torch loader, provided it contains a
    ``metadata.json`` file. A regular file is read with the joblib loader
    first; if that raises, the pickle loader is tried as a fallback.

    Parameters
    ----------
    path : str
        Location of a ``.egm`` file or a ``.egd`` directory.
    map_location : str or None
        PyTorch ``map_location`` used when loading tensors (e.g.
        ``"cpu"``). It only matters for ``.egd`` (PyTorch) saves.

    Returns
    -------
    estimator
        The estimator read from ``path``.

    Raises
    ------
    FileNotFoundError
        If ``path`` does not exist.
    ValueError
        If ``path`` is a directory without ``metadata.json``, or if both
        the joblib and the pickle loader fail on a single file.

    Examples
    --------
    >>> import endgame as eg
    >>> model = eg.load("/tmp/my_model.egm")
    >>> model.predict(X_test)
    """
    source = Path(path)
    if not source.exists():
        raise FileNotFoundError(f"No model found at: {path}")

    if source.is_dir():
        # Directory layout: a .egd save, or a directory lacking the extension.
        if not (source / "metadata.json").exists():
            raise ValueError(
                f"Directory {path} does not appear to be a valid .egd model directory "
                "(missing metadata.json)."
            )
        model, _ = load_torch(source, map_location=map_location)
        return model

    # Single-file layout: joblib first, pickle second.
    # NOTE(review): both loaders presumably unpickle arbitrary objects, so
    # only files from trusted sources should be loaded -- confirm.
    try:
        model, _ = load_joblib(source)
    except Exception:
        # Stay inside this handler so the joblib failure remains attached
        # as the implicit context of whatever the pickle attempt raises.
        try:
            model, _ = load_pickle(source)
        except Exception as pickle_error:
            raise ValueError(
                f"Could not load model from {path}. "
                "The file may be corrupted or in an unsupported format."
            ) from pickle_error
    return model