diff --git a/comfy/ldm/chroma/model.py b/comfy/ldm/chroma/model.py
index 2e8ef0687462..9fd865f20d2d 100644
--- a/comfy/ldm/chroma/model.py
+++ b/comfy/ldm/chroma/model.py
@@ -152,6 +152,7 @@ def forward_orig(
         transformer_options={},
         attn_mask: Tensor = None,
     ) -> Tensor:
+        transformer_options = transformer_options.copy()
         patches_replace = transformer_options.get("patches_replace", {})
 
         # running on sequences img
@@ -228,6 +229,7 @@ def block_wrap(args):
 
         transformer_options["total_blocks"] = len(self.single_blocks)
         transformer_options["block_type"] = "single"
+        transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
         for i, block in enumerate(self.single_blocks):
             transformer_options["block_index"] = i
             if i not in self.skip_dit:
diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py
index 3518a19224a3..8b3f500d735e 100644
--- a/comfy/ldm/flux/layers.py
+++ b/comfy/ldm/flux/layers.py
@@ -196,6 +196,9 @@ def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=N
         else:
             (img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
 
+        transformer_patches = transformer_options.get("patches", {})
+        extra_options = transformer_options.copy()
+
         # prepare image for attention
         img_modulated = self.img_norm1(img)
         img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
@@ -224,6 +227,12 @@ def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=N
         attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
         del q, k, v
 
+        if "attn1_output_patch" in transformer_patches:
+            extra_options["img_slice"] = [txt.shape[1], attn.shape[1]]
+            patch = transformer_patches["attn1_output_patch"]
+            for p in patch:
+                attn = p(attn, extra_options)
+
         txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
 
         # calculate the img bloks
@@ -303,6 +312,9 @@ def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation
         else:
             mod = vec
 
+        transformer_patches = transformer_options.get("patches", {})
+        extra_options = transformer_options.copy()
+
         qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim_first], dim=-1)
 
         q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
@@ -312,6 +324,12 @@ def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation
         # compute attention
         attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
         del q, k, v
+
+        if "attn1_output_patch" in transformer_patches:
+            patch = transformer_patches["attn1_output_patch"]
+            for p in patch:
+                attn = p(attn, extra_options)
+
         # compute activation in mlp stream, cat again and run second linear layer
         if self.yak_mlp:
             mlp = self.mlp_act(mlp[..., self.mlp_hidden_dim_first // 2:]) * mlp[..., :self.mlp_hidden_dim_first // 2]
diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py
index 260ccad7e67b..ef4dcf7c5616 100644
--- a/comfy/ldm/flux/model.py
+++ b/comfy/ldm/flux/model.py
@@ -142,6 +142,7 @@ def forward_orig(
         attn_mask: Tensor = None,
     ) -> Tensor:
+        transformer_options = transformer_options.copy()
         patches = transformer_options.get("patches", {})
         patches_replace = transformer_options.get("patches_replace", {})
 
         if img.ndim != 3 or txt.ndim != 3:
@@ -231,6 +232,7 @@ def block_wrap(args):
 
         transformer_options["total_blocks"] = len(self.single_blocks)
         transformer_options["block_type"] = "single"
+        transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
         for i, block in enumerate(self.single_blocks):
             transformer_options["block_index"] = i
             if ("single_block", i) in blocks_replace:
diff --git a/comfy/ldm/hunyuan_video/model.py b/comfy/ldm/hunyuan_video/model.py
index 563f28f6b302..b94cdfa8762e 100644
--- a/comfy/ldm/hunyuan_video/model.py
+++ b/comfy/ldm/hunyuan_video/model.py
@@ -304,6 +304,7 @@ def forward_orig(
         control=None,
         transformer_options={},
     ) -> Tensor:
+        transformer_options = transformer_options.copy()
         patches_replace = transformer_options.get("patches_replace", {})
 
         initial_shape = list(img.shape)
@@ -416,6 +417,7 @@ def block_wrap(args):
 
         transformer_options["total_blocks"] = len(self.single_blocks)
         transformer_options["block_type"] = "single"
+        transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
         for i, block in enumerate(self.single_blocks):
             transformer_options["block_index"] = i
             if ("single_block", i) in blocks_replace:
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 4a74cb1cef4d..9dcef8741423 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -178,10 +178,7 @@ def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, tran
             xc = torch.cat([xc] + [comfy.model_management.cast_to_device(c_concat, xc.device, xc.dtype)], dim=1)
 
         context = c_crossattn
-        dtype = self.get_dtype()
-
-        if self.manual_cast_dtype is not None:
-            dtype = self.manual_cast_dtype
+        dtype = self.get_dtype_inference()
 
         xc = xc.to(dtype)
         device = xc.device
@@ -218,6 +215,13 @@ def process_timestep(self, timestep, **kwargs):
     def get_dtype(self):
         return self.diffusion_model.dtype
 
+    def get_dtype_inference(self):
+        dtype = self.get_dtype()
+
+        if self.manual_cast_dtype is not None:
+            dtype = self.manual_cast_dtype
+        return dtype
+
     def encode_adm(self, **kwargs):
         return None
 
@@ -372,9 +376,7 @@ def memory_required(self, input_shape, cond_shapes={}):
                 input_shapes += shape
 
         if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
-            dtype = self.get_dtype()
-            if self.manual_cast_dtype is not None:
-                dtype = self.manual_cast_dtype
+            dtype = self.get_dtype_inference()
             #TODO: this needs to be tweaked
             area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes))
             return (area * comfy.model_management.dtype_size(dtype) * 0.01 * self.memory_usage_factor) * (1024 * 1024)
@@ -1165,7 +1167,7 @@ def extra_conds(self, **kwargs):
                 t5xxl_ids = t5xxl_ids.unsqueeze(0)
 
             if torch.is_inference_mode_enabled(): # if not we are training
-                cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.to(device=device), t5xxl_weights=t5xxl_weights.to(device=device, dtype=self.get_dtype()))
+                cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()), t5xxl_ids.to(device=device), t5xxl_weights=t5xxl_weights.to(device=device, dtype=self.get_dtype_inference()))
             else:
                 out['t5xxl_ids'] = comfy.conds.CONDRegular(t5xxl_ids)
                 out['t5xxl_weights'] = comfy.conds.CONDRegular(t5xxl_weights)
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index f01818f50c2d..21b4ce53e3fd 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -406,13 +406,16 @@ def clone_has_same_weights(self, clone: 'ModelPatcher'):
     def memory_required(self, input_shape):
         return self.model.memory_required(input_shape=input_shape)
 
+    def disable_model_cfg1_optimization(self):
+        self.model_options["disable_cfg1_optimization"] = True
+
     def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False):
         if len(inspect.signature(sampler_cfg_function).parameters) == 3:
             self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
         else:
             self.model_options["sampler_cfg_function"] = sampler_cfg_function
         if disable_cfg1_optimization:
-            self.model_options["disable_cfg1_optimization"] = True
+            self.disable_model_cfg1_optimization()
 
     def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False):
         self.model_options = set_model_options_post_cfg_function(self.model_options, post_cfg_function, disable_cfg1_optimization)
diff --git a/comfy_extras/nodes_nag.py b/comfy_extras/nodes_nag.py
new file mode 100644
index 000000000000..033e40eb9206
--- /dev/null
+++ b/comfy_extras/nodes_nag.py
@@ -0,0 +1,99 @@
+import torch
+from comfy_api.latest import ComfyExtension, io
+from typing_extensions import override
+
+
+class NAGuidance(io.ComfyNode):
+    @classmethod
+    def define_schema(cls) -> io.Schema:
+        return io.Schema(
+            node_id="NAGuidance",
+            display_name="Normalized Attention Guidance",
+            description="Applies Normalized Attention Guidance to models, enabling negative prompts on distilled/schnell models.",
+            category="",
+            is_experimental=True,
+            inputs=[
+                io.Model.Input("model", tooltip="The model to apply NAG to."),
+                io.Float.Input("nag_scale", min=0.0, default=5.0, max=50.0, step=0.1, tooltip="The guidance scale factor. Higher values push further from the negative prompt."),
+                io.Float.Input("nag_alpha", min=0.0, default=0.5, max=1.0, step=0.01, tooltip="Blending factor for the normalized attention. 1.0 is full replacement, 0.0 is no effect."),
+                io.Float.Input("nag_tau", min=1.0, default=1.5, max=10.0, step=0.01),
+                # io.Float.Input("start_percent", min=0.0, default=0.0, max=1.0, step=0.01, tooltip="The relative sampling step to begin applying NAG."),
+                # io.Float.Input("end_percent", min=0.0, default=1.0, max=1.0, step=0.01, tooltip="The relative sampling step to stop applying NAG."),
+            ],
+            outputs=[
+                io.Model.Output(tooltip="The patched model with NAG enabled."),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, model: io.Model.Type, nag_scale: float, nag_alpha: float, nag_tau: float) -> io.NodeOutput:
+        m = model.clone()
+
+        # sigma_start = m.get_model_object("model_sampling").percent_to_sigma(start_percent)
+        # sigma_end = m.get_model_object("model_sampling").percent_to_sigma(end_percent)
+
+        def nag_attention_output_patch(out, extra_options):
+            cond_or_uncond = extra_options.get("cond_or_uncond", None)
+            if cond_or_uncond is None:
+                return out
+
+            if not (1 in cond_or_uncond and 0 in cond_or_uncond):
+                return out
+
+            # sigma = extra_options.get("sigmas", None)
+            # if sigma is not None and len(sigma) > 0:
+            #     sigma = sigma[0].item()
+            #     if sigma > sigma_start or sigma < sigma_end:
+            #         return out
+
+            img_slice = extra_options.get("img_slice", None)
+
+            if img_slice is not None:
+                orig_out = out
+                out = out[:, img_slice[0]:img_slice[1]] # only apply on img part
+
+            batch_size = out.shape[0]
+            half_size = batch_size // len(cond_or_uncond)
+
+            ind_neg = cond_or_uncond.index(1)
+            ind_pos = cond_or_uncond.index(0)
+            z_pos = out[half_size * ind_pos:half_size * (ind_pos + 1)]
+            z_neg = out[half_size * ind_neg:half_size * (ind_neg + 1)]
+
+            guided = z_pos * nag_scale - z_neg * (nag_scale - 1.0)
+
+            eps = 1e-6
+            norm_pos = torch.norm(z_pos, p=1, dim=-1, keepdim=True).clamp_min(eps)
+            norm_guided = torch.norm(guided, p=1, dim=-1, keepdim=True).clamp_min(eps)
+
+            ratio = norm_guided / norm_pos
+            scale_factor = torch.minimum(ratio, torch.full_like(ratio, nag_tau)) / ratio
+
+            guided_normalized = guided * scale_factor
+
+            z_final = guided_normalized * nag_alpha + z_pos * (1.0 - nag_alpha)
+
+            if img_slice is not None:
+                orig_out[half_size * ind_neg:half_size * (ind_neg + 1), img_slice[0]:img_slice[1]] = z_final
+                orig_out[half_size * ind_pos:half_size * (ind_pos + 1), img_slice[0]:img_slice[1]] = z_final
+                return orig_out
+            else:
+                out[half_size * ind_pos:half_size * (ind_pos + 1)] = z_final
+                return out
+
+        m.set_model_attn1_output_patch(nag_attention_output_patch)
+        m.disable_model_cfg1_optimization()
+
+        return io.NodeOutput(m)
+
+
+class NagExtension(ComfyExtension):
+    @override
+    async def get_node_list(self) -> list[type[io.ComfyNode]]:
+        return [
+            NAGuidance,
+        ]
+
+
+async def comfy_entrypoint() -> NagExtension:
+    return NagExtension()
diff --git a/comfyui_version.py b/comfyui_version.py
index cf4e89816e35..8f7f3228efcb 100644
--- a/comfyui_version.py
+++ b/comfyui_version.py
@@ -1,3 +1,3 @@
 # This file is automatically generated by the build process when version is
 # updated in pyproject.toml.
-__version__ = "0.13.0"
+__version__ = "0.14.0"
diff --git a/nodes.py b/nodes.py
index db5f98408782..dff56b79ce1a 100644
--- a/nodes.py
+++ b/nodes.py
@@ -2437,6 +2437,7 @@ async def init_builtin_extra_nodes():
         "nodes_color.py",
         "nodes_toolkit.py",
         "nodes_replacements.py",
+        "nodes_nag.py",
     ]
 
     import_failed = []
diff --git a/pyproject.toml b/pyproject.toml
index 9dab9a50c3c2..b132bb9c4b8d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "ComfyUI"
-version = "0.13.0"
+version = "0.14.0"
 readme = "README.md"
 license = { file = "LICENSE" }
 requires-python = ">=3.10"