Conversation
Signed-off-by: qiyuw <qiyuw@nvidia.com>
Greptile Summary

This PR implements NVFP4 (4-bit floating point) primary weight support for distributed training with ZeRO/FSDP optimizers. The implementation adds efficient partial casting of FP32 master weight shards to NVFP4 model weights with coordinated scaling across data parallel ranks.
Confidence Score: 3/5
Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[Master Weights FP32] --> B[Batched Dtype Conversion]
B --> C[Per-Rank Shards]
C --> D[Compute Partial Amax]
D --> E[AllReduce Block Amax]
D --> F[AllReduce Global Amax]
E --> G[Compute Scales]
F --> G
G --> H[Fused Scale Kernel]
H --> I[FP8 E4M3 Scales]
H --> J[FP32 Per-Block Scales]
C --> K[Partial Cast Kernel]
J --> K
I --> K
K --> L[NVFP4 Packed Bytes]
L --> M[AllGather]
M --> N[Transpose if Needed]
N --> O[Full NVFP4 Weights]
Last reviewed commit: ebf9f4c
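The amax-coordination step in the flowchart (Compute Partial Amax → AllReduce Amax) can be sketched in plain Python. This is a hypothetical stand-in, not the PR's code: the real path uses fused CUDA kernels and `torch.distributed.all_reduce`, which is simulated here by taking the max over per-rank partial results.

```python
# Sketch of amax coordination across data-parallel ranks: each rank
# reduces only its own FP32 master-weight shard, and a MAX allreduce
# recovers exactly the full-tensor amax.
import numpy as np

def partial_amax(shard: np.ndarray) -> float:
    """Per-rank amax over the local master-weight shard."""
    return float(np.abs(shard).max()) if shard.size else 0.0

def allreduce_max(partials) -> float:
    """Stand-in for torch.distributed.all_reduce(op=ReduceOp.MAX)."""
    return max(partials)

rng = np.random.default_rng(0)
full = rng.standard_normal(1024).astype(np.float32)
shards = np.split(full, 4)  # four simulated data-parallel ranks

global_amax = allreduce_max(partial_amax(s) for s in shards)
assert global_amax == float(np.abs(full).max())
```

The same pattern applies to the per-block amaxes, with one reduction entry per quantization block instead of one per tensor.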
    start_offsets,
    group,
    fsdp_shard_model_weights=None,
    manual_post_all_gather_processing=False,
We added this kwarg to the FP8 functions for backward compatibility, but there's no point keeping them for these brand-new NVFP4 APIs:
    manual_post_all_gather_processing=False,
fsdp_shard_model_weights=None is for future FSDP support. It's in the plan.
manual_post_all_gather_processing is also needed for the same reason as FP8 blockwise scaling:
https://github.com/WanZzzzzz/TransformerEngine/blob/38b92b1a168dcfaa6242fea50f03e5a1b873e3a0/transformer_engine/pytorch/tensor/utils.py#L535
I see, that makes sense for now then. Let's change the default to True though since that's preferred.
I want to flag a potential future problem with manual_post_all_gather_processing=False: it assumes that the quantized tensor has some way to handle the post-processing automatically. For FP8 on Hopper:
cast_master_weights_to_fp8(..., manual_post_all_gather_processing=False)
torch.all_gather(...)
y = model(x)  # Float8Tensor internally performs FP8 transpose

This is not something TE will guarantee for future data formats. Maybe the next recipe has some interleaved format:
cast_master_weights_to_futureformat(...)
torch.all_gather(...)
fix_futureformat_interleaving(...)
y = model(x)  # FutureFormatTensor assumes data is interleaved

In this case, we should throw an error when the user passes manual_post_all_gather_processing=False, and it should be Mcore's responsibility to perform the post-processing in a way that's friendly to overlapping.
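A possible shape for that guard, sketched under the assumptions of the hypothetical "future format" example above (`cast_master_weights_to_futureformat` and `fix_futureformat_interleaving` are illustrative names from this thread, not real TE APIs):

```python
# Hypothetical sketch: a cast function for a format whose post-all-gather
# fix-up cannot be performed automatically by the tensor class, so the
# "automatic" mode is rejected up front instead of silently misbehaving.
def cast_master_weights_to_futureformat(
    weights,
    master_weights,
    start_offsets,
    group,
    manual_post_all_gather_processing: bool = True,
):
    if not manual_post_all_gather_processing:
        raise ValueError(
            "FutureFormat requires the caller (e.g. Megatron-Core) to run "
            "fix_futureformat_interleaving(...) after the all-gather; "
            "manual_post_all_gather_processing=False is not supported."
        )
    # ... partial cast of master-weight shards would happen here ...
```

With this guard, the error surfaces at cast time rather than as silent numerical corruption after the all-gather.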
    if isinstance(self.weights[0], QuantizedTensor):
        weight_buffer_dtype = torch.uint8
        if self.weights_are_nvfp4:
            weight_buffer_length = self.storage_total
            buffer_rank_start = storage_rank_start
            buffer_rank_end = storage_rank_end
        else:
            weight_buffer_length = self.offsets[-1]
            buffer_rank_start = rank_start
            buffer_rank_end = rank_end
    else:
        weight_buffer_dtype = weights[0].dtype
        weight_buffer_length = self.offsets[-1]
        buffer_rank_start = rank_start
        buffer_rank_end = rank_end
Nit: It's a bit convoluted, isn't it? It would be much nicer to disentangle the quantization logic from the buffer allocation by computing storage offsets in all cases (even if it's trivial for non-NVFP4 cases) and then using that blindly here.
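One way to realize that suggestion, as a hypothetical sketch rather than the PR's code: compute the storage layout up front for every case (trivially equal to the element offsets for non-NVFP4 weights), so the buffer-allocation code never branches on the quantization format.

```python
# Hypothetical refactor: storage layout is computed once, for all formats,
# and the buffer-allocation code below it uses the result blindly.
def compute_storage_layout(offsets, rank_start, rank_end, is_nvfp4,
                           storage_total=None, storage_rank_start=None,
                           storage_rank_end=None):
    if is_nvfp4:
        # NVFP4 packs two 4-bit values per byte, so byte-storage offsets
        # differ from element offsets and are supplied by the caller.
        return storage_total, storage_rank_start, storage_rank_end
    # Trivial case: storage offsets coincide with element offsets.
    return offsets[-1], rank_start, rank_end

# Allocation site no longer needs nested isinstance/format branches:
length, start, end = compute_storage_layout([0, 10, 24], 10, 24, is_nvfp4=False)
```

The `isinstance(self.weights[0], QuantizedTensor)` check then only decides the buffer dtype, while length and rank bounds come from one uniform helper.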
for more information, see https://pre-commit.ci
/te-ci L1
timmoon10 left a comment:
Overall LGTM, although there are some test failures related to missing licenses and linter warnings. I also still have some nits, although they are not blocking.
        continue;
      }
      const size_t ref_idx = mask[first] ? elem_idx[first] : elem_idx[second];
      const size_t byte_idx = (ref_idx - start_offset) >> 1;
Byte index calculation incorrect for odd start_offset values

When start_offset is odd, this formula produces wrong byte indices. For example:
- Element at index 2 with start_offset=1: (2-1)>>1 = 0 (wrong, should be 1)
- Element at index 3 with start_offset=1: (3-1)>>1 = 1 (correct)

Should be: byte_idx = (ref_idx >> 1) - (start_offset >> 1)

Current tests only use even start_offset values (multiples of shard_size), so this bug isn't caught. While typical ZeRO/FSDP sharding uses even boundaries, irregular sharding (e.g., expert parallelism) could trigger this.
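The discrepancy is easy to reproduce outside the kernel; this is a minimal Python rendering of the index arithmetic only, not of the CUDA code:

```python
# Two packed 4-bit elements per byte: element i lives in global byte i >> 1.
# The shard-local byte index must subtract the shard's starting *byte*,
# not right-shift the element-index difference.
def byte_idx_buggy(ref_idx: int, start_offset: int) -> int:
    return (ref_idx - start_offset) >> 1

def byte_idx_fixed(ref_idx: int, start_offset: int) -> int:
    return (ref_idx >> 1) - (start_offset >> 1)

# Even start_offset (all current tests): the formulas agree.
assert byte_idx_buggy(6, 4) == byte_idx_fixed(6, 4) == 1
# Odd start_offset: the buggy formula is off by one for even ref_idx.
assert byte_idx_buggy(2, 1) == 0  # wrong
assert byte_idx_fixed(2, 1) == 1  # global byte 1 minus start byte 0
```

For even `start_offset` the two expressions are identical, which is exactly why the existing tests (sharding at multiples of `shard_size`) never catch the difference.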
Description
This PR adds NVFP4 partial cast support for distributed training with ZeRO/FSDP optimizers. It enables efficient casting of FP32 master weight shards to NVFP4 model weights with coordinated scaling across data parallel ranks, while minimizing CPU overhead in large-scale training.
Type of change
Changes
This PR introduces NVFP4 partial cast infrastructure and optimizations for distributed training:
- NVFP4 Partial Cast Kernel (`nvfp4_2d_partial_cast`)
- NVFP4 Transpose Kernel (`nvfp4_transpose`): `uint2` loads/stores with 64×64 tiles for efficient memory access
- Fused Scale Kernel (`nvfp4_fused_scale`)
- Multi-Tensor Dispatch Pattern
- CPU Overhead Optimizations: avoid `torch.cat`/`torch.split`; replace `torch.zeros()` with `torch.empty()` for immediately written buffers
- Scale Computation Improvements

New Public API

- `cast_master_weights_to_nvfp4()`

Testing

- `test_nvfp4_transpose_kernel`
- `test_nvfp4_partial_cast_matches_full`
- `test_single_gpu_partial_cast_vs_full`
- `test_cast_master_weights_to_nvfp4`

This feature also passed numeric validation in GPT-3 training on the corresponding Megatron-Core branch:
https://gitlab-master.nvidia.com/qiyuw/megatron-lm-all/-/tree/fp4_primary_opt?ref_type=heads
Checklist: