Commit 55bd606

LTX2: Refactor forward function for better VRAM efficiency and fix spatial inpainting (Comfy-Org#12046)
* Disable timestep embed compression when inpainting: spatial inpainting is not compatible with the compression
* Reduce cross-attn peak VRAM
* LTX2: Refactor forward function for better VRAM efficiency
1 parent 79cdbc8 commit 55bd606
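
The core VRAM trick in this diff is replacing out-of-place gated residual updates of the form `x = x + attn_out * gate` (which materialize both the product and the sum as extra activation-sized temporaries) with in-place `Tensor.addcmul_`, and deleting each intermediate as soon as it is consumed so the gate is only fetched after the attention output exists. A minimal standalone sketch of the pattern (toy shapes; `attn_out` and `gate` stand in for the block's attention output and ada gate, not this file's API):

import torch

def gated_residual(x, attn_out, gate):
    # out-of-place: allocates attn_out * gate, then x + (...), before the old x is freed
    return x + attn_out * gate

def gated_residual_inplace(x, attn_out, gate):
    # in-place: x += attn_out * gate with no intermediate product tensor
    x.addcmul_(attn_out, gate)
    return x

x = torch.randn(2, 1024, 512)
attn_out = torch.randn(2, 1024, 512)
gate = torch.randn(2, 1, 512)  # broadcast over tokens, like an ada gate

assert torch.allclose(gated_residual(x.clone(), attn_out, gate),
                      gated_residual_inplace(x.clone(), attn_out, gate))

The numerical result is identical; only the peak allocation changes, which matters when `x` is a full video token sequence.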

File tree

1 file changed (+95 −137)

comfy/ldm/lightricks/av_model.py

Lines changed: 95 additions & 137 deletions
@@ -18,12 +18,12 @@ class CompressedTimestep:
     def __init__(self, tensor: torch.Tensor, patches_per_frame: int):
         """
         tensor: [batch_size, num_tokens, feature_dim] tensor where num_tokens = num_frames * patches_per_frame
-        patches_per_frame: Number of spatial patches per frame (height * width in latent space)
+        patches_per_frame: Number of spatial patches per frame (height * width in latent space), or None to disable compression
         """
         self.batch_size, num_tokens, self.feature_dim = tensor.shape
 
         # Check if compression is valid (num_tokens must be divisible by patches_per_frame)
-        if num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame:
+        if patches_per_frame is not None and num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame:
             self.patches_per_frame = patches_per_frame
             self.num_frames = num_tokens // patches_per_frame
 
@@ -215,22 +215,9 @@ def get_av_ca_ada_values(
         return (*scale_shift_ada_values, *gate_ada_values)
 
     def forward(
-        self,
-        x: Tuple[torch.Tensor, torch.Tensor],
-        v_context=None,
-        a_context=None,
-        attention_mask=None,
-        v_timestep=None,
-        a_timestep=None,
-        v_pe=None,
-        a_pe=None,
-        v_cross_pe=None,
-        a_cross_pe=None,
-        v_cross_scale_shift_timestep=None,
-        a_cross_scale_shift_timestep=None,
-        v_cross_gate_timestep=None,
-        a_cross_gate_timestep=None,
-        transformer_options=None,
+        self, x: Tuple[torch.Tensor, torch.Tensor], v_context=None, a_context=None, attention_mask=None, v_timestep=None, a_timestep=None,
+        v_pe=None, a_pe=None, v_cross_pe=None, a_cross_pe=None, v_cross_scale_shift_timestep=None, a_cross_scale_shift_timestep=None,
+        v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         run_vx = transformer_options.get("run_vx", True)
         run_ax = transformer_options.get("run_ax", True)
@@ -240,144 +227,102 @@ def forward(
         run_a2v = run_vx and transformer_options.get("a2v_cross_attn", True) and ax.numel() > 0
         run_v2a = run_ax and transformer_options.get("v2a_cross_attn", True)
 
+        # video
         if run_vx:
-            vshift_msa, vscale_msa, vgate_msa = (
-                self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 3))
-            )
-
+            # video self-attention
+            vshift_msa, vscale_msa = (self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 2)))
             norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa
-            vx += self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options) * vgate_msa
-            vx += self.attn2(
-                comfy.ldm.common_dit.rms_norm(vx),
-                context=v_context,
-                mask=attention_mask,
-                transformer_options=transformer_options,
-            )
-
-            del vshift_msa, vscale_msa, vgate_msa
-
+            del vshift_msa, vscale_msa
+            attn1_out = self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options)
+            del norm_vx
+            # video cross-attention
+            vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0]
+            vx.addcmul_(attn1_out, vgate_msa)
+            del vgate_msa, attn1_out
+            vx.add_(self.attn2(comfy.ldm.common_dit.rms_norm(vx), context=v_context, mask=attention_mask, transformer_options=transformer_options))
+
+        # audio
         if run_ax:
-            ashift_msa, ascale_msa, agate_msa = (
-                self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(0, 3))
-            )
-
+            # audio self-attention
+            ashift_msa, ascale_msa = (self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(0, 2)))
             norm_ax = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_msa) + ashift_msa
-            ax += (
-                self.audio_attn1(norm_ax, pe=a_pe, transformer_options=transformer_options)
-                * agate_msa
-            )
-            ax += self.audio_attn2(
-                comfy.ldm.common_dit.rms_norm(ax),
-                context=a_context,
-                mask=attention_mask,
-                transformer_options=transformer_options,
-            )
-
-            del ashift_msa, ascale_msa, agate_msa
-
-        # Audio - Video cross attention.
+            del ashift_msa, ascale_msa
+            attn1_out = self.audio_attn1(norm_ax, pe=a_pe, transformer_options=transformer_options)
+            del norm_ax
+            # audio cross-attention
+            agate_msa = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(2, 3))[0]
+            ax.addcmul_(attn1_out, agate_msa)
+            del agate_msa, attn1_out
+            ax.add_(self.audio_attn2(comfy.ldm.common_dit.rms_norm(ax), context=a_context, mask=attention_mask, transformer_options=transformer_options))
+
+        # video - audio cross attention.
         if run_a2v or run_v2a:
-            # norm3
             vx_norm3 = comfy.ldm.common_dit.rms_norm(vx)
             ax_norm3 = comfy.ldm.common_dit.rms_norm(ax)
 
-            (
-                scale_ca_audio_hidden_states_a2v,
-                shift_ca_audio_hidden_states_a2v,
-                scale_ca_audio_hidden_states_v2a,
-                shift_ca_audio_hidden_states_v2a,
-                gate_out_v2a,
-            ) = self.get_av_ca_ada_values(
-                self.scale_shift_table_a2v_ca_audio,
-                ax.shape[0],
-                a_cross_scale_shift_timestep,
-                a_cross_gate_timestep,
-            )
+            # audio to video cross attention
+            if run_a2v:
+                scale_ca_audio_hidden_states_a2v, shift_ca_audio_hidden_states_a2v = self.get_ada_values(
+                    self.scale_shift_table_a2v_ca_audio[:4, :], ax.shape[0], a_cross_scale_shift_timestep)[:2]
+                scale_ca_video_hidden_states_a2v_v, shift_ca_video_hidden_states_a2v_v = self.get_ada_values(
+                    self.scale_shift_table_a2v_ca_video[:4, :], vx.shape[0], v_cross_scale_shift_timestep)[:2]
 
-            (
-                scale_ca_video_hidden_states_a2v,
-                shift_ca_video_hidden_states_a2v,
-                scale_ca_video_hidden_states_v2a,
-                shift_ca_video_hidden_states_v2a,
-                gate_out_a2v,
-            ) = self.get_av_ca_ada_values(
-                self.scale_shift_table_a2v_ca_video,
-                vx.shape[0],
-                v_cross_scale_shift_timestep,
-                v_cross_gate_timestep,
-            )
+                vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_a2v_v) + shift_ca_video_hidden_states_a2v_v
+                ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v) + shift_ca_audio_hidden_states_a2v
+                del scale_ca_video_hidden_states_a2v_v, shift_ca_video_hidden_states_a2v_v, scale_ca_audio_hidden_states_a2v, shift_ca_audio_hidden_states_a2v
 
-            if run_a2v:
-                vx_scaled = (
-                    vx_norm3 * (1 + scale_ca_video_hidden_states_a2v)
-                    + shift_ca_video_hidden_states_a2v
-                )
-                ax_scaled = (
-                    ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v)
-                    + shift_ca_audio_hidden_states_a2v
-                )
-                vx += (
-                    self.audio_to_video_attn(
-                        vx_scaled,
-                        context=ax_scaled,
-                        pe=v_cross_pe,
-                        k_pe=a_cross_pe,
-                        transformer_options=transformer_options,
-                    )
-                    * gate_out_a2v
-                )
+                a2v_out = self.audio_to_video_attn(vx_scaled, context=ax_scaled, pe=v_cross_pe, k_pe=a_cross_pe, transformer_options=transformer_options)
+                del vx_scaled, ax_scaled
 
-                del gate_out_a2v
-                del scale_ca_video_hidden_states_a2v,\
-                    shift_ca_video_hidden_states_a2v,\
-                    scale_ca_audio_hidden_states_a2v,\
-                    shift_ca_audio_hidden_states_a2v,\
+                gate_out_a2v = self.get_ada_values(self.scale_shift_table_a2v_ca_video[4:, :], vx.shape[0], v_cross_gate_timestep)[0]
+                vx.addcmul_(a2v_out, gate_out_a2v)
+                del gate_out_a2v, a2v_out
 
+            # video to audio cross attention
             if run_v2a:
-                ax_scaled = (
-                    ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a)
-                    + shift_ca_audio_hidden_states_v2a
-                )
-                vx_scaled = (
-                    vx_norm3 * (1 + scale_ca_video_hidden_states_v2a)
-                    + shift_ca_video_hidden_states_v2a
-                )
-                ax += (
-                    self.video_to_audio_attn(
-                        ax_scaled,
-                        context=vx_scaled,
-                        pe=a_cross_pe,
-                        k_pe=v_cross_pe,
-                        transformer_options=transformer_options,
-                    )
-                    * gate_out_v2a
-                )
+                scale_ca_audio_hidden_states_v2a, shift_ca_audio_hidden_states_v2a = self.get_ada_values(
+                    self.scale_shift_table_a2v_ca_audio[:4, :], ax.shape[0], a_cross_scale_shift_timestep)[2:4]
+                scale_ca_video_hidden_states_v2a, shift_ca_video_hidden_states_v2a = self.get_ada_values(
+                    self.scale_shift_table_a2v_ca_video[:4, :], vx.shape[0], v_cross_scale_shift_timestep)[2:4]
 
-                del gate_out_v2a
-                del scale_ca_video_hidden_states_v2a,\
-                    shift_ca_video_hidden_states_v2a,\
-                    scale_ca_audio_hidden_states_v2a,\
-                    shift_ca_audio_hidden_states_v2a
+                ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a) + shift_ca_audio_hidden_states_v2a
+                vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_v2a) + shift_ca_video_hidden_states_v2a
+                del scale_ca_video_hidden_states_v2a, shift_ca_video_hidden_states_v2a, scale_ca_audio_hidden_states_v2a, shift_ca_audio_hidden_states_v2a
 
-        if run_vx:
-            vshift_mlp, vscale_mlp, vgate_mlp = (
-                self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(3, None))
-            )
+                v2a_out = self.video_to_audio_attn(ax_scaled, context=vx_scaled, pe=a_cross_pe, k_pe=v_cross_pe, transformer_options=transformer_options)
+                del ax_scaled, vx_scaled
 
+                gate_out_v2a = self.get_ada_values(self.scale_shift_table_a2v_ca_audio[4:, :], ax.shape[0], a_cross_gate_timestep)[0]
+                ax.addcmul_(v2a_out, gate_out_v2a)
+                del gate_out_v2a, v2a_out
+
+            del vx_norm3, ax_norm3
+
+        # video feedforward
+        if run_vx:
+            vshift_mlp, vscale_mlp = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(3, 5))
             vx_scaled = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_mlp) + vshift_mlp
-            vx += self.ff(vx_scaled) * vgate_mlp
-            del vshift_mlp, vscale_mlp, vgate_mlp
+            del vshift_mlp, vscale_mlp
 
-        if run_ax:
-            ashift_mlp, ascale_mlp, agate_mlp = (
-                self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(3, None))
-            )
+            ff_out = self.ff(vx_scaled)
+            del vx_scaled
 
+            vgate_mlp = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(5, 6))[0]
+            vx.addcmul_(ff_out, vgate_mlp)
+            del vgate_mlp, ff_out
+
+        # audio feedforward
+        if run_ax:
+            ashift_mlp, ascale_mlp = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(3, 5))
             ax_scaled = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_mlp) + ashift_mlp
-            ax += self.audio_ff(ax_scaled) * agate_mlp
+            del ashift_mlp, ascale_mlp
 
-            del ashift_mlp, ascale_mlp, agate_mlp
+            ff_out = self.audio_ff(ax_scaled)
+            del ax_scaled
 
+            agate_mlp = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(5, 6))[0]
+            ax.addcmul_(ff_out, agate_mlp)
+            del agate_mlp, ff_out
 
         return vx, ax
 
@@ -589,9 +534,20 @@ def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs):
         audio_length = kwargs.get("audio_length", 0)
         # Separate audio and video latents
         vx, ax = self.separate_audio_and_video_latents(x, audio_length)
+
+        has_spatial_mask = False
+        if denoise_mask is not None:
+            # check if any frame has spatial variation (inpainting)
+            for frame_idx in range(denoise_mask.shape[2]):
+                frame_mask = denoise_mask[0, 0, frame_idx]
+                if frame_mask.numel() > 0 and frame_mask.min() != frame_mask.max():
+                    has_spatial_mask = True
+                    break
+
         [vx, v_pixel_coords, additional_args] = super()._process_input(
             vx, keyframe_idxs, denoise_mask, **kwargs
         )
+        additional_args["has_spatial_mask"] = has_spatial_mask
 
         ax, a_latent_coords = self.a_patchifier.patchify(ax)
         ax = self.audio_patchify_proj(ax)
 
@@ -618,8 +574,9 @@ def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs):
         # Calculate patches_per_frame from orig_shape: [batch, channels, frames, height, width]
         # Video tokens are arranged as (frames * height * width), so patches_per_frame = height * width
         orig_shape = kwargs.get("orig_shape")
+        has_spatial_mask = kwargs.get("has_spatial_mask", None)
         v_patches_per_frame = None
-        if orig_shape is not None and len(orig_shape) == 5:
+        if not has_spatial_mask and orig_shape is not None and len(orig_shape) == 5:
             # orig_shape[3] = height, orig_shape[4] = width (in latent space)
             v_patches_per_frame = orig_shape[3] * orig_shape[4]
 
@@ -662,10 +619,11 @@ def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs):
         )
 
         # Compress cross-attention timesteps (only video side, audio is too small to benefit)
+        # v_patches_per_frame is None for spatial masks, set for temporal masks or no mask
        cross_av_timestep_ss = [
            av_ca_audio_scale_shift_timestep.view(batch_size, -1, av_ca_audio_scale_shift_timestep.shape[-1]),
-            CompressedTimestep(av_ca_video_scale_shift_timestep.view(batch_size, -1, av_ca_video_scale_shift_timestep.shape[-1]), v_patches_per_frame),  # video - compressed
-            CompressedTimestep(av_ca_a2v_gate_noise_timestep.view(batch_size, -1, av_ca_a2v_gate_noise_timestep.shape[-1]), v_patches_per_frame),  # video - compressed
+            CompressedTimestep(av_ca_video_scale_shift_timestep.view(batch_size, -1, av_ca_video_scale_shift_timestep.shape[-1]), v_patches_per_frame),  # video - compressed if possible
+            CompressedTimestep(av_ca_a2v_gate_noise_timestep.view(batch_size, -1, av_ca_a2v_gate_noise_timestep.shape[-1]), v_patches_per_frame),  # video - compressed if possible
            av_ca_v2a_gate_noise_timestep.view(batch_size, -1, av_ca_v2a_gate_noise_timestep.shape[-1]),
        ]
 