
NVFP4 primary weight support#2691

Open
WanZzzzzz wants to merge 11 commits into NVIDIA:main from WanZzzzzz:fp4_native_weights

Conversation

@WanZzzzzz

Description

This PR adds NVFP4 partial cast support for distributed training with ZeRO/FSDP optimizers. It enables efficient casting of FP32 master weight shards to NVFP4 model weights with coordinated scaling across data parallel ranks, while minimizing CPU overhead in large-scale training.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

This PR introduces NVFP4 partial cast infrastructure and optimizations for distributed training:

NVFP4 Partial Cast Kernel (nvfp4_2d_partial_cast)

  • Implements nibble-accurate partial updates for NVFP4 tensors in distributed settings
  • Supports two-level NVFP4 scaling: global FP32 scale + per-block FP8 E4M3 scale

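The two-level scheme can be illustrated with a small numpy sketch. This is an illustration under stated assumptions, not the kernel code: the FP4 E2M1 max of 6, the FP8 E4M3 max of 448, and a 16-element block size are assumed, and the real kernels additionally round both scales to their storage formats.

```python
import numpy as np

FP4_MAX = 6.0     # max magnitude representable in FP4 E2M1 (assumed)
E4M3_MAX = 448.0  # max magnitude representable in FP8 E4M3 (assumed)

def nvfp4_scales(x, block=16):
    """Sketch of two-level NVFP4 scaling: one global FP32 scale plus
    one per-block scale destined for FP8 E4M3 storage."""
    # Global FP32 scale: chosen so that the largest possible per-block
    # scale still fits into the E4M3 range after division by it.
    global_amax = np.float32(np.abs(x).max())
    s_global = np.float32(global_amax / (FP4_MAX * E4M3_MAX))
    # Per-block scale: block_amax / FP4_MAX maps each block's max onto
    # the largest FP4 value; dividing by s_global re-expresses it in
    # units that round into FP8 E4M3.
    block_amax = np.abs(x).reshape(-1, block).max(axis=1)
    s_block = (block_amax / np.float32(FP4_MAX)) / s_global
    return s_global, s_block
```

Note that the block whose amax equals the global amax ends up with a per-block scale at exactly the top of the E4M3 range, which is what makes the two-level decomposition lossless at the scale level.
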
NVFP4 Transpose Kernel (nvfp4_transpose)

  • Custom transpose kernel for nibble-packed NVFP4 data with shared memory optimization
  • Uses vectorized uint2 loads/stores with 64×64 tiles for efficient memory access
  • Handles nibble repacking during transpose (unlike FP8 byte transpose)
  • Enables columnwise data generation for GEMM operations after rowwise AllGather

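Because two NVFP4 values share one byte, a transpose cannot move bytes directly: nibbles must be unpacked and re-paired along the new row direction. A reference (unoptimized) numpy sketch follows, assuming row-major packing with the low nibble holding the even-indexed element; the actual kernel does the same repacking with shared-memory 64×64 tiles and vectorized uint2 accesses.

```python
import numpy as np

def nvfp4_transpose(packed, rows, cols):
    """Reference nibble-repacking transpose for row-major NVFP4 data.
    Packing order (low nibble = even-indexed element) is an assumption."""
    assert cols % 2 == 0 and rows % 2 == 0
    # Unpack: byte i of a row holds elements 2i (low) and 2i+1 (high).
    lo = packed & 0x0F
    hi = packed >> 4
    vals = np.empty((rows, cols), dtype=np.uint8)
    vals[:, 0::2] = lo.reshape(rows, cols // 2)
    vals[:, 1::2] = hi.reshape(rows, cols // 2)
    # Transpose element-wise, then repack nibble pairs along the new rows.
    t = vals.T
    out = (t[:, 0::2] | (t[:, 1::2] << 4)).astype(np.uint8)
    return out.reshape(-1)
```

Transposing twice recovers the original packed buffer, which is a convenient correctness check for the repacking logic.
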
Fused Scale Kernel (nvfp4_fused_scale)

  • Fuses per-block scale computation, global amax copy, and FP8 scale expansion into a single kernel
  • Eliminates multiple kernel launches and avoids D2H transfers by accepting tensor pointers
  • Reduces kernel launch overhead in the critical path

Multi-Tensor Dispatch Pattern

  • C++-side loop dispatch for NVFP4 multi-tensor operations
  • Reduces Python–C++ transition overhead compared to per-tensor Python loops
  • Collects metadata in Python and executes batched operations in C++ wrappers

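The pattern reduces to: build flat metadata lists in one Python pass, then make a single batched call. A schematic Python sketch, where `partial_cast_op` is a stand-in for the hypothetical batched C++ extension entry point:

```python
def multi_tensor_partial_cast(tensors, shard_ranges, partial_cast_op):
    """Collect per-tensor metadata once in Python, then cross the
    Python/C++ boundary a single time for the whole bucket.
    `partial_cast_op` is a hypothetical batched entry point."""
    ptrs, offsets, lengths = [], [], []
    for tensor, (start, end) in zip(tensors, shard_ranges):
        ptrs.append(tensor)
        offsets.append(start)
        lengths.append(end - start)
    # One transition for the whole bucket; the loop runs on the C++ side.
    return partial_cast_op(ptrs, offsets, lengths)
```

The saving is purely in host-side overhead: the same kernels run, but the per-tensor Python dispatch cost is amortized over the bucket.
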
CPU Overhead Optimizations

  • Batched dtype conversion via torch.cat / torch.split
  • Replaced torch.zeros() with torch.empty() for immediately written buffers
  • Consolidated metadata collection and allocation phases
  • Optimized bucket partitioning for expert parallel buffers

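The batched-conversion trick (shown here with numpy analogs of torch.cat / torch.split, since the actual TE call sites are not reproduced in this summary) replaces N small per-shard cast kernels with one concatenate, one cast, and one split:

```python
import numpy as np

def batched_cast(shards, dtype=np.float16):
    """numpy sketch of the torch.cat / torch.split batching pattern:
    cast many shards with one bulk operation instead of one per shard."""
    sizes = [s.size for s in shards]
    # One concatenate + one dtype cast over the whole bucket.
    flat = np.concatenate([s.ravel() for s in shards]).astype(dtype)
    # Split back into views with the original shapes.
    split_points = np.cumsum(sizes)[:-1]
    chunks = np.split(flat, split_points)
    return [c.reshape(s.shape) for c, s in zip(chunks, shards)]
```
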
Scale Computation Improvements

  • Fixed floating-point precision mismatch between Python and CUDA
  • Uses FP32 constants consistent with CUDA arithmetic
  • Ensures bitwise-identical results between partial and full quantization paths

New Public API

cast_master_weights_to_nvfp4()

  • Casts FP32 master weights to NVFP4 model weights
  • Handles global and per-block amax reduction across data parallel groups
  • Designed for low CPU overhead in distributed training loops

Testing

  • test_nvfp4_transpose_kernel: verifies correctness of the nibble-packed transpose
  • test_nvfp4_partial_cast_matches_full: multi-GPU; partial cast + all-gather equals full cast
  • test_single_gpu_partial_cast_vs_full: single-GPU; offset=0 partial cast matches the reference quantizer
  • _test_cast_master_weights_to_nvfp4: 500-iteration training loop with bitwise-identical loss

This feature also passed numeric validation in GPT-3 training on the corresponding Megatron-Core branch:

https://gitlab-master.nvidia.com/qiyuw/megatron-lm-all/-/tree/fp4_primary_opt?ref_type=heads

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: qiyuw <qiyuw@nvidia.com>
@WanZzzzzz WanZzzzzz mentioned this pull request Feb 19, 2026

greptile-apps bot commented Feb 19, 2026

Greptile Summary

This PR implements NVFP4 (4-bit floating point) primary weight support for distributed training with ZeRO/FSDP optimizers. The implementation adds efficient partial casting of FP32 master weight shards to NVFP4 model weights with coordinated scaling across data parallel ranks.

Key additions:

  • Custom NVFP4 kernels: partial cast with nibble-level updates, transpose with nibble repacking, and fused scale computation
  • Multi-tensor dispatch pattern to reduce Python-C++ overhead
  • CPU optimizations: batched dtype conversion via torch.cat/torch.split, replaced torch.zeros() with torch.empty() for buffers
  • New public API cast_master_weights_to_nvfp4() for distributed training loops

Critical issue found:

  • Byte index calculation in nvfp4_2d_partial_cast_kernel (line 236) is incorrect for odd start_offset values, which could cause data corruption with irregular sharding patterns. Current tests only exercise even offsets.

Test coverage:

  • Tests validate multi-GPU partial cast + allgather matches full cast
  • Single-GPU test verifies offset=0 partial cast matches reference
  • 500-iteration training loop with bitwise-identical loss validation
  • However, tests never exercise odd start_offset values

Confidence Score: 3/5

  • Safe for typical use cases with even-aligned sharding, but contains a bug that could cause silent data corruption with irregular sharding patterns
  • Score reflects a critical logic bug in the CUDA kernel byte indexing that wasn't caught by tests. While current test coverage is comprehensive for even offsets and the feature works correctly for typical ZeRO/FSDP sharding (which uses even boundaries), the bug could manifest in edge cases like expert parallelism or custom sharding. The implementation is otherwise well-structured with good optimizations.
  • transformer_engine/common/recipe/nvfp4.cu requires immediate attention to fix the byte indexing bug in nvfp4_2d_partial_cast_kernel

Important Files Changed

Filename Overview
transformer_engine/common/recipe/nvfp4.cu New NVFP4 CUDA kernels for partial cast, transpose, and scale operations. Contains potential bug in byte indexing for odd start_offsets.
transformer_engine/pytorch/csrc/extensions/nvfp4_2d_partial_cast.cpp C++ wrapper for NVFP4 partial cast operations with proper input validation.
transformer_engine/pytorch/tensor/utils.py Python implementation of NVFP4 master weight casting with batched dtype conversion optimization.
tests/pytorch/distributed/test_cast_master_weights_to_fp8.py Comprehensive tests for NVFP4 partial cast. Test coverage limited to even start_offsets.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[Master Weights FP32] --> B[Batched Dtype Conversion]
    B --> C[Per-Rank Shards]
    C --> D[Compute Partial Amax]
    D --> E[AllReduce Block Amax]
    D --> F[AllReduce Global Amax]
    E --> G[Compute Scales]
    F --> G
    G --> H[Fused Scale Kernel]
    H --> I[FP8 E4M3 Scales]
    H --> J[FP32 Per-Block Scales]
    C --> K[Partial Cast Kernel]
    J --> K
    I --> K
    K --> L[NVFP4 Packed Bytes]
    L --> M[AllGather]
    M --> N[Transpose if Needed]
    N --> O[Full NVFP4 Weights]

Last reviewed commit: ebf9f4c


start_offsets,
group,
fsdp_shard_model_weights=None,
manual_post_all_gather_processing=False,
Collaborator


We added this kwarg to the FP8 functions for backward compatibility, but there's no point keeping them for these brand-new NVFP4 APIs:

Suggested change
manual_post_all_gather_processing=False,

Author


fsdp_shard_model_weights=None is for future FSDP support. It's in the plan.
manual_post_all_gather_processing is also needed for the same reason as FP8 blockwise scaling:
https://github.com/WanZzzzzz/TransformerEngine/blob/38b92b1a168dcfaa6242fea50f03e5a1b873e3a0/transformer_engine/pytorch/tensor/utils.py#L535


@timmoon10 timmoon10 Feb 20, 2026


I see, that makes sense for now then. Let's change the default to True though since that's preferred.

I want to flag a potential future problem with manual_post_all_gather_processing=False: it assumes that the quantized tensor has some way to handle the post-processing automatically. For FP8 on Hopper:

cast_master_weights_to_fp8(..., manual_post_all_gather_processing=False)
torch.all_gather(...)

y = model(x)  # Float8Tensor internally performs FP8 transpose

This is not something TE will guarantee for future data formats. Maybe the next recipe has some interleaved format:

cast_master_weights_to_futureformat(...)
torch.all_gather(...)
fix_futureformat_interleaving(...)

y = model(x)  # FutureFormatTensor assumes data is interleaved

In this case, we should throw an error when the user passes manual_post_all_gather_processing=False, and it should be Mcore's responsibility to perform the post-processing in a way that's friendly to overlapping.

Author


OK, noted.

Comment on lines 245 to 259
if isinstance(self.weights[0], QuantizedTensor):
    weight_buffer_dtype = torch.uint8
    if self.weights_are_nvfp4:
        weight_buffer_length = self.storage_total
        buffer_rank_start = storage_rank_start
        buffer_rank_end = storage_rank_end
    else:
        weight_buffer_length = self.offsets[-1]
        buffer_rank_start = rank_start
        buffer_rank_end = rank_end
else:
    weight_buffer_dtype = weights[0].dtype
    weight_buffer_length = self.offsets[-1]
    buffer_rank_start = rank_start
    buffer_rank_end = rank_end

@timmoon10 timmoon10 Feb 20, 2026


Nit: It's a bit convoluted, isn't it? It would be much nicer to disentangle the quantization logic from the buffer allocation by computing storage offsets in all cases (even if it's trivial for non-NVFP4 cases) and then using that blindly here.

Author


Done.

qiyuw and others added 2 commits February 20, 2026 05:52

@greptile-apps greptile-apps bot left a comment


11 files reviewed, 1 comment



@greptile-apps greptile-apps bot left a comment


11 files reviewed, no comments


@timmoon10

/te-ci L1

timmoon10 previously approved these changes Feb 21, 2026

@timmoon10 timmoon10 left a comment


Overall LGTM, although there are some test failures related to missing licenses and linter warnings. I also still have some nits, although they are not blocking.

@timmoon10 timmoon10 self-requested a review February 21, 2026 00:09
@timmoon10 timmoon10 dismissed their stale review February 21, 2026 00:09

Test failures

qiyuw and others added 2 commits February 26, 2026 22:13
continue;
}
const size_t ref_idx = mask[first] ? elem_idx[first] : elem_idx[second];
const size_t byte_idx = (ref_idx - start_offset) >> 1;
Contributor


Byte index calculation incorrect for odd start_offset values

When start_offset is odd, this formula produces wrong byte indices. For example:

  • Element at index 2 with start_offset=1: (2-1)>>1 = 0 (wrong, should be 1)
  • Element at index 3 with start_offset=1: (3-1)>>1 = 1 (correct)

Should be: byte_idx = (ref_idx >> 1) - (start_offset >> 1)

Current tests only use even start_offset values (multiples of shard_size), so this bug isn't caught. While typical ZeRO/FSDP sharding uses even boundaries, irregular sharding (e.g., expert parallelism) could trigger this.
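The off-by-one-byte behavior is easy to reproduce in plain Python using the two formulas above (the buggy expression from the kernel and the suggested fix):

```python
def byte_idx_buggy(ref_idx, start_offset):
    # As written in the kernel: shifts AFTER subtracting the offset,
    # which mispairs nibbles when start_offset is odd.
    return (ref_idx - start_offset) >> 1

def byte_idx_fixed(ref_idx, start_offset):
    # Suggested fix: compute both byte positions first, then subtract.
    return (ref_idx >> 1) - (start_offset >> 1)
```

With an even start_offset both formulas agree, which is why the existing tests pass; with start_offset=1, element 2 lands in byte 0 under the buggy formula but in byte 1 (its true position) under the fix.
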

