44
55from comfy .ldm .modules .attention import optimized_attention
66import comfy .model_management
7+ import logging
78
89
910def attention (q : Tensor , k : Tensor , v : Tensor , pe : Tensor , mask = None , transformer_options = {}) -> Tensor :
@@ -13,7 +14,6 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transforme
1314 x = optimized_attention (q , k , v , heads , skip_reshape = True , mask = mask , transformer_options = transformer_options )
1415 return x
1516
16-
1717def rope (pos : Tensor , dim : int , theta : int ) -> Tensor :
1818 assert dim % 2 == 0
1919 if comfy .model_management .is_device_mps (pos .device ) or comfy .model_management .is_intel_xpu () or comfy .model_management .is_directml_enabled ():
@@ -28,13 +28,20 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
2828 out = rearrange (out , "b n d (i j) -> b n d i j" , i = 2 , j = 2 )
2929 return out .to (dtype = torch .float32 , device = pos .device )
3030
31- def apply_rope1 (x : Tensor , freqs_cis : Tensor ):
32- x_ = x .to (dtype = freqs_cis .dtype ).reshape (* x .shape [:- 1 ], - 1 , 1 , 2 )
3331
34- x_out = freqs_cis [..., 0 ] * x_ [..., 0 ]
35- x_out .addcmul_ (freqs_cis [..., 1 ], x_ [..., 1 ])
32+ try :
33+ import comfy .quant_ops
34+ apply_rope = comfy .quant_ops .ck .apply_rope
35+ apply_rope1 = comfy .quant_ops .ck .apply_rope1
36+ except :
37+ logging .warning ("No comfy kitchen, using old apply_rope functions." )
38+ def apply_rope1 (x : Tensor , freqs_cis : Tensor ):
39+ x_ = x .to (dtype = freqs_cis .dtype ).reshape (* x .shape [:- 1 ], - 1 , 1 , 2 )
40+
41+ x_out = freqs_cis [..., 0 ] * x_ [..., 0 ]
42+ x_out .addcmul_ (freqs_cis [..., 1 ], x_ [..., 1 ])
3643
37- return x_out .reshape (* x .shape ).type_as (x )
44+ return x_out .reshape (* x .shape ).type_as (x )
3845
39- def apply_rope (xq : Tensor , xk : Tensor , freqs_cis : Tensor ):
40- return apply_rope1 (xq , freqs_cis ), apply_rope1 (xk , freqs_cis )
46+ def apply_rope (xq : Tensor , xk : Tensor , freqs_cis : Tensor ):
47+ return apply_rope1 (xq , freqs_cis ), apply_rope1 (xk , freqs_cis )
0 commit comments