Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions transformer_engine/pytorch/csrc/quantizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ std::pair<TensorWrapper, py::object> Float8Quantizer::create_tensor(
kwargs["fp8_dtype"] = py::cast(this->dtype);
kwargs["data_transpose"] = transpose_py;
kwargs["quantizer"] = this->quantizer;
kwargs["fake_dtype"] = GetATenDType(dtype);

PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass),
args.ptr(), kwargs.ptr());
Expand Down Expand Up @@ -597,6 +598,7 @@ std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::create_tenso
kwargs["fp8_dtype"] = py::cast(this->dtype);
kwargs["data_transpose"] = transpose_py;
kwargs["quantizer"] = this->quantizer;
kwargs["fake_dtype"] = GetATenDType(dtype);

py::tuple args(0);
PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass),
Expand Down Expand Up @@ -966,6 +968,7 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
kwargs["fp8_dtype"] = py::cast(this->dtype);
kwargs["quantizer"] = this->quantizer;
kwargs["is_2D_scaled"] = py::cast(block_scaling_dim == 2);
kwargs["fake_dtype"] = GetATenDType(dtype);

py::tuple args(0);
PyObject* result =
Expand Down Expand Up @@ -1367,6 +1370,7 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(const std::ve
kwargs["fp8_dtype"] = py::cast(this->dtype);
kwargs["quantizer"] = this->quantizer;
kwargs["with_gemm_swizzled_scales"] = py::cast(with_gemm_swizzled_scales);
kwargs["fake_dtype"] = GetATenDType(dtype);

PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(MXFP8TensorStoragePythonClass),
args.ptr(), kwargs.ptr());
Expand Down Expand Up @@ -1773,6 +1777,7 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::create_tensor(const std::ve
kwargs["fp4_dtype"] = py::cast(this->dtype);
kwargs["quantizer"] = this->quantizer;
kwargs["with_gemm_swizzled_scales"] = py::cast(with_gemm_swizzled_scales);
kwargs["fake_dtype"] = GetATenDType(dtype);

py::tuple args(0);

Expand Down
6 changes: 3 additions & 3 deletions transformer_engine/pytorch/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -1079,7 +1079,7 @@ def _start_all_gather_fp8_blockwise(
device = inp._columnwise_data.device
else:
raise ValueError("Got Float8BlockwiseQTensorStorage input tensor without any data")
dtype = torch.bfloat16 # Only has fp8 dtype. Guess BF16 for dequant.
dtype = inp._dtype
else:
raise ValueError(
"Invalid type for input tensor (expected torch.Tensor or"
Expand Down Expand Up @@ -1317,7 +1317,7 @@ def _all_gather_nvfp4(
if inp._columnwise_data is not None:
in_shape_t = inp._columnwise_data.size()
device = inp._columnwise_data.device
dtype = torch.bfloat16
dtype = inp._dtype
else:
raise ValueError(
"Invalid type for input tensor (expected torch.Tensor or NVFP4TensorStorage, "
Expand Down Expand Up @@ -1486,7 +1486,7 @@ def _all_gather_mxfp8(
device = inp._columnwise_data.device
else:
raise ValueError("Got MXFP8 input tensor without any data")
dtype = torch.bfloat16 # Guess high-precision dtype.
dtype = inp._dtype
else:
raise ValueError(
"Invalid type for input tensor (expected torch.Tensor or MXFP8TensorStorage, "
Expand Down
2 changes: 2 additions & 0 deletions transformer_engine/pytorch/module/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,7 @@ def fill_userbuffers_buffer_for_all_gather(
data=global_tensor_data,
fp8_scale_inv=local_tensor._scale_inv,
fp8_dtype=local_tensor._fp8_dtype,
fake_dtype=local_tensor._dtype,
quantizer=quantizer,
)
return global_tensor, local_tensor
Expand Down Expand Up @@ -596,6 +597,7 @@ def fill_userbuffers_buffer_for_all_gather(
fp8_dtype=local_tensor._fp8_dtype,
quantizer=quantizer,
with_gemm_swizzled_scales=False,
fake_dtype=local_tensor._dtype,
)
return global_tensor, local_tensor

Expand Down
8 changes: 6 additions & 2 deletions transformer_engine/pytorch/quantized_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class QuantizedTensorStorage:
XTensor should only implement the functionality needed
to behave like regular torch.Tensor (like __torch_dispatch__)."""

_dtype: torch.dtype
_quantizer: Optional[Quantizer]
Comment on lines +40 to 41
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No lazy-init guard for _dtype on storage objects

QuantizedTensor.dtype (line 405–409) has a hasattr(self, "_dtype") lazy-initializer that protects against deserialization from pre-PR checkpoints. QuantizedTensorStorage and its subclasses have no equivalent protection — _dtype: torch.dtype is only a class-level annotation, not a default value.

If an *TensorStorage object is unpickled from a checkpoint that was saved before this PR, the first call to .dequantize() (or the distributed-ops in distributed.py that now access inp._dtype) will raise AttributeError: _dtype.

Consider adding a similar lazy fallback in the dequantize methods, e.g.:

if dtype is None:
    dtype = getattr(self, "_dtype", torch.float32)


def update_usage(
Expand Down Expand Up @@ -367,10 +368,13 @@ def __new__(
shape: Iterable[int],
dtype: torch.dtype,
*,
fake_dtype: Optional[torch.dtype] = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this redundant with the dtype kwarg?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is mostly to avoid issues with MRO and still have fairly straightforward constructors for the Storage classes.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also just noticed that the make_like call would be problematic there otherwise - we want to include the fake_dtype in get_metadata call, but if it was named dtype it would clash with the dtype that we pass directly in make_like.

requires_grad: bool = False,
device: Optional[torch.device] = None,
stride: Optional[Iterable[int]] = None,
):
if fake_dtype is not None and fake_dtype != dtype:
raise ValueError(f"fake_dtype ({fake_dtype}) does not match dtype ({dtype})")
Comment on lines +376 to +377
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Validation breaks existing make_like call sites

This new guard will cause regressions on every call to make_like(tensor, dtype=X) where X differs from tensor._dtype, because get_metadata() now always injects fake_dtype=self._dtype into kwargs, and QuantizedTensor.__new__ is then called with both dtype=X (the intended new dtype) and fake_dtype=old_dtype (from metadata).

Confirmed breakage paths:

  • transformer_engine/pytorch/tensor/__init__.py:63 — module cast utility (model.half(), model.bfloat16(), etc.) calls tensor.__class__.make_like(tensor, dtype=dtype) for every QuantizedTensor; whenever dtype != tensor._dtype the model cast will raise ValueError.
  • attention/dot_product_attention/context_parallel.py — 10+ call sites of the form Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) where fwd_nominal_dtype may differ from x._dtype.
  • attention/dot_product_attention/utils.py:2220 — same pattern.

The root cause is that fake_dtype is being included in get_metadata() but the constructor-level guard then rejects any case where the caller wants to create a clone at a different nominal dtype. Either:

  1. Remove the guard (it is redundant for the full-tensor path, because QuantizedTensor.__new__ already sets _dtype = dtype), or
  2. Override fake_dtype in QuantizedTensor.make_like so it matches the requested dtype before calling the constructor.

# For stride, We are assuming only contiguous tensors
# Calculate stride from shape if not provided. When creating this object from
# C++ code, we provide the stride computed from shape in C++ to avoid the
Expand Down Expand Up @@ -485,7 +489,7 @@ def clear(self):
)

def __repr__(self, *, tensor_contents=None) -> str:
    """Readable representation: class name plus the dequantized payload.

    Dequantizes at the tensor's own nominal dtype (the ``dequantize``
    default), so no explicit ``dtype=`` argument is needed here.
    """
    cls_name = self.__class__.__name__
    payload = self.dequantize()
    return f"{cls_name}(data={payload})"

def float(self) -> torch.Tensor:
# pylint: disable=missing-function-docstring
Expand Down Expand Up @@ -588,7 +592,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):

def maybe_unwrap(arg):
if isinstance(arg, QuantizedTensor):
return arg.dequantize(dtype=arg.dtype)
return arg.dequantize()
return arg

def maybe_update_inplace(arg, new_arg, schema_arg):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def __repr__(self, *, tensor_contents=None):
return (
f"Float8BlockwiseQTensor(fp8_dtype={self._fp8_dtype},"
f" is_2D_scaled={self._is_2D_scaled},"
f" data={self.dequantize(dtype=self.dtype)})"
f" data={self.dequantize()})"
)

def quantize_(
Expand Down
4 changes: 3 additions & 1 deletion transformer_engine/pytorch/tensor/float8_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ def create_tensor_from_data(
data=data,
fp8_scale_inv=1 / self.scale,
fp8_dtype=self.dtype,
fake_dtype=fake_dtype,
requires_grad=requires_grad,
data_transpose=None,
quantizer=self,
Expand Down Expand Up @@ -407,6 +408,7 @@ def create_tensor_from_data(
data=data,
fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=data.device),
fp8_dtype=self.dtype,
fake_dtype=fake_dtype,
requires_grad=requires_grad,
data_transpose=None,
quantizer=self,
Expand Down Expand Up @@ -498,7 +500,7 @@ def __repr__(self, *, tensor_contents=None):
"Float8Tensor("
f"fp8_dtype={self._fp8_dtype}, "
f"scale_inv={self._scale_inv.item()}, "
f"data={self.dequantize(dtype=self.dtype)}"
f"data={self.dequantize()}"
")"
)

Expand Down
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/tensor/mxfp8_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ def __new__(
)

def __repr__(self, *, tensor_contents=None):
    """Readable representation including FP8 dtype and dequantized data.

    Relies on ``dequantize()``'s default dtype (the tensor's recorded
    nominal dtype) rather than passing one explicitly.
    """
    payload = self.dequantize()
    return f"MXFP8Tensor(fp8_dtype={self._fp8_dtype}, data={payload})"

def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
"""
Expand Down
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/tensor/nvfp4_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,7 @@ def __new__(
return instance

def __repr__(self, *, tensor_contents=None):
    """Readable representation including the dequantized data.

    Note: the literal previously rendered as ``NVFP4Tensor, data=...)`` —
    a stray trailing ``)`` with no matching ``(`` after the class name.
    Fixed to the conventional ``NVFP4Tensor(data=...)`` form used by the
    sibling tensor classes (e.g. ``MXFP8Tensor(...)``).
    """
    return f"NVFP4Tensor(data={self.dequantize()})"

def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,14 @@ def __new__(
quantizer: Quantizer,
is_2D_scaled: bool,
*args,
fake_dtype: Optional[torch.dtype] = None,
**kwargs,
):
if cls is Float8BlockwiseQTensorStorage:
instance = object.__new__(cls)
instance._dtype = fake_dtype if fake_dtype is not None else torch.float32
else:
instance = super().__new__(cls, *args, **kwargs)
instance = super().__new__(cls, *args, fake_dtype=fake_dtype, **kwargs)
instance._rowwise_data = rowwise_data
instance._columnwise_data = columnwise_data
instance._quantizer = quantizer.copy() if quantizer is not None else None
Expand Down Expand Up @@ -101,6 +103,7 @@ def get_metadata(self) -> Dict[str, Any]:
"fp8_dtype": self._fp8_dtype,
"quantizer": self._quantizer,
"is_2D_scaled": self._is_2D_scaled,
"fake_dtype": self._dtype,
}

def prepare_for_saving(
Expand Down Expand Up @@ -149,7 +152,9 @@ def _transpose_dq_columnwise_output(self, columnwise_dq: torch.Tensor) -> torch.
permute_dims.append(0)
return torch.permute(columnwise_dq, tuple(permute_dims)).contiguous()

def _dequantize_vectorwise(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor:
def _dequantize_vectorwise(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
if dtype is None:
dtype = self._dtype
block_len = 128

q_M, q_K = 1, 1
Expand Down Expand Up @@ -211,10 +216,12 @@ def _dequantize_vectorwise(self, *, dtype: torch.dtype = torch.float32) -> torch
return self._transpose_dq_columnwise_output(result)
return result

def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor:
def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
"""
Construct plain PyTorch tensor from Float8BlockwiseQTensor
"""
if dtype is None:
dtype = self._dtype
block_len = 128
if not self._is_2D_scaled:
return self._dequantize_vectorwise(dtype=dtype)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,16 @@ def __new__(
data: Optional[torch.Tensor],
fp8_scale_inv: torch.Tensor,
fp8_dtype: TE_DType,
fake_dtype: Optional[torch.dtype] = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd prefer to just name it dtype since QuantizedTensor is already using that name in its constructor.

Suggested change
fake_dtype: Optional[torch.dtype] = None,
dtype: Optional[torch.dtype] = None,

data_transpose: Optional[torch.Tensor] = None,
quantizer: Optional[Quantizer] = None,
**kwargs,
):
if cls is Float8TensorStorage:
instance = object.__new__(cls)
instance._dtype = fake_dtype if fake_dtype is not None else torch.float32
else:
instance = super().__new__(cls, *args, **kwargs)
instance = super().__new__(cls, *args, fake_dtype=fake_dtype, **kwargs)
instance._data = data
instance._quantizer = quantizer.copy() if quantizer is not None else None
instance._fp8_dtype = fp8_dtype
Expand Down Expand Up @@ -130,6 +132,7 @@ def get_metadata(self) -> Dict[str, Any]:
"fp8_dtype": self._fp8_dtype,
"data_transpose": self._transpose,
"quantizer": self._quantizer,
"fake_dtype": self._dtype,
}

def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorStorage]:
Expand Down Expand Up @@ -159,8 +162,10 @@ def get_data_tensors(self, rowwise_data: bool = True, columnwise_data: bool = Tr
return self._transpose
raise ValueError("No data to get, both rowwise_data and columnwise_data are False")

def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor:
def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
    """Dequantize to a higher precision.

    Args:
        dtype: Target dtype of the returned tensor. When ``None``, uses the
            nominal high-precision dtype recorded on this storage
            (``_dtype``), falling back to ``torch.float32`` for objects
            deserialized from checkpoints saved before ``_dtype`` existed.

    Returns:
        A plain ``torch.Tensor`` with the dequantized values.
    """
    if dtype is None:
        # Lazy fallback: storage objects unpickled from pre-`_dtype`
        # checkpoints have no `_dtype` attribute (it is only a class-level
        # annotation, not a default), so a bare `self._dtype` would raise
        # AttributeError. Default to float32, the previous behavior.
        dtype = getattr(self, "_dtype", torch.float32)
    return _FromFloat8Func.forward(None, self, dtype)

def size(self, *args, **kwargs):
Expand Down Expand Up @@ -192,6 +197,7 @@ def view(self, shape: torch.Size):
data=out_data,
fp8_scale_inv=self._scale_inv,
fp8_dtype=self._fp8_dtype,
fake_dtype=self._dtype,
data_transpose=out_transpose,
quantizer=self._quantizer,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,14 @@ def __new__(
quantizer: Optional[Quantizer],
with_gemm_swizzled_scales: bool,
*args,
fake_dtype: Optional[torch.dtype] = None,
**kwargs,
):
if cls is MXFP8TensorStorage:
instance = object.__new__(cls)
instance._dtype = fake_dtype if fake_dtype is not None else torch.float32
else:
instance = super().__new__(cls, *args, **kwargs)
instance = super().__new__(cls, *args, fake_dtype=fake_dtype, **kwargs)
instance._rowwise_data = rowwise_data
instance._columnwise_data = columnwise_data
instance._rowwise_scale_inv = rowwise_scale_inv
Expand Down Expand Up @@ -139,6 +141,7 @@ def get_metadata(self) -> Dict[str, Any]:
"fp8_dtype": self._fp8_dtype,
"quantizer": self._quantizer,
"with_gemm_swizzled_scales": self._with_gemm_swizzled_scales,
"fake_dtype": self._dtype,
}

def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], MXFP8TensorStorage]:
Expand Down Expand Up @@ -175,8 +178,10 @@ def get_data_tensors(self, rowwise_data: bool = True, columnwise_data: bool = Tr
return self._columnwise_data
raise ValueError("No data to get, both rowwise_data and columnwise_data are False")

def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor:
def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
    """Dequantize to a higher precision.

    Args:
        dtype: Target dtype of the returned tensor. When ``None``, uses the
            nominal high-precision dtype recorded on this storage
            (``_dtype``), falling back to ``torch.float32`` for objects
            deserialized from checkpoints saved before ``_dtype`` existed.

    Returns:
        A plain ``torch.Tensor`` with the dequantized values.
    """
    if dtype is None:
        # Lazy fallback: storage objects unpickled from pre-`_dtype`
        # checkpoints lack the attribute; a bare `self._dtype` would raise
        # AttributeError. Default to float32, the previous behavior.
        dtype = getattr(self, "_dtype", torch.float32)
    return _FromMXFP8Func.forward(None, self, dtype)

def size(self, *args, **kwargs):
Expand Down Expand Up @@ -238,6 +243,7 @@ def view(self, shape: torch.Size):
fp8_dtype=self._fp8_dtype,
quantizer=self._quantizer,
with_gemm_swizzled_scales=self._with_gemm_swizzled_scales,
fake_dtype=self._dtype,
)

def __repr__(self):
Expand Down
14 changes: 11 additions & 3 deletions transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,14 @@ def __new__(
quantizer: Optional[Quantizer],
with_gemm_swizzled_scales: bool,
*args,
fake_dtype: Optional[torch.dtype] = None,
**kwargs,
):

instance = super().__new__(cls, *args, **kwargs)
if cls is NVFP4TensorStorage:
instance = object.__new__(cls)
instance._dtype = fake_dtype if fake_dtype is not None else torch.float32
else:
instance = super().__new__(cls, *args, fake_dtype=fake_dtype, **kwargs)

instance._rowwise_data = rowwise_data
instance._columnwise_data = columnwise_data
Expand Down Expand Up @@ -168,6 +172,7 @@ def get_metadata(self) -> Dict[str, Any]:
"fp4_dtype": self._fp4_dtype,
"quantizer": self._quantizer,
"with_gemm_swizzled_scales": self._with_gemm_swizzled_scales,
"fake_dtype": self._dtype,
}

def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], NVFP4TensorStorage]:
Expand Down Expand Up @@ -204,8 +209,10 @@ def get_data_tensors(self):
"""Get this Tensor's data."""
return self._rowwise_data, self._columnwise_data

def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor:
def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
    """Dequantize to a higher precision.

    Args:
        dtype: Target dtype of the returned tensor. When ``None``, uses the
            nominal high-precision dtype recorded on this storage
            (``_dtype``), falling back to ``torch.float32`` for objects
            deserialized from checkpoints saved before ``_dtype`` existed.

    Returns:
        A plain ``torch.Tensor`` with the dequantized values.
    """
    if dtype is None:
        # Lazy fallback: storage objects unpickled from pre-`_dtype`
        # checkpoints lack the attribute; a bare `self._dtype` would raise
        # AttributeError. Default to float32, the previous behavior.
        dtype = getattr(self, "_dtype", torch.float32)
    return _FromNVFP4Func.forward(None, self, dtype)

def size(self, dim: Optional[int] = None) -> Union[torch.Size, int]:
Expand Down Expand Up @@ -295,6 +302,7 @@ def view(self, shape: torch.Size):
quantizer=self._quantizer,
fp4_dtype=self._fp4_dtype,
with_gemm_swizzled_scales=self._with_gemm_swizzled_scales,
fake_dtype=self._dtype,
)

def __repr__(self):
Expand Down
1 change: 1 addition & 0 deletions transformer_engine/pytorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ def forward(
fp8_dtype=mixed_x_layer._fp8_dtype,
data=x.squeeze(split_dim) if squeeze else x,
shape=x.squeeze(split_dim).shape if squeeze else x.shape,
fake_dtype=mixed_x_layer._dtype,
quantizer=mixed_x_layer._quantizer,
)
for x in torch.split(
Expand Down
Loading