-
Notifications
You must be signed in to change notification settings - Fork 95
fix: Allow loading CUDA-saved models on CPU-only machines #296
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
robosimon
wants to merge
4
commits into
AdaptiveMotorControlLab:main
Choose a base branch
from
robosimon:fix/cuda-load-on-cpu
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+494
−13
Open
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
acd63c7
test: Add tests for loading CUDA-saved models on CPU-only machines
robosimon 97d5b90
fix: Allow loading CUDA-saved models on CPU-only machines
robosimon 55f7589
refactor: Address Steffen's review comments on PR #296
robosimon c02a95f
refactor: Address Steffen's code review comments
robosimon File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,119 @@ | ||
| #!/usr/bin/env python | ||
| """Generate a CUDA-saved checkpoint for integration testing. | ||
|
|
||
| Run this script on a machine with a CUDA GPU to produce a checkpoint file | ||
| that can be used to verify the CUDA-to-CPU loading fallback in a CI | ||
| environment (which typically has no GPU). | ||
|
|
||
| Usage:: | ||
|
|
||
| # Default output path | ||
| python tests/generate_cuda_checkpoint.py | ||
|
|
||
| # Custom output path | ||
| python tests/generate_cuda_checkpoint.py --output /tmp/cuda_checkpoint.pt | ||
|
|
||
| # Verify an existing checkpoint | ||
| python tests/generate_cuda_checkpoint.py --verify tests/test_data/cuda_checkpoint.pt | ||
|
|
||
| Requirements: | ||
| - PyTorch with CUDA support (``torch.cuda.is_available()`` must be True) | ||
| - CEBRA installed (``pip install -e .`` from the repo root) | ||
|
|
||
| The generated file is a standard ``torch.save`` checkpoint in the CEBRA | ||
| sklearn format. It contains CUDA tensors, so loading it on a CPU-only | ||
| machine *without* the fallback logic will fail with:: | ||
|
|
||
| RuntimeError: Attempting to deserialize object on a CUDA device but | ||
| torch.cuda.is_available() is False. | ||
| """ | ||
|
|
||
| import argparse | ||
| import os | ||
| import sys | ||
|
|
||
| import numpy as np | ||
| import torch | ||
|
|
||
|
|
||
| def generate(output_path: str) -> None: | ||
| """Train a minimal CEBRA model on CUDA and save the checkpoint.""" | ||
| if not torch.cuda.is_available(): | ||
| print("ERROR: CUDA is not available. Run this on a GPU machine.", | ||
| file=sys.stderr) | ||
| sys.exit(1) | ||
|
|
||
| import cebra | ||
|
|
||
| print(f"PyTorch {torch.__version__}, CUDA {torch.version.cuda}") | ||
| print(f"Device: {torch.cuda.get_device_name(0)}") | ||
|
|
||
| # Train a tiny model on GPU | ||
| X = np.random.uniform(0, 1, (200, 10)).astype(np.float32) | ||
| model = cebra.CEBRA( | ||
| model_architecture="offset1-model", | ||
| max_iterations=10, | ||
| batch_size=64, | ||
| output_dimension=4, | ||
| device="cuda", | ||
| verbose=False, | ||
| ) | ||
| model.fit(X) | ||
|
|
||
| # Sanity-check: model params should live on CUDA | ||
| param_device = next(model.solver_.model.parameters()).device | ||
| assert param_device.type == "cuda", f"Expected cuda, got {param_device}" | ||
|
|
||
| os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True) | ||
| model.save(output_path) | ||
| print(f"Saved CUDA checkpoint to {output_path}") | ||
|
|
||
| # Verify round-trip on GPU | ||
| loaded = cebra.CEBRA.load(output_path) | ||
| emb = loaded.transform(X) | ||
| assert emb.shape == (200, 4), f"Unexpected shape: {emb.shape}" | ||
| print("Round-trip verification on GPU: OK") | ||
|
|
||
|
|
||
| def verify(path: str) -> None: | ||
| """Load a checkpoint on CPU and confirm the fallback works.""" | ||
| import cebra | ||
|
|
||
| if not os.path.exists(path): | ||
| print(f"ERROR: {path} does not exist.", file=sys.stderr) | ||
| sys.exit(1) | ||
|
|
||
| print(f"Loading checkpoint from {path} ...") | ||
| model = cebra.CEBRA.load(path) | ||
| print(f" device_: {model.device_}") | ||
| print(f" device: {model.device}") | ||
|
|
||
| X = np.random.uniform(0, 1, (50, model.n_features_)).astype(np.float32) | ||
| emb = model.transform(X) | ||
| print(f" transform shape: {emb.shape}") | ||
| print("Verification: OK") | ||
|
|
||
|
|
||
| def main() -> None: | ||
| parser = argparse.ArgumentParser(description=__doc__, | ||
| formatter_class=argparse.RawDescriptionHelpFormatter) | ||
| parser.add_argument( | ||
| "--output", | ||
| default="tests/test_data/cuda_checkpoint.pt", | ||
| help="Output path for the generated checkpoint (default: tests/test_data/cuda_checkpoint.pt)", | ||
| ) | ||
| parser.add_argument( | ||
| "--verify", | ||
| metavar="PATH", | ||
| help="Instead of generating, verify an existing checkpoint can be loaded.", | ||
| ) | ||
| args = parser.parse_args() | ||
|
|
||
| if args.verify: | ||
| verify(args.verify) | ||
| else: | ||
| generate(args.output) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| main() |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.