Skip to content

Commit 8a4d85c

Browse files
Cleanups to the last PR. (Comfy-Org#12646)
1 parent a452201 commit 8a4d85c

File tree

2 files changed

+22
-39
lines changed

2 files changed

+22
-39
lines changed

comfy/conds.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,25 @@
44
import logging
55

66

7+
def is_equal(x, y):
8+
if torch.is_tensor(x) and torch.is_tensor(y):
9+
return torch.equal(x, y)
10+
elif isinstance(x, dict) and isinstance(y, dict):
11+
if x.keys() != y.keys():
12+
return False
13+
return all(is_equal(x[k], y[k]) for k in x)
14+
elif isinstance(x, (list, tuple)) and isinstance(y, (list, tuple)):
15+
if type(x) is not type(y) or len(x) != len(y):
16+
return False
17+
return all(is_equal(a, b) for a, b in zip(x, y))
18+
else:
19+
try:
20+
return x == y
21+
except Exception:
22+
logging.warning("comparison issue with COND")
23+
return False
24+
25+
726
class CONDRegular:
    def __init__(self, cond):
        # cond: the raw conditioning value (a tensor or a nested
        # structure of tensors/dicts/lists — see is_equal above).
        self.cond = cond
@@ -84,7 +103,7 @@ def process_cond(self, batch_size, **kwargs):
84103
return self._copy_with(self.cond)
85104

86105
def can_concat(self, other):
    """Return True when this cond and *other* hold structurally equal
    content and can therefore be batch-concatenated."""
    # Same truthiness test as before, expressed as a conditional expression.
    return True if is_equal(self.cond, other.cond) else False
90109

comfy/model_base.py

Lines changed: 2 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -65,42 +65,6 @@
6565
if TYPE_CHECKING:
6666
from comfy.model_patcher import ModelPatcher
6767

68-
69-
class _CONDGuideEntries(comfy.conds.CONDConstant):
70-
"""CONDConstant subclass that safely compares guide_attention_entries.
71-
72-
guide_attention_entries may contain ``pixel_mask`` tensors. The default
73-
``CONDConstant.can_concat`` uses ``!=`` which triggers a ``ValueError``
74-
on tensors. This subclass performs a structural comparison instead.
75-
"""
76-
77-
def can_concat(self, other):
78-
if not isinstance(other, _CONDGuideEntries):
79-
return False
80-
a, b = self.cond, other.cond
81-
if len(a) != len(b):
82-
return False
83-
for ea, eb in zip(a, b):
84-
if ea["pre_filter_count"] != eb["pre_filter_count"]:
85-
return False
86-
if ea["strength"] != eb["strength"]:
87-
return False
88-
if ea.get("latent_shape") != eb.get("latent_shape"):
89-
return False
90-
a_has = ea.get("pixel_mask") is not None
91-
b_has = eb.get("pixel_mask") is not None
92-
if a_has != b_has:
93-
return False
94-
if a_has:
95-
pm_a, pm_b = ea["pixel_mask"], eb["pixel_mask"]
96-
if pm_a is not pm_b:
97-
if (pm_a.shape != pm_b.shape
98-
or pm_a.device != pm_b.device
99-
or pm_a.dtype != pm_b.dtype
100-
or not torch.equal(pm_a, pm_b)):
101-
return False
102-
return True
103-
10468
class ModelType(Enum):
10569
EPS = 1
10670
V_PREDICTION = 2
@@ -1012,7 +976,7 @@ def extra_conds(self, **kwargs):
1012976

1013977
guide_attention_entries = kwargs.get("guide_attention_entries", None)
1014978
if guide_attention_entries is not None:
1015-
out['guide_attention_entries'] = _CONDGuideEntries(guide_attention_entries)
979+
out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries)
1016980

1017981
return out
1018982

@@ -1068,7 +1032,7 @@ def extra_conds(self, **kwargs):
10681032

10691033
guide_attention_entries = kwargs.get("guide_attention_entries", None)
10701034
if guide_attention_entries is not None:
1071-
out['guide_attention_entries'] = _CONDGuideEntries(guide_attention_entries)
1035+
out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries)
10721036

10731037
return out
10741038

0 commit comments

Comments
 (0)