Add fused_adam, quantized_model_init, and fsdp2 example #2698
pstjohn wants to merge 3 commits into NVIDIA:main
Conversation
Force-pushed from 22604c4 to 4d89e04
Greptile Summary

This PR enables `FusedAdam` to work with PyTorch-native FSDP2 (`fully_shard`) when parameters are `DTensor`-wrapped quantized tensors.

Note: The PR correctly documents that `fuse_wgrad_accumulation` remains incompatible with vanilla FSDP2.

Confidence Score: 4/5
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    Start([FusedAdam.step]) --> CheckGrad{p.grad or<br/>p.decoupled_grad<br/>exists?}
    CheckGrad -->|No| Skip[Skip parameter]
    CheckGrad -->|Yes| ExtractGrad[Get gradient tensor]
    ExtractGrad --> IsDTensorGrad{Is gradient<br/>DTensor?}
    IsDTensorGrad -->|Yes| UnwrapGrad[Extract p_grad._local_tensor]
    IsDTensorGrad -->|No| UseGrad[Use gradient as-is]
    UnwrapGrad --> GetStates[Get optimizer states:<br/>exp_avg, exp_avg_sq,<br/>master_param]
    UseGrad --> GetStates
    GetStates --> IsDTensorParam{Is param<br/>DTensor?}
    IsDTensorParam -->|Yes| UnwrapParam[Extract p._local_tensor]
    IsDTensorParam -->|No| UseParam[Use param as-is]
    UnwrapParam --> IsFP8{Is local tensor<br/>Float8Tensor?}
    UseParam --> IsFP8
    IsFP8 -->|Yes| ExtractFP8[Extract _data, scale,<br/>amax, scale_inv]
    IsFP8 -->|No| CheckLowPrec{Is fp16/bf16?}
    ExtractFP8 --> CallKernel[Call multi_tensor_adam<br/>with FP8 metadata]
    CheckLowPrec -->|Yes| CallKernelFP16[Call multi_tensor_adam<br/>for fp16/bf16]
    CheckLowPrec -->|No| CallKernelFP32[Call multi_tensor_adam<br/>for fp32]
    CallKernel --> End([Update complete])
    CallKernelFP16 --> End
    CallKernelFP32 --> End
    Skip --> End
```
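The unwrap-then-branch logic in the flowchart can be modeled without CUDA or torch. The `DTensor` and `Float8Tensor` classes below are hypothetical stand-ins for the real `torch.distributed.tensor.DTensor` and transformer_engine `Float8Tensor`; this is a minimal sketch of the dispatch shape, not the actual implementation:

```python
# Simplified, framework-free model of the dispatch in the flowchart above.
# DTensor and Float8Tensor are stand-ins, not the real library classes.

class DTensor:
    """Stand-in for torch.distributed.tensor.DTensor."""
    def __init__(self, local_tensor):
        self._local_tensor = local_tensor

class Float8Tensor:
    """Stand-in for transformer_engine's Float8Tensor."""
    def __init__(self, data, scale_inv):
        self._data = data
        self._scale_inv = scale_inv

def classify_param(p):
    """Mirror the flowchart: unwrap DTensor first, then branch on local type."""
    local = p._local_tensor if isinstance(p, DTensor) else p
    if isinstance(local, Float8Tensor):
        return "fp8"    # kernel receives _data plus scaling metadata
    return "plain"      # fp16/bf16/fp32 path, tensor passed as-is

assert classify_param(DTensor(Float8Tensor(b"\x01", 1.0))) == "fp8"
assert classify_param("plain tensor") == "plain"
```

The key design point is that unwrapping happens before the `isinstance` branch, so a sharded FP8 parameter is still routed down the FP8 path.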
Last reviewed commit: 9710315
Force-pushed from 0103b53 to 3c3dbd2
@pstjohn Hi, thanks for the great work! Does this PR plan to also handle the BF16 path? I noticed the BF16 branch still operates on the original p/p_grad without unwrapping when they're DTensors. In my experiments with FSDP2 + BF16, I'm seeing non-trivial overhead during the optimizer step from repeated DTensor dispatch. Curious if that's intentional or a planned follow-up.
Force-pushed from a4d691f to 872caef
Force-pushed from 9ccc0c3 to eb8606a
…hard Signed-off-by: Peter St. John <pstjohn@nvidia.com>
Force-pushed from eb8606a to c2415e4
```python
pytest.xfail(
    "MXFP8BlockScaling: FusedAdam CUDA kernel does not support "
    "MXFP8 quantized tensors, causing illegal memory access"
)
```
Claude's analysis:

Root cause: `MXFP8Tensor` inherits from `QuantizedTensor` but NOT from `Float8Tensor`. In `fused_adam.py` `step()`:

- Line 617: `isinstance(p, Float8Tensor)` → False for `MXFP8Tensor`
- Line 629: `p.dtype in [torch.float16, torch.bfloat16]` → True (nominal dtype is bfloat16)
- So `p.data` (which is still an `MXFP8Tensor` wrapper with MXFP8-layout storage) gets added to `p_f16_model`
- The `multi_tensor_adam` CUDA kernel treats this as plain bf16 memory → illegal memory access
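The misrouting described above can be reproduced in miniature without the CUDA kernel. The classes below are simplified stand-ins for transformer_engine's tensor hierarchy (the real classes carry storage and scaling state); the point is purely the `isinstance` ordering:

```python
# Stand-ins for transformer_engine's hierarchy: MXFP8Tensor subclasses
# QuantizedTensor but NOT Float8Tensor, so the Float8Tensor check misses it.

class QuantizedTensor:
    pass

class Float8Tensor(QuantizedTensor):
    pass

class MXFP8Tensor(QuantizedTensor):
    dtype = "bfloat16"  # nominal dtype; actual storage is MXFP8-layout

def route(p):
    """Simplified model of the branching in FusedAdam.step()."""
    if isinstance(p, Float8Tensor):
        return "fp8_path"
    if getattr(p, "dtype", None) in ("float16", "bfloat16"):
        # MXFP8Tensor lands here: the kernel would then read
        # MXFP8-layout bytes as bf16 -> illegal memory access
        return "fp16_path"
    return "fp32_path"

assert route(Float8Tensor()) == "fp8_path"
assert route(MXFP8Tensor()) == "fp16_path"  # the bug: wrong path
```

A fix would either add an explicit `isinstance(p, MXFP8Tensor)` branch or check the storage layout rather than the nominal dtype.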
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
```python
continue
# Extract local tensors from DTensors (e.g. from FSDP2)
# so that multi_tensor kernels receive plain CUDA tensors.
if isinstance(p_grad, DTensor):
```
Is this really needed? Since p_grad is high precision, all DTensor ops should get translated to local-tensor ops without needing to extract it.
```python
        column-wise FP8 data.
        """
        data = self._rowwise_data if self._rowwise_data is not None else self._columnwise_data
```
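The accessor in the diff prefers row-wise data and falls back to column-wise data when the row-wise copy is absent. A minimal sketch of that fallback, using a hypothetical `FakeMXFP8` stand-in rather than the real transformer_engine class:

```python
# Hypothetical stand-in illustrating the rowwise/columnwise fallback
# from the diff above; not the real MXFP8 tensor class.

class FakeMXFP8:
    def __init__(self, rowwise, columnwise):
        self._rowwise_data = rowwise
        self._columnwise_data = columnwise

    def get_data(self):
        # Prefer row-wise storage; fall back to column-wise when
        # row-wise usage is disabled and only the other copy exists.
        if self._rowwise_data is not None:
            return self._rowwise_data
        return self._columnwise_data

assert FakeMXFP8(b"row", b"col").get_data() == b"row"
assert FakeMXFP8(None, b"col").get_data() == b"col"
```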
Can you please provide a reason (maybe in the PR description) why this is needed? And where is this used? Is it needed for DCP checkpointing, and if so, why?
/te-ci L1 pytorch
```python
if isinstance(recipe, transformer_engine.common.recipe.DelayedScaling):
    model_state = {
        k: v for k, v in model.state_dict().items() if not k.endswith("_extra_state")
    }
```
Can you also please comment on why we should avoid saving _extra_state? That is, what error do we get with DCP if we don't?
This one I remember; the others I'll need to comment out and run the test suite again 😅
But this is a known hassle: torch DCP needs the sizes of these tensors to remain consistent between saving and loading, and since this is pickled data, the size changes when there's data in that field.
The alternative is a detailed load_planner for DelayedScaling that reads and allocates the extra-state data tensor.
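The size-consistency hassle can be seen with stdlib `pickle` alone. The `amax_history` key below is a hypothetical example of recipe state, not the exact `_extra_state` schema; the point is only that pickled byte length depends on content, which conflicts with DCP's expectation of stable tensor sizes across save and load:

```python
import pickle

# _extra_state is stored as pickled bytes. Its byte length depends on
# what has accumulated in it, so a checkpoint written early in training
# will not match the size expected later. (Illustrative key name only.)
empty_state = pickle.dumps({"amax_history": []})
filled_state = pickle.dumps({"amax_history": [0.1, 0.2, 0.3]})

# The two serialized blobs differ in size -> DCP's fixed-size
# allocation for the destination tensor no longer lines up.
assert len(empty_state) != len(filled_state)
```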
Summary

- Enables `FusedAdam` to work with PyTorch-native FSDP2 (`fully_shard`) when parameters are `DTensor`-wrapped `Float8Tensor`/`QuantizedTensor`
- Adds a `fuse_wgrad_accumulation` guard to avoid crashing with vanilla FSDP2 (previously assumed Megatron-Core FSDP exclusively)
- Adds `quantized_model_init` examples on single-GPU (`main.py`) and multi-GPU FSDP2 (`fully_shard.py`)

Note: `fuse_wgrad_accumulation` remains incompatible with vanilla FSDP2. The feature writes weight gradients directly into `main_grad` and returns `None` to autograd, bypassing FSDP2's reduce-scatter, so each rank ends up with an unreduced gradient. Megatron-Core FSDP solves this by wiring `get_main_grad()` into its own reduce-scatter infrastructure; vanilla FSDP2 does not yet expose an equivalent hook.

Fixes #2682
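Why bypassing reduce-scatter leaves gradients unreduced can be shown with a toy two-rank model. This is a deliberately simplified sketch (a mean all-reduce standing in for FSDP2's sharded reduce-scatter), not the real FSDP2 communication path:

```python
# Toy model: with fuse_wgrad_accumulation, each rank writes its local
# wgrad into main_grad and autograd sees None, so FSDP2 never runs its
# gradient reduction. Ranks then hold unreduced, local-only gradients.

def reduce_grads(local_grads):
    """Stand-in for FSDP2's reduction: every rank gets the mean."""
    n = len(local_grads)
    mean = sum(local_grads) / n
    return [mean] * n

local_grads = [1.0, 3.0]              # two ranks with different wgrads

reduced = reduce_grads(local_grads)   # normal FSDP2 path
bypassed = local_grads                # fused path: reduction skipped

assert reduced == [2.0, 2.0]          # ranks agree after reduction
assert bypassed != reduced            # ranks diverge -> wrong updates
```

Megatron-Core FSDP avoids this by hooking `get_main_grad()` into its own reduction; the equivalent hook does not yet exist in vanilla FSDP2, hence the guard added in this PR.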