Commit 347c46a

Detach hidden states to avoid gradient carry (#345)
Detach old and reference hidden states to prevent gradient flow across micro-batches.
1 parent 44cf87f commit 347c46a

File tree

1 file changed: +8 −0 lines changed

unsloth_zoo/rl_replacements.py

Lines changed: 8 additions & 0 deletions
@@ -626,6 +626,14 @@ def grpo_accumulated_loss(
         image_sizes = image_sizes,
         logits_to_keep = logits_to_keep + 1,
     ).logits
+
+    # Ensure old/ref states never carry grads across micro-batches
+    if old_hidden_states is not None and torch.is_tensor(old_hidden_states):
+        old_hidden_states = old_hidden_states.detach()
+
+    if ref_hidden_states is not None and torch.is_tensor(ref_hidden_states):
+        ref_hidden_states = ref_hidden_states.detach()
+
     loss, completion_length, mean_kl, delta, flat_is_ratio = UnslothEfficientGRPO.apply(
         new_hidden_states,
         old_hidden_states,
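The effect of the added guard can be exercised in isolation. A minimal sketch (the `detach_if_tensor` helper is hypothetical, not part of the commit) showing how `.detach()` stops gradients from flowing back into an earlier micro-batch's autograd graph:

```python
import torch

def detach_if_tensor(x):
    # Hypothetical helper mirroring the commit's checks: detach a tensor
    # from the autograd graph; pass None / non-tensor values through.
    if x is not None and torch.is_tensor(x):
        return x.detach()
    return x

# A parameter shared across micro-batches.
w = torch.randn(3, requires_grad=True)

# Micro-batch 1 produces hidden states that are part of w's graph.
old_hidden_states = w * 2.0

# Without detaching, a later loss built on old_hidden_states would
# backpropagate into micro-batch 1's graph. Detaching cuts that link.
old_hidden_states = detach_if_tensor(old_hidden_states)

print(old_hidden_states.requires_grad)  # False: no gradient carry
```

Because `detach()` returns a view sharing the same storage, this costs no extra memory while still freeing the upstream graph for garbage collection once the micro-batch's own backward pass is done.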

0 commit comments