|
65 | 65 | if TYPE_CHECKING: |
66 | 66 | from comfy.model_patcher import ModelPatcher |
67 | 67 |
|
68 | | - |
class _CONDGuideEntries(comfy.conds.CONDConstant):
    """CONDConstant subclass that compares guide_attention_entries safely.

    Entries may carry a ``pixel_mask`` tensor; the stock
    ``CONDConstant.can_concat`` compares conds with ``!=``, which raises on
    tensors. This subclass compares the entries field by field instead.
    """

    @staticmethod
    def _masks_equal(lhs, rhs):
        # Compare two optional pixel_mask tensors without invoking tensor
        # __ne__/__eq__ broadcasting semantics.
        if lhs is rhs:
            # Same object, or both None — nothing further to check.
            return True
        if lhs is None or rhs is None:
            # Exactly one side has a mask.
            return False
        # Guard torch.equal with cheap metadata checks first (torch.equal
        # requires comparable tensors; mismatched device would error out).
        if lhs.shape != rhs.shape or lhs.device != rhs.device or lhs.dtype != rhs.dtype:
            return False
        return torch.equal(lhs, rhs)

    @staticmethod
    def _entries_equal(lhs, rhs):
        # Structural comparison of one guide entry against another.
        if lhs["pre_filter_count"] != rhs["pre_filter_count"]:
            return False
        if lhs["strength"] != rhs["strength"]:
            return False
        if lhs.get("latent_shape") != rhs.get("latent_shape"):
            return False
        return _CONDGuideEntries._masks_equal(lhs.get("pixel_mask"), rhs.get("pixel_mask"))

    def can_concat(self, other):
        # Only another _CONDGuideEntries with pairwise-equal entries can
        # be concatenated with this one.
        if not isinstance(other, _CONDGuideEntries):
            return False
        mine, theirs = self.cond, other.cond
        if len(mine) != len(theirs):
            return False
        return all(self._entries_equal(ea, eb) for ea, eb in zip(mine, theirs))
103 | | - |
104 | 68 | class ModelType(Enum): |
105 | 69 | EPS = 1 |
106 | 70 | V_PREDICTION = 2 |
@@ -1012,7 +976,7 @@ def extra_conds(self, **kwargs): |
1012 | 976 |
|
1013 | 977 | guide_attention_entries = kwargs.get("guide_attention_entries", None) |
1014 | 978 | if guide_attention_entries is not None: |
1015 | | - out['guide_attention_entries'] = _CONDGuideEntries(guide_attention_entries) |
| 979 | + out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries) |
1016 | 980 |
|
1017 | 981 | return out |
1018 | 982 |
|
@@ -1068,7 +1032,7 @@ def extra_conds(self, **kwargs): |
1068 | 1032 |
|
1069 | 1033 | guide_attention_entries = kwargs.get("guide_attention_entries", None) |
1070 | 1034 | if guide_attention_entries is not None: |
1071 | | - out['guide_attention_entries'] = _CONDGuideEntries(guide_attention_entries) |
| 1035 | + out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries) |
1072 | 1036 |
|
1073 | 1037 | return out |
1074 | 1038 |
|
|
0 commit comments