2 changes: 1 addition & 1 deletion docs/source/grpo_trainer.md
@@ -144,7 +144,7 @@ $$
\mathcal{L}_{\text{DAPO}}(\theta) = - \frac{1}{\sum_{i=1}^G |o_i|} \sum_{i=1}^G \sum_{t=1}^{|o_i|} l_{i,t},
$$

To use this formulation, set `loss_type="bnpo"` in [`GRPOConfig`]. Note that we do not reproduce the DAPO formulation exactly: when using gradient accumulation, the loss is computed over the total number of tokens in each batch, not over the accumulated batches. `loss_type="bnpo"` is equivalent to the original DAPO formulation only when `gradient_accumulation_steps=1`.
To use this formulation, set `loss_type="dapo"` in [`GRPOConfig`].
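A minimal usage sketch (the values below are illustrative placeholders, not taken from this PR):

```python
from trl import GRPOConfig

config = GRPOConfig(
    output_dir="my-grpo-run",           # hypothetical output path
    loss_type="dapo",                   # normalize by active tokens in the global accumulated batch
    gradient_accumulation_steps=2,      # "dapo" stays exact across accumulated micro-batches
    per_device_train_batch_size=3,
    num_generations=3,
)
```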

Furthermore, it was demonstrated in the paper [Understanding R1-Zero-Like Training: A Critical Perspective](https://huggingface.co/papers/2503.20783) that the initial GRPO formulation introduces a response length bias. They show that while the DAPO formulation reduces this bias, it does not eliminate it completely. To fully remove this bias, they propose dividing by a constant instead of the sequence length, resulting in the following formulation:

18 changes: 16 additions & 2 deletions tests/test_grpo_trainer.py
@@ -71,6 +71,18 @@ def test_with_none_tensor(self):
self.assertTrue(torch.equal(result[i]["x"], expected_x_chunks[i]))
self.assertIsNone(result[i]["y"])

def test_with_scalar(self):
x = torch.arange(12).reshape(6, 2)
tensor_dict = {"x": x, "y": torch.tensor(1)}

result = split_tensor_dict(tensor_dict, 2)

expected_x_chunks = torch.chunk(x, 2, dim=0)
self.assertEqual(len(result), 2)
for i in range(2):
self.assertTrue(torch.equal(result[i]["x"], expected_x_chunks[i]))
self.assertTrue(torch.equal(result[i]["y"], torch.tensor(1)))


class ShuffleSequenceDictTester(TrlTestCase):
def test_shuffle_preserves_shape(self):
@@ -549,7 +561,7 @@ def test_training(self, config_name):
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")

@parameterized.expand([("bnpo",), ("dr_grpo",)])
@parameterized.expand([("bnpo",), ("dr_grpo",), ("dapo",)])
def test_training_loss_types(self, loss_type):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

@@ -559,6 +571,7 @@ def test_training_loss_types(self, loss_type):
per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=32, # reduce the completion length to reduce memory usage
gradient_accumulation_steps=2, # set to 2 to test that DAPO can operate with accumulated batches
loss_type=loss_type,
report_to="none",
)
@@ -1785,7 +1798,8 @@ def test_training_vlm_and_liger(self):
per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=8, # reduce the completion length to reduce memory usage
use_liger_loss=True, # Enable Liger loss
use_liger_loss=True, # enable Liger loss
loss_type="bnpo", # default dapo is not supported yet
report_to="none",
)
trainer = GRPOTrainer(
46 changes: 26 additions & 20 deletions trl/trainer/grpo_config.py
@@ -176,19 +176,22 @@ class GRPOConfig(TrainingArguments):
- `False` or `"none"`: no scaling is applied. The [Dr. GRPO
paper](https://huggingface.co/papers/2503.20783) recommends not scaling rewards, as scaling by the
standard deviation introduces a question-level difficulty bias.
loss_type (`str`, *optional*, defaults to `"bnpo"`):
loss_type (`str`, *optional*, defaults to `"dapo"`):
Specifies the loss formulation to use. Supported values are:

- `"grpo"`: Aggregates token-level losses by normalizing over sequence length. Not recommended due to
length bias—this approach tends to prefer shorter completions with positive advantages and longer ones
with negative advantages.
- `"bnpo"`: Aggregates token-level losses by normalizing number of active token in the local batch.
Note that normalization is performed over the local batch only, so results may slightly vary depending
on the local batch size, despite a constant effective batch size. When using
`per_device_train_batch_size==1`, the loss is equivalent to the GRPO loss.
length bias—this approach tends to prefer shorter completions with positive advantages and longer ones
with negative advantages.
- `"dr_grpo"`: Aggregates token-level losses by normalizing with a global constant. This method was
introduced in the [Dr. GRPO paper](https://huggingface.co/papers/2503.20783) to eliminate length bias.
The value of the constant corresponds to `max_completion_length`.
introduced in the [Dr. GRPO paper](https://huggingface.co/papers/2503.20783) to eliminate length bias.
The value of the constant corresponds to `max_completion_length`.
- `"dapo"` (default): Aggregates token-level losses by normalizing with the number of active token in the
global accumulated batch. This method was introduced in the [DAPO
paper](https://huggingface.co/papers/2503.14476) to eliminate length bias.
- `"bnpo"`: Aggregates token-level losses by normalizing with the number of active token in the local
batch. Note that normalization is performed over the local batch only, so results may slightly vary
depending on the local batch size, despite a constant effective batch size. When using
`per_device_train_batch_size==1`, the loss is equivalent to the GRPO loss.
mask_truncated_completions (`bool`, *optional*, defaults to `False`):
When enabled, truncated completions are excluded from the loss calculation, preventing them from being
incorrectly penalized and introducing noise during training. According to the
@@ -514,19 +517,22 @@ class GRPOConfig(TrainingArguments):
},
)
loss_type: str = field(
default="bnpo",
metadata={
"help": "Specifies the loss formulation to use. Supported values are `grpo`, `bnpo`, and `dr_grpo`. "
"`'grpo'`: Aggregates token-level losses by normalizing over sequence length. Not recommended due to "
"length bias—this approach tends to prefer shorter completions with positive advantages and longer ones "
"with negative advantages. "
"`'bnpo'`: Aggregates token-level losses by normalizing number of active token in the local batch. "
default="dapo",
metadata={
"help": "Specifies the loss formulation to use. Supported values are 'grpo', 'dapo', 'bnpo', and "
"'dr_grpo'. "
"'grpo': Aggregates token-level losses by normalizing over sequence length. Not recommended due to length "
"bias—this approach tends to prefer shorter completions with positive advantages and longer ones with "
"negative advantages. "
"'dapo' (default): Aggregates token-level losses by normalizing with the number of active token in the "
"global accumulated batch. This method was introduced in the DAPO paper to eliminate length bias. "
"'dr_grpo': Aggregates token-level losses by normalizing with a global constant. This method was "
"introduced in the Dr. GRPO paper to eliminate length bias. The value of the constant corresponds to "
"`max_completion_length`. "
"'bnpo': Aggregates token-level losses by normalizing with the number of active token in the local batch. "
"Note that normalization is performed over the local batch only, so results may slightly vary depending "
"on the local batch size, despite a constant effective batch size. When using "
"`per_device_train_batch_size==1`, the loss is equivalent to the GRPO loss. "
"`'dr_grpo'`: Aggregates token-level losses by normalizing with a global constant. This method was "
"introduced in the Dr. GRPO paper to eliminate length bias. The value of the constant corresponds to "
"`max_completion_length`."
"`per_device_train_batch_size==1`, the loss is equivalent to the GRPO loss."
},
)
mask_truncated_completions: bool = field(
41 changes: 31 additions & 10 deletions trl/trainer/grpo_trainer.py
@@ -226,13 +226,18 @@ def split_tensor_dict(
"""
first_tensor = next(tensor for tensor in tensor_dict.values() if tensor is not None)
chunk_size = first_tensor.shape[0] // num_chunks
return [
{
key: tensor[i * chunk_size : (i + 1) * chunk_size] if tensor is not None else None
for key, tensor in tensor_dict.items()
}
for i in range(num_chunks)
]
chunks = []
for i in range(num_chunks):
chunk_dict = {}
for key, tensor in tensor_dict.items():
if tensor is not None and tensor.ndim > 0:
chunk_dict[key] = tensor[i * chunk_size : (i + 1) * chunk_size]
elif tensor is not None and tensor.ndim == 0:
chunk_dict[key] = tensor
else:
chunk_dict[key] = None
chunks.append(chunk_dict)
return chunks
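A quick sketch of the new scalar handling (arbitrary values; it mirrors the `test_with_scalar` test added above and assumes `split_tensor_dict` is imported from this module):

```python
import torch
from trl.trainer.grpo_trainer import split_tensor_dict

# 0-dim tensors cannot be sliced along dim 0, so each chunk now receives the same scalar.
chunks = split_tensor_dict({"x": torch.arange(12).reshape(6, 2), "num_items_in_batch": torch.tensor(42)}, 3)
assert len(chunks) == 3 and chunks[0]["x"].shape == (2, 2)
assert all(c["num_items_in_batch"].item() == 42 for c in chunks)
```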


def shuffle_sequence_dict(seq_dict: dict[str, Optional[Sequence]]) -> dict[str, Optional[Sequence]]:
@@ -258,7 +263,9 @@ def shuffle_sequence_dict(seq_dict: dict[str, Optional[Sequence]]) -> dict[str,
def permute(v: Optional[Sequence]) -> Optional[Sequence]:
if v is None:
return None
if isinstance(v, torch.Tensor):
if isinstance(v, torch.Tensor) and v.ndim == 0:
Member Author: this is for `num_items_in_batch`, which is a single float that needs to be copied into every chunk.

return v
if isinstance(v, torch.Tensor) and v.ndim >= 1:
return v[permutation]
return [v[i] for i in permutation]
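Similarly, a short sketch of the `shuffle_sequence_dict` change (made-up values): 0-dim tensors now pass through untouched, while sequences are permuted with a shared permutation.

```python
import torch
from trl.trainer.grpo_trainer import shuffle_sequence_dict

shuffled = shuffle_sequence_dict({"ids": torch.arange(8).reshape(4, 2), "num_items_in_batch": torch.tensor(7)})
assert torch.equal(shuffled["num_items_in_batch"], torch.tensor(7))    # scalar copied as-is
assert sorted(shuffled["ids"].flatten().tolist()) == list(range(8))    # rows permuted, content preserved
```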

@@ -700,6 +707,12 @@ def __init__(
processing_class=processing_class,
callbacks=callbacks,
optimizers=optimizers,
# In Trainer, `training_step` scales the loss by `gradient_accumulation_steps` only if `compute_loss_func`
# is None. For DAPO, loss scaling instead depends on the total number of completion tokens across the
# global accumulated batch. To control scaling ourselves, we must disable Trainer’s built-in scaling. The
# simplest (though a bit hacky) way is to set `compute_loss_func` to any non-None value, which bypasses
# that behavior without rewriting `training_step`.
compute_loss_func="non-None value to disable scaling",
Member Author: cc @SunMarc, idk if you'd have done it this way

Member: yeah, it's a bit hacky, but I guess that's one of the simplest solutions. Maybe you can also do it by modifying `model_accepts_loss_kwargs` and `num_items_in_batch`.

)

# Reference model
@@ -1593,6 +1606,8 @@ def _generate_and_score_completions(

# Sum along sequence dimension (dim=1) to get completion length per sequence, used for logging
completion_lengths = completion_mask.sum(1)
agg_completion_lengths = self.accelerator.gather(completion_lengths)
num_items_in_batch = agg_completion_lengths.sum() # this is required for the DAPO loss

# If mask_truncated_completions is enabled, zero out truncated completions in completion_mask
if self.mask_truncated_completions:
@@ -1712,7 +1727,6 @@ def _generate_and_score_completions(
self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen]

# Log completion lengths, mean, min, max
agg_completion_lengths = self.accelerator.gather(completion_lengths)
self._metrics[mode]["completions/mean_length"].append(agg_completion_lengths.float().mean().item())
self._metrics[mode]["completions/min_length"].append(agg_completion_lengths.float().min().item())
self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item())
@@ -1754,6 +1768,7 @@ def _generate_and_score_completions(
"completion_ids": completion_ids,
"completion_mask": completion_mask,
"advantages": advantages,
"num_items_in_batch": num_items_in_batch,
}
if old_per_token_logps is not None:
output["old_per_token_logps"] = old_per_token_logps
@@ -1809,7 +1824,7 @@ def compute_liger_loss(self, unwrapped_model, inputs):
if self.beta != 0.0:
self._metrics[mode]["kl"].append(self.accelerator.gather(mean_kl).mean().item())
self._metrics[mode]["clip_ratio"].append(self.accelerator.gather(clip_ratio).mean().item())
return loss
return loss / self.current_gradient_accumulation_steps

@profiling_decorator
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
@@ -1894,10 +1909,16 @@ def _compute_loss(self, model, inputs):

if self.loss_type == "grpo":
loss = ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean()
loss = loss / self.current_gradient_accumulation_steps
Member Author (@qgallouedec, Aug 22, 2025): now that we don't rely on the Trainer to scale the loss, we need to do it here.

elif self.loss_type == "bnpo":
loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0)
loss = loss / self.current_gradient_accumulation_steps
elif self.loss_type == "dr_grpo":
loss = (per_token_loss * completion_mask).sum() / (per_token_loss.size(0) * self.max_completion_length)
loss = loss / self.current_gradient_accumulation_steps
elif self.loss_type == "dapo":
normalizer = inputs["num_items_in_batch"] / self.accelerator.num_processes
loss = (per_token_loss * completion_mask).sum() / normalizer
else:
raise ValueError(f"Unknown loss type: {self.loss_type}")
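To make the four normalizers concrete, here is a self-contained sketch that reproduces the arithmetic of this branch on toy tensors. It assumes a single process and uses `grad_accum` in place of `self.current_gradient_accumulation_steps`; with one process, the DAPO normalizer `num_items_in_batch / num_processes` reduces to the token count of the accumulated batch.

```python
import torch

# Two completions of different lengths, padded to length 4 (toy values).
per_token_loss = torch.tensor([[0.5, 0.5, 0.5, 0.5],
                               [0.2, 0.2, 0.0, 0.0]])
completion_mask = torch.tensor([[1., 1., 1., 1.],
                                [1., 1., 0., 0.]])
max_completion_length = 4
grad_accum = 2                            # gradient_accumulation_steps
num_items_in_batch = torch.tensor(12.0)   # active tokens summed over *all* accumulated micro-batches

masked = per_token_loss * completion_mask

# "grpo": mean over each sequence's own length, then mean over sequences; scaled for accumulation.
grpo = (masked.sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean() / grad_accum
# "bnpo": normalize by the active tokens of the local micro-batch only; scaled for accumulation.
bnpo = masked.sum() / completion_mask.sum().clamp(min=1.0) / grad_accum
# "dr_grpo": normalize by a constant (num sequences * max_completion_length); scaled for accumulation.
dr_grpo = masked.sum() / (per_token_loss.size(0) * max_completion_length) / grad_accum
# "dapo": normalize by the active tokens of the whole accumulated batch; no extra scaling needed.
dapo = masked.sum() / num_items_in_batch

print(grpo.item(), bnpo.item(), dr_grpo.item(), dapo.item())  # ≈ 0.175, 0.2, 0.15, 0.2
```

In this toy case `bnpo` and `dapo` happen to coincide because the other micro-batch was assumed to contain the same number of active tokens (6); with unequal micro-batches the two diverge, which is exactly the dependence on local batch content that the DAPO formulation removes.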
