diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 0135c7f01c..c212d165cc 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -282,6 +282,7 @@ std::pair Float8Quantizer::create_tensor( kwargs["fp8_dtype"] = py::cast(this->dtype); kwargs["data_transpose"] = transpose_py; kwargs["quantizer"] = this->quantizer; + kwargs["fake_dtype"] = GetATenDType(dtype); PyObject* result = PyObject_Call(reinterpret_cast(Float8TensorStoragePythonClass), args.ptr(), kwargs.ptr()); @@ -597,6 +598,7 @@ std::pair Float8CurrentScalingQuantizer::create_tenso kwargs["fp8_dtype"] = py::cast(this->dtype); kwargs["data_transpose"] = transpose_py; kwargs["quantizer"] = this->quantizer; + kwargs["fake_dtype"] = GetATenDType(dtype); py::tuple args(0); PyObject* result = PyObject_Call(reinterpret_cast(Float8TensorStoragePythonClass), @@ -966,6 +968,7 @@ std::pair Float8BlockQuantizer::create_tensor( kwargs["fp8_dtype"] = py::cast(this->dtype); kwargs["quantizer"] = this->quantizer; kwargs["is_2D_scaled"] = py::cast(block_scaling_dim == 2); + kwargs["fake_dtype"] = GetATenDType(dtype); py::tuple args(0); PyObject* result = @@ -1367,6 +1370,7 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve kwargs["fp8_dtype"] = py::cast(this->dtype); kwargs["quantizer"] = this->quantizer; kwargs["with_gemm_swizzled_scales"] = py::cast(with_gemm_swizzled_scales); + kwargs["fake_dtype"] = GetATenDType(dtype); PyObject* result = PyObject_Call(reinterpret_cast(MXFP8TensorStoragePythonClass), args.ptr(), kwargs.ptr()); @@ -1773,6 +1777,7 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve kwargs["fp4_dtype"] = py::cast(this->dtype); kwargs["quantizer"] = this->quantizer; kwargs["with_gemm_swizzled_scales"] = py::cast(with_gemm_swizzled_scales); + kwargs["fake_dtype"] = GetATenDType(dtype); py::tuple args(0); diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index f269e21b8c..f3b6716200 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -1079,7 +1079,7 @@ def _start_all_gather_fp8_blockwise( device = inp._columnwise_data.device else: raise ValueError("Got Float8BlockwiseQTensorStorage input tensor without any data") - dtype = torch.bfloat16 # Only has fp8 dtype. Guess BF16 for dequant. + dtype = inp._dtype else: raise ValueError( "Invalid type for input tensor (expected torch.Tensor or" @@ -1317,7 +1317,7 @@ def _all_gather_nvfp4( if inp._columnwise_data is not None: in_shape_t = inp._columnwise_data.size() device = inp._columnwise_data.device - dtype = torch.bfloat16 + dtype = inp._dtype else: raise ValueError( "Invalid type for input tensor (expected torch.Tensor or NVFP4TensorStorage, " @@ -1486,7 +1486,7 @@ def _all_gather_mxfp8( device = inp._columnwise_data.device else: raise ValueError("Got MXFP8 input tensor without any data") - dtype = torch.bfloat16 # Guess high-precision dtype. + dtype = inp._dtype else: raise ValueError( "Invalid type for input tensor (expected torch.Tensor or MXFP8TensorStorage, " diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 9c21141a39..3dee5a4aed 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -522,6 +522,7 @@ def fill_userbuffers_buffer_for_all_gather( data=global_tensor_data, fp8_scale_inv=local_tensor._scale_inv, fp8_dtype=local_tensor._fp8_dtype, + fake_dtype=local_tensor._dtype, quantizer=quantizer, ) return global_tensor, local_tensor @@ -596,6 +597,7 @@ def fill_userbuffers_buffer_for_all_gather( fp8_dtype=local_tensor._fp8_dtype, quantizer=quantizer, with_gemm_swizzled_scales=False, + fake_dtype=local_tensor._dtype, ) return global_tensor, local_tensor diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index cb697bc197..6ef7247545 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -37,6 +37,7 @@ class QuantizedTensorStorage: XTensor should only implement the functionality needed to behave like regular torch.Tensor (like __torch_dispatch__).""" + _dtype: torch.dtype _quantizer: Optional[Quantizer] def update_usage( @@ -367,10 +368,13 @@ def __new__( shape: Iterable[int], dtype: torch.dtype, *, + fake_dtype: Optional[torch.dtype] = None, requires_grad: bool = False, device: Optional[torch.device] = None, stride: Optional[Iterable[int]] = None, ): + if fake_dtype is not None and fake_dtype != dtype: + raise ValueError(f"fake_dtype ({fake_dtype}) does not match dtype ({dtype})") # For stride, We are assuming only contiguous tensors # Calculate stride from shape if not provided. When creating this object from # C++ code, we provide the stride computed from shape in C++ to avoid the @@ -485,7 +489,7 @@ def clear(self): ) def __repr__(self, *, tensor_contents=None) -> str: - return f"{self.__class__.__name__}(data={self.dequantize(dtype=self.dtype)})" + return f"{self.__class__.__name__}(data={self.dequantize()})" def float(self) -> torch.Tensor: # pylint: disable=missing-function-docstring @@ -588,7 +592,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): def maybe_unwrap(arg): if isinstance(arg, QuantizedTensor): - return arg.dequantize(dtype=arg.dtype) + return arg.dequantize() return arg def maybe_update_inplace(arg, new_arg, schema_arg): diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index e65730c015..0fae40f786 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -326,7 +326,7 @@ def __repr__(self, *, tensor_contents=None): return ( f"Float8BlockwiseQTensor(fp8_dtype={self._fp8_dtype}," f" is_2D_scaled={self._is_2D_scaled}," - f" data={self.dequantize(dtype=self.dtype)})" + f" data={self.dequantize()})" ) def quantize_( diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index c60bb2308d..9cc00855cd 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -182,6 +182,7 @@ def create_tensor_from_data( data=data, fp8_scale_inv=1 / self.scale, fp8_dtype=self.dtype, + fake_dtype=fake_dtype, requires_grad=requires_grad, data_transpose=None, quantizer=self, @@ -407,6 +408,7 @@ def create_tensor_from_data( data=data, fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=data.device), fp8_dtype=self.dtype, + fake_dtype=fake_dtype, requires_grad=requires_grad, data_transpose=None, quantizer=self, @@ -498,7 +500,7 @@ def __repr__(self, *, tensor_contents=None): "Float8Tensor(" f"fp8_dtype={self._fp8_dtype}, " f"scale_inv={self._scale_inv.item()}, " - f"data={self.dequantize(dtype=self.dtype)}" + f"data={self.dequantize()}" ")" ) diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 96b6a67ea8..baff9cc2aa 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -296,7 +296,7 @@ def __new__( ) def __repr__(self, *, tensor_contents=None): - return f"MXFP8Tensor(fp8_dtype={self._fp8_dtype}, data={self.dequantize(dtype=self.dtype)})" + return f"MXFP8Tensor(fp8_dtype={self._fp8_dtype}, data={self.dequantize()})" def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: """ diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index 4314fd248c..8ed1b4682c 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -450,7 +450,7 @@ def __new__( return instance def __repr__(self, *, tensor_contents=None): - return f"NVFP4Tensor, data={self.dequantize(dtype=self.dtype)})" + return f"NVFP4Tensor, data={self.dequantize()})" def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: """ diff --git a/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py index 2a86717017..52e292125e 100644 --- a/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py @@ -46,12 +46,14 @@ def __new__( quantizer: Quantizer, is_2D_scaled: bool, *args, + fake_dtype: Optional[torch.dtype] = None, **kwargs, ): if cls is Float8BlockwiseQTensorStorage: instance = object.__new__(cls) + instance._dtype = fake_dtype if fake_dtype is not None else torch.float32 else: - instance = super().__new__(cls, *args, **kwargs) + instance = super().__new__(cls, *args, fake_dtype=fake_dtype, **kwargs) instance._rowwise_data = rowwise_data instance._columnwise_data = columnwise_data instance._quantizer = quantizer.copy() if quantizer is not None else None @@ -101,6 +103,7 @@ def get_metadata(self) -> Dict[str, Any]: "fp8_dtype": self._fp8_dtype, "quantizer": self._quantizer, "is_2D_scaled": self._is_2D_scaled, + "fake_dtype": self._dtype, } def prepare_for_saving( @@ -149,7 +152,9 @@ def _transpose_dq_columnwise_output(self, columnwise_dq: torch.Tensor) -> torch. permute_dims.append(0) return torch.permute(columnwise_dq, tuple(permute_dims)).contiguous() - def _dequantize_vectorwise(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor: + def _dequantize_vectorwise(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: + if dtype is None: + dtype = self._dtype block_len = 128 q_M, q_K = 1, 1 @@ -211,10 +216,12 @@ def _dequantize_vectorwise(self, *, dtype: torch.dtype = torch.float32) -> torch return self._transpose_dq_columnwise_output(result) return result - def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor: + def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: """ Construct plain PyTorch tensor from Float8BlockwiseQTensor """ + if dtype is None: + dtype = self._dtype block_len = 128 if not self._is_2D_scaled: return self._dequantize_vectorwise(dtype=dtype) diff --git a/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py index a815b366b2..0fb7966c2f 100644 --- a/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py @@ -75,14 +75,16 @@ def __new__( data: Optional[torch.Tensor], fp8_scale_inv: torch.Tensor, fp8_dtype: TE_DType, + fake_dtype: Optional[torch.dtype] = None, data_transpose: Optional[torch.Tensor] = None, quantizer: Optional[Quantizer] = None, **kwargs, ): if cls is Float8TensorStorage: instance = object.__new__(cls) + instance._dtype = fake_dtype if fake_dtype is not None else torch.float32 else: - instance = super().__new__(cls, *args, **kwargs) + instance = super().__new__(cls, *args, fake_dtype=fake_dtype, **kwargs) instance._data = data instance._quantizer = quantizer.copy() if quantizer is not None else None instance._fp8_dtype = fp8_dtype @@ -130,6 +132,7 @@ def get_metadata(self) -> Dict[str, Any]: "fp8_dtype": self._fp8_dtype, "data_transpose": self._transpose, "quantizer": self._quantizer, + "fake_dtype": self._dtype, } def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorStorage]: @@ -159,8 +162,10 @@ def get_data_tensors(self, rowwise_data: bool = True, columnwise_data: bool = Tr return self._transpose raise ValueError("No data to get, both rowwise_data and columnwise_data are False") - def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor: + def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: """Dequantize to a higher precision.""" + if dtype is None: + dtype = self._dtype return _FromFloat8Func.forward(None, self, dtype) def size(self, *args, **kwargs): @@ -192,6 +197,7 @@ def view(self, shape: torch.Size): data=out_data, fp8_scale_inv=self._scale_inv, fp8_dtype=self._fp8_dtype, + fake_dtype=self._dtype, data_transpose=out_transpose, quantizer=self._quantizer, ) diff --git a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py index 12757aa58c..d2fb7e0df2 100644 --- a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py @@ -84,12 +84,14 @@ def __new__( quantizer: Optional[Quantizer], with_gemm_swizzled_scales: bool, *args, + fake_dtype: Optional[torch.dtype] = None, **kwargs, ): if cls is MXFP8TensorStorage: instance = object.__new__(cls) + instance._dtype = fake_dtype if fake_dtype is not None else torch.float32 else: - instance = super().__new__(cls, *args, **kwargs) + instance = super().__new__(cls, *args, fake_dtype=fake_dtype, **kwargs) instance._rowwise_data = rowwise_data instance._columnwise_data = columnwise_data instance._rowwise_scale_inv = rowwise_scale_inv @@ -139,6 +141,7 @@ def get_metadata(self) -> Dict[str, Any]: "fp8_dtype": self._fp8_dtype, "quantizer": self._quantizer, "with_gemm_swizzled_scales": self._with_gemm_swizzled_scales, + "fake_dtype": self._dtype, } def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], MXFP8TensorStorage]: @@ -175,8 +178,10 @@ def get_data_tensors(self, rowwise_data: bool = True, columnwise_data: bool = Tr return self._columnwise_data raise ValueError("No data to get, both rowwise_data and columnwise_data are False") - def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor: + def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: """Dequantize to a higher precision.""" + if dtype is None: + dtype = self._dtype return _FromMXFP8Func.forward(None, self, dtype) def size(self, *args, **kwargs): @@ -238,6 +243,7 @@ def view(self, shape: torch.Size): fp8_dtype=self._fp8_dtype, quantizer=self._quantizer, with_gemm_swizzled_scales=self._with_gemm_swizzled_scales, + fake_dtype=self._dtype, ) def __repr__(self): diff --git a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py index e7509f3994..6d7f1efb88 100644 --- a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py @@ -106,10 +106,14 @@ def __new__( quantizer: Optional[Quantizer], with_gemm_swizzled_scales: bool, *args, + fake_dtype: Optional[torch.dtype] = None, **kwargs, ): - - instance = super().__new__(cls, *args, **kwargs) + if cls is NVFP4TensorStorage: + instance = object.__new__(cls) + instance._dtype = fake_dtype if fake_dtype is not None else torch.float32 + else: + instance = super().__new__(cls, *args, fake_dtype=fake_dtype, **kwargs) instance._rowwise_data = rowwise_data instance._columnwise_data = columnwise_data @@ -168,6 +172,7 @@ def get_metadata(self) -> Dict[str, Any]: "fp4_dtype": self._fp4_dtype, "quantizer": self._quantizer, "with_gemm_swizzled_scales": self._with_gemm_swizzled_scales, + "fake_dtype": self._dtype, } def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], NVFP4TensorStorage]: @@ -204,8 +209,10 @@ def get_data_tensors(self): """Get this Tensor's data.""" return self._rowwise_data, self._columnwise_data - def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor: + def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: """Dequantize to a higher precision.""" + if dtype is None: + dtype = self._dtype return _FromNVFP4Func.forward(None, self, dtype) def size(self, dim: Optional[int] = None) -> Union[torch.Size, int]: @@ -295,6 +302,7 @@ def view(self, shape: torch.Size): quantizer=self._quantizer, fp4_dtype=self._fp4_dtype, with_gemm_swizzled_scales=self._with_gemm_swizzled_scales, + fake_dtype=self._dtype, ) def __repr__(self): diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 47af9fabe1..43eb6f2de2 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -242,6 +242,7 @@ def forward( fp8_dtype=mixed_x_layer._fp8_dtype, data=x.squeeze(split_dim) if squeeze else x, shape=x.squeeze(split_dim).shape if squeeze else x.shape, + fake_dtype=mixed_x_layer._dtype, quantizer=mixed_x_layer._quantizer, ) for x in torch.split(