Merged
1 change: 1 addition & 0 deletions comfy/controlnet.py
@@ -560,6 +560,7 @@ def load_controlnet_hunyuandit(controlnet_data, model_options={}):
def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False, model_options={}):
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
sd = model_config.process_unet_state_dict(sd)
control_model = controlnet_load_state_dict(control_model, sd)
extra_conds = ['y', 'guidance']
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
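Reviewer note: the single added line above runs the model config's process_unet_state_dict() hook on the checkpoint before it is loaded into the Flux/Mistoline controlnet, so key renames introduced elsewhere in this PR (notably the RMSNorm ".scale" -> ".weight" rename) also apply to controlnet weights. A minimal standalone sketch of that remapping idea, not the actual comfy API:

import torch

def remap_rms_norm_keys(state_dict):
    # Rename a trailing ".scale" to ".weight" so checkpoints saved with the
    # old RMSNorm parameter name still load into operations.RMSNorm modules.
    out = {}
    for k, v in state_dict.items():
        if k.endswith(".scale"):
            k = k[:-len(".scale")] + ".weight"
        out[k] = v
    return out

sd = {"double_blocks.0.img_attn.norm.key_norm.scale": torch.ones(128)}
print(list(remap_rms_norm_keys(sd)))  # ['double_blocks.0.img_attn.norm.key_norm.weight']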
3 changes: 1 addition & 2 deletions comfy/ldm/chroma/layers.py
@@ -3,7 +3,6 @@

from comfy.ldm.flux.layers import (
MLPEmbedder,
RMSNorm,
ModulationOut,
)

@@ -29,7 +28,7 @@ def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers = 5, dty
super().__init__()
self.in_proj = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
self.layers = nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)])
self.norms = nn.ModuleList([RMSNorm(hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)])
self.norms = nn.ModuleList([operations.RMSNorm(hidden_dim, dtype=dtype, device=device) for x in range( n_layers)])
self.out_proj = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device)

@property
8 changes: 3 additions & 5 deletions comfy/ldm/chroma_radiance/layers.py
@@ -4,8 +4,6 @@
import torch
from torch import nn

from comfy.ldm.flux.layers import RMSNorm


class NerfEmbedder(nn.Module):
"""
@@ -145,7 +143,7 @@ def __init__(self, hidden_size_s: int, hidden_size_x: int, mlp_ratio, dtype=None
# We now need to generate parameters for 3 matrices.
total_params = 3 * hidden_size_x**2 * mlp_ratio
self.param_generator = operations.Linear(hidden_size_s, total_params, dtype=dtype, device=device)
self.norm = RMSNorm(hidden_size_x, dtype=dtype, device=device, operations=operations)
self.norm = operations.RMSNorm(hidden_size_x, dtype=dtype, device=device)
self.mlp_ratio = mlp_ratio


@@ -178,7 +176,7 @@ def forward(self, x: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
class NerfFinalLayer(nn.Module):
def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None):
super().__init__()
self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device)
self.linear = operations.Linear(hidden_size, out_channels, dtype=dtype, device=device)

def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -190,7 +188,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
class NerfFinalLayerConv(nn.Module):
def __init__(self, hidden_size: int, out_channels: int, dtype=None, device=None, operations=None):
super().__init__()
self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device)
self.conv = operations.Conv2d(
in_channels=hidden_size,
out_channels=out_channels,
55 changes: 14 additions & 41 deletions comfy/ldm/flux/layers.py
@@ -5,8 +5,6 @@
from torch import Tensor, nn

from .math import attention, rope
import comfy.ops
import comfy.ldm.common_dit


class EmbedND(nn.Module):
@@ -87,20 +85,12 @@ def build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=False, yak_mlp=False, dt
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)

class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, dtype=None, device=None, operations=None):
super().__init__()
self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))

def forward(self, x: Tensor):
return comfy.ldm.common_dit.rms_norm(x, self.scale, 1e-6)


class QKNorm(torch.nn.Module):
def __init__(self, dim: int, dtype=None, device=None, operations=None):
super().__init__()
self.query_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
self.key_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
self.query_norm = operations.RMSNorm(dim, dtype=dtype, device=device)
self.key_norm = operations.RMSNorm(dim, dtype=dtype, device=device)

def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple:
q = self.query_norm(q)
@@ -169,7 +159,7 @@ def forward(self, x: Tensor) -> Tensor:


class DoubleStreamBlock(nn.Module):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
super().__init__()

mlp_hidden_dim = int(hidden_size * mlp_ratio)
@@ -197,8 +187,6 @@ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias:

self.txt_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)

self.flipped_img_txt = flipped_img_txt

def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
if self.modulation:
img_mod1, img_mod2 = self.img_mod(vec)
@@ -224,32 +212,17 @@ def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=N
del txt_qkv
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

if self.flipped_img_txt:
q = torch.cat((img_q, txt_q), dim=2)
del img_q, txt_q
k = torch.cat((img_k, txt_k), dim=2)
del img_k, txt_k
v = torch.cat((img_v, txt_v), dim=2)
del img_v, txt_v
# run actual attention
attn = attention(q, k, v,
pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v

img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
else:
q = torch.cat((txt_q, img_q), dim=2)
del txt_q, img_q
k = torch.cat((txt_k, img_k), dim=2)
del txt_k, img_k
v = torch.cat((txt_v, img_v), dim=2)
del txt_v, img_v
# run actual attention
attn = attention(q, k, v,
pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v

txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
q = torch.cat((txt_q, img_q), dim=2)
del txt_q, img_q
k = torch.cat((txt_k, img_k), dim=2)
del txt_k, img_k
v = torch.cat((txt_v, img_v), dim=2)
del txt_v, img_v
# run actual attention
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v

txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]

# calculate the img bloks
img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
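Reviewer note: this file drops the local RMSNorm class (whose learnable gain was registered as "scale") and switches QKNorm to operations.RMSNorm, whose parameter follows the standard "weight" name; the flipped_img_txt branch in DoubleStreamBlock is also removed, so attention always concatenates text tokens before image tokens. A standalone sketch of an equivalent RMSNorm, assuming the same 1e-6 epsilon as the removed class (illustrative only, not the comfy ops implementation):

import torch
from torch import nn

class SketchRMSNorm(nn.Module):
    # Same math as the removed flux RMSNorm, but the gain is stored as
    # "weight" (matching operations.RMSNorm) instead of "scale".
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * rms * self.weight

x = torch.randn(2, 4, 8)
print(SketchRMSNorm(8)(x).shape)  # torch.Size([2, 4, 8])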
3 changes: 1 addition & 2 deletions comfy/ldm/flux/model.py
@@ -16,7 +16,6 @@
SingleStreamBlock,
timestep_embedding,
Modulation,
RMSNorm
)

@dataclass
@@ -81,7 +80,7 @@ def __init__(self, image_model=None, final_layer=True, dtype=None, device=None,
self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)

if params.txt_norm:
self.txt_norm = RMSNorm(params.context_in_dim, dtype=dtype, device=device, operations=operations)
self.txt_norm = operations.RMSNorm(params.context_in_dim, dtype=dtype, device=device)
else:
self.txt_norm = None

11 changes: 5 additions & 6 deletions comfy/ldm/hunyuan_video/model.py
@@ -241,7 +241,6 @@ def __init__(self, image_model=None, final_layer=True, dtype=None, device=None,
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
flipped_img_txt=True,
dtype=dtype, device=device, operations=operations
)
for _ in range(params.depth)
@@ -378,14 +377,14 @@ def forward_orig(
extra_txt_ids = torch.zeros((txt_ids.shape[0], txt_vision_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
txt_ids = torch.cat((txt_ids, extra_txt_ids), dim=1)

ids = torch.cat((img_ids, txt_ids), dim=1)
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)

img_len = img.shape[1]
if txt_mask is not None:
attn_mask_len = img_len + txt.shape[1]
attn_mask = torch.zeros((1, 1, attn_mask_len), dtype=img.dtype, device=img.device)
attn_mask[:, 0, img_len:] = txt_mask
attn_mask[:, 0, :txt.shape[1]] = txt_mask
else:
attn_mask = None

@@ -413,7 +412,7 @@ def block_wrap(args):
if add is not None:
img += add

img = torch.cat((img, txt), 1)
img = torch.cat((txt, img), 1)

transformer_options["total_blocks"] = len(self.single_blocks)
transformer_options["block_type"] = "single"
@@ -435,9 +434,9 @@ def block_wrap(args):
if i < len(control_o):
add = control_o[i]
if add is not None:
img[:, : img_len] += add
img[:, txt.shape[1]: img_len + txt.shape[1]] += add

img = img[:, : img_len]
img = img[:, txt.shape[1]: img_len + txt.shape[1]]
if ref_latent is not None:
img = img[:, ref_latent.shape[1]:]

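Reviewer note: since DoubleStreamBlock no longer supports flipped_img_txt, HunyuanVideo is reworked to the txt-first token layout: positional ids are concatenated as (txt_ids, img_ids), the attention mask covers the leading txt positions, and the image tokens (plus any controlnet residuals) are sliced out of the tail of the sequence. A tiny standalone sketch of the re-indexing, with made-up sizes:

import torch

txt_len, img_len, dim = 3, 5, 4
txt = torch.randn(1, txt_len, dim)
img = torch.randn(1, img_len, dim)

# txt-first layout used after this change
tokens = torch.cat((txt, img), dim=1)

# image tokens now live at [txt_len : txt_len + img_len] instead of [:img_len]
img_slice = tokens[:, txt_len: img_len + txt_len]
print(torch.equal(img_slice, img))  # True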
2 changes: 1 addition & 1 deletion comfy/lora_convert.py
@@ -5,7 +5,7 @@
def convert_lora_bfl_control(sd): #BFL loras for Flux
sd_out = {}
for k in sd:
k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.scale.set_weight"))
k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.set_weight"))
sd_out[k_to] = sd[k]

sd_out["diffusion_model.img_in.reshape_weight"] = torch.tensor([sd["img_in.lora_B.weight"].shape[0], sd["img_in.lora_A.weight"].shape[1]])
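Reviewer note: the BFL control LoRA converter is updated to match the renamed norm parameter. A quick trace of the new mapping on an illustrative key (the key name here is hypothetical):

k = "double_blocks.0.img_attn.norm.key_norm.scale"
k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.set_weight"))
print(k_to)  # diffusion_model.double_blocks.0.img_attn.norm.key_norm.set_weight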
18 changes: 13 additions & 5 deletions comfy/model_detection.py
@@ -19,6 +19,12 @@ def count_blocks(state_dict_keys, prefix_string):
count += 1
return count

def any_suffix_in(keys, prefix, main, suffix_list=[]):
for x in suffix_list:
if "{}{}{}".format(prefix, main, x) in keys:
return True
return False

def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
context_dim = None
use_linear_in_transformer = False
@@ -186,7 +192,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["meanflow_sum"] = False
return dit_config

if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
if any_suffix_in(state_dict_keys, key_prefix, 'double_blocks.0.img_attn.norm.key_norm.', ["weight", "scale"]) and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.norms.0.', ["weight", "scale"])): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
dit_config = {}
if '{}double_stream_modulation_img.lin.weight'.format(key_prefix) in state_dict_keys:
dit_config["image_model"] = "flux2"
@@ -241,15 +247,17 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):

dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma

if any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.0.norms.0.', ["weight", "scale"]) or any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.norms.0.', ["weight", "scale"]): #Chroma
dit_config["image_model"] = "chroma"
dit_config["in_channels"] = 64
dit_config["out_channels"] = 64
dit_config["in_dim"] = 64
dit_config["out_dim"] = 3072
dit_config["hidden_dim"] = 5120
dit_config["n_layers"] = 5
if f"{key_prefix}nerf_blocks.0.norm.scale" in state_dict_keys: #Chroma Radiance

if any_suffix_in(state_dict_keys, key_prefix, 'nerf_blocks.0.norm.', ["weight", "scale"]): #Chroma Radiance
dit_config["image_model"] = "chroma_radiance"
dit_config["in_channels"] = 3
dit_config["out_channels"] = 3
@@ -259,7 +267,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["nerf_depth"] = 4
dit_config["nerf_max_freqs"] = 8
dit_config["nerf_tile_size"] = 512
dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
dit_config["nerf_final_head_type"] = "conv" if any_suffix_in(state_dict_keys, key_prefix, 'nerf_final_layer_conv.norm.', ["weight", "scale"]) else "linear"
dit_config["nerf_embedder_dtype"] = torch.float32
if "{}__x0__".format(key_prefix) in state_dict_keys: # x0 pred
dit_config["use_x0"] = True
@@ -268,7 +276,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
else:
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
dit_config["txt_norm"] = "{}txt_norm.scale".format(key_prefix) in state_dict_keys
dit_config["txt_norm"] = any_suffix_in(state_dict_keys, key_prefix, 'txt_norm.', ["weight", "scale"])
if dit_config["yak_mlp"] and dit_config["txt_norm"]: # Ovis model
dit_config["txt_ids_dims"] = [1, 2]

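Reviewer note: the new any_suffix_in() helper lets detection match either parameter suffix, so checkpoints saved with the old RMSNorm name ("...scale") and re-saved ones ("...weight") are both recognized as Flux/Chroma variants. A short usage sketch with made-up keys (the helper itself is copied from the diff above):

def any_suffix_in(keys, prefix, main, suffix_list=[]):
    for x in suffix_list:
        if "{}{}{}".format(prefix, main, x) in keys:
            return True
    return False

prefix = "model.diffusion_model."
old_keys = {prefix + "txt_norm.scale"}
new_keys = {prefix + "txt_norm.weight"}
print(any_suffix_in(old_keys, prefix, "txt_norm.", ["weight", "scale"]))  # True
print(any_suffix_in(new_keys, prefix, "txt_norm.", ["weight", "scale"]))  # True
print(any_suffix_in(new_keys, prefix, "missing.", ["weight", "scale"]))   # False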
4 changes: 4 additions & 0 deletions comfy/model_patcher.py
@@ -1561,6 +1561,8 @@ def setup_param(self, m, n, param_key):
allocated_size += weight_size
vbar.set_watermark_limit(allocated_size)

move_weight_functions(m, device_to)

logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.")

self.model.device = device_to
@@ -1601,6 +1603,8 @@ def unpatch_model(self, device_to=None, unpatch_weights=True):
if unpatch_weights:
self.partially_unload_ram(1e32)
self.partially_unload(None, 1e32)
for m in self.model.modules():
move_weight_functions(m, device_to)

def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
assert not force_patch_weights #See above
32 changes: 30 additions & 2 deletions comfy/supported_models.py
@@ -710,6 +710,15 @@ class Flux(supported_models_base.BASE):

supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]

def process_unet_state_dict(self, state_dict):
out_sd = {}
for k in list(state_dict.keys()):
key_out = k
if key_out.endswith("_norm.scale"):
key_out = "{}.weight".format(key_out[:-len(".scale")])
out_sd[key_out] = state_dict[k]
return out_sd

vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]

@@ -898,11 +907,13 @@ def process_unet_state_dict(self, state_dict):
key_out = key_out.replace("txt_in.c_embedder.linear_1.", "txt_in.c_embedder.in_layer.").replace("txt_in.c_embedder.linear_2.", "txt_in.c_embedder.out_layer.")
key_out = key_out.replace("_mod.linear.", "_mod.lin.").replace("_attn_qkv.", "_attn.qkv.")
key_out = key_out.replace("mlp.fc1.", "mlp.0.").replace("mlp.fc2.", "mlp.2.")
key_out = key_out.replace("_attn_q_norm.weight", "_attn.norm.query_norm.scale").replace("_attn_k_norm.weight", "_attn.norm.key_norm.scale")
key_out = key_out.replace(".q_norm.weight", ".norm.query_norm.scale").replace(".k_norm.weight", ".norm.key_norm.scale")
key_out = key_out.replace("_attn_q_norm.weight", "_attn.norm.query_norm.weight").replace("_attn_k_norm.weight", "_attn.norm.key_norm.weight")
key_out = key_out.replace(".q_norm.weight", ".norm.query_norm.weight").replace(".k_norm.weight", ".norm.key_norm.weight")
key_out = key_out.replace("_attn_proj.", "_attn.proj.")
key_out = key_out.replace(".modulation.linear.", ".modulation.lin.")
key_out = key_out.replace("_in.mlp.2.", "_in.out_layer.").replace("_in.mlp.0.", "_in.in_layer.")
if key_out.endswith(".scale"):
key_out = "{}.weight".format(key_out[:-len(".scale")])
out_sd[key_out] = state_dict[k]
return out_sd

@@ -1264,6 +1275,15 @@ class Hunyuan3Dv2(supported_models_base.BASE):

latent_format = latent_formats.Hunyuan3Dv2

def process_unet_state_dict(self, state_dict):
out_sd = {}
for k in list(state_dict.keys()):
key_out = k
if key_out.endswith(".scale"):
key_out = "{}.weight".format(key_out[:-len(".scale")])
out_sd[key_out] = state_dict[k]
return out_sd

def process_unet_state_dict_for_saving(self, state_dict):
replace_prefix = {"": "model."}
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
@@ -1341,6 +1361,14 @@ class Chroma(supported_models_base.BASE):

supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]

def process_unet_state_dict(self, state_dict):
out_sd = {}
for k in list(state_dict.keys()):
key_out = k
if key_out.endswith(".scale"):
key_out = "{}.weight".format(key_out[:-len(".scale")])
out_sd[key_out] = state_dict[k]
return out_sd

def get_model(self, state_dict, prefix="", device=None):
out = model_base.Chroma(self, device=device)
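Reviewer note: the process_unet_state_dict overrides added to Flux, Hunyuan3Dv2 and Chroma all rename the legacy RMSNorm ".scale" parameter suffix to ".weight" so older checkpoints keep loading, while HunyuanVideo additionally folds its q/k norm keys into the new naming before that rename. A quick trace of one HunyuanVideo replacement chain from the diff above (the key name is illustrative):

k = "double_blocks.0.img_attn_q_norm.weight"
key_out = k.replace("_attn_q_norm.weight", "_attn.norm.query_norm.weight")
key_out = key_out.replace("_attn_k_norm.weight", "_attn.norm.key_norm.weight")
print(key_out)  # double_blocks.0.img_attn.norm.query_norm.weight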