Merged

22 commits
afa4a23
feat: implement apply_lora_scale to remove boilerplate.
sayakpaul Jan 19, 2026
835a087
Merge branch 'main' into apply-lora-scale-decorator
sayakpaul Jan 20, 2026
3cdce4d
Merge branch 'main' into apply-lora-scale-decorator
sayakpaul Jan 27, 2026
9afafe5
Merge branch 'main' into apply-lora-scale-decorator
sayakpaul Jan 28, 2026
d6fcd78
apply to the rest.
sayakpaul Jan 28, 2026
290f749
up
sayakpaul Jan 28, 2026
458ac94
remove more.
sayakpaul Jan 28, 2026
8c402d3
remove.
sayakpaul Jan 28, 2026
e5ebacb
fix
sayakpaul Jan 28, 2026
9b3947c
Merge branch 'main' into apply-lora-scale-decorator
sayakpaul Jan 29, 2026
efd6d69
Merge branch 'main' into apply-lora-scale-decorator
sayakpaul Jan 29, 2026
261b061
Merge branch 'main' into apply-lora-scale-decorator
sayakpaul Feb 2, 2026
fc26162
Merge branch 'main' into apply-lora-scale-decorator
sayakpaul Feb 2, 2026
90eb7d7
Merge branch 'main' into apply-lora-scale-decorator
sayakpaul Feb 4, 2026
8f6c6dd
Merge branch 'main' into apply-lora-scale-decorator
sayakpaul Feb 8, 2026
c83bb06
Merge branch 'main' into apply-lora-scale-decorator
sayakpaul Feb 10, 2026
9a79d8a
Merge branch 'main' into apply-lora-scale-decorator
sayakpaul Feb 11, 2026
c5bced9
Merge branch 'main' into apply-lora-scale-decorator
sayakpaul Feb 12, 2026
a0085cd
Merge branch 'main' into apply-lora-scale-decorator
sayakpaul Feb 13, 2026
a5b9e8c
Merge branch 'main' into apply-lora-scale-decorator
sayakpaul Feb 13, 2026
07f6d08
apply feedback.
sayakpaul Feb 13, 2026
5d57508
resolve conflicts.
sayakpaul Feb 13, 2026
3 changes: 2 additions & 1 deletion src/diffusers/models/controlnets/controlnet.py
@@ -21,7 +21,7 @@
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils import BaseOutput, logging
from ...utils import BaseOutput, apply_lora_scale, logging
from ..attention import AttentionMixin
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
@@ -598,6 +598,7 @@ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: list[i
for module in self.children():
fn_recursive_set_attention_slice(module, reversed_slice_size)

@apply_lora_scale("cross_attention_kwargs")
def forward(
self,
sample: torch.Tensor,
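From the caller's side the behavior is unchanged: the LoRA scale still travels under the `"scale"` key of the kwargs dict named in the decorator (here `cross_attention_kwargs`) and is consumed before the forward body runs. A minimal call sketch, assuming an SD 1.5-style ControlNet checkpoint and a LoRA adapter already loaded into the model (the checkpoint name and tensor shapes are illustrative, not part of this PR):

```python
import torch
from diffusers import ControlNetModel

# Illustrative checkpoint; any SD 1.5-style ControlNet follows the same forward signature.
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float32)

out = controlnet(
    sample=torch.randn(1, 4, 64, 64),               # latent input
    timestep=10,
    encoder_hidden_states=torch.randn(1, 77, 768),  # text embeddings
    controlnet_cond=torch.randn(1, 3, 512, 512),    # conditioning image
    # The decorator pops "scale" and applies it to any loaded LoRA layers;
    # without a LoRA adapter (or without the PEFT backend) it is effectively a no-op.
    cross_attention_kwargs={"scale": 0.7},
)
```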
25 changes: 6 additions & 19 deletions src/diffusers/models/controlnets/controlnet_flux.py
@@ -20,7 +20,11 @@

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
from ...utils import (
BaseOutput,
apply_lora_scale,
logging,
)
from ..attention import AttentionMixin
from ..controlnets.controlnet import ControlNetConditioningEmbedding, zero_module
from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
@@ -150,6 +154,7 @@ def from_transformer(

return controlnet

@apply_lora_scale("joint_attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -197,20 +202,6 @@ def forward(
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy()
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0

if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
hidden_states = self.x_embedder(hidden_states)

if self.input_hint_block is not None:
@@ -323,10 +314,6 @@ def forward(
None if len(controlnet_single_block_samples) == 0 else controlnet_single_block_samples
)

if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)

if not return_dict:
return (controlnet_block_samples, controlnet_single_block_samples)

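The blocks deleted above are the boilerplate this PR factors out: copy the kwargs dict, pop `"scale"`, scale the PEFT LoRA layers before the forward body, and unscale them afterwards. The decorator itself is not shown in this diff; below is a minimal sketch of how such an `apply_lora_scale` decorator could be written, assuming only the helpers already exported from `diffusers.utils`. The actual implementation in the PR may differ in details.

```python
# Hypothetical sketch of an apply_lora_scale-style decorator (not the PR's exact code).
import functools
import inspect

from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers

logger = logging.get_logger(__name__)


def apply_lora_scale(kwargs_name: str):
    """Pop `scale` from the dict argument named `kwargs_name`, scale LoRA layers
    before calling `forward`, and unscale them afterwards."""

    def decorator(forward_fn):
        sig = inspect.signature(forward_fn)

        @functools.wraps(forward_fn)
        def wrapper(self, *args, **kwargs):
            # Resolve the named kwargs dict (e.g. `joint_attention_kwargs`)
            # whether it was passed positionally or by keyword.
            bound = sig.bind(self, *args, **kwargs)
            bound.apply_defaults()
            attention_kwargs = bound.arguments.get(kwargs_name)

            if attention_kwargs is not None:
                attention_kwargs = attention_kwargs.copy()
                lora_scale = attention_kwargs.pop("scale", 1.0)
                # Forward sees the copy, without the "scale" entry.
                bound.arguments[kwargs_name] = attention_kwargs
            else:
                lora_scale = 1.0

            if USE_PEFT_BACKEND:
                # Weight the LoRA layers by setting `lora_scale` on each PEFT layer.
                scale_lora_layers(self, lora_scale)
            elif lora_scale != 1.0:
                logger.warning(
                    f"Passing `scale` via `{kwargs_name}` when not using the PEFT backend is ineffective."
                )

            try:
                return forward_fn(*bound.args, **bound.kwargs)
            finally:
                if USE_PEFT_BACKEND:
                    # Remove `lora_scale` from each PEFT layer, mirroring the old unscale step.
                    unscale_lora_layers(self, lora_scale)

        return wrapper

    return decorator
```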
26 changes: 7 additions & 19 deletions src/diffusers/models/controlnets/controlnet_qwenimage.py
@@ -20,7 +20,12 @@

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils import (
BaseOutput,
apply_lora_scale,
deprecate,
logging,
)
from ..attention import AttentionMixin
from ..cache_utils import CacheMixin
from ..controlnets.controlnet import zero_module
@@ -123,6 +128,7 @@ def from_transformer(

return controlnet

@apply_lora_scale("joint_attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -181,20 +187,6 @@ def forward(
standard_warn=False,
)

if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy()
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0

if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
hidden_states = self.img_in(hidden_states)

# add
@@ -256,10 +248,6 @@ def forward(
controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples]
controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples

if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)

if not return_dict:
return controlnet_block_samples

22 changes: 2 additions & 20 deletions src/diffusers/models/controlnets/controlnet_sana.py
@@ -20,7 +20,7 @@

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
from ...utils import BaseOutput, apply_lora_scale, logging
from ..attention import AttentionMixin
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
from ..modeling_outputs import Transformer2DModelOutput
@@ -117,6 +117,7 @@ def __init__(

self.gradient_checkpointing = False

@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -129,21 +130,6 @@ def forward(
attention_kwargs: dict[str, Any] | None = None,
return_dict: bool = True,
) -> tuple[torch.Tensor, ...] | Transformer2DModelOutput:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0

if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)

# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
@@ -218,10 +204,6 @@ def forward(
block_res_sample = controlnet_block(block_res_sample)
controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)

if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)

controlnet_block_res_samples = [sample * conditioning_scale for sample in controlnet_block_res_samples]

if not return_dict:
22 changes: 2 additions & 20 deletions src/diffusers/models/controlnets/controlnet_sd3.py
@@ -21,7 +21,7 @@

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils import apply_lora_scale, logging
from ..attention import AttentionMixin, JointTransformerBlock
from ..attention_processor import Attention, FusedJointAttnProcessor2_0
from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
@@ -269,6 +269,7 @@ def from_transformer(

return controlnet

@apply_lora_scale("joint_attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -308,21 +309,6 @@ def forward(
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy()
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0

if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)

if self.pos_embed is not None and hidden_states.ndim != 4:
raise ValueError("hidden_states must be 4D when pos_embed is used")

@@ -382,10 +368,6 @@ def forward(
# 6. scaling
controlnet_block_res_samples = [sample * conditioning_scale for sample in controlnet_block_res_samples]

if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)

if not return_dict:
return (controlnet_block_res_samples,)

22 changes: 2 additions & 20 deletions src/diffusers/models/transformers/auraflow_transformer_2d.py
@@ -21,7 +21,7 @@

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils import apply_lora_scale, logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import AttentionMixin
from ..attention_processor import (
@@ -397,6 +397,7 @@ def unfuse_qkv_projections(self):
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)

@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.FloatTensor,
@@ -405,21 +406,6 @@ def forward(
attention_kwargs: dict[str, Any] | None = None,
return_dict: bool = True,
) -> tuple[torch.Tensor] | Transformer2DModelOutput:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0

if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)

height, width = hidden_states.shape[-2:]

# Apply patch embedding, timestep embedding, and project the caption embeddings.
@@ -486,10 +472,6 @@ def forward(
shape=(hidden_states.shape[0], out_channels, height * patch_size, width * patch_size)
)

if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)

if not return_dict:
return (output,)

22 changes: 2 additions & 20 deletions src/diffusers/models/transformers/cogvideox_transformer_3d.py
@@ -20,7 +20,7 @@

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils import apply_lora_scale, logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import Attention, AttentionMixin, FeedForward
from ..attention_processor import CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
@@ -363,6 +363,7 @@ def unfuse_qkv_projections(self):
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)

@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -374,21 +375,6 @@ def forward(
attention_kwargs: dict[str, Any] | None = None,
return_dict: bool = True,
) -> tuple[torch.Tensor] | Transformer2DModelOutput:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0

if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)

batch_size, num_frames, channels, height, width = hidden_states.shape

# 1. Time embedding
@@ -454,10 +440,6 @@ def forward(
)
output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)

if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)

if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
22 changes: 2 additions & 20 deletions src/diffusers/models/transformers/consisid_transformer_3d.py
@@ -20,7 +20,7 @@

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils import apply_lora_scale, logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import Attention, AttentionMixin, FeedForward
from ..attention_processor import CogVideoXAttnProcessor2_0
@@ -620,6 +620,7 @@ def _init_face_inputs(self):
]
)

@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -632,21 +633,6 @@ def forward(
id_vit_hidden: torch.Tensor | None = None,
return_dict: bool = True,
) -> tuple[torch.Tensor] | Transformer2DModelOutput:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0

if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)

# fuse clip and insightface
valid_face_emb = None
if self.is_train_face:
@@ -720,10 +706,6 @@ def forward(
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)

if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)

if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)