55from torch import Tensor , nn
66
77from .math import attention , rope
8- import comfy .ops
9- import comfy .ldm .common_dit
108
119
1210class EmbedND (nn .Module ):
@@ -87,20 +85,12 @@ def build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=False, yak_mlp=False, dt
8785 operations .Linear (mlp_hidden_dim , hidden_size , bias = True , dtype = dtype , device = device ),
8886 )
8987
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalization with a learnable per-channel scale.

    Normalizes the input by its RMS over the last dimension (eps = 1e-6)
    and multiplies by ``self.scale``; the arithmetic itself is delegated
    to the shared ``comfy.ldm.common_dit.rms_norm`` helper.
    """

    def __init__(self, dim: int, dtype=None, device=None, operations=None):
        super().__init__()
        # NOTE: `operations` is accepted only for signature parity with the
        # other layers in this file; it is unused here.
        weight = torch.empty((dim), dtype=dtype, device=device)
        self.scale = nn.Parameter(weight)

    def forward(self, x: Tensor) -> Tensor:
        # Shared fused helper; 1e-6 matches the epsilon used elsewhere in the file.
        return comfy.ldm.common_dit.rms_norm(x, self.scale, 1e-6)
97-
9888
9989class QKNorm (torch .nn .Module ):
10090 def __init__ (self , dim : int , dtype = None , device = None , operations = None ):
10191 super ().__init__ ()
102- self .query_norm = RMSNorm (dim , dtype = dtype , device = device , operations = operations )
103- self .key_norm = RMSNorm (dim , dtype = dtype , device = device , operations = operations )
92+ self .query_norm = operations . RMSNorm (dim , dtype = dtype , device = device )
93+ self .key_norm = operations . RMSNorm (dim , dtype = dtype , device = device )
10494
10595 def forward (self , q : Tensor , k : Tensor , v : Tensor ) -> tuple :
10696 q = self .query_norm (q )
@@ -169,7 +159,7 @@ def forward(self, x: Tensor) -> Tensor:
169159
170160
171161class DoubleStreamBlock (nn .Module ):
172- def __init__ (self , hidden_size : int , num_heads : int , mlp_ratio : float , qkv_bias : bool = False , flipped_img_txt = False , modulation = True , mlp_silu_act = False , proj_bias = True , yak_mlp = False , dtype = None , device = None , operations = None ):
162+ def __init__ (self , hidden_size : int , num_heads : int , mlp_ratio : float , qkv_bias : bool = False , modulation = True , mlp_silu_act = False , proj_bias = True , yak_mlp = False , dtype = None , device = None , operations = None ):
173163 super ().__init__ ()
174164
175165 mlp_hidden_dim = int (hidden_size * mlp_ratio )
@@ -197,8 +187,6 @@ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias:
197187
198188 self .txt_mlp = build_mlp (hidden_size , mlp_hidden_dim , mlp_silu_act = mlp_silu_act , yak_mlp = yak_mlp , dtype = dtype , device = device , operations = operations )
199189
200- self .flipped_img_txt = flipped_img_txt
201-
202190 def forward (self , img : Tensor , txt : Tensor , vec : Tensor , pe : Tensor , attn_mask = None , modulation_dims_img = None , modulation_dims_txt = None , transformer_options = {}):
203191 if self .modulation :
204192 img_mod1 , img_mod2 = self .img_mod (vec )
@@ -224,32 +212,17 @@ def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=N
224212 del txt_qkv
225213 txt_q , txt_k = self .txt_attn .norm (txt_q , txt_k , txt_v )
226214
227- if self .flipped_img_txt :
228- q = torch .cat ((img_q , txt_q ), dim = 2 )
229- del img_q , txt_q
230- k = torch .cat ((img_k , txt_k ), dim = 2 )
231- del img_k , txt_k
232- v = torch .cat ((img_v , txt_v ), dim = 2 )
233- del img_v , txt_v
234- # run actual attention
235- attn = attention (q , k , v ,
236- pe = pe , mask = attn_mask , transformer_options = transformer_options )
237- del q , k , v
238-
239- img_attn , txt_attn = attn [:, : img .shape [1 ]], attn [:, img .shape [1 ]:]
240- else :
241- q = torch .cat ((txt_q , img_q ), dim = 2 )
242- del txt_q , img_q
243- k = torch .cat ((txt_k , img_k ), dim = 2 )
244- del txt_k , img_k
245- v = torch .cat ((txt_v , img_v ), dim = 2 )
246- del txt_v , img_v
247- # run actual attention
248- attn = attention (q , k , v ,
249- pe = pe , mask = attn_mask , transformer_options = transformer_options )
250- del q , k , v
251-
252- txt_attn , img_attn = attn [:, : txt .shape [1 ]], attn [:, txt .shape [1 ]:]
215+ q = torch .cat ((txt_q , img_q ), dim = 2 )
216+ del txt_q , img_q
217+ k = torch .cat ((txt_k , img_k ), dim = 2 )
218+ del txt_k , img_k
219+ v = torch .cat ((txt_v , img_v ), dim = 2 )
220+ del txt_v , img_v
221+ # run actual attention
222+ attn = attention (q , k , v , pe = pe , mask = attn_mask , transformer_options = transformer_options )
223+ del q , k , v
224+
225+ txt_attn , img_attn = attn [:, : txt .shape [1 ]], attn [:, txt .shape [1 ]:]
253226
254227 # calculate the img bloks
255228 img += apply_mod (self .img_attn .proj (img_attn ), img_mod1 .gate , None , modulation_dims_img )
0 commit comments