Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ Put your VAE in: models/vae

AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:

```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4```
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm7.1```

This is the command to install the nightly with ROCm 7.1 which might have some performance improvements:

Expand Down
19 changes: 10 additions & 9 deletions comfy/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,18 +679,19 @@ def unpin_all_weights(self):
for key in list(self.pinned):
self.unpin_weight(key)

def _load_list(self, prio_comfy_cast_weights=False):
def _load_list(self, prio_comfy_cast_weights=False, default_device=None):
loading = []
for n, m in self.model.named_modules():
params = []
skip = False
for name, param in m.named_parameters(recurse=False):
params.append(name)
default = False
params = { name: param for name, param in m.named_parameters(recurse=False) }
for name, param in m.named_parameters(recurse=True):
if name not in params:
skip = True # skip random weights in non leaf modules
default = True # default random weights in non leaf modules
break
if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
if default and default_device is not None:
for param in params.values():
param.data = param.data.to(device=default_device)
if not default and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
module_mem = comfy.model_management.module_size(m)
module_offload_mem = module_mem
if hasattr(m, "comfy_cast_weights"):
Expand Down Expand Up @@ -1495,7 +1496,7 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False
#with pin and unpin syncrhonization which can be expensive for small weights
#with a high layer rate (e.g. autoregressive LLMs).
#prioritize the non-comfy weights (note the order reverse).
loading = self._load_list(prio_comfy_cast_weights=True)
loading = self._load_list(prio_comfy_cast_weights=True, default_device=device_to)
loading.sort(reverse=True)

for x in loading:
Expand Down Expand Up @@ -1579,7 +1580,7 @@ def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=Fals
return 0 if vbar is None else vbar.free_memory(memory_to_free)

def partially_unload_ram(self, ram_to_unload):
loading = self._load_list(prio_comfy_cast_weights=True)
loading = self._load_list(prio_comfy_cast_weights=True, default_device=self.offload_device)
for x in loading:
_, _, _, _, m, _ = x
ram_to_unload -= comfy.pinned_memory.unpin_memory(m)
Expand Down
25 changes: 14 additions & 11 deletions comfy/text_encoders/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,13 +355,6 @@ def forward(self, x: torch.Tensor):



def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)


def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_dims=None, device=None):
if not isinstance(theta, list):
theta = [theta]
Expand Down Expand Up @@ -390,20 +383,30 @@ def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_di
else:
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)
out.append((cos, sin))
sin_split = sin.shape[-1] // 2
out.append((cos, sin[..., : sin_split], -sin[..., sin_split :]))

if len(out) == 1:
return out[0]

return out


def apply_rope(xq, xk, freqs_cis):
org_dtype = xq.dtype
cos = freqs_cis[0]
sin = freqs_cis[1]
q_embed = (xq * cos) + (rotate_half(xq) * sin)
k_embed = (xk * cos) + (rotate_half(xk) * sin)
nsin = freqs_cis[2]

q_embed = (xq * cos)
q_split = q_embed.shape[-1] // 2
q_embed[..., : q_split].addcmul_(xq[..., q_split :], nsin)
q_embed[..., q_split :].addcmul_(xq[..., : q_split], sin)

k_embed = (xk * cos)
k_split = k_embed.shape[-1] // 2
k_embed[..., : k_split].addcmul_(xk[..., k_split :], nsin)
k_embed[..., k_split :].addcmul_(xk[..., : k_split], sin)

return q_embed.to(org_dtype), k_embed.to(org_dtype)


Expand Down
Loading