diff --git a/comfy/ldm/ace/ace_step15.py b/comfy/ldm/ace/ace_step15.py
index d9054965862f..f2b130bc113e 100644
--- a/comfy/ldm/ace/ace_step15.py
+++ b/comfy/ldm/ace/ace_step15.py
@@ -183,7 +183,7 @@ def forward(
         else:
             attn_bias = window_bias

-        attn_output = optimized_attention(query_states, key_states, value_states, self.num_heads, attn_bias, skip_reshape=True)
+        attn_output = optimized_attention(query_states, key_states, value_states, self.num_heads, attn_bias, skip_reshape=True, low_precision_attention=False)

         attn_output = self.o_proj(attn_output)
         return attn_output
@@ -1035,8 +1035,7 @@ def prepare_condition(
             audio_codes = torch.nn.functional.pad(audio_codes, (0, math.ceil(src_latents.shape[1] / 5) - audio_codes.shape[1]), "constant", 35847)
             lm_hints_5Hz = self.tokenizer.quantizer.get_output_from_indices(audio_codes, dtype=text_hidden_states.dtype)
         else:
-            assert False
-            # TODO ?
+            lm_hints_5Hz, indices = self.tokenizer.tokenize(refer_audio_acoustic_hidden_states_packed)

         lm_hints = self.detokenizer(lm_hints_5Hz)
diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index ccf690945aaa..10d0513256cc 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -524,6 +524,9 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha

 @wrap_attn
 def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
+    if kwargs.get("low_precision_attention", True) is False:
+        return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=skip_reshape, skip_output_reshape=skip_output_reshape, **kwargs)
+
     exception_fallback = False
     if skip_reshape:
         b, _, _, dim_head = q.shape
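A note on the attention.py hunk: the flag is read from **kwargs with a default of True, so every existing caller keeps the quantized SageAttention fast path unchanged; only call sites that explicitly pass low_precision_attention=False, like the windowed-bias attention in ace_step15.py above, are rerouted to the full-precision PyTorch kernel. A minimal standalone sketch of the opt-out pattern (stand-in implementations, not the patch itself):

    import torch
    import torch.nn.functional as F

    def attention_pytorch(q, k, v, heads, mask=None, **kwargs):
        # Full-precision reference path (stand-in for comfy's implementation).
        return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

    def attention_sage(q, k, v, heads, mask=None, **kwargs):
        # Quantizing backend honors the opt-out flag before taking its fast path.
        if kwargs.get("low_precision_attention", True) is False:
            return attention_pytorch(q, k, v, heads, mask=mask, **kwargs)
        return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)  # stand-in for the int8 path

    # A precision-sensitive call site opts out once, at the call boundary:
    q = k = v = torch.randn(1, 8, 64, 32)
    out = attention_sage(q, k, v, heads=8, low_precision_attention=False)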
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 89944548cf3d..a2a34f191eb6 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -1548,6 +1548,7 @@ def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
     def extra_conds(self, **kwargs):
         out = super().extra_conds(**kwargs)
         device = kwargs["device"]
+        noise = kwargs["noise"]
         cross_attn = kwargs.get("cross_attn", None)
         if cross_attn is not None:
@@ -1571,15 +1572,19 @@ def extra_conds(self, **kwargs):
                                      1.4844e-01, 9.4727e-02, 3.8477e-01, -1.2578e+00, -3.3203e-01,
                                      -8.5547e-01, 4.3359e-01, 4.2383e-01, -8.9453e-01, -5.0391e-01,
                                      -5.6152e-02, -2.9219e+00, -2.4658e-02, 5.0391e-01, 9.8438e-01,
-                                     7.2754e-02, -2.1582e-01, 6.3672e-01, 1.0000e+00]]], device=device).movedim(-1, 1).repeat(1, 1, 750)
+                                     7.2754e-02, -2.1582e-01, 6.3672e-01, 1.0000e+00]]], device=device).movedim(-1, 1).repeat(1, 1, noise.shape[2])
+            pass_audio_codes = True
         else:
-            refer_audio = refer_audio[-1]
-            out['refer_audio'] = comfy.conds.CONDRegular(refer_audio)
+            refer_audio = refer_audio[-1][:, :, :noise.shape[2]]
+            pass_audio_codes = False

-        audio_codes = kwargs.get("audio_codes", None)
-        if audio_codes is not None:
-            out['audio_codes'] = comfy.conds.CONDRegular(torch.tensor(audio_codes, device=device))
+        if pass_audio_codes:
+            audio_codes = kwargs.get("audio_codes", None)
+            if audio_codes is not None:
+                out['audio_codes'] = comfy.conds.CONDRegular(torch.tensor(audio_codes, device=device))
+            refer_audio = refer_audio[:, :, :750]

+        out['refer_audio'] = comfy.conds.CONDRegular(refer_audio)
         return out

 class Omnigen2(BaseModel):
diff --git a/comfy/ops.py b/comfy/ops.py
index 53c5e4dc3b4e..0f4eca7c768d 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -54,6 +54,8 @@ def scaled_dot_product_attention(q, k, v, *args, **kwargs):
         SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION)

     def scaled_dot_product_attention(q, k, v, *args, **kwargs):
+        if q.nelement() < 1024 * 128:  # arbitrary threshold; cuDNN attention seems slower for small inputs
+            return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
         with sdpa_kernel(SDPA_BACKEND_PRIORITY, set_priority=True):
             return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
 else:
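For scale on the ops.py threshold: with the usual (batch, heads, seq_len, head_dim) layout, 1024 * 128 = 131072 elements is roughly one 256-token sequence at 8 heads x 64 dims, so only fairly small attention calls skip the sdpa_kernel priority context (which carries per-call overhead) and let PyTorch pick a backend itself. A quick sanity check with illustrative shapes (standalone, not part of the patch):

    import torch

    q = torch.empty(1, 8, 256, 64)    # batch, heads, seq_len, head_dim
    print(q.nelement())               # 131072 == 1024 * 128
    print(q.nelement() < 1024 * 128)  # False: this size already takes the cuDNN priority path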
diff --git a/comfy/text_encoders/ace15.py b/comfy/text_encoders/ace15.py
index fce2b67cec97..74e62733eb39 100644
--- a/comfy/text_encoders/ace15.py
+++ b/comfy/text_encoders/ace15.py
@@ -101,9 +101,7 @@ def sample_manual_loop_no_classes(
     return output_audio_codes


-def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0):
-    cfg_scale = 2.0
-
+def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0, cfg_scale=2.0, temperature=0.85, top_p=0.9, top_k=0):
     positive = [[token for token, _ in inner_list] for inner_list in positive]
     negative = [[token for token, _ in inner_list] for inner_list in negative]
     positive = positive[0]
@@ -120,7 +118,7 @@ def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=102
         positive = [model.special_tokens["pad"]] * pos_pad + positive

     paddings = [pos_pad, neg_pad]
-    return sample_manual_loop_no_classes(model, [positive, negative], paddings, cfg_scale=cfg_scale, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
+    return sample_manual_loop_no_classes(model, [positive, negative], paddings, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)


 class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
@@ -137,6 +135,12 @@ def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
         language = kwargs.get("language", "en")
         seed = kwargs.get("seed", 0)

+        generate_audio_codes = kwargs.get("generate_audio_codes", True)
+        cfg_scale = kwargs.get("cfg_scale", 2.0)
+        temperature = kwargs.get("temperature", 0.85)
+        top_p = kwargs.get("top_p", 0.9)
+        top_k = kwargs.get("top_k", 0)
+
         duration = math.ceil(duration)
         meta_lm = 'bpm: {}\nduration: {}\nkeyscale: {}\ntimesignature: {}'.format(bpm, duration, keyscale, timesignature)
         lm_template = "<|im_start|>system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n<|im_end|>\n<|im_start|>user\n# Caption\n{}\n{}\n<|im_end|>\n<|im_start|>assistant\n\n{}\n\n\n<|im_end|>\n"
@@ -147,7 +151,14 @@ def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
         out["lyrics"] = self.qwen3_06b.tokenize_with_weights("# Languages\n{}\n\n# Lyric{}<|endoftext|><|endoftext|>".format(language, lyrics), return_word_ids, disable_weights=True, **kwargs)
         out["qwen3_06b"] = self.qwen3_06b.tokenize_with_weights("# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n# Caption\n{}# Metas\n{}<|endoftext|>\n<|endoftext|>".format(text, meta_cap), return_word_ids, **kwargs)
-        out["lm_metadata"] = {"min_tokens": duration * 5, "seed": seed}
+        out["lm_metadata"] = {"min_tokens": duration * 5,
+                              "seed": seed,
+                              "generate_audio_codes": generate_audio_codes,
+                              "cfg_scale": cfg_scale,
+                              "temperature": temperature,
+                              "top_p": top_p,
+                              "top_k": top_k,
+                              }

         return out
@@ -203,10 +214,14 @@ def encode_token_weights(self, token_weight_pairs):
         self.qwen3_06b.set_clip_options({"layer": [0]})
         lyrics_embeds, _, extra_l = self.qwen3_06b.encode_token_weights(token_weight_pairs_lyrics)

+        out = {"conditioning_lyrics": lyrics_embeds[:, 0]}
+
         lm_metadata = token_weight_pairs["lm_metadata"]
-        audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["min_tokens"], seed=lm_metadata["seed"])
+        if lm_metadata["generate_audio_codes"]:
+            audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["min_tokens"], seed=lm_metadata["seed"], cfg_scale=lm_metadata["cfg_scale"], temperature=lm_metadata["temperature"], top_p=lm_metadata["top_p"], top_k=lm_metadata["top_k"])
+            out["audio_codes"] = [audio_codes]

-        return base_out, None, {"conditioning_lyrics": lyrics_embeds[:, 0], "audio_codes": [audio_codes]}
+        return base_out, None, out

     def set_clip_options(self, options):
         self.qwen3_06b.set_clip_options(options)
diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py
index 3afd094d1fdd..b6735d210a43 100644
--- a/comfy/text_encoders/llama.py
+++ b/comfy/text_encoders/llama.py
@@ -651,10 +651,10 @@ def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermed
         mask = None
         if attention_mask is not None:
             mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, seq_len, attention_mask.shape[-1])
-            mask = mask.masked_fill(mask.to(torch.bool), torch.finfo(x.dtype).min)
+            mask = mask.masked_fill(mask.to(torch.bool), torch.finfo(x.dtype).min / 4)

         if seq_len > 1:
-            causal_mask = torch.empty(past_len + seq_len, past_len + seq_len, dtype=x.dtype, device=x.device).fill_(torch.finfo(x.dtype).min).triu_(1)
+            causal_mask = torch.empty(past_len + seq_len, past_len + seq_len, dtype=x.dtype, device=x.device).fill_(torch.finfo(x.dtype).min / 4).triu_(1)
             if mask is not None:
                 mask += causal_mask
             else:
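Why the llama.py hunk quarters torch.finfo(x.dtype).min: the padding mask and the causal mask are built separately and then summed via mask += causal_mask. If both use the full min, a position masked by both overflows to -inf in float16, and a row that is entirely -inf softmaxes to NaN (exp gives all zeros, so the normalization is 0/0). min / 4 keeps the sum finite while still saturating the softmax to ~0. A quick standalone demonstration:

    import torch

    m = torch.finfo(torch.float16).min                  # -65504.0
    full = torch.tensor(m, dtype=torch.float16)
    print(full + full)                                  # -inf: the sum overflows float16

    quarter = torch.tensor(m / 4, dtype=torch.float16)
    print(quarter + quarter)                            # -32752.0: finite, still ~-inf for softmax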
diff --git a/comfy_extras/nodes_ace.py b/comfy_extras/nodes_ace.py
index 376584e5c11f..dde5bbd2adac 100644
--- a/comfy_extras/nodes_ace.py
+++ b/comfy_extras/nodes_ace.py
@@ -44,13 +44,18 @@ def define_schema(cls):
                 io.Combo.Input("timesignature", options=['2', '3', '4', '6']),
                 io.Combo.Input("language", options=["en", "ja", "zh", "es", "de", "fr", "pt", "ru", "it", "nl", "pl", "tr", "vi", "cs", "fa", "id", "ko", "uk", "hu", "ar", "sv", "ro", "el"]),
                 io.Combo.Input("keyscale", options=[f"{root} {quality}" for quality in ["major", "minor"] for root in ["C", "C#", "Db", "D", "D#", "Eb", "E", "F", "F#", "Gb", "G", "G#", "Ab", "A", "A#", "Bb", "B"]]),
+                io.Boolean.Input("generate_audio_codes", default=True, tooltip="Enable the LLM that generates audio codes. This can be slow but will increase the quality of the generated audio. Turn this off if you are giving the model an audio reference.", advanced=True),
+                io.Float.Input("cfg_scale", default=2.0, min=0.0, max=100.0, step=0.1, advanced=True),
+                io.Float.Input("temperature", default=0.85, min=0.0, max=2.0, step=0.01, advanced=True),
+                io.Float.Input("top_p", default=0.9, min=0.0, max=1.0, step=0.01, advanced=True),
+                io.Int.Input("top_k", default=0, min=0, max=100, advanced=True),
             ],
             outputs=[io.Conditioning.Output()],
         )

     @classmethod
-    def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale) -> io.NodeOutput:
-        tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed)
+    def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k) -> io.NodeOutput:
+        tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed, generate_audio_codes=generate_audio_codes, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k)
         conditioning = clip.encode_from_tokens_scheduled(tokens)
         return io.NodeOutput(conditioning)
@@ -100,14 +105,15 @@ def execute(cls, seconds, batch_size) -> io.NodeOutput:
         latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device())
         return io.NodeOutput({"samples": latent, "type": "audio"})

-class ReferenceTimbreAudio(io.ComfyNode):
+class ReferenceAudio(io.ComfyNode):
     @classmethod
     def define_schema(cls):
         return io.Schema(
             node_id="ReferenceTimbreAudio",
+            display_name="Reference Audio",
             category="advanced/conditioning/audio",
             is_experimental=True,
-            description="This node sets the reference audio for timbre (for ace step 1.5)",
+            description="This node sets the reference audio for ACE Step 1.5",
             inputs=[
                 io.Conditioning.Input("conditioning"),
                 io.Latent.Input("latent", optional=True),
@@ -131,7 +137,7 @@ async def get_node_list(self) -> list[type[io.ComfyNode]]:
             EmptyAceStepLatentAudio,
             TextEncodeAceStepAudio15,
             EmptyAceStep15LatentAudio,
-            ReferenceTimbreAudio,
+            ReferenceAudio,
         ]

 async def comfy_entrypoint() -> AceExtension:
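The new node inputs flow one-to-one through clip.tokenize into generate_audio_codes and on to sample_manual_loop_no_classes, whose body this diff does not show. For reference, the conventional reading of these knobs, assumed here rather than taken from the repo's loop: logits are scaled by 1 / temperature, then truncated by top_k (0 disables it) and nucleus top_p before multinomial sampling, with cfg_scale presumably combining the positive and negative streams in the usual classifier-free-guidance form neg + cfg_scale * (pos - neg). A self-contained sketch of the standard filtering step:

    import torch

    def filter_logits(logits, top_k=0, top_p=0.9):
        # top-k: drop everything below the k-th largest logit (0 disables).
        if top_k > 0:
            kth = torch.topk(logits, top_k).values[..., -1, None]
            logits = logits.masked_fill(logits < kth, float("-inf"))
        # top-p (nucleus): keep the smallest sorted prefix whose cumulative
        # probability exceeds top_p; the shift guarantees the argmax survives.
        if 0.0 < top_p < 1.0:
            sorted_logits, sorted_idx = torch.sort(logits, descending=True)
            cum = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
            remove = cum > top_p
            remove[..., 1:] = remove[..., :-1].clone()
            remove[..., 0] = False
            logits = logits.masked_fill(remove.scatter(-1, sorted_idx, remove), float("-inf"))
        return logits

    logits = torch.randn(1, 50) / 0.85  # temperature first, then truncation
    probs = torch.softmax(filter_logits(logits, top_k=0, top_p=0.9), dim=-1)
    next_token = torch.multinomial(probs, 1)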