
Commit a1c101f

EasyCache: Support LTX2 (Comfy-Org#12231)
1 parent c2d7f07 commit a1c101f

1 file changed (+37 -13)

comfy_extras/nodes_easycache.py

Lines changed: 37 additions & 13 deletions
@@ -9,6 +9,14 @@
 from uuid import UUID
 
 
+def _extract_tensor(data, output_channels):
+    """Extract tensor from data, handling both single tensors and lists."""
+    if isinstance(data, list):
+        # LTX2 AV tensors: [video, audio]
+        return data[0][:, :output_channels], data[1][:, :output_channels]
+    return data[:, :output_channels], None
+
+
 def easycache_forward_wrapper(executor, *args, **kwargs):
     # get values from args
     transformer_options: dict[str] = args[-1]
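
The new `_extract_tensor` helper is the heart of the LTX2 support: LTX2 passes latents through the model as a `[video, audio]` list rather than a single tensor, so every place that used to slice `args[0]` directly now goes through this helper. A minimal sketch of its behavior (the shapes below are illustrative, not LTX2's real latent dimensions):

    import torch

    def _extract_tensor(data, output_channels):
        """Extract tensor from data, handling both single tensors and lists."""
        if isinstance(data, list):
            # LTX2 AV tensors: [video, audio]
            return data[0][:, :output_channels], data[1][:, :output_channels]
        return data[:, :output_channels], None

    # Single-tensor models: the audio slot comes back as None.
    video = torch.randn(2, 8, 4, 4)   # illustrative (batch, channels, h, w)
    x, ax = _extract_tensor(video, output_channels=4)
    assert x.shape == (2, 4, 4, 4) and ax is None

    # LTX2-style [video, audio] pair: both streams are channel-sliced.
    audio = torch.randn(2, 6, 16)     # illustrative audio latent
    x, ax = _extract_tensor([video, audio], output_channels=4)
    assert x.shape == (2, 4, 4, 4) and ax.shape == (2, 4, 16)
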
@@ -17,7 +25,7 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
     if not transformer_options:
         transformer_options = args[-2]
     easycache: EasyCacheHolder = transformer_options["easycache"]
-    x: torch.Tensor = args[0][:, :easycache.output_channels]
+    x, ax = _extract_tensor(args[0], easycache.output_channels)
     sigmas = transformer_options["sigmas"]
     uuids = transformer_options["uuids"]
     if sigmas is not None and easycache.is_past_end_timestep(sigmas):
@@ -35,7 +43,11 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
     if easycache.skip_current_step and can_apply_cache_diff:
         if easycache.verbose:
             logging.info(f"EasyCache [verbose] - was marked to skip this step by {easycache.first_cond_uuid}. Present uuids: {uuids}")
-        return easycache.apply_cache_diff(x, uuids)
+        result = easycache.apply_cache_diff(x, uuids)
+        if ax is not None:
+            result_audio = easycache.apply_cache_diff(ax, uuids, is_audio=True)
+            return [result, result_audio]
+        return result
     if easycache.initial_step:
         easycache.first_cond_uuid = uuids[0]
     has_first_cond_uuid = easycache.has_first_cond_uuid(uuids)
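
On skipped steps the wrapper now mirrors the input convention on the way out: cached diffs are applied to each stream separately, and a `[video, audio]` list is returned whenever an audio tensor is present. A condensed sketch of that control flow (the `skip_step` function name is a stand-in, not part of the commit):

    def skip_step(easycache, x, ax, uuids):
        # Video stream always gets its cached diff applied.
        result = easycache.apply_cache_diff(x, uuids)
        if ax is not None:
            # Audio stream has its own cache, selected via is_audio=True,
            # and the LTX2 [video, audio] list convention is preserved on return.
            result_audio = easycache.apply_cache_diff(ax, uuids, is_audio=True)
            return [result, result_audio]
        return result
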
@@ -51,13 +63,18 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
                 logging.info(f"EasyCache [verbose] - skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
             # other conds should also skip this step, and instead use their cached values
             easycache.skip_current_step = True
-            return easycache.apply_cache_diff(x, uuids)
+            result = easycache.apply_cache_diff(x, uuids)
+            if ax is not None:
+                result_audio = easycache.apply_cache_diff(ax, uuids, is_audio=True)
+                return [result, result_audio]
+            return result
         else:
             if easycache.verbose:
                 logging.info(f"EasyCache [verbose] - NOT skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
             easycache.cumulative_change_rate = 0.0
 
-    output: torch.Tensor = executor(*args, **kwargs)
+    full_output: torch.Tensor = executor(*args, **kwargs)
+    output, audio_output = _extract_tensor(full_output, easycache.output_channels)
     if has_first_cond_uuid and easycache.has_output_prev_norm():
         output_change = (easycache.subsample(output, uuids, clone=False) - easycache.output_prev_subsampled).flatten().abs().mean()
         if easycache.verbose:
@@ -74,13 +91,15 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
             logging.info(f"EasyCache [verbose] - output_change_rate: {output_change_rate}")
     # TODO: allow cache_diff to be offloaded
     easycache.update_cache_diff(output, next_x_prev, uuids)
+    if audio_output is not None:
+        easycache.update_cache_diff(audio_output, ax, uuids, is_audio=True)
     if has_first_cond_uuid:
         easycache.x_prev_subsampled = easycache.subsample(next_x_prev, uuids)
         easycache.output_prev_subsampled = easycache.subsample(output, uuids)
         easycache.output_prev_norm = output.flatten().abs().mean()
         if easycache.verbose:
             logging.info(f"EasyCache [verbose] - x_prev_subsampled: {easycache.x_prev_subsampled.shape}")
-    return output
+    return full_output
 
 def lazycache_predict_noise_wrapper(executor, *args, **kwargs):
     # get values from args
@@ -89,8 +108,8 @@ def lazycache_predict_noise_wrapper(executor, *args, **kwargs):
     easycache: LazyCacheHolder = model_options["transformer_options"]["easycache"]
     if easycache.is_past_end_timestep(timestep):
         return executor(*args, **kwargs)
+    x: torch.Tensor = _extract_tensor(args[0], easycache.output_channels)
     # prepare next x_prev
-    x: torch.Tensor = args[0][:, :easycache.output_channels]
     next_x_prev = x
     input_change = None
     do_easycache = easycache.should_do_easycache(timestep)
@@ -197,6 +216,7 @@ def __init__(self, reuse_threshold: float, start_percent: float, end_percent: fl
         self.output_prev_subsampled: torch.Tensor = None
         self.output_prev_norm: torch.Tensor = None
         self.uuid_cache_diffs: dict[UUID, torch.Tensor] = {}
+        self.uuid_cache_diffs_audio: dict[UUID, torch.Tensor] = {}
         self.output_change_rates = []
         self.approx_output_change_rates = []
         self.total_steps_skipped = 0
@@ -245,20 +265,21 @@ def subsample(self, x: torch.Tensor, uuids: list[UUID], clone: bool = True) -> t
     def can_apply_cache_diff(self, uuids: list[UUID]) -> bool:
         return all(uuid in self.uuid_cache_diffs for uuid in uuids)
 
-    def apply_cache_diff(self, x: torch.Tensor, uuids: list[UUID]):
-        if self.first_cond_uuid in uuids:
+    def apply_cache_diff(self, x: torch.Tensor, uuids: list[UUID], is_audio: bool = False):
+        if self.first_cond_uuid in uuids and not is_audio:
             self.total_steps_skipped += 1
+        cache_diffs = self.uuid_cache_diffs_audio if is_audio else self.uuid_cache_diffs
         batch_offset = x.shape[0] // len(uuids)
         for i, uuid in enumerate(uuids):
             # slice out only what is relevant to this cond
             batch_slice = [slice(i*batch_offset,(i+1)*batch_offset)]
             # if cached dims don't match x dims, cut off excess and hope for the best (cosmos world2video)
-            if x.shape[1:] != self.uuid_cache_diffs[uuid].shape[1:]:
+            if x.shape[1:] != cache_diffs[uuid].shape[1:]:
                 if not self.allow_mismatch:
                     raise ValueError(f"Cached dims {self.uuid_cache_diffs[uuid].shape} don't match x dims {x.shape} - this is no good")
                 slicing = []
                 skip_this_dim = True
-                for dim_u, dim_x in zip(self.uuid_cache_diffs[uuid].shape, x.shape):
+                for dim_u, dim_x in zip(cache_diffs[uuid].shape, x.shape):
                     if skip_this_dim:
                         skip_this_dim = False
                         continue
@@ -270,10 +291,11 @@ def apply_cache_diff(self, x: torch.Tensor, uuids: list[UUID]):
                 else:
                     slicing.append(slice(None))
                 batch_slice = batch_slice + slicing
-            x[tuple(batch_slice)] += self.uuid_cache_diffs[uuid].to(x.device)
+            x[tuple(batch_slice)] += cache_diffs[uuid].to(x.device)
         return x
 
-    def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[UUID]):
+    def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[UUID], is_audio: bool = False):
+        cache_diffs = self.uuid_cache_diffs_audio if is_audio else self.uuid_cache_diffs
         # if output dims don't match x dims, cut off excess and hope for the best (cosmos world2video)
         if output.shape[1:] != x.shape[1:]:
             if not self.allow_mismatch:
@@ -293,7 +315,7 @@ def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[U
         diff = output - x
         batch_offset = diff.shape[0] // len(uuids)
         for i, uuid in enumerate(uuids):
-            self.uuid_cache_diffs[uuid] = diff[i*batch_offset:(i+1)*batch_offset, ...]
+            cache_diffs[uuid] = diff[i*batch_offset:(i+1)*batch_offset, ...]
 
     def has_first_cond_uuid(self, uuids: list[UUID]) -> bool:
         return self.first_cond_uuid in uuids
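
The cache itself is just a per-uuid residual: `update_cache_diff` stores `output - x`, and `apply_cache_diff` adds it back onto a later input, so reusing the cache on an unchanged input reproduces the previous output exactly. With the `is_audio` flag, the same machinery now runs twice over two separate dicts. A simplified round-trip sketch (names are stand-ins; mismatch handling and device transfer omitted):

    import torch
    from uuid import uuid4

    video_diffs, audio_diffs = {}, {}  # stand-ins for uuid_cache_diffs / uuid_cache_diffs_audio

    def update(output, x, uuids, is_audio=False):
        cache_diffs = audio_diffs if is_audio else video_diffs
        diff = output - x
        batch_offset = diff.shape[0] // len(uuids)
        for i, uuid in enumerate(uuids):
            # store one residual slice per cond
            cache_diffs[uuid] = diff[i*batch_offset:(i+1)*batch_offset]

    def apply(x, uuids, is_audio=False):
        cache_diffs = audio_diffs if is_audio else video_diffs
        batch_offset = x.shape[0] // len(uuids)
        for i, uuid in enumerate(uuids):
            x[i*batch_offset:(i+1)*batch_offset] += cache_diffs[uuid]
        return x

    u = [uuid4()]
    x_prev, out_prev = torch.randn(1, 4, 8), torch.randn(1, 4, 8)
    update(out_prev, x_prev, u)
    # Applying the stored diff to the old input reconstructs the old output.
    assert torch.allclose(apply(x_prev.clone(), u), out_prev)
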
@@ -324,6 +346,8 @@ def reset(self):
         self.output_prev_norm = None
         del self.uuid_cache_diffs
         self.uuid_cache_diffs = {}
+        del self.uuid_cache_diffs_audio
+        self.uuid_cache_diffs_audio = {}
         self.total_steps_skipped = 0
         self.state_metadata = None
         return self
