@@ -18,12 +18,12 @@ class CompressedTimestep:
1818 def __init__ (self , tensor : torch .Tensor , patches_per_frame : int ):
1919 """
2020 tensor: [batch_size, num_tokens, feature_dim] tensor where num_tokens = num_frames * patches_per_frame
21- patches_per_frame: Number of spatial patches per frame (height * width in latent space)
21+ patches_per_frame: Number of spatial patches per frame (height * width in latent space), or None to disable compression
2222 """
2323 self .batch_size , num_tokens , self .feature_dim = tensor .shape
2424
2525 # Check if compression is valid (num_tokens must be divisible by patches_per_frame)
26- if num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame :
26+ if patches_per_frame is not None and num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame :
2727 self .patches_per_frame = patches_per_frame
2828 self .num_frames = num_tokens // patches_per_frame
2929
@@ -215,22 +215,9 @@ def get_av_ca_ada_values(
215215 return (* scale_shift_ada_values , * gate_ada_values )
216216
217217 def forward (
218- self ,
219- x : Tuple [torch .Tensor , torch .Tensor ],
220- v_context = None ,
221- a_context = None ,
222- attention_mask = None ,
223- v_timestep = None ,
224- a_timestep = None ,
225- v_pe = None ,
226- a_pe = None ,
227- v_cross_pe = None ,
228- a_cross_pe = None ,
229- v_cross_scale_shift_timestep = None ,
230- a_cross_scale_shift_timestep = None ,
231- v_cross_gate_timestep = None ,
232- a_cross_gate_timestep = None ,
233- transformer_options = None ,
218+ self , x : Tuple [torch .Tensor , torch .Tensor ], v_context = None , a_context = None , attention_mask = None , v_timestep = None , a_timestep = None ,
219+ v_pe = None , a_pe = None , v_cross_pe = None , a_cross_pe = None , v_cross_scale_shift_timestep = None , a_cross_scale_shift_timestep = None ,
220+ v_cross_gate_timestep = None , a_cross_gate_timestep = None , transformer_options = None ,
234221 ) -> Tuple [torch .Tensor , torch .Tensor ]:
235222 run_vx = transformer_options .get ("run_vx" , True )
236223 run_ax = transformer_options .get ("run_ax" , True )
@@ -240,144 +227,102 @@ def forward(
240227 run_a2v = run_vx and transformer_options .get ("a2v_cross_attn" , True ) and ax .numel () > 0
241228 run_v2a = run_ax and transformer_options .get ("v2a_cross_attn" , True )
242229
230+ # video
243231 if run_vx :
244- vshift_msa , vscale_msa , vgate_msa = (
245- self .get_ada_values (self .scale_shift_table , vx .shape [0 ], v_timestep , slice (0 , 3 ))
246- )
247-
232+ # video self-attention
233+ vshift_msa , vscale_msa = (self .get_ada_values (self .scale_shift_table , vx .shape [0 ], v_timestep , slice (0 , 2 )))
248234 norm_vx = comfy .ldm .common_dit .rms_norm (vx ) * (1 + vscale_msa ) + vshift_msa
249- vx += self . attn1 ( norm_vx , pe = v_pe , transformer_options = transformer_options ) * vgate_msa
250- vx + = self .attn2 (
251- comfy . ldm . common_dit . rms_norm ( vx ),
252- context = v_context ,
253- mask = attention_mask ,
254- transformer_options = transformer_options ,
255- )
256-
257- del vshift_msa , vscale_msa , vgate_msa
258-
235+ del vshift_msa , vscale_msa
236+ attn1_out = self .attn1 ( norm_vx , pe = v_pe , transformer_options = transformer_options )
237+ del norm_vx
238+ # video cross-attention
239+ vgate_msa = self . get_ada_values ( self . scale_shift_table , vx . shape [ 0 ], v_timestep , slice ( 2 , 3 ))[ 0 ]
240+ vx . addcmul_ ( attn1_out , vgate_msa )
241+ del vgate_msa , attn1_out
242+ vx . add_ ( self . attn2 ( comfy . ldm . common_dit . rms_norm ( vx ), context = v_context , mask = attention_mask , transformer_options = transformer_options ))
243+
244+ # audio
259245 if run_ax :
260- ashift_msa , ascale_msa , agate_msa = (
261- self .get_ada_values (self .audio_scale_shift_table , ax .shape [0 ], a_timestep , slice (0 , 3 ))
262- )
263-
246+ # audio self-attention
247+ ashift_msa , ascale_msa = (self .get_ada_values (self .audio_scale_shift_table , ax .shape [0 ], a_timestep , slice (0 , 2 )))
264248 norm_ax = comfy .ldm .common_dit .rms_norm (ax ) * (1 + ascale_msa ) + ashift_msa
265- ax += (
266- self .audio_attn1 (norm_ax , pe = a_pe , transformer_options = transformer_options )
267- * agate_msa
268- )
269- ax += self .audio_attn2 (
270- comfy .ldm .common_dit .rms_norm (ax ),
271- context = a_context ,
272- mask = attention_mask ,
273- transformer_options = transformer_options ,
274- )
275-
276- del ashift_msa , ascale_msa , agate_msa
277-
278- # Audio - Video cross attention.
249+ del ashift_msa , ascale_msa
250+ attn1_out = self .audio_attn1 (norm_ax , pe = a_pe , transformer_options = transformer_options )
251+ del norm_ax
252+ # audio cross-attention
253+ agate_msa = self .get_ada_values (self .audio_scale_shift_table , ax .shape [0 ], a_timestep , slice (2 , 3 ))[0 ]
254+ ax .addcmul_ (attn1_out , agate_msa )
255+ del agate_msa , attn1_out
256+ ax .add_ (self .audio_attn2 (comfy .ldm .common_dit .rms_norm (ax ), context = a_context , mask = attention_mask , transformer_options = transformer_options ))
257+
258+ # video - audio cross attention.
279259 if run_a2v or run_v2a :
280- # norm3
281260 vx_norm3 = comfy .ldm .common_dit .rms_norm (vx )
282261 ax_norm3 = comfy .ldm .common_dit .rms_norm (ax )
283262
284- (
285- scale_ca_audio_hidden_states_a2v ,
286- shift_ca_audio_hidden_states_a2v ,
287- scale_ca_audio_hidden_states_v2a ,
288- shift_ca_audio_hidden_states_v2a ,
289- gate_out_v2a ,
290- ) = self .get_av_ca_ada_values (
291- self .scale_shift_table_a2v_ca_audio ,
292- ax .shape [0 ],
293- a_cross_scale_shift_timestep ,
294- a_cross_gate_timestep ,
295- )
263+ # audio to video cross attention
264+ if run_a2v :
265+ scale_ca_audio_hidden_states_a2v , shift_ca_audio_hidden_states_a2v = self .get_ada_values (
266+ self .scale_shift_table_a2v_ca_audio [:4 , :], ax .shape [0 ], a_cross_scale_shift_timestep )[:2 ]
267+ scale_ca_video_hidden_states_a2v_v , shift_ca_video_hidden_states_a2v_v = self .get_ada_values (
268+ self .scale_shift_table_a2v_ca_video [:4 , :], vx .shape [0 ], v_cross_scale_shift_timestep )[:2 ]
296269
297- (
298- scale_ca_video_hidden_states_a2v ,
299- shift_ca_video_hidden_states_a2v ,
300- scale_ca_video_hidden_states_v2a ,
301- shift_ca_video_hidden_states_v2a ,
302- gate_out_a2v ,
303- ) = self .get_av_ca_ada_values (
304- self .scale_shift_table_a2v_ca_video ,
305- vx .shape [0 ],
306- v_cross_scale_shift_timestep ,
307- v_cross_gate_timestep ,
308- )
270+ vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_a2v_v ) + shift_ca_video_hidden_states_a2v_v
271+ ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v ) + shift_ca_audio_hidden_states_a2v
272+ del scale_ca_video_hidden_states_a2v_v , shift_ca_video_hidden_states_a2v_v , scale_ca_audio_hidden_states_a2v , shift_ca_audio_hidden_states_a2v
309273
310- if run_a2v :
311- vx_scaled = (
312- vx_norm3 * (1 + scale_ca_video_hidden_states_a2v )
313- + shift_ca_video_hidden_states_a2v
314- )
315- ax_scaled = (
316- ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v )
317- + shift_ca_audio_hidden_states_a2v
318- )
319- vx += (
320- self .audio_to_video_attn (
321- vx_scaled ,
322- context = ax_scaled ,
323- pe = v_cross_pe ,
324- k_pe = a_cross_pe ,
325- transformer_options = transformer_options ,
326- )
327- * gate_out_a2v
328- )
274+ a2v_out = self .audio_to_video_attn (vx_scaled , context = ax_scaled , pe = v_cross_pe , k_pe = a_cross_pe , transformer_options = transformer_options )
275+ del vx_scaled , ax_scaled
329276
330- del gate_out_a2v
331- del scale_ca_video_hidden_states_a2v ,\
332- shift_ca_video_hidden_states_a2v ,\
333- scale_ca_audio_hidden_states_a2v ,\
334- shift_ca_audio_hidden_states_a2v ,\
277+ gate_out_a2v = self .get_ada_values (self .scale_shift_table_a2v_ca_video [4 :, :], vx .shape [0 ], v_cross_gate_timestep )[0 ]
278+ vx .addcmul_ (a2v_out , gate_out_a2v )
279+ del gate_out_a2v , a2v_out
335280
281+ # video to audio cross attention
336282 if run_v2a :
337- ax_scaled = (
338- ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a )
339- + shift_ca_audio_hidden_states_v2a
340- )
341- vx_scaled = (
342- vx_norm3 * (1 + scale_ca_video_hidden_states_v2a )
343- + shift_ca_video_hidden_states_v2a
344- )
345- ax += (
346- self .video_to_audio_attn (
347- ax_scaled ,
348- context = vx_scaled ,
349- pe = a_cross_pe ,
350- k_pe = v_cross_pe ,
351- transformer_options = transformer_options ,
352- )
353- * gate_out_v2a
354- )
283+ scale_ca_audio_hidden_states_v2a , shift_ca_audio_hidden_states_v2a = self .get_ada_values (
284+ self .scale_shift_table_a2v_ca_audio [:4 , :], ax .shape [0 ], a_cross_scale_shift_timestep )[2 :4 ]
285+ scale_ca_video_hidden_states_v2a , shift_ca_video_hidden_states_v2a = self .get_ada_values (
286+ self .scale_shift_table_a2v_ca_video [:4 , :], vx .shape [0 ], v_cross_scale_shift_timestep )[2 :4 ]
355287
356- del gate_out_v2a
357- del scale_ca_video_hidden_states_v2a ,\
358- shift_ca_video_hidden_states_v2a ,\
359- scale_ca_audio_hidden_states_v2a ,\
360- shift_ca_audio_hidden_states_v2a
288+ ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a ) + shift_ca_audio_hidden_states_v2a
289+ vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_v2a ) + shift_ca_video_hidden_states_v2a
290+ del scale_ca_video_hidden_states_v2a , shift_ca_video_hidden_states_v2a , scale_ca_audio_hidden_states_v2a , shift_ca_audio_hidden_states_v2a
361291
362- if run_vx :
363- vshift_mlp , vscale_mlp , vgate_mlp = (
364- self .get_ada_values (self .scale_shift_table , vx .shape [0 ], v_timestep , slice (3 , None ))
365- )
292+ v2a_out = self .video_to_audio_attn (ax_scaled , context = vx_scaled , pe = a_cross_pe , k_pe = v_cross_pe , transformer_options = transformer_options )
293+ del ax_scaled , vx_scaled
366294
295+ gate_out_v2a = self .get_ada_values (self .scale_shift_table_a2v_ca_audio [4 :, :], ax .shape [0 ], a_cross_gate_timestep )[0 ]
296+ ax .addcmul_ (v2a_out , gate_out_v2a )
297+ del gate_out_v2a , v2a_out
298+
299+ del vx_norm3 , ax_norm3
300+
301+ # video feedforward
302+ if run_vx :
303+ vshift_mlp , vscale_mlp = self .get_ada_values (self .scale_shift_table , vx .shape [0 ], v_timestep , slice (3 , 5 ))
367304 vx_scaled = comfy .ldm .common_dit .rms_norm (vx ) * (1 + vscale_mlp ) + vshift_mlp
368- vx += self .ff (vx_scaled ) * vgate_mlp
369- del vshift_mlp , vscale_mlp , vgate_mlp
305+ del vshift_mlp , vscale_mlp
370306
371- if run_ax :
372- ashift_mlp , ascale_mlp , agate_mlp = (
373- self .get_ada_values (self .audio_scale_shift_table , ax .shape [0 ], a_timestep , slice (3 , None ))
374- )
307+ ff_out = self .ff (vx_scaled )
308+ del vx_scaled
375309
310+ vgate_mlp = self .get_ada_values (self .scale_shift_table , vx .shape [0 ], v_timestep , slice (5 , 6 ))[0 ]
311+ vx .addcmul_ (ff_out , vgate_mlp )
312+ del vgate_mlp , ff_out
313+
314+ # audio feedforward
315+ if run_ax :
316+ ashift_mlp , ascale_mlp = self .get_ada_values (self .audio_scale_shift_table , ax .shape [0 ], a_timestep , slice (3 , 5 ))
376317 ax_scaled = comfy .ldm .common_dit .rms_norm (ax ) * (1 + ascale_mlp ) + ashift_mlp
377- ax += self . audio_ff ( ax_scaled ) * agate_mlp
318+ del ashift_mlp , ascale_mlp
378319
379- del ashift_mlp , ascale_mlp , agate_mlp
320+ ff_out = self .audio_ff (ax_scaled )
321+ del ax_scaled
380322
323+ agate_mlp = self .get_ada_values (self .audio_scale_shift_table , ax .shape [0 ], a_timestep , slice (5 , 6 ))[0 ]
324+ ax .addcmul_ (ff_out , agate_mlp )
325+ del agate_mlp , ff_out
381326
382327 return vx , ax
383328
@@ -589,9 +534,20 @@ def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs):
589534 audio_length = kwargs .get ("audio_length" , 0 )
590535 # Separate audio and video latents
591536 vx , ax = self .separate_audio_and_video_latents (x , audio_length )
537+
538+ has_spatial_mask = False
539+ if denoise_mask is not None :
540+ # check if any frame has spatial variation (inpainting)
541+ for frame_idx in range (denoise_mask .shape [2 ]):
542+ frame_mask = denoise_mask [0 , 0 , frame_idx ]
543+ if frame_mask .numel () > 0 and frame_mask .min () != frame_mask .max ():
544+ has_spatial_mask = True
545+ break
546+
592547 [vx , v_pixel_coords , additional_args ] = super ()._process_input (
593548 vx , keyframe_idxs , denoise_mask , ** kwargs
594549 )
550+ additional_args ["has_spatial_mask" ] = has_spatial_mask
595551
596552 ax , a_latent_coords = self .a_patchifier .patchify (ax )
597553 ax = self .audio_patchify_proj (ax )
@@ -618,8 +574,9 @@ def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs):
618574 # Calculate patches_per_frame from orig_shape: [batch, channels, frames, height, width]
619575 # Video tokens are arranged as (frames * height * width), so patches_per_frame = height * width
620576 orig_shape = kwargs .get ("orig_shape" )
577+ has_spatial_mask = kwargs .get ("has_spatial_mask" , None )
621578 v_patches_per_frame = None
622- if orig_shape is not None and len (orig_shape ) == 5 :
579+ if not has_spatial_mask and orig_shape is not None and len (orig_shape ) == 5 :
623580 # orig_shape[3] = height, orig_shape[4] = width (in latent space)
624581 v_patches_per_frame = orig_shape [3 ] * orig_shape [4 ]
625582
@@ -662,10 +619,11 @@ def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs):
662619 )
663620
664621 # Compress cross-attention timesteps (only video side, audio is too small to benefit)
622+ # v_patches_per_frame is None for spatial masks, set for temporal masks or no mask
665623 cross_av_timestep_ss = [
666624 av_ca_audio_scale_shift_timestep .view (batch_size , - 1 , av_ca_audio_scale_shift_timestep .shape [- 1 ]),
667- CompressedTimestep (av_ca_video_scale_shift_timestep .view (batch_size , - 1 , av_ca_video_scale_shift_timestep .shape [- 1 ]), v_patches_per_frame ), # video - compressed
668- CompressedTimestep (av_ca_a2v_gate_noise_timestep .view (batch_size , - 1 , av_ca_a2v_gate_noise_timestep .shape [- 1 ]), v_patches_per_frame ), # video - compressed
625+ CompressedTimestep (av_ca_video_scale_shift_timestep .view (batch_size , - 1 , av_ca_video_scale_shift_timestep .shape [- 1 ]), v_patches_per_frame ), # video - compressed if possible
626+ CompressedTimestep (av_ca_a2v_gate_noise_timestep .view (batch_size , - 1 , av_ca_a2v_gate_noise_timestep .shape [- 1 ]), v_patches_per_frame ), # video - compressed if possible
669627 av_ca_v2a_gate_noise_timestep .view (batch_size , - 1 , av_ca_v2a_gate_noise_timestep .shape [- 1 ]),
670628 ]
671629
0 commit comments