Skip to content

Commit 6c73f7a

Browse files
jarrycyxjarrycyx
andauthored
Fix: prevent rope_embedding AssertionError by checking kv_seq_len before reuse (#3578)
* fix: add kv_seq_len boundary check before reusing RoPE embeddings Prevented AssertionError in rope_embedding.forward when kv_seq_len exceeds the cached rope size. Added condition to verify kv_seq_len <= position_embeddings[0].shape[0] before reuse, ensuring dynamic extension triggers correctly. Fixes #3036 #3216 * fix falcon h1 --------- Co-authored-by: jarrycyx <dzdzzd@126.com>
1 parent d9a27fb commit 6c73f7a

File tree

3 files changed

+3
-3
lines changed

3 files changed

+3
-3
lines changed

‎unsloth/models/falcon_h1.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def FalconH1Attention_fast_forward(
116116
if past_key_value is not None:
117117
kv_seq_len += past_key_value[0].shape[-2]
118118

119-
if position_embeddings:
119+
if position_embeddings and kv_seq_len <= position_embeddings[0].shape[0]:
120120
cos, sin = position_embeddings
121121
else:
122122
# Extend RoPE dynamically to fit in VRA

‎unsloth/models/llama.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -566,7 +566,7 @@ def LlamaAttention_fast_forward(
566566
if past_key_value is not None:
567567
kv_seq_len += past_key_value[0].shape[-2]
568568

569-
if position_embeddings:
569+
if position_embeddings and kv_seq_len <= position_embeddings[0].shape[0]:
570570
cos, sin = position_embeddings
571571
else:
572572
# Extend RoPE dynamically to fit in VRA

‎unsloth/models/qwen3.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def Qwen3Attention_fast_forward(
110110
if past_key_value is not None:
111111
kv_seq_len += past_key_value[0].shape[-2]
112112

113-
if position_embeddings:
113+
if position_embeddings and kv_seq_len <= position_embeddings[0].shape[0]:
114114
cos, sin = position_embeddings
115115
else:
116116
# Extend RoPE dynamically to fit in VRA

0 commit comments

Comments
 (0)