From f262444dd4818b6acdbc1350856679dd6245f7f5 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 18 Feb 2026 15:36:35 -0800 Subject: [PATCH 1/2] Add simple 3 band equalizer node for audio. (#12519) --- comfy_extras/nodes_audio.py | 62 +++++++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py index b63dd8e9717e..7e74169f2e37 100644 --- a/comfy_extras/nodes_audio.py +++ b/comfy_extras/nodes_audio.py @@ -698,6 +698,67 @@ def execute(cls, duration, sample_rate, channels) -> IO.NodeOutput: create_empty_audio = execute # TODO: remove +class AudioEqualizer3Band(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="AudioEqualizer3Band", + search_aliases=["eq", "bass boost", "treble boost", "equalizer"], + display_name="Audio Equalizer (3-Band)", + category="audio", + is_experimental=True, + inputs=[ + IO.Audio.Input("audio"), + IO.Float.Input("low_gain_dB", default=0.0, min=-24.0, max=24.0, step=0.1, tooltip="Gain for Low frequencies (Bass)"), + IO.Int.Input("low_freq", default=100, min=20, max=500, tooltip="Cutoff frequency for Low shelf"), + IO.Float.Input("mid_gain_dB", default=0.0, min=-24.0, max=24.0, step=0.1, tooltip="Gain for Mid frequencies"), + IO.Int.Input("mid_freq", default=1000, min=200, max=4000, tooltip="Center frequency for Mids"), + IO.Float.Input("mid_q", default=0.707, min=0.1, max=10.0, step=0.1, tooltip="Q factor (bandwidth) for Mids"), + IO.Float.Input("high_gain_dB", default=0.0, min=-24.0, max=24.0, step=0.1, tooltip="Gain for High frequencies (Treble)"), + IO.Int.Input("high_freq", default=5000, min=1000, max=15000, tooltip="Cutoff frequency for High shelf"), + ], + outputs=[IO.Audio.Output()], + ) + + @classmethod + def execute(cls, audio, low_gain_dB, low_freq, mid_gain_dB, mid_freq, mid_q, high_gain_dB, high_freq) -> IO.NodeOutput: + waveform = audio["waveform"] + sample_rate = audio["sample_rate"] + eq_waveform = waveform.clone() + + # 1. Apply Low Shelf (Bass) + if low_gain_dB != 0: + eq_waveform = torchaudio.functional.bass_biquad( + eq_waveform, + sample_rate, + gain=low_gain_dB, + central_freq=float(low_freq), + Q=0.707 + ) + + # 2. Apply Peaking EQ (Mids) + if mid_gain_dB != 0: + eq_waveform = torchaudio.functional.equalizer_biquad( + eq_waveform, + sample_rate, + center_freq=float(mid_freq), + gain=mid_gain_dB, + Q=mid_q + ) + + # 3. Apply High Shelf (Treble) + if high_gain_dB != 0: + eq_waveform = torchaudio.functional.treble_biquad( + eq_waveform, + sample_rate, + gain=high_gain_dB, + central_freq=float(high_freq), + Q=0.707 + ) + + return IO.NodeOutput({"waveform": eq_waveform, "sample_rate": sample_rate}) + + class AudioExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: @@ -720,6 +781,7 @@ async def get_node_list(self) -> list[type[IO.ComfyNode]]: AudioMerge, AudioAdjustVolume, EmptyAudio, + AudioEqualizer3Band, ] async def comfy_entrypoint() -> AudioExtension: From 6d11cc73549e14a0a31e9ff8c90bfd71b380fe2c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Thu, 19 Feb 2026 03:49:43 +0200 Subject: [PATCH 2/2] feat: Add basic text generation support with native models, initially supporting Gemma3 (#12392) --- comfy/sd.py | 29 +++- comfy/sd1_clip.py | 18 +++ comfy/text_encoders/llama.py | 148 +++++++++++++++++++- comfy/text_encoders/lt.py | 90 +++++++++--- comfy/text_encoders/lumina2.py | 36 ++++- comfy/text_encoders/spiece_tokenizer.py | 27 +++- comfy/utils.py | 8 ++ comfy_extras/nodes_textgen.py | 176 ++++++++++++++++++++++++ nodes.py | 1 + 9 files changed, 501 insertions(+), 32 deletions(-) create mode 100644 comfy_extras/nodes_textgen.py diff --git a/comfy/sd.py b/comfy/sd.py index f65e7caddd55..164f30803cd3 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -423,6 +423,19 @@ def load_model(self, tokens={}): def get_key_patches(self): return self.patcher.get_key_patches() + def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None): + self.cond_stage_model.reset_clip_options() + + if self.layer_idx is not None: + self.cond_stage_model.set_clip_options({"layer": self.layer_idx}) + + self.load_model() + self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device}) + return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed) + + def decode(self, token_ids, skip_special_tokens=True): + return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens) + class VAE: def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None): if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format @@ -1182,6 +1195,7 @@ class TEModel(Enum): JINA_CLIP_2 = 19 QWEN3_8B = 20 QWEN3_06B = 21 + GEMMA_3_4B_VISION = 22 def detect_te_model(sd): @@ -1210,7 +1224,10 @@ def detect_te_model(sd): if 'model.layers.47.self_attn.q_norm.weight' in sd: return TEModel.GEMMA_3_12B if 'model.layers.0.self_attn.q_norm.weight' in sd: - return TEModel.GEMMA_3_4B + if 'vision_model.embeddings.patch_embedding.weight' in sd: + return TEModel.GEMMA_3_4B_VISION + else: + return TEModel.GEMMA_3_4B return TEModel.GEMMA_2_2B if 'model.layers.0.self_attn.k_proj.bias' in sd: weight = sd['model.layers.0.self_attn.k_proj.bias'] @@ -1270,6 +1287,8 @@ class EmptyClass: else: if "text_projection" in clip_data[i]: clip_data[i]["text_projection.weight"] = clip_data[i]["text_projection"].transpose(0, 1) #old models saved with the CLIPSave node + if "lm_head.weight" in clip_data[i]: + clip_data[i]["model.lm_head.weight"] = clip_data[i].pop("lm_head.weight") # prefix missing in some models tokenizer_data = {} clip_target = EmptyClass() @@ -1335,6 +1354,14 @@ class EmptyClass: clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b") clip_target.tokenizer = comfy.text_encoders.lumina2.NTokenizer tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) + elif te_model == TEModel.GEMMA_3_4B_VISION: + clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b_vision") + clip_target.tokenizer = comfy.text_encoders.lumina2.NTokenizer + tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) + elif te_model == TEModel.GEMMA_3_12B: + clip_target.clip = comfy.text_encoders.lt.gemma3_te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.lt.Gemma3_12BTokenizer + tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) elif te_model == TEModel.LLAMA3_8: clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data), clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index b564d152989d..d9d014055e6d 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -308,6 +308,15 @@ def encode(self, tokens): def load_sd(self, sd): return self.transformer.load_state_dict(sd, strict=False, assign=getattr(self, "can_assign_sd", False)) + def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[]): + if isinstance(tokens, dict): + tokens_only = next(iter(tokens.values())) # todo: get this better? + else: + tokens_only = tokens + tokens_only = [[t[0] for t in b] for b in tokens_only] + embeds = self.process_tokens(tokens_only, device=self.execution_device)[0] + return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens) + def parse_parentheses(string): result = [] current_item = "" @@ -663,6 +672,9 @@ def untokenize(self, token_weight_pair): def state_dict(self): return {} + def decode(self, token_ids, skip_special_tokens=True): + return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens) + class SD1Tokenizer: def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer, name=None): if name is not None: @@ -686,6 +698,9 @@ def untokenize(self, token_weight_pair): def state_dict(self): return getattr(self, self.clip).state_dict() + def decode(self, token_ids, skip_special_tokens=True): + return getattr(self, self.clip).decode(token_ids, skip_special_tokens=skip_special_tokens) + class SD1CheckpointClipModel(SDClipModel): def __init__(self, device="cpu", dtype=None, model_options={}): super().__init__(device=device, return_projected_pooled=False, dtype=dtype, model_options=model_options) @@ -722,3 +737,6 @@ def encode_token_weights(self, token_weight_pairs): def load_sd(self, sd): return getattr(self, self.clip).load_sd(sd) + + def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None): + return getattr(self, self.clip).generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed) diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index 54f3d5595416..e5d21fa74a91 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -3,6 +3,8 @@ from dataclasses import dataclass from typing import Optional, Any, Tuple import math +from tqdm import tqdm +import comfy.utils from comfy.ldm.modules.attention import optimized_attention_for_device import comfy.model_management @@ -313,6 +315,13 @@ class Gemma3_4B_Config: final_norm: bool = True lm_head: bool = False +GEMMA3_VISION_CONFIG = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 14} + +@dataclass +class Gemma3_4B_Vision_Config(Gemma3_4B_Config): + vision_config = GEMMA3_VISION_CONFIG + mm_tokens_per_image = 256 + @dataclass class Gemma3_12B_Config: vocab_size: int = 262208 @@ -336,7 +345,7 @@ class Gemma3_12B_Config: rope_scale = [8.0, 1.0] final_norm: bool = True lm_head: bool = False - vision_config = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 14} + vision_config = GEMMA3_VISION_CONFIG mm_tokens_per_image = 256 class RMSNorm(nn.Module): @@ -441,8 +450,10 @@ def forward( freqs_cis: Optional[torch.Tensor] = None, optimized_attention=None, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + sliding_window: Optional[int] = None, ): batch_size, seq_length, _ = hidden_states.shape + xq = self.q_proj(hidden_states) xk = self.k_proj(hidden_states) xv = self.v_proj(hidden_states) @@ -477,6 +488,11 @@ def forward( else: present_key_value = (xk, xv, index + num_tokens) + if sliding_window is not None and xk.shape[2] > sliding_window: + xk = xk[:, :, -sliding_window:] + xv = xv[:, :, -sliding_window:] + attention_mask = attention_mask[..., -sliding_window:] if attention_mask is not None else None + xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) @@ -559,10 +575,12 @@ def forward( optimized_attention=None, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ): + sliding_window = None if self.transformer_type == 'gemma3': if self.sliding_attention: + sliding_window = self.sliding_attention if x.shape[1] > self.sliding_attention: - sliding_mask = torch.full((x.shape[1], x.shape[1]), float("-inf"), device=x.device, dtype=x.dtype) + sliding_mask = torch.full((x.shape[1], x.shape[1]), torch.finfo(x.dtype).min, device=x.device, dtype=x.dtype) sliding_mask.tril_(diagonal=-self.sliding_attention) if attention_mask is not None: attention_mask = attention_mask + sliding_mask @@ -581,6 +599,7 @@ def forward( freqs_cis=freqs_cis, optimized_attention=optimized_attention, past_key_value=past_key_value, + sliding_window=sliding_window, ) x = self.post_attention_layernorm(x) @@ -765,6 +784,104 @@ def set_input_embeddings(self, embeddings): def forward(self, input_ids, *args, **kwargs): return self.model(input_ids, *args, **kwargs) +class BaseGenerate: + def logits(self, x): + input = x[:, -1:] + if hasattr(self.model, "lm_head"): + module = self.model.lm_head + else: + module = self.model.embed_tokens + + offload_stream = None + if module.comfy_cast_weights: + weight, _, offload_stream = comfy.ops.cast_bias_weight(module, input, offloadable=True) + else: + weight = self.model.embed_tokens.weight.to(x) + + x = torch.nn.functional.linear(input, weight, None) + + comfy.ops.uncast_bias_weight(module, weight, None, offload_stream) + return x + + def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=[], initial_tokens=[], execution_dtype=None, min_tokens=0): + device = embeds.device + model_config = self.model.config + + if execution_dtype is None: + if comfy.model_management.should_use_bf16(device): + execution_dtype = torch.bfloat16 + else: + execution_dtype = torch.float32 + embeds = embeds.to(execution_dtype) + + if embeds.ndim == 2: + embeds = embeds.unsqueeze(0) + + past_key_values = [] #kv_cache init + max_cache_len = embeds.shape[1] + max_length + for x in range(model_config.num_hidden_layers): + past_key_values.append((torch.empty([embeds.shape[0], model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), + torch.empty([embeds.shape[0], model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0)) + + generator = torch.Generator(device=device).manual_seed(seed) if do_sample else None + + generated_token_ids = [] + pbar = comfy.utils.ProgressBar(max_length) + + # Generation loop + for step in tqdm(range(max_length), desc="Generating tokens"): + x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values) + logits = self.logits(x)[:, -1] + next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample) + token_id = next_token[0].item() + generated_token_ids.append(token_id) + + embeds = self.model.embed_tokens(next_token).to(execution_dtype) + pbar.update(1) + + if token_id in stop_tokens: + break + + return generated_token_ids + + def sample_token(self, logits, temperature, top_k, top_p, min_p, repetition_penalty, token_history, generator, do_sample=True): + + if not do_sample or temperature == 0.0: + return torch.argmax(logits, dim=-1, keepdim=True) + + # Sampling mode + if repetition_penalty != 1.0: + for i in range(logits.shape[0]): + for token_id in set(token_history): + logits[i, token_id] *= repetition_penalty if logits[i, token_id] < 0 else 1/repetition_penalty + + if temperature != 1.0: + logits = logits / temperature + + if top_k > 0: + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits[indices_to_remove] = torch.finfo(logits.dtype).min + + if min_p > 0.0: + probs_before_filter = torch.nn.functional.softmax(logits, dim=-1) + top_probs, _ = probs_before_filter.max(dim=-1, keepdim=True) + min_threshold = min_p * top_probs + indices_to_remove = probs_before_filter < min_threshold + logits[indices_to_remove] = torch.finfo(logits.dtype).min + + if top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1) + sorted_indices_to_remove = cumulative_probs > top_p + sorted_indices_to_remove[..., 0] = False + indices_to_remove = torch.zeros_like(logits, dtype=torch.bool) + indices_to_remove.scatter_(1, sorted_indices, sorted_indices_to_remove) + logits[indices_to_remove] = torch.finfo(logits.dtype).min + + probs = torch.nn.functional.softmax(logits, dim=-1) + + return torch.multinomial(probs, num_samples=1, generator=generator) + class BaseQwen3: def logits(self, x): input = x[:, -1:] @@ -871,7 +988,7 @@ def __init__(self, config_dict, dtype, device, operations): self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype -class Qwen25_7BVLI(BaseLlama, torch.nn.Module): +class Qwen25_7BVLI(BaseLlama, BaseGenerate, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() config = Qwen25_7BVLI_Config(**config_dict) @@ -881,6 +998,9 @@ def __init__(self, config_dict, dtype, device, operations): self.visual = qwen_vl.Qwen2VLVisionTransformer(hidden_size=1280, output_hidden_size=config.hidden_size, device=device, dtype=dtype, ops=operations) self.dtype = dtype + # todo: should this be tied or not? + #self.lm_head = operations.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype) + def preprocess_embed(self, embed, device): if embed["type"] == "image": image, grid = qwen_vl.process_qwen2vl_images(embed["data"]) @@ -923,7 +1043,7 @@ def __init__(self, config_dict, dtype, device, operations): self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype -class Gemma3_4B(BaseLlama, torch.nn.Module): +class Gemma3_4B(BaseLlama, BaseGenerate, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() config = Gemma3_4B_Config(**config_dict) @@ -932,7 +1052,25 @@ def __init__(self, config_dict, dtype, device, operations): self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype -class Gemma3_12B(BaseLlama, torch.nn.Module): +class Gemma3_4B_Vision(BaseLlama, BaseGenerate, torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + config = Gemma3_4B_Vision_Config(**config_dict) + self.num_layers = config.num_hidden_layers + + self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) + self.dtype = dtype + self.multi_modal_projector = Gemma3MultiModalProjector(config, dtype, device, operations) + self.vision_model = comfy.clip_model.CLIPVision(config.vision_config, dtype, device, operations) + self.image_size = config.vision_config["image_size"] + + def preprocess_embed(self, embed, device): + if embed["type"] == "image": + image = comfy.clip_model.clip_preprocess(embed["data"], size=self.image_size, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], crop=True) + return self.multi_modal_projector(self.vision_model(image.to(device, dtype=torch.float32))[0]), None + return None, None + +class Gemma3_12B(BaseLlama, BaseGenerate, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() config = Gemma3_12B_Config(**config_dict) diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py index 9cf87c0b290e..82fbacf59c60 100644 --- a/comfy/text_encoders/lt.py +++ b/comfy/text_encoders/lt.py @@ -6,6 +6,7 @@ from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector import torch import comfy.utils +import math class T5XXLTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): @@ -22,40 +23,79 @@ def ltxv_te(*args, **kwargs): return comfy.text_encoders.genmo.mochi_te(*args, **kwargs) -class Gemma3_12BTokenizer(sd1_clip.SDTokenizer): +class Gemma3_Tokenizer(): + def state_dict(self): + return {"spiece_model": self.tokenizer.serialize_model()} + + def tokenize_with_weights(self, text, return_word_ids=False, image=None, llama_template=None, skip_template=True, **kwargs): + self.llama_template = "system\nYou are a helpful assistant.\nuser\n{}\nmodel\n" + self.llama_template_images = "system\nYou are a helpful assistant.\nuser\n\n{}\n\nmodel\n" + + if image is None: + images = [] + else: + samples = image.movedim(-1, 1) + total = int(896 * 896) + + scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2])) + width = round(samples.shape[3] * scale_by) + height = round(samples.shape[2] * scale_by) + + s = comfy.utils.common_upscale(samples, width, height, "area", "disabled").movedim(1, -1) + images = [s[:, :, :, :3]] + + if text.startswith(''): + skip_template = True + + if skip_template: + llama_text = text + else: + if llama_template is None: + if len(images) > 0: + llama_text = self.llama_template_images.format(text) + else: + llama_text = self.llama_template.format(text) + else: + llama_text = llama_template.format(text) + + text_tokens = super().tokenize_with_weights(llama_text, return_word_ids) + + if len(images) > 0: + embed_count = 0 + for r in text_tokens: + for i, token in enumerate(r): + if token[0] == 262144 and embed_count < len(images): + r[i] = ({"type": "image", "data": images[embed_count]},) + token[1:] + embed_count += 1 + return text_tokens + +class Gemma3_12BTokenizer(Gemma3_Tokenizer, sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer = tokenizer_data.get("spiece_model", None) - super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_left=True, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data) + special_tokens = {"": 262144, "": 106} + super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_left=True, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False, "special_tokens": special_tokens}, tokenizer_data=tokenizer_data) - def state_dict(self): - return {"spiece_model": self.tokenizer.serialize_model()} class LTXAVGemmaTokenizer(sd1_clip.SD1Tokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma3_12b", tokenizer=Gemma3_12BTokenizer) + class Gemma3_12BModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="all", layer_idx=None, dtype=None, attention_mask=True, model_options={}): llama_quantization_metadata = model_options.get("llama_quantization_metadata", None) if llama_quantization_metadata is not None: model_options = model_options.copy() model_options["quantization_metadata"] = llama_quantization_metadata - + self.dtypes = set() + self.dtypes.add(dtype) super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_12B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) - def tokenize_with_weights(self, text, return_word_ids=False, llama_template="{}", image_embeds=None, **kwargs): - text = llama_template.format(text) - text_tokens = super().tokenize_with_weights(text, return_word_ids) - embed_count = 0 - for k in text_tokens: - tt = text_tokens[k] - for r in tt: - for i in range(len(r)): - if r[i][0] == 262144: - if image_embeds is not None and embed_count < image_embeds.shape[0]: - r[i] = ({"type": "embedding", "data": image_embeds[embed_count], "original_type": "image"},) + r[i][1:] - embed_count += 1 - return text_tokens + def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed): + tokens_only = [[t[0] for t in b] for b in tokens] + embeds, _, _, embeds_info = self.process_tokens(tokens_only, self.execution_device) + comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5) + return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106]) # 106 is class LTXAVTEModel(torch.nn.Module): def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}): @@ -112,6 +152,9 @@ def encode_token_weights(self, token_weight_pairs): return out.to(out_device), pooled + def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed): + return self.gemma3_12b.generate(tokens["gemma3_12b"], do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed) + def load_sd(self, sd): if "model.layers.47.self_attn.q_norm.weight" in sd: return self.gemma3_12b.load_sd(sd) @@ -152,3 +195,14 @@ def __init__(self, device="cpu", dtype=None, model_options={}): dtype = dtype_llama super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options) return LTXAVTEModel_ + +def gemma3_te(dtype_llama=None, llama_quantization_metadata=None): + class Gemma3_12BModel_(Gemma3_12BModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["llama_quantization_metadata"] = llama_quantization_metadata + if dtype_llama is not None: + dtype = dtype_llama + super().__init__(device=device, dtype=dtype, model_options=model_options) + return Gemma3_12BModel_ diff --git a/comfy/text_encoders/lumina2.py b/comfy/text_encoders/lumina2.py index b29a7cc873e0..1b731e09441e 100644 --- a/comfy/text_encoders/lumina2.py +++ b/comfy/text_encoders/lumina2.py @@ -1,23 +1,23 @@ from comfy import sd1_clip from .spiece_tokenizer import SPieceTokenizer import comfy.text_encoders.llama - +from comfy.text_encoders.lt import Gemma3_Tokenizer +import comfy.utils class Gemma2BTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer = tokenizer_data.get("spiece_model", None) - super().__init__(tokenizer, pad_with_end=False, embedding_size=2304, embedding_key='gemma2_2b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data) + special_tokens = {"": 107} + super().__init__(tokenizer, pad_with_end=False, embedding_size=2304, embedding_key='gemma2_2b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False, "special_tokens": special_tokens}, tokenizer_data=tokenizer_data) def state_dict(self): return {"spiece_model": self.tokenizer.serialize_model()} -class Gemma3_4BTokenizer(sd1_clip.SDTokenizer): +class Gemma3_4BTokenizer(Gemma3_Tokenizer, sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer = tokenizer_data.get("spiece_model", None) - super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, disable_weights=True, tokenizer_data=tokenizer_data) - - def state_dict(self): - return {"spiece_model": self.tokenizer.serialize_model()} + special_tokens = {"": 262144, "": 106} + super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False, "special_tokens": special_tokens}, disable_weights=True, tokenizer_data=tokenizer_data) class LuminaTokenizer(sd1_clip.SD1Tokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): @@ -31,6 +31,9 @@ class Gemma2_2BModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}): super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma2_2B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + def generate(self, embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed): + return super().generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[107]) + class Gemma3_4BModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}): llama_quantization_metadata = model_options.get("llama_quantization_metadata", None) @@ -40,6 +43,23 @@ def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, atten super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + def generate(self, embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed): + return super().generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106]) + +class Gemma3_4B_Vision_Model(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}): + llama_quantization_metadata = model_options.get("llama_quantization_metadata", None) + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["quantization_metadata"] = llama_quantization_metadata + + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B_Vision, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + + def process_tokens(self, tokens, device): + embeds, _, _, embeds_info = super().process_tokens(tokens, device) + comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5) + return embeds + class LuminaModel(sd1_clip.SD1ClipModel): def __init__(self, device="cpu", dtype=None, model_options={}, name="gemma2_2b", clip_model=Gemma2_2BModel): super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options) @@ -50,6 +70,8 @@ def te(dtype_llama=None, llama_quantization_metadata=None, model_type="gemma2_2b model = Gemma2_2BModel elif model_type == "gemma3_4b": model = Gemma3_4BModel + elif model_type == "gemma3_4b_vision": + model = Gemma3_4B_Vision_Model class LuminaTEModel_(LuminaModel): def __init__(self, device="cpu", dtype=None, model_options={}): diff --git a/comfy/text_encoders/spiece_tokenizer.py b/comfy/text_encoders/spiece_tokenizer.py index caccb3ca283b..099d8d2d98a5 100644 --- a/comfy/text_encoders/spiece_tokenizer.py +++ b/comfy/text_encoders/spiece_tokenizer.py @@ -6,9 +6,10 @@ class SPieceTokenizer: def from_pretrained(path, **kwargs): return SPieceTokenizer(path, **kwargs) - def __init__(self, tokenizer_path, add_bos=False, add_eos=True): + def __init__(self, tokenizer_path, add_bos=False, add_eos=True, special_tokens=None): self.add_bos = add_bos self.add_eos = add_eos + self.special_tokens = special_tokens import sentencepiece if torch.is_tensor(tokenizer_path): tokenizer_path = tokenizer_path.numpy().tobytes() @@ -27,8 +28,32 @@ def get_vocab(self): return out def __call__(self, string): + if self.special_tokens is not None: + import re + special_tokens_pattern = '|'.join(re.escape(token) for token in self.special_tokens.keys()) + if special_tokens_pattern and re.search(special_tokens_pattern, string): + parts = re.split(f'({special_tokens_pattern})', string) + result = [] + for part in parts: + if not part: + continue + if part in self.special_tokens: + result.append(self.special_tokens[part]) + else: + encoded = self.tokenizer.encode(part, add_bos=False, add_eos=False) + result.extend(encoded) + return {"input_ids": result} + out = self.tokenizer.encode(string) return {"input_ids": out} + def decode(self, token_ids, skip_special_tokens=False): + + if skip_special_tokens and self.special_tokens: + special_token_ids = set(self.special_tokens.values()) + token_ids = [tid for tid in token_ids if tid not in special_token_ids] + + return self.tokenizer.decode(token_ids) + def serialize_model(self): return torch.ByteTensor(list(self.tokenizer.serialized_model_proto())) diff --git a/comfy/utils.py b/comfy/utils.py index c1ce540b5555..17443b4ccd6d 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1418,3 +1418,11 @@ def deepcopy_list_dict(obj, memo=None): memo[obj_id] = res return res + +def normalize_image_embeddings(embeds, embeds_info, scale_factor): + """Normalize image embeddings to match text embedding scale""" + for info in embeds_info: + if info.get("type") == "image": + start_idx = info["index"] + end_idx = start_idx + info["size"] + embeds[:, start_idx:end_idx, :] /= scale_factor diff --git a/comfy_extras/nodes_textgen.py b/comfy_extras/nodes_textgen.py new file mode 100644 index 000000000000..dd4f6b0d39b2 --- /dev/null +++ b/comfy_extras/nodes_textgen.py @@ -0,0 +1,176 @@ +from comfy_api.latest import ComfyExtension, io +from typing_extensions import override + +class TextGenerate(io.ComfyNode): + @classmethod + def define_schema(cls): + # Define dynamic combo options for sampling mode + sampling_options = [ + io.DynamicCombo.Option( + key="on", + inputs=[ + io.Float.Input("temperature", default=0.7, min=0.01, max=2.0, step=0.000001), + io.Int.Input("top_k", default=64, min=0, max=1000), + io.Float.Input("top_p", default=0.95, min=0.0, max=1.0, step=0.01), + io.Float.Input("min_p", default=0.05, min=0.0, max=1.0, step=0.01), + io.Float.Input("repetition_penalty", default=1.05, min=0.0, max=5.0, step=0.01), + io.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff), + ] + ), + io.DynamicCombo.Option( + key="off", + inputs=[] + ), + ] + + return io.Schema( + node_id="TextGenerate", + category="textgen/", + search_aliases=["LLM", "gemma"], + inputs=[ + io.Clip.Input("clip"), + io.String.Input("prompt", multiline=True, dynamic_prompts=True, default=""), + io.Image.Input("image", optional=True), + io.Int.Input("max_length", default=256, min=1, max=2048), + io.DynamicCombo.Input("sampling_mode", options=sampling_options, display_name="Sampling Mode"), + ], + outputs=[ + io.String.Output(display_name="generated_text"), + ], + ) + + @classmethod + def execute(cls, clip, prompt, max_length, sampling_mode, image=None) -> io.NodeOutput: + + tokens = clip.tokenize(prompt, image=image, skip_template=False) + + # Get sampling parameters from dynamic combo + do_sample = sampling_mode.get("sampling_mode") == "on" + temperature = sampling_mode.get("temperature", 1.0) + top_k = sampling_mode.get("top_k", 50) + top_p = sampling_mode.get("top_p", 1.0) + min_p = sampling_mode.get("min_p", 0.0) + seed = sampling_mode.get("seed", None) + repetition_penalty = sampling_mode.get("repetition_penalty", 1.0) + + generated_ids = clip.generate( + tokens, + do_sample=do_sample, + max_length=max_length, + temperature=temperature, + top_k=top_k, + top_p=top_p, + min_p=min_p, + repetition_penalty=repetition_penalty, + seed=seed + ) + + generated_text = clip.decode(generated_ids, skip_special_tokens=True) + return io.NodeOutput(generated_text) + + +LTX2_T2V_SYSTEM_PROMPT = """You are a Creative Assistant. Given a user's raw input prompt describing a scene or concept, expand it into a detailed video generation prompt with specific visuals and integrated audio to guide a text-to-video model. +#### Guidelines +- Strictly follow all aspects of the user's raw input: include every element requested (style, visuals, motions, actions, camera movement, audio). + - If the input is vague, invent concrete details: lighting, textures, materials, scene settings, etc. + - For characters: describe gender, clothing, hair, expressions. DO NOT invent unrequested characters. +- Use active language: present-progressive verbs ("is walking," "speaking"). If no action specified, describe natural movements. +- Maintain chronological flow: use temporal connectors ("as," "then," "while"). +- Audio layer: Describe complete soundscape (background audio, ambient sounds, SFX, speech/music when requested). Integrate sounds chronologically alongside actions. Be specific (e.g., "soft footsteps on tile"), not vague (e.g., "ambient sound is present"). +- Speech (only when requested): + - For ANY speech-related input (talking, conversation, singing, etc.), ALWAYS include exact words in quotes with voice characteristics (e.g., "The man says in an excited voice: 'You won't believe what I just saw!'"). + - Specify language if not English and accent if relevant. +- Style: Include visual style at the beginning: "Style: