99 from uuid import UUID
1010
1111
12+ def _extract_tensor (data , output_channels ):
13+ """Extract tensor from data, handling both single tensors and lists."""
14+ if isinstance (data , list ):
15+ # LTX2 AV tensors: [video, audio]
16+ return data [0 ][:, :output_channels ], data [1 ][:, :output_channels ]
17+ return data [:, :output_channels ], None
18+
19+
1220def easycache_forward_wrapper (executor , * args , ** kwargs ):
1321 # get values from args
1422 transformer_options : dict [str ] = args [- 1 ]
@@ -17,7 +25,7 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
1725 if not transformer_options :
1826 transformer_options = args [- 2 ]
1927 easycache : EasyCacheHolder = transformer_options ["easycache" ]
20- x : torch . Tensor = args [0 ][:, : easycache .output_channels ]
28+ x , ax = _extract_tensor ( args [0 ], easycache .output_channels )
2129 sigmas = transformer_options ["sigmas" ]
2230 uuids = transformer_options ["uuids" ]
2331 if sigmas is not None and easycache .is_past_end_timestep (sigmas ):
@@ -35,7 +43,11 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
3543 if easycache .skip_current_step and can_apply_cache_diff :
3644 if easycache .verbose :
3745 logging .info (f"EasyCache [verbose] - was marked to skip this step by { easycache .first_cond_uuid } . Present uuids: { uuids } " )
38- return easycache .apply_cache_diff (x , uuids )
46+ result = easycache .apply_cache_diff (x , uuids )
47+ if ax is not None :
48+ result_audio = easycache .apply_cache_diff (ax , uuids , is_audio = True )
49+ return [result , result_audio ]
50+ return result
3951 if easycache .initial_step :
4052 easycache .first_cond_uuid = uuids [0 ]
4153 has_first_cond_uuid = easycache .has_first_cond_uuid (uuids )
@@ -51,13 +63,18 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
5163 logging .info (f"EasyCache [verbose] - skipping step; cumulative_change_rate: { easycache .cumulative_change_rate } , reuse_threshold: { easycache .reuse_threshold } " )
5264 # other conds should also skip this step, and instead use their cached values
5365 easycache .skip_current_step = True
54- return easycache .apply_cache_diff (x , uuids )
66+ result = easycache .apply_cache_diff (x , uuids )
67+ if ax is not None :
68+ result_audio = easycache .apply_cache_diff (ax , uuids , is_audio = True )
69+ return [result , result_audio ]
70+ return result
5571 else :
5672 if easycache .verbose :
5773 logging .info (f"EasyCache [verbose] - NOT skipping step; cumulative_change_rate: { easycache .cumulative_change_rate } , reuse_threshold: { easycache .reuse_threshold } " )
5874 easycache .cumulative_change_rate = 0.0
5975
60- output : torch .Tensor = executor (* args , ** kwargs )
76+ full_output : torch .Tensor = executor (* args , ** kwargs )
77+ output , audio_output = _extract_tensor (full_output , easycache .output_channels )
6178 if has_first_cond_uuid and easycache .has_output_prev_norm ():
6279 output_change = (easycache .subsample (output , uuids , clone = False ) - easycache .output_prev_subsampled ).flatten ().abs ().mean ()
6380 if easycache .verbose :
@@ -74,13 +91,15 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
7491 logging .info (f"EasyCache [verbose] - output_change_rate: { output_change_rate } " )
7592 # TODO: allow cache_diff to be offloaded
7693 easycache .update_cache_diff (output , next_x_prev , uuids )
94+ if audio_output is not None :
95+ easycache .update_cache_diff (audio_output , ax , uuids , is_audio = True )
7796 if has_first_cond_uuid :
7897 easycache .x_prev_subsampled = easycache .subsample (next_x_prev , uuids )
7998 easycache .output_prev_subsampled = easycache .subsample (output , uuids )
8099 easycache .output_prev_norm = output .flatten ().abs ().mean ()
81100 if easycache .verbose :
82101 logging .info (f"EasyCache [verbose] - x_prev_subsampled: { easycache .x_prev_subsampled .shape } " )
83- return output
102+ return full_output
84103
85104def lazycache_predict_noise_wrapper (executor , * args , ** kwargs ):
86105 # get values from args
@@ -89,8 +108,8 @@ def lazycache_predict_noise_wrapper(executor, *args, **kwargs):
89108 easycache : LazyCacheHolder = model_options ["transformer_options" ]["easycache" ]
90109 if easycache .is_past_end_timestep (timestep ):
91110 return executor (* args , ** kwargs )
# _extract_tensor returns a (tensor, audio) pair; unpack it so x is the tensor
# itself rather than the tuple. The audio part is not used at this point.
x, _ = _extract_tensor(args[0], easycache.output_channels)
92112 # prepare next x_prev
93- x : torch .Tensor = args [0 ][:, :easycache .output_channels ]
94113 next_x_prev = x
95114 input_change = None
96115 do_easycache = easycache .should_do_easycache (timestep )
@@ -197,6 +216,7 @@ def __init__(self, reuse_threshold: float, start_percent: float, end_percent: fl
197216 self .output_prev_subsampled : torch .Tensor = None
198217 self .output_prev_norm : torch .Tensor = None
199218 self .uuid_cache_diffs : dict [UUID , torch .Tensor ] = {}
219+ self .uuid_cache_diffs_audio : dict [UUID , torch .Tensor ] = {}
200220 self .output_change_rates = []
201221 self .approx_output_change_rates = []
202222 self .total_steps_skipped = 0
@@ -245,20 +265,21 @@ def subsample(self, x: torch.Tensor, uuids: list[UUID], clone: bool = True) -> t
def can_apply_cache_diff(self, uuids: list[UUID]) -> bool:
    """Return True when every uuid in `uuids` has an entry in self.uuid_cache_diffs.

    An empty `uuids` yields True. Only self.uuid_cache_diffs is consulted.
    """
    for cond_uuid in uuids:
        if cond_uuid not in self.uuid_cache_diffs:
            # a single missing entry is enough to rule out cache reuse
            return False
    return True
247267
248- def apply_cache_diff (self , x : torch .Tensor , uuids : list [UUID ]):
249- if self .first_cond_uuid in uuids :
268+ def apply_cache_diff (self , x : torch .Tensor , uuids : list [UUID ], is_audio : bool = False ):
269+ if self .first_cond_uuid in uuids and not is_audio :
250270 self .total_steps_skipped += 1
271+ cache_diffs = self .uuid_cache_diffs_audio if is_audio else self .uuid_cache_diffs
251272 batch_offset = x .shape [0 ] // len (uuids )
252273 for i , uuid in enumerate (uuids ):
253274 # slice out only what is relevant to this cond
254275 batch_slice = [slice (i * batch_offset ,(i + 1 )* batch_offset )]
255276 # if cached dims don't match x dims, cut off excess and hope for the best (cosmos world2video)
256- if x .shape [1 :] != self . uuid_cache_diffs [uuid ].shape [1 :]:
277+ if x .shape [1 :] != cache_diffs [uuid ].shape [1 :]:
257278 if not self .allow_mismatch :
# report the shape of the cache that was actually compared against x
# (cache_diffs is the audio or the video cache, depending on is_audio)
raise ValueError(f"Cached dims {cache_diffs[uuid].shape} don't match x dims {x.shape} - this is no good")
259280 slicing = []
260281 skip_this_dim = True
261- for dim_u , dim_x in zip (self . uuid_cache_diffs [uuid ].shape , x .shape ):
282+ for dim_u , dim_x in zip (cache_diffs [uuid ].shape , x .shape ):
262283 if skip_this_dim :
263284 skip_this_dim = False
264285 continue
@@ -270,10 +291,11 @@ def apply_cache_diff(self, x: torch.Tensor, uuids: list[UUID]):
270291 else :
271292 slicing .append (slice (None ))
272293 batch_slice = batch_slice + slicing
273- x [tuple (batch_slice )] += self . uuid_cache_diffs [uuid ].to (x .device )
294+ x [tuple (batch_slice )] += cache_diffs [uuid ].to (x .device )
274295 return x
275296
276- def update_cache_diff (self , output : torch .Tensor , x : torch .Tensor , uuids : list [UUID ]):
297+ def update_cache_diff (self , output : torch .Tensor , x : torch .Tensor , uuids : list [UUID ], is_audio : bool = False ):
298+ cache_diffs = self .uuid_cache_diffs_audio if is_audio else self .uuid_cache_diffs
277299 # if output dims don't match x dims, cut off excess and hope for the best (cosmos world2video)
278300 if output .shape [1 :] != x .shape [1 :]:
279301 if not self .allow_mismatch :
@@ -293,7 +315,7 @@ def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[U
293315 diff = output - x
294316 batch_offset = diff .shape [0 ] // len (uuids )
295317 for i , uuid in enumerate (uuids ):
296- self . uuid_cache_diffs [uuid ] = diff [i * batch_offset :(i + 1 )* batch_offset , ...]
318+ cache_diffs [uuid ] = diff [i * batch_offset :(i + 1 )* batch_offset , ...]
297319
def has_first_cond_uuid(self, uuids: list[UUID]) -> bool:
    """Return whether self.first_cond_uuid is one of the given `uuids`."""
    first_uuid = self.first_cond_uuid
    return first_uuid in uuids
@@ -324,6 +346,8 @@ def reset(self):
324346 self .output_prev_norm = None
325347 del self .uuid_cache_diffs
326348 self .uuid_cache_diffs = {}
349+ del self .uuid_cache_diffs_audio
350+ self .uuid_cache_diffs_audio = {}
327351 self .total_steps_skipped = 0
328352 self .state_metadata = None
329353 return self
0 commit comments