-
Notifications
You must be signed in to change notification settings - Fork 653
[PyTorch] Add dtype information to QuantizedTensorStorage class #2676
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
f6ad0bb
ab059a8
8200276
be723b2
369f8b5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -37,6 +37,7 @@ class QuantizedTensorStorage: | |
| XTensor should only implement the functionality needed | ||
| to behave like regular torch.Tensor (like __torch_dispatch__).""" | ||
|
|
||
| _dtype: torch.dtype | ||
| _quantizer: Optional[Quantizer] | ||
|
|
||
| def update_usage( | ||
|
|
@@ -367,10 +368,13 @@ def __new__( | |
| shape: Iterable[int], | ||
| dtype: torch.dtype, | ||
| *, | ||
| fake_dtype: Optional[torch.dtype] = None, | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Isn't this redundant with the
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is mostly to avoid issues with MRO and still have fairly straightforward constructors for the Storage classes.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Also just noticed that the make_like call would be problematic there otherwise - we want to include the fake_dtype in get_metadata call, but if it was named dtype it would clash with the dtype that we pass directly in make_like. |
||
| requires_grad: bool = False, | ||
| device: Optional[torch.device] = None, | ||
| stride: Optional[Iterable[int]] = None, | ||
| ): | ||
| if fake_dtype is not None and fake_dtype != dtype: | ||
| raise ValueError(f"fake_dtype ({fake_dtype}) does not match dtype ({dtype})") | ||
|
Comment on lines
+376
to
+377
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

Validation breaks existing callers

This new guard will cause regressions on every call that passes a `fake_dtype` different from `dtype`. Confirmed breakage paths:
The root cause is that
|
||
| # For stride, We are assuming only contiguous tensors | ||
| # Calculate stride from shape if not provided. When creating this object from | ||
| # C++ code, we provide the stride computed from shape in C++ to avoid the | ||
|
|
@@ -485,7 +489,7 @@ def clear(self): | |
| ) | ||
|
|
||
| def __repr__(self, *, tensor_contents=None) -> str: | ||
| return f"{self.__class__.__name__}(data={self.dequantize(dtype=self.dtype)})" | ||
| return f"{self.__class__.__name__}(data={self.dequantize()})" | ||
|
|
||
| def float(self) -> torch.Tensor: | ||
| # pylint: disable=missing-function-docstring | ||
|
|
@@ -588,7 +592,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): | |
|
|
||
| def maybe_unwrap(arg): | ||
| if isinstance(arg, QuantizedTensor): | ||
| return arg.dequantize(dtype=arg.dtype) | ||
| return arg.dequantize() | ||
| return arg | ||
|
|
||
| def maybe_update_inplace(arg, new_arg, schema_arg): | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -75,14 +75,16 @@ def __new__( | |||||
| data: Optional[torch.Tensor], | ||||||
| fp8_scale_inv: torch.Tensor, | ||||||
| fp8_dtype: TE_DType, | ||||||
| fake_dtype: Optional[torch.dtype] = None, | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'd prefer to just name it `dtype`:
Suggested change
|
||||||
| data_transpose: Optional[torch.Tensor] = None, | ||||||
| quantizer: Optional[Quantizer] = None, | ||||||
| **kwargs, | ||||||
| ): | ||||||
| if cls is Float8TensorStorage: | ||||||
| instance = object.__new__(cls) | ||||||
| instance._dtype = fake_dtype if fake_dtype is not None else torch.float32 | ||||||
| else: | ||||||
| instance = super().__new__(cls, *args, **kwargs) | ||||||
| instance = super().__new__(cls, *args, fake_dtype=fake_dtype, **kwargs) | ||||||
| instance._data = data | ||||||
| instance._quantizer = quantizer.copy() if quantizer is not None else None | ||||||
| instance._fp8_dtype = fp8_dtype | ||||||
|
|
@@ -130,6 +132,7 @@ def get_metadata(self) -> Dict[str, Any]: | |||||
| "fp8_dtype": self._fp8_dtype, | ||||||
| "data_transpose": self._transpose, | ||||||
| "quantizer": self._quantizer, | ||||||
| "fake_dtype": self._dtype, | ||||||
| } | ||||||
|
|
||||||
| def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorStorage]: | ||||||
|
|
@@ -159,8 +162,10 @@ def get_data_tensors(self, rowwise_data: bool = True, columnwise_data: bool = Tr | |||||
| return self._transpose | ||||||
| raise ValueError("No data to get, both rowwise_data and columnwise_data are False") | ||||||
|
|
||||||
| def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor: | ||||||
| def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: | ||||||
| """Dequantize to a higher precision.""" | ||||||
| if dtype is None: | ||||||
| dtype = self._dtype | ||||||
| return _FromFloat8Func.forward(None, self, dtype) | ||||||
|
|
||||||
| def size(self, *args, **kwargs): | ||||||
|
|
@@ -192,6 +197,7 @@ def view(self, shape: torch.Size): | |||||
| data=out_data, | ||||||
| fp8_scale_inv=self._scale_inv, | ||||||
| fp8_dtype=self._fp8_dtype, | ||||||
| fake_dtype=self._dtype, | ||||||
| data_transpose=out_transpose, | ||||||
| quantizer=self._quantizer, | ||||||
| ) | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No lazy-init guard for `_dtype` on storage objects

`QuantizedTensor.dtype` (line 405–409) has a `hasattr(self, "_dtype")` lazy-initializer that protects against deserialization from pre-PR checkpoints. `QuantizedTensorStorage` and its subclasses have no equivalent protection — `_dtype: torch.dtype` is only a class-level annotation, not a default value.

If a `*TensorStorage` object is unpickled from a checkpoint that was saved before this PR, the first call to `.dequantize()` (or the distributed ops in `distributed.py` that now access `inp._dtype`) will raise `AttributeError: _dtype`.

Consider adding a similar lazy fallback in the `dequantize` methods, e.g.: