Enable dequantization from MXFP8 tensor with only columnwise data #2712

Open
ptrendx wants to merge 1 commit into NVIDIA:main from ptrendx:pr_dequantize

Conversation

@ptrendx
Member

@ptrendx ptrendx commented Feb 26, 2026

Description

Enabled dequantization from an MXFP8 tensor with only columnwise data. The C++ part of this was already in place; the only change needed was to enable it on the Python side.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Member

@ksivaman ksivaman left a comment


LGTM

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Feb 26, 2026

Greptile Summary

This PR unlocks dequantization of MXFP8 tensors that carry only columnwise-scaled data. The entire C++ path (NVTETensorFromMXFP8Tensor → nvte_dequantize) was already able to handle columnwise-only tensors; the only barrier was a too-restrictive Python guard in _FromMXFP8Func.forward that required _rowwise_data to be non-None. The change expands that guard to accept either data view and upgrades the failure-path exception from NotImplementedError to the more semantically accurate ValueError.

Key changes:

  • mxfp8_tensor_storage.py: Condition broadened from `_rowwise_data is not None` to `_rowwise_data is not None or _columnwise_data is not None`; dead-end error changed to ValueError.
  • test_quantized_tensor.py: New TestMXFP8Tensor class adds two tests — one that strips rowwise data from a dual-mode tensor before dequantizing, and one that quantizes in columnwise-only mode from the start.
  • The second new test (test_mxfp8_dequantize_columnwise_only_quantized_separately) has narrower coverage than the first: it hardcodes dtype = torch.bfloat16 and skips the asymmetric [128, 256] dimension, which could mask regressions for other dtypes or non-square shapes.
  • The pre-existing __repr__ method now silently dequantizes from columnwise data when that is all that is available, but still labels the output as rowwise_scaled_data — a minor cosmetic inaccuracy that is more visible after this PR.

Confidence Score: 4/5

  • Safe to merge; the logic change is minimal and correct, backed by the pre-existing C++ implementation.
  • The core two-line Python change is consistent with the C++ layer which already handled columnwise-only tensors. Tests cover the primary scenarios. Minor gaps in the second test's dtype/dimension coverage and a pre-existing __repr__ label issue prevent a perfect score.
  • tests/pytorch/test_quantized_tensor.py — the second test's limited parameterization may hide regressions for non-bfloat16 dtypes or asymmetric tensor shapes.

Important Files Changed

  • transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py — Two-line change: extends the dequantization guard from rowwise-only to either rowwise OR columnwise, and improves the exception type from NotImplementedError to ValueError. The C++ layer (NVTETensorFromMXFP8Tensor) already handles columnwise-only tensors correctly.
  • tests/pytorch/test_quantized_tensor.py — Adds a TestMXFP8Tensor class with two new tests covering columnwise-only dequantization. The first test (parameterized over all dtypes and including asymmetric [128, 256] dims) strips rowwise data after dual-mode quantization. The second test (hardcoded to bfloat16, only symmetric dims) quantizes columnwise-only from the start; the narrower coverage may hide gaps for other dtypes/shapes.

Sequence Diagram

sequenceDiagram
    participant Caller
    participant _FromMXFP8Func
    participant tex_dequantize as tex.dequantize (C++)
    participant NVTETensorFromMXFP8Tensor

    Caller->>_FromMXFP8Func: forward(tensor, dtype)
    alt _rowwise_data is not None OR _columnwise_data is not None
        _FromMXFP8Func->>tex_dequantize: dequantize(tensor, dtype)
        tex_dequantize->>NVTETensorFromMXFP8Tensor: convert tensor
        Note over NVTETensorFromMXFP8Tensor: Sets rowwise data if present<br/>Sets columnwise data if present<br/>NVTE_CHECK(rowwise || columnwise)
        NVTETensorFromMXFP8Tensor-->>tex_dequantize: TensorWrapper
        tex_dequantize-->>_FromMXFP8Func: output Tensor
        _FromMXFP8Func-->>Caller: dequantized Tensor
    else neither rowwise nor columnwise
        _FromMXFP8Func-->>Caller: raises ValueError
    end

Last reviewed commit: 27196ec


@greptile-apps greptile-apps bot left a comment


2 files reviewed, 2 comments

Edit Code Review Agent Settings | Greptile

Comment on lines +707 to +741
        torch.testing.assert_close(x_deq_columnwise, x_deq_rowwise, **_tols[fp8_dtype])

        # Make sure we are not trivially passing the test
        with pytest.raises(AssertionError):
            torch.testing.assert_close(x_deq_columnwise, -x_ref, **_tols[fp8_dtype])

    @pytest.mark.parametrize("fp8_dtype", _fp8_dtypes)
    @pytest.mark.parametrize("dims", [[128, 128], [256, 256]])
    def test_mxfp8_dequantize_columnwise_only_quantized_separately(
        self,
        fp8_dtype: tex.DType,
        dims: DimsType,
    ) -> None:
        """Check dequantization of MXFP8 tensor quantized with columnwise only"""

        dtype = torch.bfloat16

        # Initialize random data
        x_ref = 2 * torch.rand(_to_list(dims), dtype=dtype, device="cuda") - 1

        # Quantize with columnwise only (no rowwise)
        quantizer = MXFP8Quantizer(fp8_dtype=fp8_dtype, rowwise=False, columnwise=True)
        x_mxfp8 = quantizer(x_ref)
        assert x_mxfp8._rowwise_data is None
        assert x_mxfp8._columnwise_data is not None

        # Dequantize from columnwise only
        x_deq = x_mxfp8.dequantize(dtype=dtype)

        # Should be close to the original
        torch.testing.assert_close(x_deq, x_ref, **_tols[fp8_dtype])

        # Make sure we are not trivially passing the test
        with pytest.raises(AssertionError):
            torch.testing.assert_close(x_deq, -x_ref, **_tols[fp8_dtype])

Limited dtype and dimension coverage in second test

test_mxfp8_dequantize_columnwise_only_quantized_separately hardcodes dtype = torch.bfloat16 and omits the asymmetric [128, 256] dimension, while the companion test test_mxfp8_dequantize_columnwise_only is fully parameterized over _dtypes (float32, float16, bfloat16) and includes [128, 256].

If the columnwise-only quantization path genuinely cannot handle float32/float16 inputs or non-square tensors, that constraint is invisible from the test and should be documented. If it can, the missing coverage means regressions in those cases could go undetected.

Consider either:

  • Parameterizing the second test the same way as the first (adding @pytest.mark.parametrize("dtype", _dtypes) and [128, 256] to dims), or
  • Adding a comment explaining why those combinations are intentionally excluded.
Suggested change

Before:

        torch.testing.assert_close(x_deq_columnwise, x_deq_rowwise, **_tols[fp8_dtype])

        # Make sure we are not trivially passing the test
        with pytest.raises(AssertionError):
            torch.testing.assert_close(x_deq_columnwise, -x_ref, **_tols[fp8_dtype])

    @pytest.mark.parametrize("fp8_dtype", _fp8_dtypes)
    @pytest.mark.parametrize("dims", [[128, 128], [256, 256]])
    def test_mxfp8_dequantize_columnwise_only_quantized_separately(
        self,
        fp8_dtype: tex.DType,
        dims: DimsType,
    ) -> None:
        """Check dequantization of MXFP8 tensor quantized with columnwise only"""
        dtype = torch.bfloat16
        # Initialize random data
        x_ref = 2 * torch.rand(_to_list(dims), dtype=dtype, device="cuda") - 1
        # Quantize with columnwise only (no rowwise)
        quantizer = MXFP8Quantizer(fp8_dtype=fp8_dtype, rowwise=False, columnwise=True)
        x_mxfp8 = quantizer(x_ref)
        assert x_mxfp8._rowwise_data is None
        assert x_mxfp8._columnwise_data is not None
        # Dequantize from columnwise only
        x_deq = x_mxfp8.dequantize(dtype=dtype)
        # Should be close to the original
        torch.testing.assert_close(x_deq, x_ref, **_tols[fp8_dtype])
        # Make sure we are not trivially passing the test
        with pytest.raises(AssertionError):
            torch.testing.assert_close(x_deq, -x_ref, **_tols[fp8_dtype])

After:

    @pytest.mark.parametrize("fp8_dtype", _fp8_dtypes)
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("dims", [[128, 128], [256, 256], [128, 256]])
    def test_mxfp8_dequantize_columnwise_only_quantized_separately(
        self,
        fp8_dtype: tex.DType,
        dtype: torch.dtype,
        dims: DimsType,
    ) -> None:
        """Check dequantization of MXFP8 tensor quantized with columnwise only"""

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

@greptile-apps
Contributor

greptile-apps bot commented Feb 26, 2026

Additional Comments (1)

transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py
__repr__ label misleading for columnwise-only tensors

__repr__ calls self.dequantize() (which now succeeds even when only _columnwise_data is present), but unconditionally labels the output as rowwise_scaled_data and prints rowwise_scale_inv. For a columnwise-only tensor the label is incorrect and _rowwise_scale_inv is None, producing output like rowwise_scale_inv=None.

This is a pre-existing issue that the PR's extended dequantize path now exposes more broadly. A small guard would improve clarity:

def __repr__(self):
    data = self.dequantize()
    data_label = (
        "rowwise_scaled_data" if self._rowwise_data is not None else "columnwise_scaled_data"
    )
    scale_inv = (
        self._rowwise_scale_inv if self._rowwise_data is not None else self._columnwise_scale_inv
    )

    return (
        "MXFP8TensorStorage("
        f"fp8_dtype={self._fp8_dtype}, "
        f"{data_label}={data}, "
        f"scale_inv={scale_inv}"
        ")"
    )

@ptrendx
Member Author

ptrendx commented Feb 27, 2026

/te-ci pytorch
