11from abc import ABC , abstractmethod
22from enum import Enum
33import functools
4+ import logging
45import math
56from typing import Dict , Optional , Tuple
67
1415
1516from .symmetric_patchifier import SymmetricPatchifier , latent_to_pixel_coords
1617
18+ logger = logging .getLogger (__name__ )
19+
1720def _log_base (x , base ):
1821 return np .log (x ) / np .log (base )
1922
@@ -415,12 +418,12 @@ def __init__(
415418
416419 self .scale_shift_table = nn .Parameter (torch .empty (6 , dim , device = device , dtype = dtype ))
417420
418- def forward (self , x , context = None , attention_mask = None , timestep = None , pe = None , transformer_options = {}):
421+ def forward (self , x , context = None , attention_mask = None , timestep = None , pe = None , transformer_options = {}, self_attention_mask = None ):
419422 shift_msa , scale_msa , gate_msa , shift_mlp , scale_mlp , gate_mlp = (self .scale_shift_table [None , None ].to (device = x .device , dtype = x .dtype ) + timestep .reshape (x .shape [0 ], timestep .shape [1 ], self .scale_shift_table .shape [0 ], - 1 )).unbind (dim = 2 )
420423
421424 attn1_input = comfy .ldm .common_dit .rms_norm (x )
422425 attn1_input = torch .addcmul (attn1_input , attn1_input , scale_msa ).add_ (shift_msa )
423- attn1_input = self .attn1 (attn1_input , pe = pe , transformer_options = transformer_options )
426+ attn1_input = self .attn1 (attn1_input , pe = pe , mask = self_attention_mask , transformer_options = transformer_options )
424427 x .addcmul_ (attn1_input , gate_msa )
425428 del attn1_input
426429
@@ -638,8 +641,16 @@ def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs):
638641 """Process input data. Must be implemented by subclasses."""
639642 pass
640643
644+ def _build_guide_self_attention_mask (self , x , transformer_options , merged_args ):
645+ """Build self-attention mask for per-guide attention attenuation.
646+
647+ Base implementation returns None (no attenuation). Subclasses that
648+ support guide-based attention control should override this.
649+ """
650+ return None
651+
641652 @abstractmethod
642- def _process_transformer_blocks (self , x , context , attention_mask , timestep , pe , ** kwargs ):
653+ def _process_transformer_blocks (self , x , context , attention_mask , timestep , pe , self_attention_mask = None , ** kwargs ):
643654 """Process transformer blocks. Must be implemented by subclasses."""
644655 pass
645656
@@ -788,9 +799,17 @@ def _forward(
788799 attention_mask = self ._prepare_attention_mask (attention_mask , input_dtype )
789800 pe = self ._prepare_positional_embeddings (pixel_coords , frame_rate , input_dtype )
790801
802+ # Build self-attention mask for per-guide attenuation
803+ self_attention_mask = self ._build_guide_self_attention_mask (
804+ x , transformer_options , merged_args
805+ )
806+
791807 # Process transformer blocks
792808 x = self ._process_transformer_blocks (
793- x , context , attention_mask , timestep , pe , transformer_options = transformer_options , ** merged_args
809+ x , context , attention_mask , timestep , pe ,
810+ transformer_options = transformer_options ,
811+ self_attention_mask = self_attention_mask ,
812+ ** merged_args ,
794813 )
795814
796815 # Process output
@@ -890,13 +909,243 @@ def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs):
890909 pixel_coords = pixel_coords [:, :, grid_mask , ...]
891910
892911 kf_grid_mask = grid_mask [- keyframe_idxs .shape [2 ]:]
912+
913+ # Compute per-guide surviving token counts from guide_attention_entries.
914+ # Each entry tracks one guide reference; they are appended in order and
915+ # their pre_filter_counts partition the kf_grid_mask.
916+ guide_entries = kwargs .get ("guide_attention_entries" , None )
917+ if guide_entries :
918+ total_pfc = sum (e ["pre_filter_count" ] for e in guide_entries )
919+ if total_pfc != len (kf_grid_mask ):
920+ raise ValueError (
921+ f"guide pre_filter_counts ({ total_pfc } ) != "
922+ f"keyframe grid mask length ({ len (kf_grid_mask )} )"
923+ )
924+ resolved_entries = []
925+ offset = 0
926+ for entry in guide_entries :
927+ pfc = entry ["pre_filter_count" ]
928+ entry_mask = kf_grid_mask [offset :offset + pfc ]
929+ surviving = int (entry_mask .sum ().item ())
930+ resolved_entries .append ({
931+ ** entry ,
932+ "surviving_count" : surviving ,
933+ })
934+ offset += pfc
935+ additional_args ["resolved_guide_entries" ] = resolved_entries
936+
893937 keyframe_idxs = keyframe_idxs [..., kf_grid_mask , :]
894938 pixel_coords [:, :, - keyframe_idxs .shape [2 ]:, :] = keyframe_idxs
895939
940+ # Total surviving guide tokens (all guides)
941+ additional_args ["num_guide_tokens" ] = keyframe_idxs .shape [2 ]
942+
896943 x = self .patchify_proj (x )
897944 return x , pixel_coords , additional_args
898945
899- def _process_transformer_blocks (self , x , context , attention_mask , timestep , pe , transformer_options = {}, ** kwargs ):
946+ def _build_guide_self_attention_mask (self , x , transformer_options , merged_args ):
947+ """Build self-attention mask for per-guide attention attenuation.
948+
949+ Reads resolved_guide_entries from merged_args (computed in _process_input)
950+ to build a log-space additive bias mask that attenuates noisy ↔ guide
951+ attention for each guide reference independently.
952+
953+ Returns None if no attenuation is needed (all strengths == 1.0 and no
954+ spatial masks, or no guide tokens).
955+ """
956+ if isinstance (x , list ):
957+ # AV model: x = [vx, ax]; use vx for token count and device
958+ total_tokens = x [0 ].shape [1 ]
959+ device = x [0 ].device
960+ dtype = x [0 ].dtype
961+ else :
962+ total_tokens = x .shape [1 ]
963+ device = x .device
964+ dtype = x .dtype
965+
966+ num_guide_tokens = merged_args .get ("num_guide_tokens" , 0 )
967+ if num_guide_tokens == 0 :
968+ return None
969+
970+ resolved_entries = merged_args .get ("resolved_guide_entries" , None )
971+ if not resolved_entries :
972+ return None
973+
974+ # Check if any attenuation is actually needed
975+ needs_attenuation = any (
976+ e ["strength" ] < 1.0 or e .get ("pixel_mask" ) is not None
977+ for e in resolved_entries
978+ )
979+ if not needs_attenuation :
980+ return None
981+
982+ # Build per-guide-token weights for all tracked guide tokens.
983+ # Guides are appended in order at the end of the sequence.
984+ guide_start = total_tokens - num_guide_tokens
985+ all_weights = []
986+ total_tracked = 0
987+
988+ for entry in resolved_entries :
989+ surviving = entry ["surviving_count" ]
990+ if surviving == 0 :
991+ continue
992+
993+ strength = entry ["strength" ]
994+ pixel_mask = entry .get ("pixel_mask" )
995+ latent_shape = entry .get ("latent_shape" )
996+
997+ if pixel_mask is not None and latent_shape is not None :
998+ f_lat , h_lat , w_lat = latent_shape
999+ per_token = self ._downsample_mask_to_latent (
1000+ pixel_mask .to (device = device , dtype = dtype ),
1001+ f_lat , h_lat , w_lat ,
1002+ )
1003+ # per_token shape: (B, f_lat*h_lat*w_lat).
1004+ # Collapse batch dim — the mask is assumed identical across the
1005+ # batch; validate and take the first element to get (1, tokens).
1006+ if per_token .shape [0 ] > 1 :
1007+ ref = per_token [0 ]
1008+ for bi in range (1 , per_token .shape [0 ]):
1009+ if not torch .equal (ref , per_token [bi ]):
1010+ logger .warning (
1011+ "pixel_mask differs across batch elements; "
1012+ "using first element only."
1013+ )
1014+ break
1015+ per_token = per_token [:1 ]
1016+ # `surviving` is the post-grid_mask token count.
1017+ # Clamp to surviving to handle any mismatch safely.
1018+ n_weights = min (per_token .shape [1 ], surviving )
1019+ weights = per_token [:, :n_weights ] * strength # (1, n_weights)
1020+ else :
1021+ weights = torch .full (
1022+ (1 , surviving ), strength , device = device , dtype = dtype
1023+ )
1024+
1025+ all_weights .append (weights )
1026+ total_tracked += weights .shape [1 ]
1027+
1028+ if not all_weights :
1029+ return None
1030+
1031+ # Concatenate per-token weights for all tracked guides
1032+ tracked_weights = torch .cat (all_weights , dim = 1 ) # (1, total_tracked)
1033+
1034+ # Check if any weight is actually < 1.0 (otherwise no attenuation needed)
1035+ if (tracked_weights >= 1.0 ).all ():
1036+ return None
1037+
1038+ # Build the mask: guide tokens are at the end of the sequence.
1039+ # Tracked guides come first (in order), untracked follow.
1040+ return self ._build_self_attention_mask (
1041+ total_tokens , num_guide_tokens , total_tracked ,
1042+ tracked_weights , guide_start , device , dtype ,
1043+ )
1044+
1045+ @staticmethod
1046+ def _downsample_mask_to_latent (mask , f_lat , h_lat , w_lat ):
1047+ """Downsample a pixel-space mask to per-token latent weights.
1048+
1049+ Args:
1050+ mask: (B, 1, F_pix, H_pix, W_pix) pixel-space mask with values in [0, 1].
1051+ f_lat: Number of latent frames (pre-dilation original count).
1052+ h_lat: Latent height (pre-dilation original height).
1053+ w_lat: Latent width (pre-dilation original width).
1054+
1055+ Returns:
1056+ (B, F_lat * H_lat * W_lat) flattened per-token weights.
1057+ """
1058+ b = mask .shape [0 ]
1059+ f_pix = mask .shape [2 ]
1060+
1061+ # Spatial downsampling: area interpolation per frame
1062+ spatial_down = torch .nn .functional .interpolate (
1063+ rearrange (mask , "b 1 f h w -> (b f) 1 h w" ),
1064+ size = (h_lat , w_lat ),
1065+ mode = "area" ,
1066+ )
1067+ spatial_down = rearrange (spatial_down , "(b f) 1 h w -> b 1 f h w" , b = b )
1068+
1069+ # Temporal downsampling: first pixel frame maps to first latent frame,
1070+ # remaining pixel frames are averaged in groups for causal temporal structure.
1071+ first_frame = spatial_down [:, :, :1 , :, :]
1072+ if f_pix > 1 and f_lat > 1 :
1073+ remaining_pix = f_pix - 1
1074+ remaining_lat = f_lat - 1
1075+ t = remaining_pix // remaining_lat
1076+ if t < 1 :
1077+ # Fewer pixel frames than latent frames — upsample by repeating
1078+ # the available pixel frames via nearest interpolation.
1079+ rest_flat = rearrange (
1080+ spatial_down [:, :, 1 :, :, :],
1081+ "b 1 f h w -> (b h w) 1 f" ,
1082+ )
1083+ rest_up = torch .nn .functional .interpolate (
1084+ rest_flat , size = remaining_lat , mode = "nearest" ,
1085+ )
1086+ rest = rearrange (
1087+ rest_up , "(b h w) 1 f -> b 1 f h w" ,
1088+ b = b , h = h_lat , w = w_lat ,
1089+ )
1090+ else :
1091+ # Trim trailing pixel frames that don't fill a complete group
1092+ usable = remaining_lat * t
1093+ rest = rearrange (
1094+ spatial_down [:, :, 1 :1 + usable , :, :],
1095+ "b 1 (f t) h w -> b 1 f t h w" ,
1096+ t = t ,
1097+ )
1098+ rest = rest .mean (dim = 3 )
1099+ latent_mask = torch .cat ([first_frame , rest ], dim = 2 )
1100+ elif f_lat > 1 :
1101+ # Single pixel frame but multiple latent frames — repeat the
1102+ # single frame across all latent frames.
1103+ latent_mask = first_frame .expand (- 1 , - 1 , f_lat , - 1 , - 1 )
1104+ else :
1105+ latent_mask = first_frame
1106+
1107+ return rearrange (latent_mask , "b 1 f h w -> b (f h w)" )
1108+
1109+ @staticmethod
1110+ def _build_self_attention_mask (total_tokens , num_guide_tokens , tracked_count ,
1111+ tracked_weights , guide_start , device , dtype ):
1112+ """Build a log-space additive self-attention bias mask.
1113+
1114+ Attenuates attention between noisy tokens and tracked guide tokens.
1115+ Untracked guide tokens (at the end of the guide portion) keep full attention.
1116+
1117+ Args:
1118+ total_tokens: Total sequence length.
1119+ num_guide_tokens: Total guide tokens (all guides) at end of sequence.
1120+ tracked_count: Number of tracked guide tokens (first in the guide portion).
1121+ tracked_weights: (1, tracked_count) tensor, values in [0, 1].
1122+ guide_start: Index where guide tokens begin in the sequence.
1123+ device: Target device.
1124+ dtype: Target dtype.
1125+
1126+ Returns:
1127+ (1, 1, total_tokens, total_tokens) additive bias mask.
1128+ 0.0 = full attention, negative = attenuated, finfo.min = effectively fully masked.
1129+ """
1130+ finfo = torch .finfo (dtype )
1131+ mask = torch .zeros ((1 , 1 , total_tokens , total_tokens ), device = device , dtype = dtype )
1132+ tracked_end = guide_start + tracked_count
1133+
1134+ # Convert weights to log-space bias
1135+ w = tracked_weights .to (device = device , dtype = dtype ) # (1, tracked_count)
1136+ log_w = torch .full_like (w , finfo .min )
1137+ positive_mask = w > 0
1138+ if positive_mask .any ():
1139+ log_w [positive_mask ] = torch .log (w [positive_mask ].clamp (min = finfo .tiny ))
1140+
1141+ # noisy → tracked guides: each noisy row gets the same per-guide weight
1142+ mask [:, :, :guide_start , guide_start :tracked_end ] = log_w .view (1 , 1 , 1 , - 1 )
1143+ # tracked guides → noisy: each guide row broadcasts its weight across noisy cols
1144+ mask [:, :, guide_start :tracked_end , :guide_start ] = log_w .view (1 , 1 , - 1 , 1 )
1145+
1146+ return mask
1147+
1148+ def _process_transformer_blocks (self , x , context , attention_mask , timestep , pe , transformer_options = {}, self_attention_mask = None , ** kwargs ):
9001149 """Process transformer blocks for LTXV."""
9011150 patches_replace = transformer_options .get ("patches_replace" , {})
9021151 blocks_replace = patches_replace .get ("dit" , {})
@@ -906,10 +1155,10 @@ def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe,
9061155
9071156 def block_wrap (args ):
9081157 out = {}
909- out ["img" ] = block (args ["img" ], context = args ["txt" ], attention_mask = args ["attention_mask" ], timestep = args ["vec" ], pe = args ["pe" ], transformer_options = args ["transformer_options" ])
1158+ out ["img" ] = block (args ["img" ], context = args ["txt" ], attention_mask = args ["attention_mask" ], timestep = args ["vec" ], pe = args ["pe" ], transformer_options = args ["transformer_options" ], self_attention_mask = args . get ( "self_attention_mask" ) )
9101159 return out
9111160
912- out = blocks_replace [("double_block" , i )]({"img" : x , "txt" : context , "attention_mask" : attention_mask , "vec" : timestep , "pe" : pe , "transformer_options" : transformer_options }, {"original_block" : block_wrap })
1161+ out = blocks_replace [("double_block" , i )]({"img" : x , "txt" : context , "attention_mask" : attention_mask , "vec" : timestep , "pe" : pe , "transformer_options" : transformer_options , "self_attention_mask" : self_attention_mask }, {"original_block" : block_wrap })
9131162 x = out ["img" ]
9141163 else :
9151164 x = block (
@@ -919,6 +1168,7 @@ def block_wrap(args):
9191168 timestep = timestep ,
9201169 pe = pe ,
9211170 transformer_options = transformer_options ,
1171+ self_attention_mask = self_attention_mask ,
9221172 )
9231173
9241174 return x
0 commit comments