
Add fused_adam, quantized_model_init, and fsdp2 example #2698

Open
pstjohn wants to merge 3 commits into NVIDIA:main from pstjohn:pstjohn/fsdp2-fused-adam

Conversation

@pstjohn
Contributor

@pstjohn pstjohn commented Feb 22, 2026

Summary

  • Fix FusedAdam to work with PyTorch-native FSDP2 (fully_shard) when parameters are DTensor-wrapped Float8Tensor/QuantizedTensor
  • Fix fuse_wgrad_accumulation guard to avoid crashing with vanilla FSDP2 (previously assumed Megatron-Core FSDP exclusively)
  • Add examples for quantized_model_init on single-GPU (main.py) and multi-GPU FSDP2 (fully_shard.py)

Note: fuse_wgrad_accumulation remains incompatible with vanilla FSDP2

fuse_wgrad_accumulation still cannot be used with vanilla FSDP2. The feature writes weight gradients directly into main_grad and returns None to autograd, bypassing FSDP2's reduce-scatter. Each rank ends up with an unreduced gradient. Megatron-Core FSDP solves this by wiring get_main_grad() into its own reduce-scatter infrastructure. Vanilla FSDP2 does not yet expose an equivalent hook.
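The mechanism described above can be seen in a toy single-process sketch (illustrative names only, not TE's actual code): the weight gradient is written into a `main_grad` side buffer and autograd is told there is no weight gradient, so any post-accumulate-grad hook, which is where FSDP2 triggers its reduce-scatter, never sees a gradient to reduce.

```python
import torch

class MatmulWithMainGrad(torch.autograd.Function):
    """Toy sketch of the fuse_wgrad_accumulation pattern."""

    @staticmethod
    def forward(ctx, x, weight):
        ctx.save_for_backward(x, weight)
        return x @ weight.t()

    @staticmethod
    def backward(ctx, grad_out):
        x, weight = ctx.saved_tensors
        weight.main_grad += grad_out.t() @ x  # wgrad goes to the side buffer
        return grad_out @ weight, None        # autograd is told there is no wgrad

w = torch.nn.Parameter(torch.randn(4, 3))
w.main_grad = torch.zeros(4, 3)
x = torch.randn(2, 3, requires_grad=True)
MatmulWithMainGrad.apply(x, w).sum().backward()

assert w.grad is None               # nothing for an FSDP2 grad hook to reduce-scatter
assert w.main_grad.abs().sum() > 0  # the gradient lives only in main_grad
```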

Fixes #2682

@pstjohn pstjohn force-pushed the pstjohn/fsdp2-fused-adam branch 2 times, most recently from 22604c4 to 4d89e04 Compare February 23, 2026 15:28
@pstjohn pstjohn marked this pull request as ready for review February 23, 2026 17:27
@greptile-apps
Contributor

greptile-apps bot commented Feb 23, 2026

Greptile Summary

This PR enables FusedAdam to work with PyTorch FSDP2 when parameters are DTensor-wrapped Float8Tensor/QuantizedTensor, and adds examples for quantized_model_init on single-GPU and multi-GPU FSDP2.

Key changes:

  • FusedAdam extracts _local_tensor from DTensor parameters before operations to ensure multi-tensor kernels receive plain CUDA tensors
  • Quantizers add __getstate__ to exclude unpicklable process groups for checkpoint serialization
  • Block-scaled tensors (Float8Blockwise, NVFP4) convert unsupported view/reshape operations from hard errors to warnings with dequantize fallback
  • Adds untyped_storage() methods for MXFP8 and NVFP4 tensors
  • Comprehensive test suite covering FP8/bf16 params, master weights, DCP checkpointing (sync/async), and safetensors export
  • Recipe handling unified to PascalCase with getattr pattern

Note: The PR correctly documents that fuse_wgrad_accumulation remains incompatible with vanilla FSDP2 (see xfail test on line 118 of test_torch_fsdp2.py)
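The `__getstate__` change in the summary can be sketched with a toy quantizer (illustrative only; real `torch.distributed` process groups are the unpicklable handle TE's quantizers must drop before checkpoint serialization):

```python
import pickle

class ToyQuantizer:
    """Stand-in for a quantizer holding an unpicklable process-group handle."""

    def __init__(self, scale, process_group):
        self.scale = scale
        self.process_group = process_group

    def __getstate__(self):
        state = self.__dict__.copy()
        state["process_group"] = None  # dropped on save, re-attached from runtime context
        return state

q = ToyQuantizer(scale=0.5, process_group=lambda: None)  # a lambda is unpicklable
restored = pickle.loads(pickle.dumps(q))
assert restored.scale == 0.5 and restored.process_group is None
```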

Confidence Score: 4/5

  • This PR is safe to merge with minor caveats around FSDP2 compatibility workarounds
  • Score reflects solid implementation with comprehensive testing, but includes deliberate workarounds (dequantize fallback for unsupported ops, process group serialization handling) that merit awareness. The test coverage is excellent with 8 distinct scenarios, examples are clear, and DTensor extraction logic is correct. Known limitations are properly documented.
  • transformer_engine/pytorch/tensor/float8_blockwise_tensor.py and transformer_engine/pytorch/tensor/nvfp4_tensor.py - the warning + dequantize fallback for unsupported view/reshape operations is a workaround that changes tensor types mid-computation

Important Files Changed

  • transformer_engine/pytorch/optimizers/fused_adam.py: Adds DTensor support by extracting _local_tensor before operations; handles QuantizedTensor dequantization for optimizer states
  • transformer_engine/pytorch/tensor/float8_blockwise_tensor.py: Converts RuntimeError to warnings with dequantize fallback for unsupported view/reshape operations (FSDP2 workaround)
  • transformer_engine/pytorch/tensor/nvfp4_tensor.py: Adds __getstate__ and untyped_storage(); converts reshape/view errors to warnings with dequantize fallback
  • tests/pytorch/distributed/run_fsdp2_fused_adam.py: New test runner implementing 8 distinct FSDP2+FusedAdam scenarios with FP8/bf16 params, master weights, DCP, and safetensors
  • examples/pytorch/quantized_model_init/fully_shard.py: New multi-GPU FSDP2 example with meta-device init, quantized sharding, DCP checkpointing, and FP32 safetensors export

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    Start([FusedAdam.step]) --> CheckGrad{p.grad or<br/>p.decoupled_grad<br/>exists?}
    CheckGrad -->|No| Skip[Skip parameter]
    CheckGrad -->|Yes| ExtractGrad[Get gradient tensor]
    ExtractGrad --> IsDTensorGrad{Is gradient<br/>DTensor?}
    IsDTensorGrad -->|Yes| UnwrapGrad[Extract p_grad._local_tensor]
    IsDTensorGrad -->|No| UseGrad[Use gradient as-is]
    UnwrapGrad --> GetStates[Get optimizer states:<br/>exp_avg, exp_avg_sq,<br/>master_param]
    UseGrad --> GetStates
    GetStates --> IsDTensorParam{Is param<br/>DTensor?}
    IsDTensorParam -->|Yes| UnwrapParam[Extract p._local_tensor]
    IsDTensorParam -->|No| UseParam[Use param as-is]
    UnwrapParam --> IsFP8{Is local tensor<br/>Float8Tensor?}
    UseParam --> IsFP8
    IsFP8 -->|Yes| ExtractFP8[Extract _data, scale,<br/>amax, scale_inv]
    IsFP8 -->|No| CheckLowPrec{Is fp16/bf16?}
    ExtractFP8 --> CallKernel[Call multi_tensor_adam<br/>with FP8 metadata]
    CheckLowPrec -->|Yes| CallKernelFP16[Call multi_tensor_adam<br/>for fp16/bf16]
    CheckLowPrec -->|No| CallKernelFP32[Call multi_tensor_adam<br/>for fp32]
    CallKernel --> End([Update complete])
    CallKernelFP16 --> End
    CallKernelFP32 --> End
    Skip --> End
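The dispatch in the flowchart condenses to a small sketch. The attribute names (`_local_tensor`, `_data`) follow the PR description; the helper and bucket names are illustrative, not TE's actual code:

```python
import torch

def adam_bucket(p):
    """Pick the multi_tensor_adam bucket for a parameter (simplified sketch)."""
    local = getattr(p, "_local_tensor", p)  # unwrap a DTensor shard to the plain tensor
    if hasattr(local, "_data"):             # stand-in for the Float8Tensor check
        return "fp8"                        # kernel receives _data plus scale/amax metadata
    if local.dtype in (torch.float16, torch.bfloat16):
        return "f16"
    return "f32"

assert adam_bucket(torch.zeros(2, dtype=torch.bfloat16)) == "f16"
assert adam_bucket(torch.zeros(2)) == "f32"
```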

Last reviewed commit: 9710315

Contributor

@greptile-apps greptile-apps bot left a comment


12 files reviewed, 13 comments


Contributor

@greptile-apps greptile-apps bot left a comment


6 files reviewed, no comments


Member

@cspades cspades left a comment


LGTM, clean edits.

@pstjohn pstjohn force-pushed the pstjohn/fsdp2-fused-adam branch from 0103b53 to 3c3dbd2 Compare February 24, 2026 20:06
Contributor

@greptile-apps greptile-apps bot left a comment


6 files reviewed, no comments


Contributor

@greptile-apps greptile-apps bot left a comment


6 files reviewed, no comments


@XueSongTap

@pstjohn Hi, thanks for the great work! Does this PR plan to also handle the BF16 path? I noticed the BF16 branch still operates on the original p/p_grad without unwrapping when they're DTensors. In my experiments with FSDP2 + BF16, I'm seeing non-trivial overhead during the optimizer step from repeated DTensor dispatch. Curious if that's intentional or a planned follow-up.

@pstjohn pstjohn force-pushed the pstjohn/fsdp2-fused-adam branch from a4d691f to 872caef Compare February 26, 2026 15:11
@pstjohn pstjohn marked this pull request as draft February 26, 2026 20:08
@pstjohn pstjohn force-pushed the pstjohn/fsdp2-fused-adam branch 4 times, most recently from 9ccc0c3 to eb8606a Compare February 26, 2026 21:55
@pstjohn pstjohn marked this pull request as ready for review February 26, 2026 21:56
…hard

Signed-off-by: Peter St. John <pstjohn@nvidia.com>
@pstjohn pstjohn force-pushed the pstjohn/fsdp2-fused-adam branch from eb8606a to c2415e4 Compare February 26, 2026 22:50
Comment on lines +167 to +170
pytest.xfail(
"MXFP8BlockScaling: FusedAdam CUDA kernel does not support "
"MXFP8 quantized tensors, causing illegal memory access"
)
Contributor Author

claude's analysis:

Root cause: MXFP8Tensor inherits from QuantizedTensor but NOT from Float8Tensor. In fused_adam.py step():

  1. Line 617: isinstance(p, Float8Tensor) → False for MXFP8Tensor
  2. Line 629: p.dtype in [torch.float16, torch.bfloat16] → True (nominal dtype is bfloat16)
  3. So p.data (which is still an MXFP8Tensor wrapper with MXFP8-layout storage) gets added to p_f16_model
  4. The multi_tensor_adam CUDA kernel treats this as plain bf16 memory → illegal memory access
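The isinstance-ordering pitfall in that analysis can be reproduced with stand-in tensor subclasses (illustrative classes only; the real types live in transformer_engine):

```python
import torch

class FakeFloat8Tensor(torch.Tensor):
    """Stand-in for TE's Float8Tensor."""

class FakeMXFP8Tensor(torch.Tensor):
    """Like MXFP8Tensor: a quantized subclass that is *not* a Float8Tensor,
    advertising a nominal bf16 dtype over block-scaled storage."""

p = torch.zeros(2, dtype=torch.bfloat16).as_subclass(FakeMXFP8Tensor)

# The two checks from the analysis above, in order:
assert not isinstance(p, FakeFloat8Tensor)         # step 1: the FP8 branch is missed
assert p.dtype in (torch.float16, torch.bfloat16)  # step 2: it falls into the bf16 branch
# -> the raw MXFP8-layout storage would be handed to the plain-bf16 kernel
```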

Signed-off-by: Peter St. John <pstjohn@nvidia.com>
continue
# Extract local tensors from DTensors (e.g. from FSDP2)
# so that multi_tensor kernels receive plain CUDA tensors.
if isinstance(p_grad, DTensor):
Collaborator

@vthumbe1503 vthumbe1503 commented Feb 28, 2026

Is this really needed? Since p_grad is high precision, all of the DTensor's ops should dispatch to the local tensor's ops without needing to extract it explicitly.

column-wise FP8 data.

"""
data = self._rowwise_data if self._rowwise_data is not None else self._columnwise_data
Collaborator

Can you please explain (maybe in the PR description) why this is needed, and where it is used? Is it needed for DCP checkpointing, and if so, why?

@vthumbe1503
Collaborator

/te-ci L1 pytorch

Collaborator

@vthumbe1503 vthumbe1503 left a comment


Thanks for the clean PR and great work. Left a few minor comments. LGTM post CI success.

if isinstance(recipe, transformer_engine.common.recipe.DelayedScaling):
model_state = {
k: v for k, v in model.state_dict().items() if not k.endswith("_extra_state")
}
Collaborator

@vthumbe1503 vthumbe1503 commented Feb 28, 2026

Can you also please comment on why we should avoid saving _extra_state? That is, what error do we get with DCP if we don't?

Contributor Author

This one I remember, the others I'll need to comment out and run the test suite again 😅

But this is a known hassle where torch DCP needs the sizes of these tensors to remain consistent during saving & loading, and since this is pickled data, it changes when there's data in that field.

The alternative is a detailed load_planner for DelayedScaling that reads and allocates the extra state data tensor
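The size-consistency problem described here can be seen with a toy blob (a sketch of the mechanism only, not TE's actual _extra_state serialization): the pickled payload's byte length, and hence the wrapping tensor's shape, changes with its contents, while DCP requires shapes to match between save and load.

```python
import io
import pickle
import torch

def extra_state_blob(payload):
    """Serialize a payload the way extra state is commonly stored:
    pickled bytes wrapped in a uint8 tensor."""
    buf = io.BytesIO()
    pickle.dump(payload, buf)
    return torch.frombuffer(bytearray(buf.getvalue()), dtype=torch.uint8)

empty = extra_state_blob({})
with_amax = extra_state_blob({"amax_history": [0.0] * 16})

# The tensor shape depends on what was pickled, so a checkpoint saved
# with one payload cannot be loaded into a buffer sized for another:
assert empty.numel() != with_amax.numel()
```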



Development

Successfully merging this pull request may close these issues.

Example of quantized_model_init for low-precision compute weights, fp32 main weights using fused_adam with fsdp2
