Skip to content

Commit c3c3e93

Browse files
Use rope functions from comfy kitchen. (Comfy-Org#11674)
1 parent 6ffc159 commit c3c3e93

File tree

2 files changed

+16
-9
lines changed

2 files changed

+16
-9
lines changed

comfy/ldm/flux/math.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from comfy.ldm.modules.attention import optimized_attention
66
import comfy.model_management
7+
import logging
78

89

910
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor:
@@ -13,7 +14,6 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transforme
1314
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options)
1415
return x
1516

16-
1717
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
1818
assert dim % 2 == 0
1919
if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu() or comfy.model_management.is_directml_enabled():
@@ -28,13 +28,20 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
2828
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
2929
return out.to(dtype=torch.float32, device=pos.device)
3030

31-
def apply_rope1(x: Tensor, freqs_cis: Tensor):
32-
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
3331

34-
x_out = freqs_cis[..., 0] * x_[..., 0]
35-
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
32+
try:
33+
import comfy.quant_ops
34+
apply_rope = comfy.quant_ops.ck.apply_rope
35+
apply_rope1 = comfy.quant_ops.ck.apply_rope1
36+
except:
37+
logging.warning("No comfy kitchen, using old apply_rope functions.")
38+
def apply_rope1(x: Tensor, freqs_cis: Tensor):
39+
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
40+
41+
x_out = freqs_cis[..., 0] * x_[..., 0]
42+
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
3643

37-
return x_out.reshape(*x.shape).type_as(x)
44+
return x_out.reshape(*x.shape).type_as(x)
3845

39-
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
40-
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
46+
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
47+
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ psutil
2121
alembic
2222
SQLAlchemy
2323
av>=14.2.0
24-
comfy-kitchen>=0.2.1
24+
comfy-kitchen>=0.2.2
2525

2626
#non essential dependencies:
2727
kornia>=0.7.1

0 commit comments

Comments
 (0)