Skip to content

Commit a452201

Browse files
authored
feat: per-guide attention strength control in self-attention (Comfy-Org#12518)
Implements per-guide attention attenuation via log-space additive bias in self-attention. Each guide reference tracks its own strength and optional spatial mask in conditioning metadata (guide_attention_entries).
1 parent 907e5dc commit a452201

File tree

4 files changed

+352
-12
lines changed

4 files changed

+352
-12
lines changed

comfy/ldm/lightricks/av_model.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ def get_av_ca_ada_values(
218218
def forward(
219219
self, x: Tuple[torch.Tensor, torch.Tensor], v_context=None, a_context=None, attention_mask=None, v_timestep=None, a_timestep=None,
220220
v_pe=None, a_pe=None, v_cross_pe=None, a_cross_pe=None, v_cross_scale_shift_timestep=None, a_cross_scale_shift_timestep=None,
221-
v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None,
221+
v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None, self_attention_mask=None,
222222
) -> Tuple[torch.Tensor, torch.Tensor]:
223223
run_vx = transformer_options.get("run_vx", True)
224224
run_ax = transformer_options.get("run_ax", True)
@@ -234,7 +234,7 @@ def forward(
234234
vshift_msa, vscale_msa = (self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 2)))
235235
norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa
236236
del vshift_msa, vscale_msa
237-
attn1_out = self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options)
237+
attn1_out = self.attn1(norm_vx, pe=v_pe, mask=self_attention_mask, transformer_options=transformer_options)
238238
del norm_vx
239239
# video cross-attention
240240
vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0]
@@ -726,7 +726,7 @@ def _prepare_positional_embeddings(self, pixel_coords, frame_rate, x_dtype):
726726
return [(v_pe, av_cross_video_freq_cis), (a_pe, av_cross_audio_freq_cis)]
727727

728728
def _process_transformer_blocks(
729-
self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs
729+
self, x, context, attention_mask, timestep, pe, transformer_options={}, self_attention_mask=None, **kwargs
730730
):
731731
vx = x[0]
732732
ax = x[1]
@@ -770,6 +770,7 @@ def block_wrap(args):
770770
v_cross_gate_timestep=args["v_cross_gate_timestep"],
771771
a_cross_gate_timestep=args["a_cross_gate_timestep"],
772772
transformer_options=args["transformer_options"],
773+
self_attention_mask=args.get("self_attention_mask"),
773774
)
774775
return out
775776

@@ -790,6 +791,7 @@ def block_wrap(args):
790791
"v_cross_gate_timestep": av_ca_a2v_gate_noise_timestep,
791792
"a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep,
792793
"transformer_options": transformer_options,
794+
"self_attention_mask": self_attention_mask,
793795
},
794796
{"original_block": block_wrap},
795797
)
@@ -811,6 +813,7 @@ def block_wrap(args):
811813
v_cross_gate_timestep=av_ca_a2v_gate_noise_timestep,
812814
a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep,
813815
transformer_options=transformer_options,
816+
self_attention_mask=self_attention_mask,
814817
)
815818

816819
return [vx, ax]

comfy/ldm/lightricks/model.py

Lines changed: 257 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from abc import ABC, abstractmethod
22
from enum import Enum
33
import functools
4+
import logging
45
import math
56
from typing import Dict, Optional, Tuple
67

@@ -14,6 +15,8 @@
1415

1516
from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
1617

18+
logger = logging.getLogger(__name__)
19+
1720
def _log_base(x, base):
1821
return np.log(x) / np.log(base)
1922

@@ -415,12 +418,12 @@ def __init__(
415418

416419
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
417420

418-
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}):
421+
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}, self_attention_mask=None):
419422
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
420423

421424
attn1_input = comfy.ldm.common_dit.rms_norm(x)
422425
attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa)
423-
attn1_input = self.attn1(attn1_input, pe=pe, transformer_options=transformer_options)
426+
attn1_input = self.attn1(attn1_input, pe=pe, mask=self_attention_mask, transformer_options=transformer_options)
424427
x.addcmul_(attn1_input, gate_msa)
425428
del attn1_input
426429

@@ -638,8 +641,16 @@ def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs):
638641
"""Process input data. Must be implemented by subclasses."""
639642
pass
640643

644+
def _build_guide_self_attention_mask(self, x, transformer_options, merged_args):
645+
"""Build self-attention mask for per-guide attention attenuation.
646+
647+
Base implementation returns None (no attenuation). Subclasses that
648+
support guide-based attention control should override this.
649+
"""
650+
return None
651+
641652
@abstractmethod
642-
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, **kwargs):
653+
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, self_attention_mask=None, **kwargs):
643654
"""Process transformer blocks. Must be implemented by subclasses."""
644655
pass
645656

@@ -788,9 +799,17 @@ def _forward(
788799
attention_mask = self._prepare_attention_mask(attention_mask, input_dtype)
789800
pe = self._prepare_positional_embeddings(pixel_coords, frame_rate, input_dtype)
790801

802+
# Build self-attention mask for per-guide attenuation
803+
self_attention_mask = self._build_guide_self_attention_mask(
804+
x, transformer_options, merged_args
805+
)
806+
791807
# Process transformer blocks
792808
x = self._process_transformer_blocks(
793-
x, context, attention_mask, timestep, pe, transformer_options=transformer_options, **merged_args
809+
x, context, attention_mask, timestep, pe,
810+
transformer_options=transformer_options,
811+
self_attention_mask=self_attention_mask,
812+
**merged_args,
794813
)
795814

796815
# Process output
@@ -890,13 +909,243 @@ def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs):
890909
pixel_coords = pixel_coords[:, :, grid_mask, ...]
891910

892911
kf_grid_mask = grid_mask[-keyframe_idxs.shape[2]:]
912+
913+
# Compute per-guide surviving token counts from guide_attention_entries.
914+
# Each entry tracks one guide reference; they are appended in order and
915+
# their pre_filter_counts partition the kf_grid_mask.
916+
guide_entries = kwargs.get("guide_attention_entries", None)
917+
if guide_entries:
918+
total_pfc = sum(e["pre_filter_count"] for e in guide_entries)
919+
if total_pfc != len(kf_grid_mask):
920+
raise ValueError(
921+
f"guide pre_filter_counts ({total_pfc}) != "
922+
f"keyframe grid mask length ({len(kf_grid_mask)})"
923+
)
924+
resolved_entries = []
925+
offset = 0
926+
for entry in guide_entries:
927+
pfc = entry["pre_filter_count"]
928+
entry_mask = kf_grid_mask[offset:offset + pfc]
929+
surviving = int(entry_mask.sum().item())
930+
resolved_entries.append({
931+
**entry,
932+
"surviving_count": surviving,
933+
})
934+
offset += pfc
935+
additional_args["resolved_guide_entries"] = resolved_entries
936+
893937
keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :]
894938
pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs
895939

940+
# Total surviving guide tokens (all guides)
941+
additional_args["num_guide_tokens"] = keyframe_idxs.shape[2]
942+
896943
x = self.patchify_proj(x)
897944
return x, pixel_coords, additional_args
898945

899-
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs):
946+
def _build_guide_self_attention_mask(self, x, transformer_options, merged_args):
947+
"""Build self-attention mask for per-guide attention attenuation.
948+
949+
Reads resolved_guide_entries from merged_args (computed in _process_input)
950+
to build a log-space additive bias mask that attenuates noisy ↔ guide
951+
attention for each guide reference independently.
952+
953+
Returns None if no attenuation is needed (all strengths == 1.0 and no
954+
spatial masks, or no guide tokens).
955+
"""
956+
if isinstance(x, list):
957+
# AV model: x = [vx, ax]; use vx for token count and device
958+
total_tokens = x[0].shape[1]
959+
device = x[0].device
960+
dtype = x[0].dtype
961+
else:
962+
total_tokens = x.shape[1]
963+
device = x.device
964+
dtype = x.dtype
965+
966+
num_guide_tokens = merged_args.get("num_guide_tokens", 0)
967+
if num_guide_tokens == 0:
968+
return None
969+
970+
resolved_entries = merged_args.get("resolved_guide_entries", None)
971+
if not resolved_entries:
972+
return None
973+
974+
# Check if any attenuation is actually needed
975+
needs_attenuation = any(
976+
e["strength"] < 1.0 or e.get("pixel_mask") is not None
977+
for e in resolved_entries
978+
)
979+
if not needs_attenuation:
980+
return None
981+
982+
# Build per-guide-token weights for all tracked guide tokens.
983+
# Guides are appended in order at the end of the sequence.
984+
guide_start = total_tokens - num_guide_tokens
985+
all_weights = []
986+
total_tracked = 0
987+
988+
for entry in resolved_entries:
989+
surviving = entry["surviving_count"]
990+
if surviving == 0:
991+
continue
992+
993+
strength = entry["strength"]
994+
pixel_mask = entry.get("pixel_mask")
995+
latent_shape = entry.get("latent_shape")
996+
997+
if pixel_mask is not None and latent_shape is not None:
998+
f_lat, h_lat, w_lat = latent_shape
999+
per_token = self._downsample_mask_to_latent(
1000+
pixel_mask.to(device=device, dtype=dtype),
1001+
f_lat, h_lat, w_lat,
1002+
)
1003+
# per_token shape: (B, f_lat*h_lat*w_lat).
1004+
# Collapse batch dim — the mask is assumed identical across the
1005+
# batch; validate and take the first element to get (1, tokens).
1006+
if per_token.shape[0] > 1:
1007+
ref = per_token[0]
1008+
for bi in range(1, per_token.shape[0]):
1009+
if not torch.equal(ref, per_token[bi]):
1010+
logger.warning(
1011+
"pixel_mask differs across batch elements; "
1012+
"using first element only."
1013+
)
1014+
break
1015+
per_token = per_token[:1]
1016+
# `surviving` is the post-grid_mask token count.
1017+
# Clamp to surviving to handle any mismatch safely.
1018+
n_weights = min(per_token.shape[1], surviving)
1019+
weights = per_token[:, :n_weights] * strength # (1, n_weights)
1020+
else:
1021+
weights = torch.full(
1022+
(1, surviving), strength, device=device, dtype=dtype
1023+
)
1024+
1025+
all_weights.append(weights)
1026+
total_tracked += weights.shape[1]
1027+
1028+
if not all_weights:
1029+
return None
1030+
1031+
# Concatenate per-token weights for all tracked guides
1032+
tracked_weights = torch.cat(all_weights, dim=1) # (1, total_tracked)
1033+
1034+
# Check if any weight is actually < 1.0 (otherwise no attenuation needed)
1035+
if (tracked_weights >= 1.0).all():
1036+
return None
1037+
1038+
# Build the mask: guide tokens are at the end of the sequence.
1039+
# Tracked guides come first (in order), untracked follow.
1040+
return self._build_self_attention_mask(
1041+
total_tokens, num_guide_tokens, total_tracked,
1042+
tracked_weights, guide_start, device, dtype,
1043+
)
1044+
1045+
@staticmethod
1046+
def _downsample_mask_to_latent(mask, f_lat, h_lat, w_lat):
1047+
"""Downsample a pixel-space mask to per-token latent weights.
1048+
1049+
Args:
1050+
mask: (B, 1, F_pix, H_pix, W_pix) pixel-space mask with values in [0, 1].
1051+
f_lat: Number of latent frames (pre-dilation original count).
1052+
h_lat: Latent height (pre-dilation original height).
1053+
w_lat: Latent width (pre-dilation original width).
1054+
1055+
Returns:
1056+
(B, F_lat * H_lat * W_lat) flattened per-token weights.
1057+
"""
1058+
b = mask.shape[0]
1059+
f_pix = mask.shape[2]
1060+
1061+
# Spatial downsampling: area interpolation per frame
1062+
spatial_down = torch.nn.functional.interpolate(
1063+
rearrange(mask, "b 1 f h w -> (b f) 1 h w"),
1064+
size=(h_lat, w_lat),
1065+
mode="area",
1066+
)
1067+
spatial_down = rearrange(spatial_down, "(b f) 1 h w -> b 1 f h w", b=b)
1068+
1069+
# Temporal downsampling: first pixel frame maps to first latent frame,
1070+
# remaining pixel frames are averaged in groups for causal temporal structure.
1071+
first_frame = spatial_down[:, :, :1, :, :]
1072+
if f_pix > 1 and f_lat > 1:
1073+
remaining_pix = f_pix - 1
1074+
remaining_lat = f_lat - 1
1075+
t = remaining_pix // remaining_lat
1076+
if t < 1:
1077+
# Fewer pixel frames than latent frames — upsample by repeating
1078+
# the available pixel frames via nearest interpolation.
1079+
rest_flat = rearrange(
1080+
spatial_down[:, :, 1:, :, :],
1081+
"b 1 f h w -> (b h w) 1 f",
1082+
)
1083+
rest_up = torch.nn.functional.interpolate(
1084+
rest_flat, size=remaining_lat, mode="nearest",
1085+
)
1086+
rest = rearrange(
1087+
rest_up, "(b h w) 1 f -> b 1 f h w",
1088+
b=b, h=h_lat, w=w_lat,
1089+
)
1090+
else:
1091+
# Trim trailing pixel frames that don't fill a complete group
1092+
usable = remaining_lat * t
1093+
rest = rearrange(
1094+
spatial_down[:, :, 1:1 + usable, :, :],
1095+
"b 1 (f t) h w -> b 1 f t h w",
1096+
t=t,
1097+
)
1098+
rest = rest.mean(dim=3)
1099+
latent_mask = torch.cat([first_frame, rest], dim=2)
1100+
elif f_lat > 1:
1101+
# Single pixel frame but multiple latent frames — repeat the
1102+
# single frame across all latent frames.
1103+
latent_mask = first_frame.expand(-1, -1, f_lat, -1, -1)
1104+
else:
1105+
latent_mask = first_frame
1106+
1107+
return rearrange(latent_mask, "b 1 f h w -> b (f h w)")
1108+
1109+
@staticmethod
1110+
def _build_self_attention_mask(total_tokens, num_guide_tokens, tracked_count,
1111+
tracked_weights, guide_start, device, dtype):
1112+
"""Build a log-space additive self-attention bias mask.
1113+
1114+
Attenuates attention between noisy tokens and tracked guide tokens.
1115+
Untracked guide tokens (at the end of the guide portion) keep full attention.
1116+
1117+
Args:
1118+
total_tokens: Total sequence length.
1119+
num_guide_tokens: Total guide tokens (all guides) at end of sequence.
1120+
tracked_count: Number of tracked guide tokens (first in the guide portion).
1121+
tracked_weights: (1, tracked_count) tensor, values in [0, 1].
1122+
guide_start: Index where guide tokens begin in the sequence.
1123+
device: Target device.
1124+
dtype: Target dtype.
1125+
1126+
Returns:
1127+
(1, 1, total_tokens, total_tokens) additive bias mask.
1128+
0.0 = full attention, negative = attenuated, finfo.min = effectively fully masked.
1129+
"""
1130+
finfo = torch.finfo(dtype)
1131+
mask = torch.zeros((1, 1, total_tokens, total_tokens), device=device, dtype=dtype)
1132+
tracked_end = guide_start + tracked_count
1133+
1134+
# Convert weights to log-space bias
1135+
w = tracked_weights.to(device=device, dtype=dtype) # (1, tracked_count)
1136+
log_w = torch.full_like(w, finfo.min)
1137+
positive_mask = w > 0
1138+
if positive_mask.any():
1139+
log_w[positive_mask] = torch.log(w[positive_mask].clamp(min=finfo.tiny))
1140+
1141+
# noisy → tracked guides: each noisy row gets the same per-guide weight
1142+
mask[:, :, :guide_start, guide_start:tracked_end] = log_w.view(1, 1, 1, -1)
1143+
# tracked guides → noisy: each guide row broadcasts its weight across noisy cols
1144+
mask[:, :, guide_start:tracked_end, :guide_start] = log_w.view(1, 1, -1, 1)
1145+
1146+
return mask
1147+
1148+
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, self_attention_mask=None, **kwargs):
9001149
"""Process transformer blocks for LTXV."""
9011150
patches_replace = transformer_options.get("patches_replace", {})
9021151
blocks_replace = patches_replace.get("dit", {})
@@ -906,10 +1155,10 @@ def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe,
9061155

9071156
def block_wrap(args):
9081157
out = {}
909-
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"])
1158+
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"], self_attention_mask=args.get("self_attention_mask"))
9101159
return out
9111160

912-
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options}, {"original_block": block_wrap})
1161+
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options, "self_attention_mask": self_attention_mask}, {"original_block": block_wrap})
9131162
x = out["img"]
9141163
else:
9151164
x = block(
@@ -919,6 +1168,7 @@ def block_wrap(args):
9191168
timestep=timestep,
9201169
pe=pe,
9211170
transformer_options=transformer_options,
1171+
self_attention_mask=self_attention_mask,
9221172
)
9231173

9241174
return x

0 commit comments

Comments (0)