.gitignore (4 additions, 0 deletions)

@@ -35,6 +35,10 @@ saved_models/
*.svg
*.mp4

# Test checkpoint binaries (generate with tests/generate_cuda_checkpoint.py)
tests/test_data/cuda_saved_checkpoint/
tests/test_data/*.pt

## gitignore for Python
## Source: https://raw.githubusercontent.com/github/gitignore/master/Python.gitignore
# Byte-compiled / optimized / DLL files
cebra/integrations/sklearn/cebra.py (124 additions, 13 deletions)

@@ -69,18 +69,79 @@ def check_version(estimator):
sklearn.__version__) < packaging.version.parse("1.6.dev")


-def _safe_torch_load(filename, weights_only=False, **kwargs):
-    checkpoint = None
def _safe_torch_load(filename, weights_only=False, _is_retry=False, **kwargs):
"""Load a checkpoint with automatic CUDA/MPS to CPU fallback.

    If loading fails due to a CUDA or MPS device error (e.g., the checkpoint
    was saved on a GPU but no GPU is available), this function automatically
    retries with ``map_location="cpu"`` and issues a warning.

Args:
filename: Path to the checkpoint file.
weights_only: Passed through to :func:`torch.load`.
_is_retry: Internal flag to prevent infinite recursion. Do not set
this manually.
**kwargs: Additional keyword arguments forwarded to :func:`torch.load`.

Returns:
        The loaded checkpoint.

Raises:
RuntimeError: If loading fails on the retry attempt or for non-device
related reasons.
"""
legacy_mode = packaging.version.parse(
torch.__version__) < packaging.version.parse("2.6.0")

-    if legacy_mode:
-        checkpoint = torch.load(filename, weights_only=False, **kwargs)
-    else:
-        with torch.serialization.safe_globals(CEBRA_LOAD_SAFE_GLOBALS):
-            checkpoint = torch.load(filename,
-                                    weights_only=weights_only,
-                                    **kwargs)
try:
if legacy_mode:
checkpoint = torch.load(filename, weights_only=False, **kwargs)
else:
with torch.serialization.safe_globals(CEBRA_LOAD_SAFE_GLOBALS):
checkpoint = torch.load(filename,
weights_only=weights_only,
**kwargs)
except RuntimeError as e:
error_msg = str(e)
        is_device_error = ("cuda" in error_msg.lower()
                           or "mps" in error_msg.lower())
if is_device_error:
if _is_retry:
raise RuntimeError(
f"Failed to load checkpoint even with map_location='cpu'. "
f"The checkpoint appears to require a device (CUDA/MPS) "
f"that is not available in the current environment. "
f"Please verify your PyTorch installation or load on a "
f"machine with the required hardware. "
f"Original error: {e}"
) from e
if "map_location" in kwargs:
raise RuntimeError(
f"Loading the checkpoint failed with a device error even "
f"though map_location={kwargs['map_location']!r} was "
f"explicitly specified. The checkpoint was likely saved on "
f"a CUDA/MPS device that is not available. Please check "
f"your PyTorch installation or use a machine with the "
f"required hardware. Original error: {e}"
) from e
warnings.warn(
f"Checkpoint was saved on a device that is not available "
f"(error: {error_msg}). Automatically falling back to CPU. "
f"To suppress this warning, pass map_location='cpu' "
f"explicitly.",
UserWarning,
stacklevel=2,
)
kwargs["map_location"] = torch.device("cpu")
return _safe_torch_load(
filename,
weights_only=weights_only,
_is_retry=True,
**kwargs,
)
raise

if not isinstance(checkpoint, dict):
_check_type_checkpoint(checkpoint)
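
A minimal sketch of the fallback behavior (``_safe_torch_load`` is a private
helper in ``cebra.integrations.sklearn.cebra``, so calling it directly is for
illustration only, and the checkpoint path is hypothetical)::

    import warnings

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        checkpoint = _safe_torch_load("model_saved_on_gpu.pt")

    # On a CPU-only machine the first torch.load attempt raises a device
    # RuntimeError; the helper retries with map_location="cpu" and the
    # fallback UserWarning is captured in `caught`.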
@@ -334,6 +395,47 @@ def _check_type_checkpoint(checkpoint):
return checkpoint


def _resolve_checkpoint_device(device: Union[str, torch.device]) -> str:
"""Resolve the device stored in a checkpoint for the current runtime.

If a checkpoint was saved on a device (CUDA, MPS, ...) that is unavailable
at load time, this falls back to CPU and issues a warning.

Args:
device: The device from the checkpoint.

Returns:
The resolved device string.
"""
if isinstance(device, str):
device = torch.device(device)

if not isinstance(device, torch.device):
raise TypeError(
f"Expected checkpoint device to be a string or torch.device, "
f"got {type(device)}.")

fallback_to_cpu = False

if device.type == "cuda" and not torch.cuda.is_available():
fallback_to_cpu = True
elif device.type == "mps" and (
not hasattr(torch.backends, "mps")
or not torch.backends.mps.is_available()):
fallback_to_cpu = True

if fallback_to_cpu:
warnings.warn(
f"Checkpoint was saved on '{device}' which is not available in "
f"the current environment. Automatically falling back to CPU.",
UserWarning,
stacklevel=2,
)
return "cpu"

return sklearn_utils.check_device(str(device))
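
Assuming a CPU-only runtime, the resolution behaves roughly as follows (an
illustrative sketch, not part of the diff)::

    _resolve_checkpoint_device("cpu")                # -> "cpu"
    _resolve_checkpoint_device("cuda:0")             # -> "cpu", with a UserWarning
    _resolve_checkpoint_device(torch.device("mps"))  # -> "cpu", with a UserWarning
    _resolve_checkpoint_device(42)                   # raises TypeError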


def _load_cebra_with_sklearn_backend(cebra_info: Dict) -> "CEBRA":
"""Loads a CEBRA model with a Sklearn backend.

@@ -357,11 +459,20 @@ def _load_cebra_with_sklearn_backend(cebra_info: Dict) -> "CEBRA":

args, state, state_dict = cebra_info['args'], cebra_info[
'state'], cebra_info['state_dict']

saved_device = state["device_"]
load_device = _resolve_checkpoint_device(saved_device)

cebra_ = cebra.CEBRA(**args)

for key, value in state.items():
setattr(cebra_, key, value)

    cebra_.device_ = load_device
    if isinstance(saved_device, str):
        saved_device = torch.device(saved_device)
    # Keep the public ``device`` parameter in sync with the resolved device
    # so later ``fit``/``transform`` calls do not target missing hardware.
    if saved_device.type in ("cuda", "mps") and load_device == "cpu":
        cebra_.device = "cpu"

#TODO(stes): unused right now
#state_and_args = {**args, **state}

@@ -375,7 +486,7 @@ def _load_cebra_with_sklearn_backend(cebra_info: Dict) -> "CEBRA":
num_neurons=state["n_features_in_"],
num_units=args["num_hidden_units"],
num_output=args["output_dimension"],
-        ).to(state['device_'])
).to(load_device)

elif isinstance(cebra_.num_sessions_, int):
model = nn.ModuleList([
@@ -385,10 +496,10 @@ def _load_cebra_with_sklearn_backend(cebra_info: Dict) -> "CEBRA":
num_units=args["num_hidden_units"],
num_output=args["output_dimension"],
) for n_features in state["n_features_in_"]
-        ]).to(state['device_'])
]).to(load_device)

criterion = cebra_._prepare_criterion()
-    criterion.to(state['device_'])
criterion.to(load_device)

optimizer = torch.optim.Adam(
itertools.chain(model.parameters(), criterion.parameters()),
@@ -404,7 +515,7 @@ def _load_cebra_with_sklearn_backend(cebra_info: Dict) -> "CEBRA":
tqdm_on=args['verbose'],
)
solver.load_state_dict(state_dict)
-    solver.to(state['device_'])
solver.to(load_device)

cebra_.model_ = model
cebra_.solver_ = solver
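
Taken together, these changes let a checkpoint saved with ``device="cuda"``
load transparently on a CPU-only machine. A sketch of the resulting
user-facing behavior (the file name is hypothetical)::

    import numpy as np
    import cebra

    model = cebra.CEBRA.load("model_trained_on_gpu.pt")  # warns, falls back
    assert model.device_ == "cpu"
    X = np.random.uniform(0, 1, (100, model.n_features_)).astype(np.float32)
    embedding = model.transform(X)  # runs on CPU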
tests/generate_cuda_checkpoint.py (new file, 119 additions)
@@ -0,0 +1,119 @@
#!/usr/bin/env python
"""Generate a CUDA-saved checkpoint for integration testing.

Run this script on a machine with a CUDA GPU to produce a checkpoint file
that can be used to verify the CUDA-to-CPU loading fallback in a CI
environment (which typically has no GPU).

Usage::

# Default output path
python tests/generate_cuda_checkpoint.py

# Custom output path
python tests/generate_cuda_checkpoint.py --output /tmp/cuda_checkpoint.pt

# Verify an existing checkpoint
python tests/generate_cuda_checkpoint.py --verify tests/test_data/cuda_checkpoint.pt

Requirements:
- PyTorch with CUDA support (``torch.cuda.is_available()`` must be True)
- CEBRA installed (``pip install -e .`` from the repo root)

The generated file is a standard ``torch.save`` checkpoint in the CEBRA
sklearn format. It contains CUDA tensors, so loading it on a CPU-only
machine *without* the fallback logic will fail with::

RuntimeError: Attempting to deserialize object on a CUDA device but
torch.cuda.is_available() is False.
"""

import argparse
import os
import sys

import numpy as np
import torch


def generate(output_path: str) -> None:
"""Train a minimal CEBRA model on CUDA and save the checkpoint."""
if not torch.cuda.is_available():
print("ERROR: CUDA is not available. Run this on a GPU machine.",
file=sys.stderr)
sys.exit(1)

import cebra

print(f"PyTorch {torch.__version__}, CUDA {torch.version.cuda}")
print(f"Device: {torch.cuda.get_device_name(0)}")

# Train a tiny model on GPU
X = np.random.uniform(0, 1, (200, 10)).astype(np.float32)
model = cebra.CEBRA(
model_architecture="offset1-model",
max_iterations=10,
batch_size=64,
output_dimension=4,
device="cuda",
verbose=False,
)
model.fit(X)

# Sanity-check: model params should live on CUDA
param_device = next(model.solver_.model.parameters()).device
assert param_device.type == "cuda", f"Expected cuda, got {param_device}"

os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
model.save(output_path)
print(f"Saved CUDA checkpoint to {output_path}")

# Verify round-trip on GPU
loaded = cebra.CEBRA.load(output_path)
emb = loaded.transform(X)
assert emb.shape == (200, 4), f"Unexpected shape: {emb.shape}"
print("Round-trip verification on GPU: OK")


def verify(path: str) -> None:
"""Load a checkpoint on CPU and confirm the fallback works."""
import cebra

if not os.path.exists(path):
print(f"ERROR: {path} does not exist.", file=sys.stderr)
sys.exit(1)

print(f"Loading checkpoint from {path} ...")
model = cebra.CEBRA.load(path)
print(f" device_: {model.device_}")
print(f" device: {model.device}")

X = np.random.uniform(0, 1, (50, model.n_features_)).astype(np.float32)
emb = model.transform(X)
print(f" transform shape: {emb.shape}")
print("Verification: OK")


def main() -> None:
parser = argparse.ArgumentParser(description=__doc__,
formatter_class=argparse.RawDescriptionHelpFormatter)
parser.add_argument(
"--output",
default="tests/test_data/cuda_checkpoint.pt",
help="Output path for the generated checkpoint (default: tests/test_data/cuda_checkpoint.pt)",
)
parser.add_argument(
"--verify",
metavar="PATH",
help="Instead of generating, verify an existing checkpoint can be loaded.",
)
args = parser.parse_args()

if args.verify:
verify(args.verify)
else:
generate(args.output)


if __name__ == "__main__":
main()
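
In CI, the generated artifact could be consumed by a test along these lines
(a sketch only: the test name, the ``pytest.warns`` match string, and the
committed checkpoint path are assumptions, not part of this PR)::

    import os

    import numpy as np
    import pytest
    import torch

    import cebra

    CHECKPOINT = "tests/test_data/cuda_checkpoint.pt"

    @pytest.mark.skipif(
        not os.path.exists(CHECKPOINT) or torch.cuda.is_available(),
        reason="requires the checkpoint artifact and a CPU-only runner",
    )
    def test_cuda_checkpoint_falls_back_to_cpu():
        # Loading on a CPU-only runner should warn and fall back to CPU.
        with pytest.warns(UserWarning, match="falling back to CPU"):
            model = cebra.CEBRA.load(CHECKPOINT)
        assert model.device_ == "cpu"
        X = np.random.uniform(0, 1,
                              (50, model.n_features_)).astype(np.float32)
        assert model.transform(X).shape == (50, model.output_dimension)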