Skip to content

Commit 5495589

Browse files
Respect the dtype the op was initialized with, for non-quant mixed ops. (Comfy-Org#11282)
1 parent 982876d commit 5495589

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

comfy/ops.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -497,8 +497,10 @@ def __init__(
497497
) -> None:
498498
super().__init__()
499499

500-
self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
501-
# self.factory_kwargs = {"device": device, "dtype": dtype}
500+
if dtype is None:
501+
dtype = MixedPrecisionOps._compute_dtype
502+
503+
self.factory_kwargs = {"device": device, "dtype": dtype}
502504

503505
self.in_features = in_features
504506
self.out_features = out_features
@@ -530,7 +532,10 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata,
530532
layer_conf = json.loads(layer_conf.numpy().tobytes())
531533

532534
if layer_conf is None:
533-
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
535+
dtype = self.factory_kwargs["dtype"]
536+
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=dtype), requires_grad=False)
537+
if dtype != MixedPrecisionOps._compute_dtype:
538+
self.comfy_cast_weights = True
534539
else:
535540
self.quant_format = layer_conf.get("format", None)
536541
if not self._full_precision_mm:

0 commit comments

Comments (0)