29 changes: 28 additions & 1 deletion comfy/sd.py
@@ -423,6 +423,19 @@ def load_model(self, tokens={}):
def get_key_patches(self):
return self.patcher.get_key_patches()

def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None):
self.cond_stage_model.reset_clip_options()

if self.layer_idx is not None:
self.cond_stage_model.set_clip_options({"layer": self.layer_idx})

self.load_model()
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed)

def decode(self, token_ids, skip_special_tokens=True):
return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)

class VAE:
def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None):
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
@@ -1182,6 +1195,7 @@ class TEModel(Enum):
JINA_CLIP_2 = 19
QWEN3_8B = 20
QWEN3_06B = 21
GEMMA_3_4B_VISION = 22


def detect_te_model(sd):
@@ -1210,7 +1224,10 @@ def detect_te_model(sd):
if 'model.layers.47.self_attn.q_norm.weight' in sd:
return TEModel.GEMMA_3_12B
if 'model.layers.0.self_attn.q_norm.weight' in sd:
return TEModel.GEMMA_3_4B
if 'vision_model.embeddings.patch_embedding.weight' in sd:
return TEModel.GEMMA_3_4B_VISION
else:
return TEModel.GEMMA_3_4B
return TEModel.GEMMA_2_2B
if 'model.layers.0.self_attn.k_proj.bias' in sd:
weight = sd['model.layers.0.self_attn.k_proj.bias']
@@ -1270,6 +1287,8 @@ class EmptyClass:
else:
if "text_projection" in clip_data[i]:
clip_data[i]["text_projection.weight"] = clip_data[i]["text_projection"].transpose(0, 1) #old models saved with the CLIPSave node
if "lm_head.weight" in clip_data[i]:
clip_data[i]["model.lm_head.weight"] = clip_data[i].pop("lm_head.weight") # prefix missing in some models

tokenizer_data = {}
clip_target = EmptyClass()
@@ -1335,6 +1354,14 @@ class EmptyClass:
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b")
clip_target.tokenizer = comfy.text_encoders.lumina2.NTokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif te_model == TEModel.GEMMA_3_4B_VISION:
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b_vision")
clip_target.tokenizer = comfy.text_encoders.lumina2.NTokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif te_model == TEModel.GEMMA_3_12B:
clip_target.clip = comfy.text_encoders.lt.gemma3_te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.lt.Gemma3_12BTokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif te_model == TEModel.LLAMA3_8:
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data),
clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None)
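Together, the new generate() and decode() methods give the CLIP wrapper a minimal text-generation interface that reuses the same token structure the conditioning path produces. A rough usage sketch, assuming a clip object already loaded by one of the existing loaders; the prompt and sampling settings below are illustrative, not part of this diff:

# Hypothetical usage of the new CLIP.generate()/CLIP.decode() API (not part of the diff).
# Assumes `clip` is a comfy.sd.CLIP instance loaded elsewhere, e.g. by a checkpoint loader.
def generate_text(clip, prompt):
    tokens = clip.tokenize(prompt)  # same token-weight structure used for conditioning
    token_ids = clip.generate(
        tokens,
        do_sample=True,
        max_length=128,
        temperature=0.8,
        top_k=50,
        top_p=0.95,
        repetition_penalty=1.05,
        seed=42,
    )
    return clip.decode(token_ids, skip_special_tokens=True)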
18 changes: 18 additions & 0 deletions comfy/sd1_clip.py
@@ -308,6 +308,15 @@ def encode(self, tokens):
def load_sd(self, sd):
return self.transformer.load_state_dict(sd, strict=False, assign=getattr(self, "can_assign_sd", False))

def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[]):
if isinstance(tokens, dict):
tokens_only = next(iter(tokens.values())) # todo: get this better?
else:
tokens_only = tokens
tokens_only = [[t[0] for t in b] for b in tokens_only]
embeds = self.process_tokens(tokens_only, device=self.execution_device)[0]
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens)

def parse_parentheses(string):
result = []
current_item = ""
@@ -663,6 +672,9 @@ def untokenize(self, token_weight_pair):
def state_dict(self):
return {}

def decode(self, token_ids, skip_special_tokens=True):
return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)

class SD1Tokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer, name=None):
if name is not None:
@@ -686,6 +698,9 @@ def untokenize(self, token_weight_pair):
def state_dict(self):
return getattr(self, self.clip).state_dict()

def decode(self, token_ids, skip_special_tokens=True):
return getattr(self, self.clip).decode(token_ids, skip_special_tokens=skip_special_tokens)

class SD1CheckpointClipModel(SDClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, return_projected_pooled=False, dtype=dtype, model_options=model_options)
@@ -722,3 +737,6 @@ def encode_token_weights(self, token_weight_pairs):

def load_sd(self, sd):
return getattr(self, self.clip).load_sd(sd)

def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None):
return getattr(self, self.clip).generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed)
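For reference, the tokens argument reaching SDClipModel.generate uses the same structure as the encode path: per-batch lists of (token_id, weight, ...) tuples, optionally wrapped in a dict keyed by sub-tokenizer name. The comprehension in generate simply discards the weights before embedding. A minimal illustration with made-up values (the key name and ids are assumptions, not taken from the diff):

# Illustrative only: the token structure handled by SDClipModel.generate (values are made up).
tokens = {"gemma3": [[(2, 1.0), (10647, 1.0), (1134, 1.0)]]}  # one batch of (token_id, weight) pairs

tokens_only = next(iter(tokens.values()))                     # [[(2, 1.0), (10647, 1.0), (1134, 1.0)]]
token_ids = [[t[0] for t in batch] for batch in tokens_only]  # [[2, 10647, 1134]]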
148 changes: 143 additions & 5 deletions comfy/text_encoders/llama.py
@@ -3,6 +3,8 @@
from dataclasses import dataclass
from typing import Optional, Any, Tuple
import math
from tqdm import tqdm
import comfy.utils

from comfy.ldm.modules.attention import optimized_attention_for_device
import comfy.model_management
@@ -313,6 +315,13 @@ class Gemma3_4B_Config:
final_norm: bool = True
lm_head: bool = False

GEMMA3_VISION_CONFIG = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 14}

@dataclass
class Gemma3_4B_Vision_Config(Gemma3_4B_Config):
vision_config = GEMMA3_VISION_CONFIG
mm_tokens_per_image = 256

@dataclass
class Gemma3_12B_Config:
vocab_size: int = 262208
@@ -336,7 +345,7 @@ class Gemma3_12B_Config:
rope_scale = [8.0, 1.0]
final_norm: bool = True
lm_head: bool = False
vision_config = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 14}
vision_config = GEMMA3_VISION_CONFIG
mm_tokens_per_image = 256

class RMSNorm(nn.Module):
Expand Down Expand Up @@ -441,8 +450,10 @@ def forward(
freqs_cis: Optional[torch.Tensor] = None,
optimized_attention=None,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
sliding_window: Optional[int] = None,
):
batch_size, seq_length, _ = hidden_states.shape

xq = self.q_proj(hidden_states)
xk = self.k_proj(hidden_states)
xv = self.v_proj(hidden_states)
@@ -477,6 +488,11 @@ def forward(
else:
present_key_value = (xk, xv, index + num_tokens)

if sliding_window is not None and xk.shape[2] > sliding_window:
xk = xk[:, :, -sliding_window:]
xv = xv[:, :, -sliding_window:]
attention_mask = attention_mask[..., -sliding_window:] if attention_mask is not None else None

xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)

@@ -559,10 +575,12 @@ def forward(
optimized_attention=None,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
):
sliding_window = None
if self.transformer_type == 'gemma3':
if self.sliding_attention:
sliding_window = self.sliding_attention
if x.shape[1] > self.sliding_attention:
sliding_mask = torch.full((x.shape[1], x.shape[1]), float("-inf"), device=x.device, dtype=x.dtype)
sliding_mask = torch.full((x.shape[1], x.shape[1]), torch.finfo(x.dtype).min, device=x.device, dtype=x.dtype)
sliding_mask.tril_(diagonal=-self.sliding_attention)
if attention_mask is not None:
attention_mask = attention_mask + sliding_mask
@@ -581,6 +599,7 @@ def forward(
freqs_cis=freqs_cis,
optimized_attention=optimized_attention,
past_key_value=past_key_value,
sliding_window=sliding_window,
)

x = self.post_attention_layernorm(x)
@@ -765,6 +784,104 @@ def set_input_embeddings(self, embeddings):
def forward(self, input_ids, *args, **kwargs):
return self.model(input_ids, *args, **kwargs)

class BaseGenerate:
def logits(self, x):
input = x[:, -1:]
if hasattr(self.model, "lm_head"):
module = self.model.lm_head
else:
module = self.model.embed_tokens

offload_stream = None
if module.comfy_cast_weights:
weight, _, offload_stream = comfy.ops.cast_bias_weight(module, input, offloadable=True)
else:
weight = self.model.embed_tokens.weight.to(x)

x = torch.nn.functional.linear(input, weight, None)

comfy.ops.uncast_bias_weight(module, weight, None, offload_stream)
return x

def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=[], initial_tokens=[], execution_dtype=None, min_tokens=0):
device = embeds.device
model_config = self.model.config

if execution_dtype is None:
if comfy.model_management.should_use_bf16(device):
execution_dtype = torch.bfloat16
else:
execution_dtype = torch.float32
embeds = embeds.to(execution_dtype)

if embeds.ndim == 2:
embeds = embeds.unsqueeze(0)

past_key_values = [] #kv_cache init
max_cache_len = embeds.shape[1] + max_length
for x in range(model_config.num_hidden_layers):
past_key_values.append((torch.empty([embeds.shape[0], model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype),
torch.empty([embeds.shape[0], model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0))

generator = torch.Generator(device=device).manual_seed(seed) if do_sample else None

generated_token_ids = []
pbar = comfy.utils.ProgressBar(max_length)

# Generation loop
for step in tqdm(range(max_length), desc="Generating tokens"):
x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values)
logits = self.logits(x)[:, -1]
next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample)
token_id = next_token[0].item()
generated_token_ids.append(token_id)

embeds = self.model.embed_tokens(next_token).to(execution_dtype)
pbar.update(1)

if token_id in stop_tokens:
break

return generated_token_ids

def sample_token(self, logits, temperature, top_k, top_p, min_p, repetition_penalty, token_history, generator, do_sample=True):

if not do_sample or temperature == 0.0:
return torch.argmax(logits, dim=-1, keepdim=True)

# Sampling mode
if repetition_penalty != 1.0:
for i in range(logits.shape[0]):
for token_id in set(token_history):
logits[i, token_id] *= repetition_penalty if logits[i, token_id] < 0 else 1/repetition_penalty

if temperature != 1.0:
logits = logits / temperature

if top_k > 0:
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = torch.finfo(logits.dtype).min

if min_p > 0.0:
probs_before_filter = torch.nn.functional.softmax(logits, dim=-1)
top_probs, _ = probs_before_filter.max(dim=-1, keepdim=True)
min_threshold = min_p * top_probs
indices_to_remove = probs_before_filter < min_threshold
logits[indices_to_remove] = torch.finfo(logits.dtype).min

if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 0] = False
indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
indices_to_remove.scatter_(1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = torch.finfo(logits.dtype).min

probs = torch.nn.functional.softmax(logits, dim=-1)

return torch.multinomial(probs, num_samples=1, generator=generator)
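The filters in sample_token follow the usual scheme: each step masks disallowed logits to the dtype's minimum value before the final softmax and multinomial draw. A standalone toy example of the top-k and top-p steps (min-p and repetition penalty follow the same mask-then-softmax pattern); this is purely illustrative and independent of the class:

# Toy illustration of the top-k / top-p filtering used above (not part of the diff).
import torch

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, -3.0]])
top_k, top_p = 3, 0.9

# top-k: keep only the k largest logits
kth = torch.topk(logits, top_k)[0][..., -1, None]
logits = logits.masked_fill(logits < kth, torch.finfo(logits.dtype).min)

# top-p: drop the tail of the sorted cumulative distribution
sorted_logits, sorted_idx = torch.sort(logits, descending=True)
cumprobs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
drop = cumprobs > top_p
drop[..., 0] = False  # always keep the most likely token
mask = torch.zeros_like(logits, dtype=torch.bool).scatter_(1, sorted_idx, drop)
logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min)

probs = torch.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)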

class BaseQwen3:
def logits(self, x):
input = x[:, -1:]
@@ -871,7 +988,7 @@ def __init__(self, config_dict, dtype, device, operations):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype

class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
class Qwen25_7BVLI(BaseLlama, BaseGenerate, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Qwen25_7BVLI_Config(**config_dict)
@@ -881,6 +998,9 @@ def __init__(self, config_dict, dtype, device, operations):
self.visual = qwen_vl.Qwen2VLVisionTransformer(hidden_size=1280, output_hidden_size=config.hidden_size, device=device, dtype=dtype, ops=operations)
self.dtype = dtype

# todo: should this be tied or not?
#self.lm_head = operations.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)

def preprocess_embed(self, embed, device):
if embed["type"] == "image":
image, grid = qwen_vl.process_qwen2vl_images(embed["data"])
@@ -923,7 +1043,7 @@ def __init__(self, config_dict, dtype, device, operations):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype

class Gemma3_4B(BaseLlama, torch.nn.Module):
class Gemma3_4B(BaseLlama, BaseGenerate, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Gemma3_4B_Config(**config_dict)
@@ -932,7 +1052,25 @@ def __init__(self, config_dict, dtype, device, operations):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype

class Gemma3_12B(BaseLlama, torch.nn.Module):
class Gemma3_4B_Vision(BaseLlama, BaseGenerate, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Gemma3_4B_Vision_Config(**config_dict)
self.num_layers = config.num_hidden_layers

self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
self.multi_modal_projector = Gemma3MultiModalProjector(config, dtype, device, operations)
self.vision_model = comfy.clip_model.CLIPVision(config.vision_config, dtype, device, operations)
self.image_size = config.vision_config["image_size"]

def preprocess_embed(self, embed, device):
if embed["type"] == "image":
image = comfy.clip_model.clip_preprocess(embed["data"], size=self.image_size, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], crop=True)
return self.multi_modal_projector(self.vision_model(image.to(device, dtype=torch.float32))[0]), None
return None, None

class Gemma3_12B(BaseLlama, BaseGenerate, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Gemma3_12B_Config(**config_dict)
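As a closing note on the new Gemma3_4B_Vision class: its preprocess_embed hook resizes and normalizes an image, runs it through the SigLIP vision tower, and projects the result into the language model's embedding space (the config's mm_tokens_per_image tokens per image). A hedged sketch of how that path might be driven, assuming te is a constructed Gemma3_4B_Vision instance and the image is a ComfyUI-style [B, H, W, C] float tensor in 0..1; both are assumptions, not shown in this diff:

# Hypothetical driver for the Gemma3_4B_Vision image path (assumptions noted above).
import torch

image = torch.rand(1, 512, 512, 3)  # placeholder batch of one RGB image
vision_embeds, _ = te.preprocess_embed({"type": "image", "data": image}, device=torch.device("cpu"))
# vision_embeds holds the projected vision tokens, ready to be spliced into the text
# embedding sequence at the image placeholder positions.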