Enable dequantization from MXFP8 tensor with only columnwise data #2712
ptrendx wants to merge 1 commit into NVIDIA:main
Conversation
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Greptile Summary

This PR unlocks dequantization of MXFP8 tensors that carry only columnwise-scaled data. The C++ path already supported this; the change enables it on the Python side.

Key changes:
Confidence Score: 4/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant Caller
    participant _FromMXFP8Func
    participant tex_dequantize as tex.dequantize (C++)
    participant NVTETensorFromMXFP8Tensor
    Caller->>_FromMXFP8Func: forward(tensor, dtype)
    alt _rowwise_data is not None OR _columnwise_data is not None
        _FromMXFP8Func->>tex_dequantize: dequantize(tensor, dtype)
        tex_dequantize->>NVTETensorFromMXFP8Tensor: convert tensor
        Note over NVTETensorFromMXFP8Tensor: Sets rowwise data if present<br/>Sets columnwise data if present<br/>NVTE_CHECK(rowwise || columnwise)
        NVTETensorFromMXFP8Tensor-->>tex_dequantize: TensorWrapper
        tex_dequantize-->>_FromMXFP8Func: output Tensor
        _FromMXFP8Func-->>Caller: dequantized Tensor
    else neither rowwise nor columnwise
        _FromMXFP8Func-->>Caller: raises ValueError
    end
```
Last reviewed commit: 27196ec
```python
torch.testing.assert_close(x_deq_columnwise, x_deq_rowwise, **_tols[fp8_dtype])

# Make sure we are not trivially passing the test
with pytest.raises(AssertionError):
    torch.testing.assert_close(x_deq_columnwise, -x_ref, **_tols[fp8_dtype])

@pytest.mark.parametrize("fp8_dtype", _fp8_dtypes)
@pytest.mark.parametrize("dims", [[128, 128], [256, 256]])
def test_mxfp8_dequantize_columnwise_only_quantized_separately(
    self,
    fp8_dtype: tex.DType,
    dims: DimsType,
) -> None:
    """Check dequantization of MXFP8 tensor quantized with columnwise only"""

    dtype = torch.bfloat16

    # Initialize random data
    x_ref = 2 * torch.rand(_to_list(dims), dtype=dtype, device="cuda") - 1

    # Quantize with columnwise only (no rowwise)
    quantizer = MXFP8Quantizer(fp8_dtype=fp8_dtype, rowwise=False, columnwise=True)
    x_mxfp8 = quantizer(x_ref)
    assert x_mxfp8._rowwise_data is None
    assert x_mxfp8._columnwise_data is not None

    # Dequantize from columnwise only
    x_deq = x_mxfp8.dequantize(dtype=dtype)

    # Should be close to the original
    torch.testing.assert_close(x_deq, x_ref, **_tols[fp8_dtype])

    # Make sure we are not trivially passing the test
    with pytest.raises(AssertionError):
        torch.testing.assert_close(x_deq, -x_ref, **_tols[fp8_dtype])
```
Limited dtype and dimension coverage in second test
`test_mxfp8_dequantize_columnwise_only_quantized_separately` hardcodes `dtype = torch.bfloat16` and omits the asymmetric `[128, 256]` dimension, while the companion test `test_mxfp8_dequantize_columnwise_only` is fully parameterized over `_dtypes` (float32, float16, bfloat16) and includes `[128, 256]`.

If the columnwise-only quantization path genuinely cannot handle float32/float16 inputs or non-square tensors, that constraint is invisible from the test and should be documented. If it can, the missing coverage means regressions in those cases could go undetected.

Consider either:
- Parameterizing the second test the same way as the first (adding `@pytest.mark.parametrize("dtype", _dtypes)` and `[128, 256]` to `dims`), or
- Adding a comment explaining why those combinations are intentionally excluded.
Suggested change (parameterize over `dtype` and add the asymmetric shape):

```python
@pytest.mark.parametrize("fp8_dtype", _fp8_dtypes)
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("dims", [[128, 128], [256, 256], [128, 256]])
def test_mxfp8_dequantize_columnwise_only_quantized_separately(
    self,
    fp8_dtype: tex.DType,
    dtype: torch.dtype,
    dims: DimsType,
) -> None:
    """Check dequantization of MXFP8 tensor quantized with columnwise only"""
```
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
Additional Comments (1)
This is a pre-existing issue that the PR's extended dequantize path now exposes more broadly. A small guard would improve clarity:

```python
def __repr__(self):
    data = self.dequantize()
    data_label = "rowwise_scaled_data" if self._rowwise_data is not None else "columnwise_scaled_data"
    scale_inv = self._rowwise_scale_inv if self._rowwise_data is not None else self._columnwise_scale_inv
    return (
        "MXFP8TensorStorage("
        f"fp8_dtype={self._fp8_dtype}, "
        f"{data_label}={data}, "
        f"scale_inv={scale_inv}"
        ")"
    )
/te-ci pytorch
Description
Enabled dequantization from an MXFP8 tensor with only columnwise data. The C++ part of this was already in place; the only change needed was to enable it on the Python side.
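For context, MX formats (such as MXFP8) store quantized values together with one shared power-of-two decode scale per block of 32 elements; "columnwise" means those blocks run along the columns of the matrix. The sketch below is an illustrative, NumPy-only model of columnwise block dequantization, not TransformerEngine's kernel: the function name, the array layout, and the scale layout are all assumptions made for the example.

```python
import numpy as np

BLOCK = 32  # MX formats share one decode scale per 32-element block


def dequantize_columnwise(data: np.ndarray, scale_inv: np.ndarray) -> np.ndarray:
    """Illustrative columnwise MX dequantization.

    data: (rows, cols) quantized values (modeled here as float32).
    scale_inv: (rows // BLOCK, cols) decode scales; each 32-row block of
    every column shares a single scale.
    """
    rows, cols = data.shape
    out = data.astype(np.float32).copy()
    for b in range(rows // BLOCK):
        # Multiply each 32-row block of every column by its shared scale.
        out[b * BLOCK:(b + 1) * BLOCK, :] *= scale_inv[b, :]
    return out


# Toy example: 64x4 "quantized" ones with per-block power-of-two scales.
data = np.ones((64, 4), dtype=np.float32)
scale_inv = np.array([[0.5, 1.0, 2.0, 4.0],
                      [0.5, 1.0, 2.0, 4.0]], dtype=np.float32)
x = dequantize_columnwise(data, scale_inv)
```

This is why a tensor holding only columnwise data plus its columnwise scales is sufficient to reconstruct the full matrix, which is exactly what the Python side previously refused to do.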
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: