@@ -234,7 +234,7 @@ def generate_freq_grid(self, spacing, dtype, device):
234234
235235 return indices
236236
237- def precompute_freqs_cis (self , indices_grid , spacing = "exp" ):
237+ def precompute_freqs_cis (self , indices_grid , spacing = "exp" , out_dtype = None ):
238238 dim = self .inner_dim
239239 n_elem = 2 # 2 because of cos and sin
240240 freqs = self .precompute_freqs (indices_grid , spacing )
@@ -247,7 +247,7 @@ def precompute_freqs_cis(self, indices_grid, spacing="exp"):
247247 )
248248 else :
249249 cos_freq , sin_freq = interleaved_freqs_cis (freqs , dim % n_elem )
250- return cos_freq .to (self . dtype ), sin_freq .to (self . dtype ), self .split_rope
250+ return cos_freq .to (dtype = out_dtype ), sin_freq .to (dtype = out_dtype ), self .split_rope
251251
252252 def forward (
253253 self ,
@@ -288,7 +288,7 @@ def forward(
288288 hidden_states .shape [1 ], dtype = torch .float32 , device = hidden_states .device
289289 )
290290 indices_grid = indices_grid [None , None , :]
291- freqs_cis = self .precompute_freqs_cis (indices_grid )
291+ freqs_cis = self .precompute_freqs_cis (indices_grid , out_dtype = hidden_states . dtype )
292292
293293 # 2. Blocks
294294 for block_idx , block in enumerate (self .transformer_1d_blocks ):
0 commit comments