diff --git a/.gitignore b/.gitignore
index e30f5f43..d21f5a98 100644
--- a/.gitignore
+++ b/.gitignore
@@ -35,6 +35,10 @@ saved_models/
 *.svg
 *.mp4
 
+# Test checkpoint binaries (generate with tests/generate_cuda_checkpoint.py)
+tests/test_data/cuda_saved_checkpoint/
+tests/test_data/*.pt
+
 ## gitignore for Python ##
 # Source: https://raw.githubusercontent.com/github/gitignore/master/Python.gitignore
 # Byte-compiled / optimized / DLL files
diff --git a/cebra/integrations/sklearn/cebra.py b/cebra/integrations/sklearn/cebra.py
index 00645523..c4e96117 100644
--- a/cebra/integrations/sklearn/cebra.py
+++ b/cebra/integrations/sklearn/cebra.py
@@ -69,18 +69,79 @@ def check_version(estimator):
         sklearn.__version__) < packaging.version.parse("1.6.dev")
 
 
-def _safe_torch_load(filename, weights_only=False, **kwargs):
-    checkpoint = None
+def _safe_torch_load(filename, weights_only=False, _is_retry=False, **kwargs):
+    """Load a checkpoint with an automatic CUDA/MPS-to-CPU fallback.
+
+    If loading fails due to a CUDA or MPS device error (e.g. the checkpoint
+    was saved on a GPU but no GPU is available), this function automatically
+    retries with ``map_location="cpu"`` and issues a warning.
+
+    Args:
+        filename: Path to the checkpoint file.
+        weights_only: Passed through to :func:`torch.load`.
+        _is_retry: Internal flag to prevent infinite recursion. Do not set
+            this manually.
+        **kwargs: Additional keyword arguments forwarded to :func:`torch.load`.
+
+    Returns:
+        The loaded checkpoint dictionary.
+
+    Raises:
+        RuntimeError: If loading fails on the retry attempt or for reasons
+            unrelated to the device.
+    """
     legacy_mode = packaging.version.parse(
         torch.__version__) < packaging.version.parse("2.6.0")
-    if legacy_mode:
-        checkpoint = torch.load(filename, weights_only=False, **kwargs)
-    else:
-        with torch.serialization.safe_globals(CEBRA_LOAD_SAFE_GLOBALS):
-            checkpoint = torch.load(filename,
-                                    weights_only=weights_only,
-                                    **kwargs)
+    try:
+        if legacy_mode:
+            checkpoint = torch.load(filename, weights_only=False, **kwargs)
+        else:
+            with torch.serialization.safe_globals(CEBRA_LOAD_SAFE_GLOBALS):
+                checkpoint = torch.load(filename,
+                                        weights_only=weights_only,
+                                        **kwargs)
+    except RuntimeError as e:
+        error_msg = str(e)
+        # The lowercase check covers "CUDA"/"cuda" and "MPS"/"mps" alike.
+        is_device_error = ("cuda" in error_msg.lower() or
+                           "mps" in error_msg.lower())
+        if is_device_error:
+            if _is_retry:
+                raise RuntimeError(
+                    f"Failed to load checkpoint even with map_location='cpu'. "
+                    f"The checkpoint appears to require a device (CUDA/MPS) "
+                    f"that is not available in the current environment. "
+                    f"Please verify your PyTorch installation or load the "
+                    f"checkpoint on a machine with the required hardware. "
+                    f"Original error: {e}") from e
+            if "map_location" in kwargs:
+                raise RuntimeError(
+                    f"Loading the checkpoint failed with a device error even "
+                    f"though map_location={kwargs['map_location']!r} was "
+                    f"explicitly specified. The checkpoint was likely saved on "
+                    f"a CUDA/MPS device that is not available. Please check "
+                    f"your PyTorch installation or use a machine with the "
+                    f"required hardware. Original error: {e}") from e
+            warnings.warn(
+                f"Checkpoint was saved on a device that is not available "
+                f"(error: {error_msg}). Automatically falling back to CPU. "
+                f"To suppress this warning, pass map_location='cpu' "
+                f"explicitly.",
+                UserWarning,
+                stacklevel=2,
+            )
+            kwargs["map_location"] = torch.device("cpu")
+            return _safe_torch_load(
+                filename,
+                weights_only=weights_only,
+                _is_retry=True,
+                **kwargs,
+            )
+        raise
 
     if not isinstance(checkpoint, dict):
         _check_type_checkpoint(checkpoint)
@@ -334,6 +395,47 @@ def _check_type_checkpoint(checkpoint):
     return checkpoint
 
 
+def _resolve_checkpoint_device(device: Union[str, torch.device]) -> str:
+    """Resolve the device stored in a checkpoint for the current runtime.
+
+    If a checkpoint was saved on a device (CUDA, MPS, ...) that is unavailable
+    at load time, this falls back to CPU and issues a warning.
+
+    Args:
+        device: The device from the checkpoint.
+
+    Returns:
+        The resolved device string.
+    """
+    if isinstance(device, str):
+        device = torch.device(device)
+
+    if not isinstance(device, torch.device):
+        raise TypeError(
+            f"Expected checkpoint device to be a string or torch.device, "
+            f"got {type(device)}.")
+
+    fallback_to_cpu = False
+
+    if device.type == "cuda" and not torch.cuda.is_available():
+        fallback_to_cpu = True
+    elif device.type == "mps" and (not hasattr(torch.backends, "mps") or
+                                   not torch.backends.mps.is_available()):
+        fallback_to_cpu = True
+
+    if fallback_to_cpu:
+        warnings.warn(
+            f"Checkpoint was saved on '{device}' which is not available in "
+            f"the current environment. Automatically falling back to CPU.",
+            UserWarning,
+            stacklevel=2,
+        )
+        return "cpu"
+
+    return sklearn_utils.check_device(str(device))
+
+
 def _load_cebra_with_sklearn_backend(cebra_info: Dict) -> "CEBRA":
     """Loads a CEBRA model with a Sklearn backend.
 
@@ -357,11 +459,20 @@ def _load_cebra_with_sklearn_backend(cebra_info: Dict) -> "CEBRA":
     args, state, state_dict = cebra_info['args'], cebra_info[
         'state'], cebra_info['state_dict']
+
+    saved_device = state["device_"]
+    load_device = _resolve_checkpoint_device(saved_device)
+
     cebra_ = cebra.CEBRA(**args)
     for key, value in state.items():
         setattr(cebra_, key, value)
 
+    cebra_.device_ = load_device
+    saved_device = torch.device(saved_device) if isinstance(
+        saved_device, str) else saved_device
+    if saved_device.type in ("cuda", "mps") and load_device == "cpu":
+        cebra_.device = "cpu"
+
     #TODO(stes): unused right now
     #state_and_args = {**args, **state}
 
@@ -375,7 +486,7 @@ def _load_cebra_with_sklearn_backend(cebra_info: Dict) -> "CEBRA":
             num_neurons=state["n_features_in_"],
             num_units=args["num_hidden_units"],
             num_output=args["output_dimension"],
-        ).to(state['device_'])
+        ).to(load_device)
 
     elif isinstance(cebra_.num_sessions_, int):
         model = nn.ModuleList([
@@ -385,10 +496,10 @@ def _load_cebra_with_sklearn_backend(cebra_info: Dict) -> "CEBRA":
                 num_units=args["num_hidden_units"],
                 num_output=args["output_dimension"],
             ) for n_features in state["n_features_in_"]
-        ]).to(state['device_'])
+        ]).to(load_device)
 
     criterion = cebra_._prepare_criterion()
-    criterion.to(state['device_'])
+    criterion.to(load_device)
 
     optimizer = torch.optim.Adam(
         itertools.chain(model.parameters(), criterion.parameters()),
@@ -404,7 +515,7 @@ def _load_cebra_with_sklearn_backend(cebra_info: Dict) -> "CEBRA":
         tqdm_on=args['verbose'],
     )
     solver.load_state_dict(state_dict)
-    solver.to(state['device_'])
+    solver.to(load_device)
 
     cebra_.model_ = model
     cebra_.solver_ = solver
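
A minimal sketch of the resulting user-facing behavior (the checkpoint path
is hypothetical; any model written by ``CEBRA.save()`` on a GPU machine
behaves the same). On a CPU-only machine the load now warns and falls back
instead of raising::

    import warnings

    import cebra

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        model = cebra.CEBRA.load("model_saved_on_gpu.pt")  # hypothetical path

    assert model.device_ == "cpu"
    assert any("falling back to CPU" in str(w.message) for w in caught)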
" + f"To suppress this warning, pass map_location='cpu' " + f"explicitly.", + UserWarning, + stacklevel=2, + ) + kwargs["map_location"] = torch.device("cpu") + return _safe_torch_load( + filename, + weights_only=weights_only, + _is_retry=True, + **kwargs, + ) + raise if not isinstance(checkpoint, dict): _check_type_checkpoint(checkpoint) @@ -334,6 +395,47 @@ def _check_type_checkpoint(checkpoint): return checkpoint +def _resolve_checkpoint_device(device: Union[str, torch.device]) -> str: + """Resolve the device stored in a checkpoint for the current runtime. + + If a checkpoint was saved on a device (CUDA, MPS, ...) that is unavailable + at load time, this falls back to CPU and issues a warning. + + Args: + device: The device from the checkpoint. + + Returns: + The resolved device string. + """ + if isinstance(device, str): + device = torch.device(device) + + if not isinstance(device, torch.device): + raise TypeError( + f"Expected checkpoint device to be a string or torch.device, " + f"got {type(device)}.") + + fallback_to_cpu = False + + if device.type == "cuda" and not torch.cuda.is_available(): + fallback_to_cpu = True + elif device.type == "mps" and ( + not hasattr(torch.backends, "mps") + or not torch.backends.mps.is_available()): + fallback_to_cpu = True + + if fallback_to_cpu: + warnings.warn( + f"Checkpoint was saved on '{device}' which is not available in " + f"the current environment. Automatically falling back to CPU.", + UserWarning, + stacklevel=2, + ) + return "cpu" + + return sklearn_utils.check_device(str(device)) + + def _load_cebra_with_sklearn_backend(cebra_info: Dict) -> "CEBRA": """Loads a CEBRA model with a Sklearn backend. @@ -357,11 +459,20 @@ def _load_cebra_with_sklearn_backend(cebra_info: Dict) -> "CEBRA": args, state, state_dict = cebra_info['args'], cebra_info[ 'state'], cebra_info['state_dict'] + + saved_device = state["device_"] + load_device = _resolve_checkpoint_device(saved_device) + cebra_ = cebra.CEBRA(**args) for key, value in state.items(): setattr(cebra_, key, value) + cebra_.device_ = load_device + saved_device = torch.device(saved_device) if isinstance(saved_device, str) else saved_device + if saved_device.type == "cuda" and load_device == "cpu": + cebra_.device = "cpu" + #TODO(stes): unused right now #state_and_args = {**args, **state} @@ -375,7 +486,7 @@ def _load_cebra_with_sklearn_backend(cebra_info: Dict) -> "CEBRA": num_neurons=state["n_features_in_"], num_units=args["num_hidden_units"], num_output=args["output_dimension"], - ).to(state['device_']) + ).to(load_device) elif isinstance(cebra_.num_sessions_, int): model = nn.ModuleList([ @@ -385,10 +496,10 @@ def _load_cebra_with_sklearn_backend(cebra_info: Dict) -> "CEBRA": num_units=args["num_hidden_units"], num_output=args["output_dimension"], ) for n_features in state["n_features_in_"] - ]).to(state['device_']) + ]).to(load_device) criterion = cebra_._prepare_criterion() - criterion.to(state['device_']) + criterion.to(load_device) optimizer = torch.optim.Adam( itertools.chain(model.parameters(), criterion.parameters()), @@ -404,7 +515,7 @@ def _load_cebra_with_sklearn_backend(cebra_info: Dict) -> "CEBRA": tqdm_on=args['verbose'], ) solver.load_state_dict(state_dict) - solver.to(state['device_']) + solver.to(load_device) cebra_.model_ = model cebra_.solver_ = solver diff --git a/tests/generate_cuda_checkpoint.py b/tests/generate_cuda_checkpoint.py new file mode 100644 index 00000000..06b1c85a --- /dev/null +++ b/tests/generate_cuda_checkpoint.py @@ -0,0 +1,119 @@ +#!/usr/bin/env python 
+"""Generate a CUDA-saved checkpoint for integration testing. + +Run this script on a machine with a CUDA GPU to produce a checkpoint file +that can be used to verify the CUDA-to-CPU loading fallback in a CI +environment (which typically has no GPU). + +Usage:: + + # Default output path + python tests/generate_cuda_checkpoint.py + + # Custom output path + python tests/generate_cuda_checkpoint.py --output /tmp/cuda_checkpoint.pt + + # Verify an existing checkpoint + python tests/generate_cuda_checkpoint.py --verify tests/test_data/cuda_checkpoint.pt + +Requirements: + - PyTorch with CUDA support (``torch.cuda.is_available()`` must be True) + - CEBRA installed (``pip install -e .`` from the repo root) + +The generated file is a standard ``torch.save`` checkpoint in the CEBRA +sklearn format. It contains CUDA tensors, so loading it on a CPU-only +machine *without* the fallback logic will fail with:: + + RuntimeError: Attempting to deserialize object on a CUDA device but + torch.cuda.is_available() is False. +""" + +import argparse +import os +import sys + +import numpy as np +import torch + + +def generate(output_path: str) -> None: + """Train a minimal CEBRA model on CUDA and save the checkpoint.""" + if not torch.cuda.is_available(): + print("ERROR: CUDA is not available. Run this on a GPU machine.", + file=sys.stderr) + sys.exit(1) + + import cebra + + print(f"PyTorch {torch.__version__}, CUDA {torch.version.cuda}") + print(f"Device: {torch.cuda.get_device_name(0)}") + + # Train a tiny model on GPU + X = np.random.uniform(0, 1, (200, 10)).astype(np.float32) + model = cebra.CEBRA( + model_architecture="offset1-model", + max_iterations=10, + batch_size=64, + output_dimension=4, + device="cuda", + verbose=False, + ) + model.fit(X) + + # Sanity-check: model params should live on CUDA + param_device = next(model.solver_.model.parameters()).device + assert param_device.type == "cuda", f"Expected cuda, got {param_device}" + + os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True) + model.save(output_path) + print(f"Saved CUDA checkpoint to {output_path}") + + # Verify round-trip on GPU + loaded = cebra.CEBRA.load(output_path) + emb = loaded.transform(X) + assert emb.shape == (200, 4), f"Unexpected shape: {emb.shape}" + print("Round-trip verification on GPU: OK") + + +def verify(path: str) -> None: + """Load a checkpoint on CPU and confirm the fallback works.""" + import cebra + + if not os.path.exists(path): + print(f"ERROR: {path} does not exist.", file=sys.stderr) + sys.exit(1) + + print(f"Loading checkpoint from {path} ...") + model = cebra.CEBRA.load(path) + print(f" device_: {model.device_}") + print(f" device: {model.device}") + + X = np.random.uniform(0, 1, (50, model.n_features_)).astype(np.float32) + emb = model.transform(X) + print(f" transform shape: {emb.shape}") + print("Verification: OK") + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument( + "--output", + default="tests/test_data/cuda_checkpoint.pt", + help="Output path for the generated checkpoint (default: tests/test_data/cuda_checkpoint.pt)", + ) + parser.add_argument( + "--verify", + metavar="PATH", + help="Instead of generating, verify an existing checkpoint can be loaded.", + ) + args = parser.parse_args() + + if args.verify: + verify(args.verify) + else: + generate(args.output) + + +if __name__ == "__main__": + main() diff --git a/tests/test_sklearn.py b/tests/test_sklearn.py index 831ad49d..c41a55fe 100644 --- 
+
+
+def test_safe_torch_load_cuda_fallback(monkeypatch):
+    """Test that _safe_torch_load retries with map_location='cpu' on CUDA errors.
+
+    This simulates (by mocking torch.load) the failure path hit when CUDA
+    tensors are present in a checkpoint but CUDA is unavailable.
+    """
+    import os
+    import tempfile
+
+    # Create a simple checkpoint
+    checkpoint = {"test": torch.tensor([1.0, 2.0, 3.0])}
+
+    with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f:
+        tempname = f.name
+    # Save after the handle is closed so this also works on Windows.
+    torch.save(checkpoint, tempname)
+
+    try:
+        # Mock torch.load to fail when map_location is not set
+        # (simulating a CUDA tensor load error)
+        original_torch_load = torch.load
+        call_count = [0]
+
+        def mock_torch_load(*args, **kwargs):
+            call_count[0] += 1
+            if "map_location" not in kwargs:
+                raise RuntimeError(
+                    "Attempting to deserialize object on a CUDA device "
+                    "but torch.cuda.is_available() is False")
+            return original_torch_load(*args, **kwargs)
+
+        monkeypatch.setattr(torch, "load", mock_torch_load)
+
+        # Should retry with map_location='cpu' and succeed, emitting a warning
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter("always")
+            result = cebra_sklearn_cebra._safe_torch_load(tempname)
+
+        assert "test" in result
+        assert torch.allclose(result["test"], checkpoint["test"])
+        # Two calls: the first fails (no map_location), the second succeeds
+        # (with map_location)
+        assert call_count[0] == 2
+
+        # Should have warned about the fallback
+        fallback_warnings = [
+            x for x in w if "falling back to CPU" in str(x.message)
+        ]
+        assert len(fallback_warnings) == 1
+
+    finally:
+        os.unlink(tempname)
+
+
+def test_safe_torch_load_meaningful_error_on_retry_failure(monkeypatch):
+    """Test that a meaningful error is raised when the CPU fallback also fails."""
+    import os
+    import tempfile
+
+    checkpoint = {"test": torch.tensor([1.0, 2.0, 3.0])}
+
+    with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f:
+        tempname = f.name
+    torch.save(checkpoint, tempname)
+
+    try:
+        # Mock torch.load to always fail with a CUDA error
+        def mock_torch_load(*args, **kwargs):
+            raise RuntimeError("CUDA error: device-side assert triggered")
+
+        monkeypatch.setattr(torch, "load", mock_torch_load)
+
+        with pytest.raises(RuntimeError,
+                           match="Failed to load checkpoint even with"):
+            with warnings.catch_warnings(record=True):
+                warnings.simplefilter("always")
+                cebra_sklearn_cebra._safe_torch_load(tempname)
+    finally:
+        os.unlink(tempname)
+
+
+def test_safe_torch_load_error_with_explicit_map_location(monkeypatch):
+    """Test the error raised when map_location is set but a CUDA error occurs."""
+    import os
+    import tempfile
+
+    checkpoint = {"test": torch.tensor([1.0, 2.0, 3.0])}
+
+    with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f:
+        tempname = f.name
+    torch.save(checkpoint, tempname)
+
+    try:
+        def mock_torch_load(*args, **kwargs):
+            raise RuntimeError(
+                "Attempting to deserialize object on a CUDA device "
+                "but torch.cuda.is_available() is False")
+
+        monkeypatch.setattr(torch, "load", mock_torch_load)
+
+        with pytest.raises(RuntimeError, match="explicitly specified"):
+            cebra_sklearn_cebra._safe_torch_load(
+                tempname, map_location=torch.device("cpu"))
+    finally:
+        os.unlink(tempname)
+ """ + import tempfile + import os + + # Create a simple checkpoint + checkpoint = {"test": torch.tensor([1.0, 2.0, 3.0])} + + with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f: + tempname = f.name + torch.save(checkpoint, tempname) + + try: + # Mock torch.load to fail when map_location is not set + # (simulating CUDA tensor load error) + original_torch_load = torch.load + call_count = [0] + + def mock_torch_load(*args, **kwargs): + call_count[0] += 1 + if "map_location" not in kwargs: + raise RuntimeError( + "Attempting to deserialize object on a CUDA device " + "but torch.cuda.is_available() is False" + ) + return original_torch_load(*args, **kwargs) + + monkeypatch.setattr(torch, "load", mock_torch_load) + + # Should retry with map_location='cpu' and succeed, emitting a warning + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + result = cebra_sklearn_cebra._safe_torch_load(tempname) + + assert "test" in result + assert torch.allclose(result["test"], checkpoint["test"]) + # Two calls: first fails (no map_location), second succeeds (with map_location) + assert call_count[0] == 2 + + # Should have warned about the fallback + fallback_warnings = [ + x for x in w if "falling back to CPU" in str(x.message) + ] + assert len(fallback_warnings) == 1 + + finally: + os.unlink(tempname) + + +def test_safe_torch_load_meaningful_error_on_retry_failure(monkeypatch): + """Test that a meaningful error is raised when CPU fallback also fails.""" + import tempfile + import os + + checkpoint = {"test": torch.tensor([1.0, 2.0, 3.0])} + + with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f: + tempname = f.name + torch.save(checkpoint, tempname) + + try: + # Mock torch.load to always fail with a CUDA error + def mock_torch_load(*args, **kwargs): + raise RuntimeError( + "CUDA error: device-side assert triggered" + ) + + monkeypatch.setattr(torch, "load", mock_torch_load) + + with pytest.raises(RuntimeError, match="Failed to load checkpoint even with"): + with warnings.catch_warnings(record=True): + warnings.simplefilter("always") + cebra_sklearn_cebra._safe_torch_load(tempname) + finally: + os.unlink(tempname) + + +def test_safe_torch_load_error_with_explicit_map_location(monkeypatch): + """Test meaningful error when map_location is already set but CUDA error occurs.""" + import tempfile + import os + + checkpoint = {"test": torch.tensor([1.0, 2.0, 3.0])} + + with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f: + tempname = f.name + torch.save(checkpoint, tempname) + + try: + def mock_torch_load(*args, **kwargs): + raise RuntimeError( + "Attempting to deserialize object on a CUDA device " + "but torch.cuda.is_available() is False" + ) + + monkeypatch.setattr(torch, "load", mock_torch_load) + + with pytest.raises(RuntimeError, match="explicitly specified"): + cebra_sklearn_cebra._safe_torch_load( + tempname, map_location=torch.device("cpu") + ) + finally: + os.unlink(tempname) + + +@pytest.mark.parametrize("saved_device", ["cuda", "cuda:0"]) +def test_load_cuda_checkpoint_with_device_override(saved_device, monkeypatch): + """Test that automatic CPU fallback works with CUDA checkpoints.""" + X = np.random.uniform(0, 1, (100, 5)) + + cebra_model = cebra_sklearn_cebra.CEBRA( + model_architecture="offset1-model", + max_iterations=5, + device="cpu" + ).fit(X) + + with _windows_compatible_tempfile(mode="w+b") as tempname: + cebra_model.save(tempname) + checkpoint = cebra_sklearn_cebra._safe_torch_load(tempname) + 
checkpoint["state"]["device_"] = saved_device + torch.save(checkpoint, tempname) + + monkeypatch.setattr(torch.cuda, "is_available", lambda: False) + + # Load should automatically fall back to CPU with a warning + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + loaded_model = cebra_sklearn_cebra.CEBRA.load(tempname) + + device_warnings = [x for x in w if "falling back to CPU" in str(x.message)] + assert len(device_warnings) > 0 + + # Model should be usable + X_test = np.random.uniform(0, 1, (10, 5)) + embedding = loaded_model.transform(X_test) + assert embedding.shape[0] == 10 + assert embedding.shape[1] > 0 + + +@pytest.mark.parametrize("saved_device", ["mps"]) +def test_load_mps_checkpoint_falls_back_to_cpu(saved_device, monkeypatch): + """Test that MPS-saved checkpoints can be loaded when MPS is unavailable. + + Mirrors the CUDA fallback test but for Apple Silicon MPS devices. + """ + X = np.random.uniform(0, 1, (100, 5)) + + cebra_model = cebra_sklearn_cebra.CEBRA( + model_architecture="offset1-model", + max_iterations=5, + device="cpu" + ).fit(X) + + with _windows_compatible_tempfile(mode="w+b") as tempname: + cebra_model.save(tempname) + + # Patch the checkpoint to claim it was saved on MPS + checkpoint = cebra_sklearn_cebra._safe_torch_load(tempname) + checkpoint["state"]["device_"] = saved_device + torch.save(checkpoint, tempname) + + # Mock MPS as unavailable + monkeypatch.setattr(torch.backends.mps, "is_available", lambda: False) + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + loaded_model = cebra_sklearn_cebra.CEBRA.load(tempname) + + device_warnings = [x for x in w if "falling back to CPU" in str(x.message)] + assert len(device_warnings) > 0, ( + f"Expected a warning about falling back to CPU, got: " + f"{[str(x.message) for x in w]}" + ) + + assert loaded_model.device_ == "cpu" + assert loaded_model.device == "cpu" + + X_test = np.random.uniform(0, 1, (10, 5)) + embedding = loaded_model.transform(X_test) + assert embedding.shape[0] == 10 + assert isinstance(embedding, np.ndarray) + + def test_fit_after_moving_to_device(): expected_device = 'cpu' expected_type = type(expected_device)