14 changes: 1 addition & 13 deletions comfy/ldm/modules/diffusionmodules/model.py
@@ -102,19 +102,7 @@ def forward(self, x):
         return self.conv(x)

 def interpolate_up(x, scale_factor):
-    try:
-        return torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="nearest")
-    except: #operation not implemented for bf16
-        orig_shape = list(x.shape)
-        out_shape = orig_shape[:2]
-        for i in range(len(orig_shape) - 2):
-            out_shape.append(round(orig_shape[i + 2] * scale_factor[i]))
-        out = torch.empty(out_shape, dtype=x.dtype, layout=x.layout, device=x.device)
-        split = 8
-        l = out.shape[1] // split
-        for i in range(0, out.shape[1], l):
-            out[:,i:i+l] = torch.nn.functional.interpolate(x[:,i:i+l].to(torch.float32), scale_factor=scale_factor, mode="nearest").to(x.dtype)
-        return out
+    return torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="nearest")

 class Upsample(nn.Module):
     def __init__(self, in_channels, with_conv, conv_op=ops.Conv2d, scale_factor=2.0):
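The simplification above assumes that nearest-neighbor interpolation now handles bfloat16 inputs directly, so the chunked float32 fallback is no longer needed; the minimum PyTorch version this relies on is not stated in the diff, so treat the bf16 support as an assumption to verify. A minimal sketch of the call that the removed fallback used to guard:

import torch
import torch.nn.functional as F

# Sketch only: checks the assumption behind removing the fp32 fallback,
# namely that mode="nearest" interpolation accepts bfloat16 input directly.
x = torch.randn(1, 4, 8, 8, dtype=torch.bfloat16)
y = F.interpolate(x, scale_factor=2.0, mode="nearest")
print(y.shape, y.dtype)  # expected: torch.Size([1, 4, 16, 16]) torch.bfloat16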
25 changes: 25 additions & 0 deletions comfy/lora.py
@@ -374,6 +374,31 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor:

     return padded_tensor

+def calculate_shape(patches, weight, key, original_weights=None):
+    current_shape = weight.shape
+
+    for p in patches:
+        v = p[1]
+        offset = p[3]
+
+        # Offsets restore the old shape; lists force a diff without metadata
+        if offset is not None or isinstance(v, list):
+            continue
+
+        if isinstance(v, weight_adapter.WeightAdapterBase):
+            adapter_shape = v.calculate_shape(key)
+            if adapter_shape is not None:
+                current_shape = adapter_shape
+            continue
+
+        # Standard diff logic with padding
+        if len(v) == 2:
+            patch_type, patch_data = v[0], v[1]
+            if patch_type == "diff" and len(patch_data) > 1 and patch_data[1]['pad_weight']:
+                current_shape = patch_data[0].shape
+
+    return current_shape
+
 def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, original_weights=None):
     for p in patches:
         strength = p[0]
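For reference, a minimal sketch of how the new calculate_shape is meant to be read. The patch tuple below is a hypothetical stand-in, assuming the (strength, value, strength_model, offset, function) layout implied by the p[1] and p[3] indexing above; only the shape logic is exercised:

import torch
import comfy.lora

weight = torch.zeros(4, 4)

# Hypothetical patch entry: a "diff" payload carrying a larger tensor plus
# pad_weight metadata, so the final shape should follow the diff tensor.
patches = [(1.0, ("diff", (torch.zeros(8, 4), {"pad_weight": True})), 1.0, None, None)]

print(comfy.lora.calculate_shape(patches, weight, "some.layer.weight"))  # expected: torch.Size([8, 4])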
35 changes: 27 additions & 8 deletions comfy/model_patcher.py
@@ -1514,8 +1514,10 @@ def setup_param(self, m, n, param_key):

     weight, _, _ = get_key_weight(self.model, key)
     if weight is None:
-        return 0
+        return (False, 0)
     if key in self.patches:
+        if comfy.lora.calculate_shape(self.patches[key], weight, key) != weight.shape:
+            return (True, 0)
         setattr(m, param_key + "_lowvram_function", LowVramPatch(key, self.patches))
         num_patches += 1
     else:
@@ -1529,21 +1531,33 @@ def setup_param(self, m, n, param_key):
     model_dtype = getattr(m, param_key + "_comfy_model_dtype", None) or weight.dtype
     weight._model_dtype = model_dtype
     geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
-    return comfy.memory_management.vram_aligned_size(geometry)
+    return (False, comfy.memory_management.vram_aligned_size(geometry))

+def force_load_param(self, param_key, device_to):
+    key = key_param_name_to_key(n, param_key)
+    if key in self.backup:
+        comfy.utils.set_attr_param(self.model, key, self.backup[key].weight)
+    self.patch_weight_to_device(key, device_to=device_to)
+
 if hasattr(m, "comfy_cast_weights"):
     m.comfy_cast_weights = True
     m.pin_failed = False
     m.seed_key = n
     set_dirty(m, dirty)

-    v_weight_size = 0
-    v_weight_size += setup_param(self, m, n, "weight")
-    v_weight_size += setup_param(self, m, n, "bias")
+    force_load, v_weight_size = setup_param(self, m, n, "weight")
+    force_load_bias, v_weight_bias = setup_param(self, m, n, "bias")
+    force_load = force_load or force_load_bias
+    v_weight_size += v_weight_bias

-    if vbar is not None and not hasattr(m, "_v"):
-        m._v = vbar.alloc(v_weight_size)
-    allocated_size += v_weight_size
+    if force_load:
+        logging.info(f"Module {n} has resizing Lora - force loading")
+        force_load_param(self, "weight", device_to)
+        force_load_param(self, "bias", device_to)
+    else:
+        if vbar is not None and not hasattr(m, "_v"):
+            m._v = vbar.alloc(v_weight_size)
+        allocated_size += v_weight_size

 else:
     for param in params:
@@ -1606,6 +1620,11 @@ def unpatch_model(self, device_to=None, unpatch_weights=True):
         for m in self.model.modules():
             move_weight_functions(m, device_to)

+        keys = list(self.backup.keys())
+        for k in keys:
+            bk = self.backup[k]
+            comfy.utils.set_attr_param(self.model, k, bk.weight)
+
     def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
         assert not force_patch_weights #See above
         with self.use_ejected(skip_and_inject_on_exit_only=True):
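Taken together, the model_patcher.py changes gate the lazy low-VRAM patching path on whether the patches preserve a parameter's shape. A condensed, hypothetical restatement of that rule (standalone sketch; setup_param and force_load_param above live inside the real loading code, so the helper below is illustrative only):

import logging
import comfy.lora

def choose_patching_strategy(patches, weight, key):
    # Illustrative helper only: a patch set that resizes the weight cannot be
    # applied lazily on top of a VRAM reservation sized from weight.shape,
    # so the module is force loaded and patched eagerly instead.
    if comfy.lora.calculate_shape(patches, weight, key) != weight.shape:
        logging.info(f"{key} has a resizing patch - force loading")
        return "force_load"
    return "lowvram_lazy_patch"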
6 changes: 6 additions & 0 deletions comfy/weight_adapter/base.py
@@ -49,6 +49,12 @@ def create_train(cls, weight, *args) -> "WeightAdapterTrainBase":
"""
raise NotImplementedError

def calculate_shape(
self,
key
):
return None

def calculate_weight(
self,
weight,
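The base-class hook defaults to None, meaning the adapter does not change the target shape. A hypothetical subclass illustrating how an adapter can advertise a different shape (the class below is an illustration, not one of the shipped adapters):

from comfy.weight_adapter.base import WeightAdapterBase

class ReshapingAdapter(WeightAdapterBase):
    # Hypothetical adapter: it reports the shape its patch will produce,
    # so the loader can size the parameter before the patch is applied.
    def __init__(self, new_shape):
        self.new_shape = new_shape

    def calculate_shape(self, key):
        return tuple(self.new_shape)

An instance placed as p[1] in a patch entry would make the calculate_shape loop in comfy/lora.py return tuple(new_shape) for that key.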
7 changes: 7 additions & 0 deletions comfy/weight_adapter/lora.py
@@ -214,6 +214,13 @@ def load(
         else:
             return None

+    def calculate_shape(
+        self,
+        key
+    ):
+        reshape = self.weights[5]
+        return tuple(reshape) if reshape is not None else None
+
     def calculate_weight(
         self,
         weight,
44 changes: 41 additions & 3 deletions comfy_api_nodes/apis/bria.py
@@ -45,17 +45,55 @@ class BriaEditImageRequest(BaseModel):
     )


+class BriaRemoveBackgroundRequest(BaseModel):
+    image: str = Field(...)
+    sync: bool = Field(False)
+    visual_input_content_moderation: bool = Field(
+        False, description="If true, returns 422 on input image moderation failure."
+    )
+    visual_output_content_moderation: bool = Field(
+        False, description="If true, returns 422 on visual output moderation failure."
+    )
+    seed: int = Field(...)
+

 class BriaStatusResponse(BaseModel):
     request_id: str = Field(...)
     status_url: str = Field(...)
     warning: str | None = Field(None)


-class BriaResult(BaseModel):
+class BriaRemoveBackgroundResult(BaseModel):
     image_url: str = Field(...)


+class BriaRemoveBackgroundResponse(BaseModel):
+    status: str = Field(...)
+    result: BriaRemoveBackgroundResult | None = Field(None)
+
+
+class BriaImageEditResult(BaseModel):
+    structured_prompt: str = Field(...)
+    image_url: str = Field(...)
+
+
-class BriaResponse(BaseModel):
+class BriaImageEditResponse(BaseModel):
     status: str = Field(...)
+    result: BriaImageEditResult | None = Field(None)
+
+
+class BriaRemoveVideoBackgroundRequest(BaseModel):
+    video: str = Field(...)
+    background_color: str = Field(default="transparent", description="Background color for the output video.")
+    output_container_and_codec: str = Field(...)
+    preserve_audio: bool = Field(True)
+    seed: int = Field(...)
+
+
+class BriaRemoveVideoBackgroundResult(BaseModel):
+    video_url: str = Field(...)
+
+
+class BriaRemoveVideoBackgroundResponse(BaseModel):
+    status: str = Field(...)
-    result: BriaResult | None = Field(None)
+    result: BriaRemoveVideoBackgroundResult | None = Field(None)
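A minimal usage sketch for one of the new request models (assuming pydantic v2; the video URL and codec string below are placeholder values, not values documented by the Bria API):

from comfy_api_nodes.apis.bria import BriaRemoveVideoBackgroundRequest

req = BriaRemoveVideoBackgroundRequest(
    video="https://example.com/input.mp4",    # placeholder input
    output_container_and_codec="mp4_h264",    # placeholder codec identifier
    seed=42,
)
# background_color defaults to "transparent" and preserve_audio to True.
print(req.model_dump())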