
Conversation

@jarrycyx
Contributor

📝 Pull Request Description

Problem
When running with long sequences, the following assertion occasionally triggers:

File "unsloth/kernels/rope_embedding.py", line 91, in forward
    assert(seq_len <= cos.shape[0])
AssertionError

This happens because the current position_embeddings buffer is reused even when kv_seq_len exceeds its allocated length: because of a missing check, the dynamic RoPE extension is never invoked.

Root Cause
It seems that current_rope_size, defined here, never gets extended:

self.current_rope_size = min(4 * 8192, self.max_position_embeddings)

In unsloth/models/llama.py, the condition:

if position_embeddings:
    cos, sin = position_embeddings
else:

does not verify whether kv_seq_len fits within the existing RoPE cache.
As a result, cos and sin may be too short, leading to the AssertionError.
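
As a toy illustration (a hypothetical example; the rotary width of 128 and the max_position_embeddings value of 131072 are arbitrary placeholders), the cached cos buffer only covers current_rope_size positions, so a longer kv_seq_len fails the comparison that rope_embedding.forward asserts on:

import torch

# Placeholder sizes, for illustration only
current_rope_size = min(4 * 8192, 131072)   # initial cache size, as above -> 32768
cos = torch.empty(current_rope_size, 128)   # dummy [positions, rotary dim] buffer
kv_seq_len = 40_000                         # long-sequence request
print(kv_seq_len <= cos.shape[0])           # False -> the assert at line 91 fires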

Fix
Added a boundary check before reusing position_embeddings:

if position_embeddings and kv_seq_len <= position_embeddings[0].shape[0]:
    cos, sin = position_embeddings
else:
    rotary_emb = self.rotary_emb
    rotary_emb.extend_rope_embedding(V, seq_len = kv_seq_len)

This ensures RoPE is extended dynamically when the sequence length exceeds the cached size.
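
For readers following the flow, here is a minimal sketch of the intended control path, not the exact llama.py code: the helper name get_cos_sin is invented for this example, and reading the refreshed buffers via get_cached (from the rotary embedding class quoted later in this thread) is an assumption about how the extended cache would be fetched.

# Illustrative sketch only -- not the actual llama.py implementation
def get_cos_sin(position_embeddings, rotary_emb, V, kv_seq_len, device_index):
    # Reuse the precomputed (cos, sin) pair only if it covers kv_seq_len
    if position_embeddings is not None and kv_seq_len <= position_embeddings[0].shape[0]:
        return position_embeddings
    # Otherwise grow the RoPE cache first (no-op if already large enough),
    # then read back the per-device buffers
    rotary_emb.extend_rope_embedding(V, seq_len = kv_seq_len)
    return rotary_emb.get_cached(kv_seq_len, device_index)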

Related Issue
Fixes #3036 #3216.

Impact

  • Prevents AssertionError in rope_embedding.py during long-sequence inference or training.
  • No change in behavior for short sequences.
  • Only fixes Qwen3 and Llama for now; other models need further investigation.
jarrycyx and others added 2 commits November 11, 2025 13:39
Prevented AssertionError in rope_embedding.forward when kv_seq_len exceeds
the cached rope size. Added condition to verify kv_seq_len <=
position_embeddings[0].shape[0] before reuse, ensuring dynamic extension
triggers correctly.

Fixes unslothai#3036 unslothai#3216
@gemini-code-assist
Contributor

Summary of Changes

Hello @jarrycyx, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves an AssertionError that could occur in rope_embedding.py when processing long sequences. The issue stemmed from an insufficient check before reusing cached positional embeddings, leading to kv_seq_len exceeding the allocated buffer. The solution introduces a boundary check to dynamically extend the RoPE embedding only when necessary, ensuring correct behavior for long sequences while maintaining existing functionality for shorter ones. This change specifically targets Llama and Qwen3 models.

Highlights

  • Preventing RoPE AssertionError: Addresses an AssertionError in rope_embedding.py that occurred when processing long sequences. This error was caused by reusing position_embeddings without verifying if kv_seq_len fit within the allocated buffer, leading to out-of-bounds access.
  • Dynamic RoPE Extension: Implements a crucial boundary check (kv_seq_len <= position_embeddings[0].shape[0]) before reusing position_embeddings. This ensures that the RoPE embedding is dynamically extended only when the sequence length exceeds the currently cached size, preventing the assertion failure.
  • Model Scope: The fix has been applied to Llama and Qwen3 models. Further investigation is noted as necessary for other models to ensure comprehensive coverage.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request correctly addresses a critical AssertionError that occurs during long-sequence processing by adding a boundary check for kv_seq_len before reusing cached RoPE position embeddings. The fix is applied to LlamaAttention_fast_forward in llama.py and Qwen3Attention_fast_forward in qwen3.py, which is a solid approach.

As you noted, further investigation on other models is needed. My analysis confirms this, as similar vulnerabilities seem to exist in other models within the repository. Specifically:

  • cohere.py (CohereAttention_fast_forward)
  • granite.py (GraniteAttention_fast_forward)
  • falcon_h1.py (FalconH1Attention_fast_forward)

These models also appear to use position_embeddings without a sufficient length check against kv_seq_len. The fix might not be as straightforward for all of them since some attention layers may not have direct access to self.rotary_emb to extend the embeddings.

The changes in this PR are correct and well-implemented. Great job on identifying and fixing this issue!

@Datta0
Collaborator

Datta0 commented Nov 12, 2025

Hey @jarrycyx , thanks a lot for identifying and fixing the issue. It'd be great if you can also adapt the same to other models as well so that the PR is complete...

@jarrycyx
Contributor Author

Hey @jarrycyx , thanks a lot for identifying and fixing the issue. It'd be great if you can also adapt the same to other models as well so that the PR is complete...

Thank you for the reply. I have added this fix to falcon_h1.
However, it seems that #3586 is better suited to fixing the other models and is already complete. This PR can serve as a double check on kv_seq_len.

@Teeeto

Teeeto commented Nov 12, 2025

Once extended, does it get scaled back for shorter sequences on the next training step?

@jarrycyx
Contributor Author

Once extended, does it get scaled back for shorter sequences on the next training step?

Actually, I don't see any function for scaling back here:

class LlamaRotaryEmbedding(torch.nn.Module):
    # Fixes https://github.com/huggingface/transformers/pull/28837
    # https://github.com/microsoft/DeepSpeed/issues/4932
    # The precision of RoPE buffers is not correct, so we cast to int64.
    def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=None,
        config = None, # [TODO] Hack to pass in config - need to remove later
    ):
        super().__init__()
        if config is not None:
            # [TODO] Hack to pass in config - need to remove later
            base = config.rope_theta
            partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
            dim = getattr(config, "head_dim", None)
            if dim is None: dim = int((config.hidden_size // config.num_attention_heads))
            device = DEVICE_TYPE_TORCH
            max_position_embeddings = config.max_position_embeddings
        pass

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # Dynamic RoPE we first set it to a max of 4 * 8192 tokens then we iteratively grow this
        self.current_rope_size = min(4 * 8192, self.max_position_embeddings)
        self.multi_gpu_cos_cached = [None]*DEVICE_COUNT
        self.multi_gpu_sin_cached = [None]*DEVICE_COUNT

        # Build here to make `torch.jit.trace` work.
        for device_idx in range(DEVICE_COUNT):
            self._set_cos_sin_cache(seq_len=self.current_rope_size, device=torch.device(device_idx), dtype=torch.get_default_dtype())

        # dummy so that patch_utils doesn't fail for now
        self.cos_cached = torch.empty(1, device=get_current_device(), dtype=torch.get_default_dtype())
        self.sin_cached = torch.empty(1, device=get_current_device(), dtype=torch.get_default_dtype())
    pass

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        # Note: on the original Llama codebase, these tensors are created on the target device (and not on CPU) and
        # in FP32. They are applied (multiplied) in FP32 as well.
        self.current_rope_size = seq_len

        inv_freq = 1.0 / (
            self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device="cpu").float() / self.dim)
        )
        t = torch.arange(self.current_rope_size, device="cpu", dtype=torch.int64).float()

        freqs = torch.outer(t, inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos().to(dtype=dtype, device=device, non_blocking=True)
        sin = emb.sin().to(dtype=dtype, device=device, non_blocking=True)
        self.multi_gpu_cos_cached[device.index] = cos
        self.multi_gpu_sin_cached[device.index] = sin
        return cos, sin
    pass

    def forward(self, x, position_ids=None, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len is not None and seq_len > self.current_rope_size:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        device_index = x.device.index
        return (
            self.multi_gpu_cos_cached[device_index][:seq_len],
            self.multi_gpu_sin_cached[device_index][:seq_len],
        )
    pass

    def get_cached(self, seq_len = None, device_index = None):
        if device_index is None:
            device_index = get_current_device()
        return self.multi_gpu_cos_cached[device_index], self.multi_gpu_sin_cached[device_index]
    pass

    def extend_rope_embedding(self, x, seq_len):
        if seq_len <= self.current_rope_size: return
        # Iteratively grow by increments of 8192
        self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192
        for device_idx in range(DEVICE_COUNT):
            self._set_cos_sin_cache(self.current_rope_size, device = torch.device(device_idx), dtype = x.dtype)
    pass
pass


class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
    """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
    # Fixes https://github.com/huggingface/transformers/pull/28837
    # https://github.com/microsoft/DeepSpeed/issues/4932
    # The precision of RoPE buffers is not correct, so we cast to int64.
    def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0,
        config = None, # [TODO] Hack to pass in config - need to remove later
    ):
        self.scaling_factor = scaling_factor
        super().__init__(dim = dim, max_position_embeddings = max_position_embeddings, base = base, device = device, config = config)
    pass

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.current_rope_size = seq_len

        inv_freq = 1.0 / (
            self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device="cpu").float() / self.dim)
        )
        t = torch.arange(self.current_rope_size, device="cpu", dtype=torch.int64).float()
        t = t / self.scaling_factor

        freqs = torch.outer(t, inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos().to(dtype=dtype, device=device, non_blocking=True)
        sin = emb.sin().to(dtype=dtype, device=device, non_blocking=True)
        self.multi_gpu_cos_cached[device.index] = cos
        self.multi_gpu_sin_cached[device.index] = sin
        return cos, sin
    pass
pass
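
To make the "no scale-back" point concrete, here is a minimal standalone sketch of the rounding logic in extend_rope_embedding above (the helper name next_rope_size is invented for this example): the cache size only ever rounds up to the next multiple of 8192 and is left untouched by later, shorter sequences.

def next_rope_size(seq_len, current, step = 8192):
    # Mirrors extend_rope_embedding: grow only, rounded up to a multiple of `step`
    if seq_len <= current:
        return current                                            # no shrink path exists
    return ((seq_len // step) + ((seq_len % step) != 0)) * step   # round up

print(next_rope_size(40_000, 32_768))   # 40960 -- grows to the next 8192 multiple
print(next_rope_size(1_024, 40_960))    # 40960 -- stays extended on later, shorter steps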

@Teeeto

Teeeto commented Nov 12, 2025

Once extended, does it get scaled back for shorter sequences on the next training step?

Actually I don't see any functions for scaling back here.

I don't have enough knowledge to estimate the consequences, but may it cause unexpected drift in training results?

@jarrycyx
Contributor Author

Once extended, does it get scaled back for shorter sequences on the next training step?

Actually I don't see any functions for scaling back here.

I don't have enough knowledge to estimate the consequences, but may it cause unexpected drift in training results?

inv_freq = 1.0 / (
    self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device="cpu").float() / self.dim)
)
t = torch.arange(self.current_rope_size, device="cpu", dtype=torch.int64).float()

I am also not an expert, but my understanding is that the frequencies computed here do not depend on the context length, so maybe not?
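
A small self-contained check (mirroring the _set_cos_sin_cache math quoted above; the dim and base values are placeholders) points the same way: inv_freq depends only on dim and base, so extending the cache appends new positions without changing the existing cos/sin rows.

import torch

def build_cos_sin(seq_len, dim = 128, base = 10000.0):
    # Same construction as _set_cos_sin_cache, minus device/dtype handling
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
    t = torch.arange(seq_len, dtype=torch.int64).float()
    freqs = torch.outer(t, inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)
    return emb.cos(), emb.sin()

cos_short, sin_short = build_cos_sin(4096)
cos_long,  sin_long  = build_cos_sin(8192)
# The first 4096 rows match exactly, so shorter sequences after an extension
# see the same rotary values as before
print(torch.allclose(cos_short, cos_long[:4096]), torch.allclose(sin_short, sin_long[:4096]))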

@danielhanchen
Contributor

Thanks so much for the fix!

@danielhanchen danielhanchen merged commit 6c73f7a into unslothai:main Nov 14, 2025
1 check passed
