Skip to content

Commit 07ca685

Browse files
Fix dtype issue in embeddings connector. (Comfy-Org#12570)
1 parent f266b8d commit 07ca685

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

comfy/ldm/lightricks/embeddings_connector.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ def generate_freq_grid(self, spacing, dtype, device):
234234

235235
return indices
236236

237-
def precompute_freqs_cis(self, indices_grid, spacing="exp"):
237+
def precompute_freqs_cis(self, indices_grid, spacing="exp", out_dtype=None):
238238
dim = self.inner_dim
239239
n_elem = 2 # 2 because of cos and sin
240240
freqs = self.precompute_freqs(indices_grid, spacing)
@@ -247,7 +247,7 @@ def precompute_freqs_cis(self, indices_grid, spacing="exp"):
247247
)
248248
else:
249249
cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem)
250-
return cos_freq.to(self.dtype), sin_freq.to(self.dtype), self.split_rope
250+
return cos_freq.to(dtype=out_dtype), sin_freq.to(dtype=out_dtype), self.split_rope
251251

252252
def forward(
253253
self,
@@ -288,7 +288,7 @@ def forward(
288288
hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device
289289
)
290290
indices_grid = indices_grid[None, None, :]
291-
freqs_cis = self.precompute_freqs_cis(indices_grid)
291+
freqs_cis = self.precompute_freqs_cis(indices_grid, out_dtype=hidden_states.dtype)
292292

293293
# 2. Blocks
294294
for block_idx, block in enumerate(self.transformer_1d_blocks):

0 commit comments

Comments
 (0)