
[Bug] Assertion error in RoPE Embedding Kernel when training Qwen3-8B with long context #3216

@arnavgarg1

Description

  1. Did you update? pip install --upgrade unsloth unsloth_zoo: Yes
  2. Colab or Kaggle or local / cloud: local on H200 GPU
  3. Number GPUs used, use nvidia-smi: 1
  4. Which notebook? Please link!: N/A
  5. Which Unsloth version, TRL version, transformers version, PyTorch version?: Master
  6. Which trainer? SFTTrainer, GRPOTrainer etc: SFTTrainer

I'm running into this assertion error when training the DeepSeek R1 distill of Qwen3 8B (https://huggingface.co/deepseek-ai/DeepSeek-R1-0528-Qwen3-8B) with context lengths above 32K. I expected it to work, since Qwen3 works and this model is just a distill of it. It looks like an issue with the RoPE embedding kernel: the assertion that fails checks that the sequence length does not exceed the length of the cached `cos` table.

  File "/opt/poetry-venv/lib/python3.10/site-packages/unsloth/models/llama.py", line 995, in LlamaModel_fast_forward
    layer_outputs = decoder_layer(
  File "/opt/poetry-venv/lib/python3.10/site-packages/transformers/modeling_layers.py", line 92, in __call__
    return self._gradient_checkpointing_func(partial(super().__call__, **kwargs), *args)
  File "/opt/poetry-venv/lib/python3.10/site-packages/torch/_compile.py", line 32, in inner
    return disable_fn(*args, **kwargs)
  File "/opt/poetry-venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 745, in _fn
    return fn(*args, **kwargs)
  File "/opt/poetry-venv/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 489, in checkpoint
    return CheckpointFunction.apply(function, preserve, *args)
  File "/opt/poetry-venv/lib/python3.10/site-packages/torch/autograd/function.py", line 575, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/opt/poetry-venv/lib/python3.10/site-packages/unsloth_zoo/gradient_checkpointing.py", line 475, in forward
    outputs = run_function(*args)
  File "/opt/poetry-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/poetry-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/poetry-venv/lib/python3.10/site-packages/unsloth/models/llama.py", line 667, in LlamaDecoderLayer_fast_forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/opt/poetry-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/poetry-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/poetry-venv/lib/python3.10/site-packages/unsloth/models/qwen3.py", line 121, in Qwen3Attention_fast_forward
    Q, K = fast_rope_embedding(Q, K, cos, sin)
  File "/opt/poetry-venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 745, in _fn
    return fn(*args, **kwargs)
  File "/opt/poetry-venv/lib/python3.10/site-packages/unsloth/kernels/rope_embedding.py", line 156, in fast_rope_embedding
    Q = Fast_RoPE_Embedding.apply(Q.transpose(1, 2), cos, sin).transpose(1, 2)
  File "/opt/poetry-venv/lib/python3.10/site-packages/torch/autograd/function.py", line 575, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/opt/poetry-venv/lib/python3.10/site-packages/unsloth/kernels/rope_embedding.py", line 91, in forward
    assert(seq_len <= cos.shape[0])
AssertionError
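To make the failing check concrete, here is a minimal pure-Python sketch of a standard RoPE cos/sin cache and a check equivalent to the one at rope_embedding.py line 91. This is an illustration under stated assumptions, not Unsloth's code. `build_rope_cache`, `check_rope_inputs` and `ensure_rope_cache` are hypothetical helpers. The sketch uses plain base-10000 RoPE frequencies, ignores any RoPE scaling the model config may specify, and stores `head_dim // 2` frequencies per position, which may not match the real kernel's tensor layout. It illustrates that the assertion fires whenever the batch's sequence length exceeds the number of positions the cos/sin tables were built for. That suggests the cache for this model is sized for a shorter length than the >32K sequences being trained on.

```python
import math

def build_rope_cache(max_len, head_dim, base=10000.0):
    """Hypothetical helper: return (cos, sin) tables of shape [max_len][head_dim // 2].

    Uses the standard RoPE frequencies theta_i = base ** (-2i / head_dim).
    Assumption: no RoPE scaling (e.g. YaRN) is applied here.
    """
    half = head_dim // 2
    inv_freq = [base ** (-2.0 * i / head_dim) for i in range(half)]
    cos = [[math.cos(pos * f) for f in inv_freq] for pos in range(max_len)]
    sin = [[math.sin(pos * f) for f in inv_freq] for pos in range(max_len)]
    return cos, sin

def check_rope_inputs(seq_len, cos):
    # Same condition as the failing `assert(seq_len <= cos.shape[0])`:
    # every position in the batch needs a row in the cached table.
    if seq_len > len(cos):
        raise AssertionError(
            f"seq_len={seq_len} exceeds cached RoPE length {len(cos)}"
        )

def ensure_rope_cache(seq_len, cos, sin, head_dim, base=10000.0):
    # Hypothetical workaround: rebuild the tables when a longer batch arrives.
    if seq_len > len(cos):
        cos, sin = build_rope_cache(seq_len, head_dim, base)
    return cos, sin

if __name__ == "__main__":
    cos, sin = build_rope_cache(max_len=8, head_dim=4)
    check_rope_inputs(8, cos)          # fits in the cache
    try:
        check_rope_inputs(9, cos)      # one position too many
    except AssertionError as err:
        print("fails like the kernel:", err)
    cos, sin = ensure_rope_cache(9, cos, sin, head_dim=4)
    check_rope_inputs(9, cos)          # passes after rebuilding
```

A cache built for 8 positions accepts `seq_len=8` but rejects `seq_len=9`, and `ensure_rope_cache` rebuilds it to 9 positions. Whether the cache should be extended on demand like this or sized up front from the configured maximum sequence length is a design choice for the library; the sketch only shows the on-demand variant.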
