
Commit e1add56

Use torch RMSNorm for flux models and refactor hunyuan video code. (Comfy-Org#12432)
1 parent 8902907 · commit e1add56

10 files changed: +75 -70 lines


comfy/controlnet.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -560,6 +560,7 @@ def load_controlnet_hunyuandit(controlnet_data, model_options={}):
 def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False, model_options={}):
     model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
     control_model = comfy.ldm.flux.controlnet.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
+    sd = model_config.process_unet_state_dict(sd)
     control_model = controlnet_load_state_dict(control_model, sd)
     extra_conds = ['y', 'guidance']
     control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
```
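The single added line exists because of the RMSNorm change elsewhere in this commit: the norm gain is now registered under the torch name `weight` instead of `scale`, while XLabs/Mistoline controlnet checkpoints on disk still carry the old name, so keys must be remapped before `controlnet_load_state_dict` runs. A minimal sketch of the idea; `rename_rmsnorm_keys` is a hypothetical stand-in for the real remap, the `process_unet_state_dict` method added to `supported_models.py` below:

```python
# Hypothetical stand-in for model_config.process_unet_state_dict.
def rename_rmsnorm_keys(sd):
    out = {}
    for k, v in sd.items():
        if k.endswith(".scale"):
            # old checkpoint key -> new torch RMSNorm parameter name
            k = k[:-len(".scale")] + ".weight"
        out[k] = v
    return out
```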

comfy/ldm/chroma/layers.py

Lines changed: 1 addition & 2 deletions

```diff
@@ -3,7 +3,6 @@
 
 from comfy.ldm.flux.layers import (
     MLPEmbedder,
-    RMSNorm,
     ModulationOut,
 )
 
@@ -29,7 +28,7 @@ def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers = 5, dty
         super().__init__()
         self.in_proj = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
         self.layers = nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)])
-        self.norms = nn.ModuleList([RMSNorm(hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)])
+        self.norms = nn.ModuleList([operations.RMSNorm(hidden_dim, dtype=dtype, device=device) for x in range( n_layers)])
         self.out_proj = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device)
 
     @property
```

comfy/ldm/chroma_radiance/layers.py

Lines changed: 3 additions & 5 deletions

```diff
@@ -4,8 +4,6 @@
 import torch
 from torch import nn
 
-from comfy.ldm.flux.layers import RMSNorm
-
 
 class NerfEmbedder(nn.Module):
     """
@@ -145,7 +143,7 @@ def __init__(self, hidden_size_s: int, hidden_size_x: int, mlp_ratio, dtype=None
         # We now need to generate parameters for 3 matrices.
         total_params = 3 * hidden_size_x**2 * mlp_ratio
         self.param_generator = operations.Linear(hidden_size_s, total_params, dtype=dtype, device=device)
-        self.norm = RMSNorm(hidden_size_x, dtype=dtype, device=device, operations=operations)
+        self.norm = operations.RMSNorm(hidden_size_x, dtype=dtype, device=device)
         self.mlp_ratio = mlp_ratio
 
 
@@ -178,7 +176,7 @@ def forward(self, x: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
 class NerfFinalLayer(nn.Module):
     def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None):
         super().__init__()
-        self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
+        self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device)
         self.linear = operations.Linear(hidden_size, out_channels, dtype=dtype, device=device)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -190,7 +188,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 class NerfFinalLayerConv(nn.Module):
     def __init__(self, hidden_size: int, out_channels: int, dtype=None, device=None, operations=None):
         super().__init__()
-        self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
+        self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device)
         self.conv = operations.Conv2d(
             in_channels=hidden_size,
             out_channels=out_channels,
```

comfy/ldm/flux/layers.py

Lines changed: 14 additions & 41 deletions

```diff
@@ -5,8 +5,6 @@
 from torch import Tensor, nn
 
 from .math import attention, rope
-import comfy.ops
-import comfy.ldm.common_dit
 
 
 class EmbedND(nn.Module):
@@ -87,20 +85,12 @@ def build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=False, yak_mlp=False, dt
         operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
     )
 
-class RMSNorm(torch.nn.Module):
-    def __init__(self, dim: int, dtype=None, device=None, operations=None):
-        super().__init__()
-        self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
-
-    def forward(self, x: Tensor):
-        return comfy.ldm.common_dit.rms_norm(x, self.scale, 1e-6)
-
 
 class QKNorm(torch.nn.Module):
     def __init__(self, dim: int, dtype=None, device=None, operations=None):
         super().__init__()
-        self.query_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
-        self.key_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
+        self.query_norm = operations.RMSNorm(dim, dtype=dtype, device=device)
+        self.key_norm = operations.RMSNorm(dim, dtype=dtype, device=device)
 
     def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple:
         q = self.query_norm(q)
@@ -169,7 +159,7 @@ def forward(self, x: Tensor) -> Tensor:
 
 
 class DoubleStreamBlock(nn.Module):
-    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
+    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
         super().__init__()
 
         mlp_hidden_dim = int(hidden_size * mlp_ratio)
@@ -197,8 +187,6 @@ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias:
 
         self.txt_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)
 
-        self.flipped_img_txt = flipped_img_txt
-
     def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
         if self.modulation:
             img_mod1, img_mod2 = self.img_mod(vec)
@@ -224,32 +212,17 @@ def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=N
         del txt_qkv
         txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
 
-        if self.flipped_img_txt:
-            q = torch.cat((img_q, txt_q), dim=2)
-            del img_q, txt_q
-            k = torch.cat((img_k, txt_k), dim=2)
-            del img_k, txt_k
-            v = torch.cat((img_v, txt_v), dim=2)
-            del img_v, txt_v
-            # run actual attention
-            attn = attention(q, k, v,
-                             pe=pe, mask=attn_mask, transformer_options=transformer_options)
-            del q, k, v
-
-            img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
-        else:
-            q = torch.cat((txt_q, img_q), dim=2)
-            del txt_q, img_q
-            k = torch.cat((txt_k, img_k), dim=2)
-            del txt_k, img_k
-            v = torch.cat((txt_v, img_v), dim=2)
-            del txt_v, img_v
-            # run actual attention
-            attn = attention(q, k, v,
-                             pe=pe, mask=attn_mask, transformer_options=transformer_options)
-            del q, k, v
-
-            txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
+        q = torch.cat((txt_q, img_q), dim=2)
+        del txt_q, img_q
+        k = torch.cat((txt_k, img_k), dim=2)
+        del txt_k, img_k
+        v = torch.cat((txt_v, img_v), dim=2)
+        del txt_v, img_v
+        # run actual attention
+        attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
+        del q, k, v
+
+        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
 
         # calculate the img bloks
         img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
```
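Two independent simplifications land here. First, the hand-rolled `RMSNorm`, which stored its gain as `scale` and delegated to `comfy.ldm.common_dit.rms_norm` with eps 1e-6, is replaced by `operations.RMSNorm`. Second, the `flipped_img_txt` branch is deleted, so `DoubleStreamBlock` always concatenates txt before img (HunyuanVideo, the only user of the flipped order, is refactored to match below). A sketch of the numerical equivalence behind the first change, assuming `operations.RMSNorm` behaves like `torch.nn.RMSNorm` (PyTorch 2.4+), whose gain parameter is named `weight`:

```python
import torch

dim = 128
x = torch.randn(2, 16, dim)
gain = torch.randn(dim)

# The deleted class: x * rsqrt(mean(x^2) + eps), scaled by a "scale" parameter.
old = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + 1e-6) * gain

# The replacement: torch's RMSNorm, whose parameter is named "weight".
norm = torch.nn.RMSNorm(dim, eps=1e-6)
with torch.no_grad():
    norm.weight.copy_(gain)  # same tensor, new name: scale -> weight
print(torch.allclose(old, norm(x), atol=1e-5))  # True
```

The `scale` to `weight` rename is what drives every state-dict remap in the rest of this diff.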

comfy/ldm/flux/model.py

Lines changed: 1 addition & 2 deletions

```diff
@@ -16,7 +16,6 @@
     SingleStreamBlock,
     timestep_embedding,
     Modulation,
-    RMSNorm
 )
 
 @dataclass
@@ -81,7 +80,7 @@ def __init__(self, image_model=None, final_layer=True, dtype=None, device=None,
         self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)
 
         if params.txt_norm:
-            self.txt_norm = RMSNorm(params.context_in_dim, dtype=dtype, device=device, operations=operations)
+            self.txt_norm = operations.RMSNorm(params.context_in_dim, dtype=dtype, device=device)
         else:
             self.txt_norm = None
 
```

comfy/ldm/hunyuan_video/model.py

Lines changed: 5 additions & 6 deletions

```diff
@@ -241,7 +241,6 @@ def __init__(self, image_model=None, final_layer=True, dtype=None, device=None,
                 self.num_heads,
                 mlp_ratio=params.mlp_ratio,
                 qkv_bias=params.qkv_bias,
-                flipped_img_txt=True,
                 dtype=dtype, device=device, operations=operations
             )
             for _ in range(params.depth)
@@ -378,14 +377,14 @@ def forward_orig(
             extra_txt_ids = torch.zeros((txt_ids.shape[0], txt_vision_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
             txt_ids = torch.cat((txt_ids, extra_txt_ids), dim=1)
 
-        ids = torch.cat((img_ids, txt_ids), dim=1)
+        ids = torch.cat((txt_ids, img_ids), dim=1)
         pe = self.pe_embedder(ids)
 
         img_len = img.shape[1]
         if txt_mask is not None:
             attn_mask_len = img_len + txt.shape[1]
             attn_mask = torch.zeros((1, 1, attn_mask_len), dtype=img.dtype, device=img.device)
-            attn_mask[:, 0, img_len:] = txt_mask
+            attn_mask[:, 0, :txt.shape[1]] = txt_mask
         else:
             attn_mask = None
 
@@ -413,7 +412,7 @@ def block_wrap(args):
             if add is not None:
                 img += add
 
-        img = torch.cat((img, txt), 1)
+        img = torch.cat((txt, img), 1)
 
         transformer_options["total_blocks"] = len(self.single_blocks)
         transformer_options["block_type"] = "single"
@@ -435,9 +434,9 @@ def block_wrap(args):
                 if i < len(control_o):
                     add = control_o[i]
                 if add is not None:
-                    img[:, : img_len] += add
+                    img[:, txt.shape[1]: img_len + txt.shape[1]] += add
 
-        img = img[:, : img_len]
+        img = img[:, txt.shape[1]: img_len + txt.shape[1]]
         if ref_latent is not None:
             img = img[:, ref_latent.shape[1]:]
```
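With `flipped_img_txt` gone, HunyuanVideo adopts the Flux token order: txt tokens now lead the sequence, so the positional ids, the attention mask, the controlnet residual, and the final slice all offset by `txt.shape[1]` instead of assuming img comes first. A small sketch of the new layout arithmetic (shapes illustrative):

```python
# Sequence layout after the refactor: [txt tokens | img tokens].
txt_len, img_len = 77, 1024
seq = ["txt"] * txt_len + ["img"] * img_len

# The txt mask now fills the front of the attention mask:
#   attn_mask[:, 0, :txt_len] = txt_mask      (was attn_mask[:, 0, img_len:])
# and img tokens are sliced out from behind the txt block:
img_slice = slice(txt_len, txt_len + img_len)  # was slice(0, img_len)
assert all(t == "img" for t in seq[img_slice])
```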

comfy/lora_convert.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -5,7 +5,7 @@
 def convert_lora_bfl_control(sd): #BFL loras for Flux
     sd_out = {}
     for k in sd:
-        k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.scale.set_weight"))
+        k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.set_weight"))
         sd_out[k_to] = sd[k]
 
     sd_out["diffusion_model.img_in.reshape_weight"] = torch.tensor([sd["img_in.lora_B.weight"].shape[0], sd["img_in.lora_A.weight"].shape[1]])
```
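The converter follows the parameter rename: with the `scale` parameter gone, a BFL LoRA entry for a norm gain now targets the module's `set_weight` patch directly rather than the old `scale.set_weight` path. A worked example with a real Flux norm key:

```python
k = "double_blocks.0.img_attn.norm.key_norm.scale"
k_to = "diffusion_model.{}".format(
    k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.set_weight"))
print(k_to)
# before: diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale.set_weight
# after:  diffusion_model.double_blocks.0.img_attn.norm.key_norm.set_weight
```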

comfy/model_detection.py

Lines changed: 13 additions & 5 deletions

```diff
@@ -19,6 +19,12 @@ def count_blocks(state_dict_keys, prefix_string):
         count += 1
     return count
 
+def any_suffix_in(keys, prefix, main, suffix_list=[]):
+    for x in suffix_list:
+        if "{}{}{}".format(prefix, main, x) in keys:
+            return True
+    return False
+
 def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
     context_dim = None
     use_linear_in_transformer = False
@@ -186,7 +192,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["meanflow_sum"] = False
         return dit_config
 
-    if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
+    if any_suffix_in(state_dict_keys, key_prefix, 'double_blocks.0.img_attn.norm.key_norm.', ["weight", "scale"]) and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.norms.0.', ["weight", "scale"])): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
         dit_config = {}
         if '{}double_stream_modulation_img.lin.weight'.format(key_prefix) in state_dict_keys:
             dit_config["image_model"] = "flux2"
@@ -241,15 +247,17 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
 
         dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
         dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
-        if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma
+
+        if any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.0.norms.0.', ["weight", "scale"]) or any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.norms.0.', ["weight", "scale"]): #Chroma
             dit_config["image_model"] = "chroma"
             dit_config["in_channels"] = 64
             dit_config["out_channels"] = 64
             dit_config["in_dim"] = 64
             dit_config["out_dim"] = 3072
             dit_config["hidden_dim"] = 5120
             dit_config["n_layers"] = 5
-            if f"{key_prefix}nerf_blocks.0.norm.scale" in state_dict_keys: #Chroma Radiance
+
+            if any_suffix_in(state_dict_keys, key_prefix, 'nerf_blocks.0.norm.', ["weight", "scale"]): #Chroma Radiance
                 dit_config["image_model"] = "chroma_radiance"
                 dit_config["in_channels"] = 3
                 dit_config["out_channels"] = 3
@@ -259,7 +267,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
                 dit_config["nerf_depth"] = 4
                 dit_config["nerf_max_freqs"] = 8
                 dit_config["nerf_tile_size"] = 512
-                dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
+                dit_config["nerf_final_head_type"] = "conv" if any_suffix_in(state_dict_keys, key_prefix, 'nerf_final_layer_conv.norm.', ["weight", "scale"]) else "linear"
                 dit_config["nerf_embedder_dtype"] = torch.float32
             if "{}__x0__".format(key_prefix) in state_dict_keys: # x0 pred
                 dit_config["use_x0"] = True
@@ -268,7 +276,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         else:
             dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
             dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
-            dit_config["txt_norm"] = "{}txt_norm.scale".format(key_prefix) in state_dict_keys
+            dit_config["txt_norm"] = any_suffix_in(state_dict_keys, key_prefix, 'txt_norm.', ["weight", "scale"])
             if dit_config["yak_mlp"] and dit_config["txt_norm"]: # Ovis model
                 dit_config["txt_ids_dims"] = [1, 2]
```
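`any_suffix_in` lets detection accept both vintages of checkpoint: ones saved with the old `scale` gain name and ones re-serialized with the new `weight` name. A self-contained usage sketch with a key pattern taken from this diff:

```python
def any_suffix_in(keys, prefix, main, suffix_list=[]):
    for x in suffix_list:
        if "{}{}{}".format(prefix, main, x) in keys:
            return True
    return False

new_style = {"model.diffusion_model.txt_norm.weight"}
old_style = {"model.diffusion_model.txt_norm.scale"}
for keys in (new_style, old_style):
    # True for both checkpoint vintages:
    print(any_suffix_in(keys, "model.diffusion_model.", "txt_norm.", ["weight", "scale"]))
```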

comfy/supported_models.py

Lines changed: 30 additions & 2 deletions

```diff
@@ -710,6 +710,15 @@ class Flux(supported_models_base.BASE):
 
     supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
 
+    def process_unet_state_dict(self, state_dict):
+        out_sd = {}
+        for k in list(state_dict.keys()):
+            key_out = k
+            if key_out.endswith("_norm.scale"):
+                key_out = "{}.weight".format(key_out[:-len(".scale")])
+            out_sd[key_out] = state_dict[k]
+        return out_sd
+
     vae_key_prefix = ["vae."]
     text_encoder_key_prefix = ["text_encoders."]
 
@@ -898,11 +907,13 @@ def process_unet_state_dict(self, state_dict):
             key_out = key_out.replace("txt_in.c_embedder.linear_1.", "txt_in.c_embedder.in_layer.").replace("txt_in.c_embedder.linear_2.", "txt_in.c_embedder.out_layer.")
             key_out = key_out.replace("_mod.linear.", "_mod.lin.").replace("_attn_qkv.", "_attn.qkv.")
             key_out = key_out.replace("mlp.fc1.", "mlp.0.").replace("mlp.fc2.", "mlp.2.")
-            key_out = key_out.replace("_attn_q_norm.weight", "_attn.norm.query_norm.scale").replace("_attn_k_norm.weight", "_attn.norm.key_norm.scale")
-            key_out = key_out.replace(".q_norm.weight", ".norm.query_norm.scale").replace(".k_norm.weight", ".norm.key_norm.scale")
+            key_out = key_out.replace("_attn_q_norm.weight", "_attn.norm.query_norm.weight").replace("_attn_k_norm.weight", "_attn.norm.key_norm.weight")
+            key_out = key_out.replace(".q_norm.weight", ".norm.query_norm.weight").replace(".k_norm.weight", ".norm.key_norm.weight")
             key_out = key_out.replace("_attn_proj.", "_attn.proj.")
             key_out = key_out.replace(".modulation.linear.", ".modulation.lin.")
             key_out = key_out.replace("_in.mlp.2.", "_in.out_layer.").replace("_in.mlp.0.", "_in.in_layer.")
+            if key_out.endswith(".scale"):
+                key_out = "{}.weight".format(key_out[:-len(".scale")])
             out_sd[key_out] = state_dict[k]
         return out_sd
 
@@ -1264,6 +1275,15 @@ class Hunyuan3Dv2(supported_models_base.BASE):
 
     latent_format = latent_formats.Hunyuan3Dv2
 
+    def process_unet_state_dict(self, state_dict):
+        out_sd = {}
+        for k in list(state_dict.keys()):
+            key_out = k
+            if key_out.endswith(".scale"):
+                key_out = "{}.weight".format(key_out[:-len(".scale")])
+            out_sd[key_out] = state_dict[k]
+        return out_sd
+
     def process_unet_state_dict_for_saving(self, state_dict):
         replace_prefix = {"": "model."}
         return utils.state_dict_prefix_replace(state_dict, replace_prefix)
@@ -1341,6 +1361,14 @@ class Chroma(supported_models_base.BASE):
 
     supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
 
+    def process_unet_state_dict(self, state_dict):
+        out_sd = {}
+        for k in list(state_dict.keys()):
+            key_out = k
+            if key_out.endswith(".scale"):
+                key_out = "{}.weight".format(key_out[:-len(".scale")])
+            out_sd[key_out] = state_dict[k]
+        return out_sd
 
     def get_model(self, state_dict, prefix="", device=None):
         out = model_base.Chroma(self, device=device)
```
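The same remap is added almost verbatim to `Flux`, `Hunyuan3Dv2`, and `Chroma`; only `Flux` matches the stricter `_norm.scale` suffix. Distilled into a single function, the shared logic is:

```python
def scale_to_weight(key, suffix=".scale"):
    # "...key_norm.scale" -> "...key_norm.weight"; other keys pass through.
    if key.endswith(suffix):
        return "{}.weight".format(key[:-len(".scale")])
    return key

print(scale_to_weight("double_blocks.0.img_attn.norm.key_norm.scale"))
# double_blocks.0.img_attn.norm.key_norm.weight
```

Since the three copies differ only in the suffix they match, hoisting this into the shared base class would be a natural follow-up.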
comfy/utils.py

Lines changed: 6 additions & 6 deletions

```diff
@@ -675,10 +675,10 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
         "ff_context.linear_in.bias": "txt_mlp.0.bias",
         "ff_context.linear_out.weight": "txt_mlp.2.weight",
         "ff_context.linear_out.bias": "txt_mlp.2.bias",
-        "attn.norm_q.weight": "img_attn.norm.query_norm.scale",
-        "attn.norm_k.weight": "img_attn.norm.key_norm.scale",
-        "attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale",
-        "attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale",
+        "attn.norm_q.weight": "img_attn.norm.query_norm.weight",
+        "attn.norm_k.weight": "img_attn.norm.key_norm.weight",
+        "attn.norm_added_q.weight": "txt_attn.norm.query_norm.weight",
+        "attn.norm_added_k.weight": "txt_attn.norm.key_norm.weight",
     }
 
     for k in block_map:
@@ -701,8 +701,8 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
         "norm.linear.bias": "modulation.lin.bias",
         "proj_out.weight": "linear2.weight",
         "proj_out.bias": "linear2.bias",
-        "attn.norm_q.weight": "norm.query_norm.scale",
-        "attn.norm_k.weight": "norm.key_norm.scale",
+        "attn.norm_q.weight": "norm.query_norm.weight",
+        "attn.norm_k.weight": "norm.key_norm.weight",
         "attn.to_qkv_mlp_proj.weight": "linear1.weight", # Flux 2
         "attn.to_out.weight": "linear2.weight", # Flux 2
     }
```
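`flux_to_diffusers` maps Diffusers key names onto ComfyUI-native ones, so the mapping targets must track the rename as well. An illustrative expansion of one entry; the per-block `transformer_blocks.{i}.` to `double_blocks.{i}.` prefixing is an assumption about the surrounding function, which this hunk does not show:

```python
block_map = {"attn.norm_q.weight": "img_attn.norm.query_norm.weight"}  # entry from this diff
key_map = {}
for i in range(2):  # pretend depth of 2 double blocks
    for k, v in block_map.items():
        key_map["transformer_blocks.{}.{}".format(i, k)] = "double_blocks.{}.{}".format(i, v)
print(key_map["transformer_blocks.0.attn.norm_q.weight"])
# double_blocks.0.img_attn.norm.query_norm.weight
```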
