From 26dd7eb42180fb57c9da47e60d0a2bac659e47ad Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 4 Feb 2026 15:25:06 -0800 Subject: [PATCH 1/4] Fix ace step nan issue on some hardware/pytorch configs. (#12289) --- comfy/text_encoders/llama.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index 3afd094d1fdd..b6735d210a43 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -651,10 +651,10 @@ def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermed mask = None if attention_mask is not None: mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, seq_len, attention_mask.shape[-1]) - mask = mask.masked_fill(mask.to(torch.bool), torch.finfo(x.dtype).min) + mask = mask.masked_fill(mask.to(torch.bool), torch.finfo(x.dtype).min / 4) if seq_len > 1: - causal_mask = torch.empty(past_len + seq_len, past_len + seq_len, dtype=x.dtype, device=x.device).fill_(torch.finfo(x.dtype).min).triu_(1) + causal_mask = torch.empty(past_len + seq_len, past_len + seq_len, dtype=x.dtype, device=x.device).fill_(torch.finfo(x.dtype).min / 4).triu_(1) if mask is not None: mask += causal_mask else: From c8fcbd66eef0ab48d9fe7e4ee35c683a193af46b Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 4 Feb 2026 16:37:05 -0800 Subject: [PATCH 2/4] Try to fix ace text encoder slowness on some configs. (#12290) --- comfy/ops.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/comfy/ops.py b/comfy/ops.py index 53c5e4dc3b4e..0f4eca7c768d 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -54,6 +54,8 @@ def scaled_dot_product_attention(q, k, v, *args, **kwargs): SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION) def scaled_dot_product_attention(q, k, v, *args, **kwargs): + if q.nelement() < 1024 * 128: # arbitrary number, for small inputs cudnn attention seems slower + return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs) with sdpa_kernel(SDPA_BACKEND_PRIORITY, set_priority=True): return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs) else: From 6125b8097952a374009af39639ff45da85f65500 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 4 Feb 2026 18:29:22 -0800 Subject: [PATCH 3/4] Add llm sampling options and make reference audio work on ace step 1.5 (#12295) --- comfy/ldm/ace/ace_step15.py | 3 +-- comfy/model_base.py | 17 +++++++++++------ comfy/text_encoders/ace15.py | 29 ++++++++++++++++++++++------- comfy_extras/nodes_ace.py | 16 +++++++++++----- 4 files changed, 45 insertions(+), 20 deletions(-) diff --git a/comfy/ldm/ace/ace_step15.py b/comfy/ldm/ace/ace_step15.py index d9054965862f..17a37e573782 100644 --- a/comfy/ldm/ace/ace_step15.py +++ b/comfy/ldm/ace/ace_step15.py @@ -1035,8 +1035,7 @@ def prepare_condition( audio_codes = torch.nn.functional.pad(audio_codes, (0, math.ceil(src_latents.shape[1] / 5) - audio_codes.shape[1]), "constant", 35847) lm_hints_5Hz = self.tokenizer.quantizer.get_output_from_indices(audio_codes, dtype=text_hidden_states.dtype) else: - assert False - # TODO ? + lm_hints_5Hz, indices = self.tokenizer.tokenize(refer_audio_acoustic_hidden_states_packed) lm_hints = self.detokenizer(lm_hints_5Hz) diff --git a/comfy/model_base.py b/comfy/model_base.py index 89944548cf3d..a2a34f191eb6 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1548,6 +1548,7 @@ def __init__(self, model_config, model_type=ModelType.FLOW, device=None): def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) device = kwargs["device"] + noise = kwargs["noise"] cross_attn = kwargs.get("cross_attn", None) if cross_attn is not None: @@ -1571,15 +1572,19 @@ def extra_conds(self, **kwargs): 1.4844e-01, 9.4727e-02, 3.8477e-01, -1.2578e+00, -3.3203e-01, -8.5547e-01, 4.3359e-01, 4.2383e-01, -8.9453e-01, -5.0391e-01, -5.6152e-02, -2.9219e+00, -2.4658e-02, 5.0391e-01, 9.8438e-01, - 7.2754e-02, -2.1582e-01, 6.3672e-01, 1.0000e+00]]], device=device).movedim(-1, 1).repeat(1, 1, 750) + 7.2754e-02, -2.1582e-01, 6.3672e-01, 1.0000e+00]]], device=device).movedim(-1, 1).repeat(1, 1, noise.shape[2]) + pass_audio_codes = True else: - refer_audio = refer_audio[-1] - out['refer_audio'] = comfy.conds.CONDRegular(refer_audio) + refer_audio = refer_audio[-1][:, :, :noise.shape[2]] + pass_audio_codes = False - audio_codes = kwargs.get("audio_codes", None) - if audio_codes is not None: - out['audio_codes'] = comfy.conds.CONDRegular(torch.tensor(audio_codes, device=device)) + if pass_audio_codes: + audio_codes = kwargs.get("audio_codes", None) + if audio_codes is not None: + out['audio_codes'] = comfy.conds.CONDRegular(torch.tensor(audio_codes, device=device)) + refer_audio = refer_audio[:, :, :750] + out['refer_audio'] = comfy.conds.CONDRegular(refer_audio) return out class Omnigen2(BaseModel): diff --git a/comfy/text_encoders/ace15.py b/comfy/text_encoders/ace15.py index fce2b67cec97..74e62733eb39 100644 --- a/comfy/text_encoders/ace15.py +++ b/comfy/text_encoders/ace15.py @@ -101,9 +101,7 @@ def sample_manual_loop_no_classes( return output_audio_codes -def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0): - cfg_scale = 2.0 - +def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0, cfg_scale=2.0, temperature=0.85, top_p=0.9, top_k=0): positive = [[token for token, _ in inner_list] for inner_list in positive] negative = [[token for token, _ in inner_list] for inner_list in negative] positive = positive[0] @@ -120,7 +118,7 @@ def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=102 positive = [model.special_tokens["pad"]] * pos_pad + positive paddings = [pos_pad, neg_pad] - return sample_manual_loop_no_classes(model, [positive, negative], paddings, cfg_scale=cfg_scale, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens) + return sample_manual_loop_no_classes(model, [positive, negative], paddings, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens) class ACE15Tokenizer(sd1_clip.SD1Tokenizer): @@ -137,6 +135,12 @@ def tokenize_with_weights(self, text, return_word_ids=False, **kwargs): language = kwargs.get("language", "en") seed = kwargs.get("seed", 0) + generate_audio_codes = kwargs.get("generate_audio_codes", True) + cfg_scale = kwargs.get("cfg_scale", 2.0) + temperature = kwargs.get("temperature", 0.85) + top_p = kwargs.get("top_p", 0.9) + top_k = kwargs.get("top_k", 0.0) + duration = math.ceil(duration) meta_lm = 'bpm: {}\nduration: {}\nkeyscale: {}\ntimesignature: {}'.format(bpm, duration, keyscale, timesignature) lm_template = "<|im_start|>system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n<|im_end|>\n<|im_start|>user\n# Caption\n{}\n{}\n<|im_end|>\n<|im_start|>assistant\n\n{}\n\n\n<|im_end|>\n" @@ -147,7 +151,14 @@ def tokenize_with_weights(self, text, return_word_ids=False, **kwargs): out["lyrics"] = self.qwen3_06b.tokenize_with_weights("# Languages\n{}\n\n# Lyric{}<|endoftext|><|endoftext|>".format(language, lyrics), return_word_ids, disable_weights=True, **kwargs) out["qwen3_06b"] = self.qwen3_06b.tokenize_with_weights("# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n# Caption\n{}# Metas\n{}<|endoftext|>\n<|endoftext|>".format(text, meta_cap), return_word_ids, **kwargs) - out["lm_metadata"] = {"min_tokens": duration * 5, "seed": seed} + out["lm_metadata"] = {"min_tokens": duration * 5, + "seed": seed, + "generate_audio_codes": generate_audio_codes, + "cfg_scale": cfg_scale, + "temperature": temperature, + "top_p": top_p, + "top_k": top_k, + } return out @@ -203,10 +214,14 @@ def encode_token_weights(self, token_weight_pairs): self.qwen3_06b.set_clip_options({"layer": [0]}) lyrics_embeds, _, extra_l = self.qwen3_06b.encode_token_weights(token_weight_pairs_lyrics) + out = {"conditioning_lyrics": lyrics_embeds[:, 0]} + lm_metadata = token_weight_pairs["lm_metadata"] - audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["min_tokens"], seed=lm_metadata["seed"]) + if lm_metadata["generate_audio_codes"]: + audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["min_tokens"], seed=lm_metadata["seed"], cfg_scale=lm_metadata["cfg_scale"], temperature=lm_metadata["temperature"], top_p=lm_metadata["top_p"], top_k=lm_metadata["top_k"]) + out["audio_codes"] = [audio_codes] - return base_out, None, {"conditioning_lyrics": lyrics_embeds[:, 0], "audio_codes": [audio_codes]} + return base_out, None, out def set_clip_options(self, options): self.qwen3_06b.set_clip_options(options) diff --git a/comfy_extras/nodes_ace.py b/comfy_extras/nodes_ace.py index 376584e5c11f..dde5bbd2adac 100644 --- a/comfy_extras/nodes_ace.py +++ b/comfy_extras/nodes_ace.py @@ -44,13 +44,18 @@ def define_schema(cls): io.Combo.Input("timesignature", options=['2', '3', '4', '6']), io.Combo.Input("language", options=["en", "ja", "zh", "es", "de", "fr", "pt", "ru", "it", "nl", "pl", "tr", "vi", "cs", "fa", "id", "ko", "uk", "hu", "ar", "sv", "ro", "el"]), io.Combo.Input("keyscale", options=[f"{root} {quality}" for quality in ["major", "minor"] for root in ["C", "C#", "Db", "D", "D#", "Eb", "E", "F", "F#", "Gb", "G", "G#", "Ab", "A", "A#", "Bb", "B"]]), + io.Boolean.Input("generate_audio_codes", default=True, tooltip="Enable the LLM that generates audio codes. This can be slow but will increase the quality of the generated audio. Turn this off if you are giving the model an audio reference.", advanced=True), + io.Float.Input("cfg_scale", default=2.0, min=0.0, max=100.0, step=0.1, advanced=True), + io.Float.Input("temperature", default=0.85, min=0.0, max=2.0, step=0.01, advanced=True), + io.Float.Input("top_p", default=0.9, min=0.0, max=2000.0, step=0.01, advanced=True), + io.Int.Input("top_k", default=0, min=0, max=100, advanced=True), ], outputs=[io.Conditioning.Output()], ) @classmethod - def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale) -> io.NodeOutput: - tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed) + def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k) -> io.NodeOutput: + tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed, generate_audio_codes=generate_audio_codes, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k) conditioning = clip.encode_from_tokens_scheduled(tokens) return io.NodeOutput(conditioning) @@ -100,14 +105,15 @@ def execute(cls, seconds, batch_size) -> io.NodeOutput: latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device()) return io.NodeOutput({"samples": latent, "type": "audio"}) -class ReferenceTimbreAudio(io.ComfyNode): +class ReferenceAudio(io.ComfyNode): @classmethod def define_schema(cls): return io.Schema( node_id="ReferenceTimbreAudio", + display_name="Reference Audio", category="advanced/conditioning/audio", is_experimental=True, - description="This node sets the reference audio for timbre (for ace step 1.5)", + description="This node sets the reference audio for ace step 1.5", inputs=[ io.Conditioning.Input("conditioning"), io.Latent.Input("latent", optional=True), @@ -131,7 +137,7 @@ async def get_node_list(self) -> list[type[io.ComfyNode]]: EmptyAceStepLatentAudio, TextEncodeAceStepAudio15, EmptyAceStep15LatentAudio, - ReferenceTimbreAudio, + ReferenceAudio, ] async def comfy_entrypoint() -> AceExtension: From a50c32d63fe55d073edd7af2242f0536f50b362e Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 4 Feb 2026 19:15:30 -0800 Subject: [PATCH 4/4] Disable sage attention on ace step 1.5 (#12297) --- comfy/ldm/ace/ace_step15.py | 2 +- comfy/ldm/modules/attention.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/comfy/ldm/ace/ace_step15.py b/comfy/ldm/ace/ace_step15.py index 17a37e573782..f2b130bc113e 100644 --- a/comfy/ldm/ace/ace_step15.py +++ b/comfy/ldm/ace/ace_step15.py @@ -183,7 +183,7 @@ def forward( else: attn_bias = window_bias - attn_output = optimized_attention(query_states, key_states, value_states, self.num_heads, attn_bias, skip_reshape=True) + attn_output = optimized_attention(query_states, key_states, value_states, self.num_heads, attn_bias, skip_reshape=True, low_precision_attention=False) attn_output = self.o_proj(attn_output) return attn_output diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index ccf690945aaa..10d0513256cc 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -524,6 +524,9 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha @wrap_attn def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): + if kwargs.get("low_precision_attention", True) is False: + return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=skip_reshape, skip_output_reshape=skip_output_reshape, **kwargs) + exception_fallback = False if skip_reshape: b, _, _, dim_head = q.shape