[Bug] RuntimeError: backward through graph a second time - Whisper training with gradient checkpointing #3592

@VincentValueCare

Description

  1. Did you update? pip install --upgrade unsloth unsloth-zoo
    Yes, already updated to the latest version. The issue persists.

  2. Colab or Kaggle or local / cloud?
    • Google Colab (Tesla T4)
    • Local (RTX 3080 Mobile) using the official Unsloth Docker image

  3. Number of GPUs used (use nvidia-smi)
    1 GPU: Tesla T4 (15360 MiB) on Colab

  4. Which notebook? Please link!
    https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Whisper.ipynb

  5. Which Unsloth version, TRL version, transformers version, PyTorch version?
    • Unsloth: 2025.11.2
    • Unsloth Zoo: 2025.11.3
    • TRL: 0.22.2
    • Transformers: 4.56.2
    • PyTorch: 2.8.0+cu126
    • Python: 3.12.12

  6. Which trainer? SFTTrainer, GRPOTrainer, etc.
    Seq2SeqTrainer

Error Description

Training fails immediately during the first training step with a gradient checkpointing error. This issue occurs consistently on:

  • Google Colab with Tesla T4
  • Local RTX 3080 Mobile using the official Unsloth Docker image

The error suggests that the gradient computation graph is being traversed twice during the backward pass, which is not allowed without `retain_graph=True`.
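For reference, this is the generic PyTorch error raised whenever an already-freed autograd graph is traversed a second time. A minimal standalone sketch (independent of Unsloth and Whisper) that triggers the same message:

import torch

x = torch.randn(4, requires_grad=True)
loss = (x ** 2).sum()  # pow saves x for its backward pass

loss.backward()  # the first backward frees the graph's saved tensors
loss.backward()  # RuntimeError: Trying to backward through the graph a second time

Passing retain_graph=True to the first call would silence the error, but in normal training each graph should only be traversed once; in the traceback below, the second traversal happens inside the checkpointing wrapper, not in user code.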

This is a Regression

This code worked with an earlier version of Unsloth; training completed successfully with the following output:

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 1,123 | Num Epochs = 1 | Total steps = 60
O^O/ \_/ \    Batch size per device = 1 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (1 x 4 x 1) = 4
 "-____-"     Trainable parameters = 31,457,280 of 1,574,947,840 (2.00% trained)
Unsloth: Not an error, but WhisperForConditionalGeneration does not accept `num_items_in_batch`.
Using gradient accumulation will be very slightly less accurate.
Read more on gradient accumulation issues here: https://unsloth.ai/blog/gradient
Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`...
Unsloth: Will smartly offload gradients to save VRAM!
[60/60 06:44, Epoch 0/1]
Step    Training Loss    Validation Loss    Wer
5       1.916300        1.587457           19.969040
10      1.441900        1.481879           19.504644
15      1.221700        1.397027           19.969040
20      2.508200        1.317600           20.046440
25      1.962400        1.243164           19.349845
30      1.558600        1.176595           20.046440
35      1.126300        1.122559           20.510836
40      1.012300        1.080295           21.439628
45      1.360200        1.044217           22.368421
50      1.133800        1.012587           23.297214
55      0.789300        0.989529           23.839009
60      1.143100        0.978841           23.761610

After updating to the latest version, the exact same code now crashes immediately at the first training step.

Minimal code to reproduce

from unsloth import FastModel
from transformers import WhisperForConditionalGeneration, Seq2SeqTrainingArguments, Seq2SeqTrainer
from datasets import load_dataset, Audio
from unsloth import is_bf16_supported
import torch

# Load the model
model, tokenizer = FastModel.from_pretrained(
    model_name = "unsloth/whisper-large-v3",
    dtype = None,
    load_in_4bit = False,
    auto_model = WhisperForConditionalGeneration,
    whisper_language = "English",
    whisper_task = "transcribe",
)

# LoRA configuration
model = FastModel.get_peft_model(
    model,
    r = 64,
    target_modules = ["q_proj", "v_proj"],
    lora_alpha = 64,
    lora_dropout = 0,
    bias = "none",
    use_gradient_checkpointing = "unsloth",
    random_state = 3407,
    use_rslora = False,
    loftq_config = None,
    task_type = None,
)

# Prepare the dataset
model.generation_config.language = "<|en|>"
model.generation_config.task = "transcribe"
model.config.suppress_tokens = []
model.generation_config.forced_decoder_ids = None

def formatting_prompts_func(example):
    audio_arrays = example['audio']['array']
    sampling_rate = example["audio"]["sampling_rate"]
    features = tokenizer.feature_extractor(audio_arrays, sampling_rate=sampling_rate)
    tokenized_text = tokenizer.tokenizer(example["text"])
    return {
        "input_features": features.input_features[0],
        "labels": tokenized_text.input_ids,
    }

dataset = load_dataset("MrDragonFox/Elise", split="train")
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
dataset = dataset.train_test_split(test_size=0.06)
train_dataset = [formatting_prompts_func(example) for example in dataset['train']]

# DataCollator
from dataclasses import dataclass
from typing import Any, Dict, List, Union

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    
    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Pad the log-mel input features to a uniform length
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
        
        # Pad the label token ids separately with the tokenizer
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
        
        # Replace padding token ids with -100 so they are ignored by the loss
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        
        # If a BOS token was already prepended during tokenization, cut it here;
        # the decoder start token is prepended automatically when labels are shifted
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]
        
        batch["labels"] = labels
        return batch

# Trainer configuration
trainer = Seq2SeqTrainer(
    model = model,
    train_dataset = train_dataset,
    data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=tokenizer),
    tokenizer = tokenizer.feature_extractor,
    args = Seq2SeqTrainingArguments(
        per_device_train_batch_size = 1,
        gradient_accumulation_steps = 4,
        warmup_steps = 5,
        max_steps = 60,
        learning_rate = 1e-4,
        logging_steps = 1,
        optim = "adamw_8bit",
        fp16 = not is_bf16_supported(),
        bf16 = is_bf16_supported(),
        weight_decay = 0.001,
        remove_unused_columns = False,
        lr_scheduler_type = "linear",
        label_names = ['labels'],
        seed = 3407,
        output_dir = "outputs",
        report_to = "none",
    ),
)

# Start training - FAILS IMMEDIATELY
trainer_stats = trainer.train()

Full Error Traceback

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 1,123 | Num Epochs = 1 | Total steps = 60
O^O/ \_/ \    Batch size per device = 1 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (1 x 4 x 1) = 4
 "-____-"     Trainable parameters = 31,457,280 of 1,574,947,840 (2.00% trained)
Unsloth: Not an error, but WhisperForConditionalGeneration does not accept `num_items_in_batch`.
Using gradient accumulation will be very slightly less accurate.
Read more on gradient accumulation issues here: https://unsloth.ai/blog/gradient
`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`...

---------------------------------------------------------------------------

RuntimeError                              Traceback (most recent call last)

/tmp/ipython-input-773422404.py in <cell line: 0>()
----> 1 trainer_stats = trainer.train()

12 frames

/usr/local/lib/python3.12/dist-packages/transformers/trainer.py in train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   2326                 hf_hub_utils.enable_progress_bars()
   2327         else:
-> 2328             return inner_training_loop(
   2329                 args=args,
   2330                 resume_from_checkpoint=resume_from_checkpoint,

/usr/local/lib/python3.12/dist-packages/unsloth_zoo/compiler.py in _fast_inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)

/usr/local/lib/python3.12/dist-packages/unsloth/models/_utils.py in _unsloth_training_step(***failed resolving arguments***)

/usr/local/lib/python3.12/dist-packages/accelerate/accelerator.py in backward(self, loss, **kwargs)
   2734             return
   2735         elif self.scaler is not None:
-> 2736             self.scaler.scale(loss).backward(**kwargs)
   2737         elif learning_rate is not None and self.has_lomo_optimizer:
   2738             self.lomo_backward(loss, learning_rate)

/usr/local/lib/python3.12/dist-packages/torch/_tensor.py in backward(self, gradient, retain_graph, create_graph, inputs)
    645                 inputs=inputs,
    646             )
--> 647         torch.autograd.backward(
    648             self, gradient, retain_graph, create_graph, inputs=inputs
    649         )

/usr/local/lib/python3.12/dist-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    352     # some Python versions print out the first line of a multi-line function
    353     # calls in the traceback and some print out the last line
--> 354     _engine_run_backward(
    355         tensors,
    356         grad_tensors_,

/usr/local/lib/python3.12/dist-packages/torch/autograd/graph.py in _engine_run_backward(t_outputs, *args, **kwargs)
    827         unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)
    828     try:
--> 829         return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    830             t_outputs, *args, **kwargs
    831         )  # Calls into the C++ engine to run the backward pass

/usr/local/lib/python3.12/dist-packages/torch/autograd/function.py in apply(self, *args)
    309             )
    310         user_fn = vjp_fn if vjp_fn is not Function.vjp else backward_fn
--> 311         return user_fn(self, *args)
    312 
    313     def apply_jvp(self, *args):

/usr/local/lib/python3.12/dist-packages/unsloth_zoo/gradient_checkpointing.py in backward(ctx, *args)
    596             # )
    597         else:
--> 598             torch.autograd.backward(outputs_with_grad, args_with_grad)
    599         pass
    600 

/usr/local/lib/python3.12/dist-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    352     # some Python versions print out the first line of a multi-line function
    353     # calls in the traceback and some print out the last line
--> 354     _engine_run_backward(
    355         tensors,
    356         grad_tensors_,

/usr/local/lib/python3.12/dist-packages/torch/autograd/graph.py in _engine_run_backward(t_outputs, *args, **kwargs)
    827         unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)
    828     try:
--> 829         return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    830             t_outputs, *args, **kwargs
    831         )  # Calls into the C++ engine to run the backward pass

/usr/local/lib/python3.12/dist-packages/torch/autograd/function.py in apply(self, *args)
    309             )
    310         user_fn = vjp_fn if vjp_fn is not Function.vjp else backward_fn
--> 311         return user_fn(self, *args)
    312 
    313     def apply_jvp(self, *args):

/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py in backward(ctx, *flat_args)
   2196             def backward(ctx, *flat_args):
   2197                 all_args = _backward_prologue_functional(
-> 2198                     ctx.saved_tensors,
   2199                     ctx.symints,
   2200                     CompiledFunction.metadata,

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

Additional Context

In the previous working version, the output included the message "Unsloth: Will smartly offload gradients to save VRAM!", which is missing from the current failing run. This may indicate that the gradient offloading mechanism was changed or broken in a recent update.

The error occurs immediately on the first training step and originates from Unsloth's gradient checkpointing implementation in `unsloth_zoo/gradient_checkpointing.py`.
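As a diagnostic (not a proposed fix), swapping Unsloth's custom checkpointing for the stock Hugging Face implementation should show whether the regression is isolated to `unsloth_zoo/gradient_checkpointing.py`. This is a sketch based on the documented use_gradient_checkpointing options (True / False / "unsloth"); only the changed argument matters:

# Diagnostic only: if training succeeds with the stock HF implementation,
# the failure is specific to Unsloth's checkpointing path.
model = FastModel.get_peft_model(
    model,
    r = 64,
    target_modules = ["q_proj", "v_proj"],
    lora_alpha = 64,
    lora_dropout = 0,
    bias = "none",
    use_gradient_checkpointing = True,  # was "unsloth" in the failing run
    random_state = 3407,
    use_rslora = False,
    loftq_config = None,
    task_type = None,
)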
