Detach hidden states to avoid gradient carry
Detach old and reference hidden states to prevent gradient flow across micro-batches.
pluesclues authored Nov 4, 2025
commit d5ff9be234ef20bf982a57d9c41fb434122d0ee9
8 changes: 8 additions & 0 deletions unsloth_zoo/rl_replacements.py
@@ -626,6 +626,14 @@ def grpo_accumulated_loss(
image_sizes = image_sizes,
logits_to_keep = logits_to_keep + 1,
).logits

# Ensure old/ref states never carry grads across micro-batches
if old_hidden_states is not None and torch.is_tensor(old_hidden_states):
old_hidden_states = old_hidden_states.detach()

if ref_hidden_states is not None and torch.is_tensor(ref_hidden_states):
ref_hidden_states = ref_hidden_states.detach()

loss, completion_length, mean_kl, delta, flat_is_ratio = UnslothEfficientGRPO.apply(
new_hidden_states,
old_hidden_states,
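
Why the detach matters: if old_hidden_states or ref_hidden_states stay attached to an autograd graph while being reused, a later backward pass can flow gradients through stale graph nodes and keep the old activation memory alive across micro-batches. A minimal sketch of the effect, using assumed toy tensors rather than the actual Unsloth internals:

import torch

hidden = torch.randn(2, 4, requires_grad=True)
carried = hidden * 2                 # attached: grad_fn keeps the old graph alive
assert carried.grad_fn is not None

carried = carried.detach()           # severed: no grad history, no grad flow
assert carried.grad_fn is None and not carried.requires_grad

# Gradients now reach `hidden` only through ops built in the current
# micro-batch, mirroring the old/ref detach in the diff above.
loss = (carried * hidden).sum()
loss.backward()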