🏌️ DAPO loss type #3938
```diff
@@ -226,13 +226,18 @@ def split_tensor_dict(
     """
     first_tensor = next(tensor for tensor in tensor_dict.values() if tensor is not None)
     chunk_size = first_tensor.shape[0] // num_chunks
-    return [
-        {
-            key: tensor[i * chunk_size : (i + 1) * chunk_size] if tensor is not None else None
-            for key, tensor in tensor_dict.items()
-        }
-        for i in range(num_chunks)
-    ]
+    chunks = []
+    for i in range(num_chunks):
+        chunk_dict = {}
+        for key, tensor in tensor_dict.items():
+            if tensor is not None and tensor.ndim > 0:
+                chunk_dict[key] = tensor[i * chunk_size : (i + 1) * chunk_size]
+            elif tensor is not None and tensor.ndim == 0:
+                chunk_dict[key] = tensor
+            else:
+                chunk_dict[key] = None
+        chunks.append(chunk_dict)
+    return chunks
```
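For context, here is a usage sketch of the updated `split_tensor_dict` (toy field names and shapes of my own, not taken from the PR's tests, and assuming the `(tensor_dict, num_chunks)` argument order implied by the truncated signature): batched tensors are sliced chunk by chunk, a 0-dim scalar such as `num_items_in_batch` is carried into every chunk unchanged, and `None` entries stay `None`.

```python
import torch

# Assumes the updated split_tensor_dict shown above is in scope.
tensor_dict = {
    "advantages": torch.arange(4.0),          # shape (4,): sliced into two chunks of 2
    "num_items_in_batch": torch.tensor(110),  # 0-dim: the same scalar appears in every chunk
    "old_per_token_logps": None,              # None: propagated unchanged
}

chunks = split_tensor_dict(tensor_dict, 2)
assert torch.equal(chunks[0]["advantages"], torch.tensor([0.0, 1.0]))
assert chunks[0]["num_items_in_batch"] is chunks[1]["num_items_in_batch"]
assert chunks[1]["old_per_token_logps"] is None
```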
```diff
@@ -258,7 +263,9 @@ def shuffle_sequence_dict(seq_dict: dict[str, Optional[Sequence]]) -> dict[str, Optional[Sequence]]:
     def permute(v: Optional[Sequence]) -> Optional[Sequence]:
         if v is None:
             return None
-        if isinstance(v, torch.Tensor):
+        if isinstance(v, torch.Tensor) and v.ndim == 0:
+            return v
+        if isinstance(v, torch.Tensor) and v.ndim >= 1:
             return v[permutation]
         return [v[i] for i in permutation]
```
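A companion sketch for the shuffling change (again with illustrative field names of my own): batched fields are reordered with one shared permutation, while a 0-dim scalar is returned as-is rather than indexed.

```python
import torch

# Assumes the updated shuffle_sequence_dict shown above is in scope.
seq_dict = {
    "prompt_ids": torch.tensor([[1], [2], [3]]),  # rows shuffled with the shared permutation
    "rewards": [0.1, 0.5, 0.9],                   # plain list, shuffled with the same permutation
    "num_items_in_batch": torch.tensor(110),      # 0-dim: returned untouched
}

shuffled = shuffle_sequence_dict(seq_dict)
assert shuffled["num_items_in_batch"] is seq_dict["num_items_in_batch"]
# Rows and rewards stay aligned because both are permuted identically.
pairs = {(int(ids[0]), r) for ids, r in zip(shuffled["prompt_ids"], shuffled["rewards"])}
assert pairs == {(1, 0.1), (2, 0.5), (3, 0.9)}
```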
```diff
@@ -700,6 +707,12 @@ def __init__(
             processing_class=processing_class,
             callbacks=callbacks,
             optimizers=optimizers,
+            # In Trainer, `training_step` scales the loss by `gradient_accumulation_steps` only if `compute_loss_func`
+            # is None. For DAPO, loss scaling instead depends on the total number of completion tokens across the
+            # global accumulated batch. To control scaling ourselves, we must disable Trainer's built-in scaling. The
+            # simplest (though a bit hacky) way is to set `compute_loss_func` to any non-None value, which bypasses
+            # that behavior without rewriting `training_step`.
+            compute_loss_func="non-None value to disable scaling",
         )

         # Reference model
```

Member (Author): cc @SunMarc, idk if you'd have done it this way

Member: yeah it's a bit hacky, but I guess that's one of the simplest solutions. Maybe you can also do it by modifying …
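The inline comment is the crux of the workaround. As a schematic paraphrase of the described behavior (my own sketch, not the actual `transformers` source), the condition it relies on looks like this:

```python
# Schematic paraphrase of the Trainer behavior described in the inline comment above;
# an illustrative sketch, not the real `transformers` implementation.
def scale_loss_like_training_step(loss, compute_loss_func, gradient_accumulation_steps):
    if compute_loss_func is None:
        # Default Trainer path: spread the loss over the accumulation steps.
        return loss / gradient_accumulation_steps
    # Any non-None `compute_loss_func` skips the built-in scaling, leaving the
    # trainer subclass (here, the GRPO/DAPO code) responsible for normalization.
    return loss
```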
```diff
@@ -1593,6 +1606,8 @@ def _generate_and_score_completions(
         # Sum along sequence dimension (dim=1) to get completion length per sequence, used for logging
         completion_lengths = completion_mask.sum(1)
+        agg_completion_lengths = self.accelerator.gather(completion_lengths)
+        num_items_in_batch = agg_completion_lengths.sum()  # this is required for the DAPO loss

         # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask
         if self.mask_truncated_completions:
```
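To make the bookkeeping concrete, a toy illustration with hypothetical per-rank numbers (two processes, four completions each): `accelerator.gather` concatenates the per-sequence token counts from all processes along dim 0, so summing them gives the total number of completion tokens in the global batch.

```python
import torch

# Hypothetical per-rank completion lengths (tokens per completion) on two processes.
completion_lengths_rank0 = torch.tensor([12, 7, 30, 5])
completion_lengths_rank1 = torch.tensor([9, 22, 14, 11])

# accelerator.gather would hand every rank this concatenation of all ranks' tensors;
# torch.cat stands in for it here since this snippet runs on a single process.
agg_completion_lengths = torch.cat([completion_lengths_rank0, completion_lengths_rank1])

num_items_in_batch = agg_completion_lengths.sum()
print(num_items_in_batch)  # tensor(110): the DAPO normalizer for this global batch
```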
```diff
@@ -1712,7 +1727,6 @@ def _generate_and_score_completions(
             self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen]

         # Log completion lengths, mean, min, max
-        agg_completion_lengths = self.accelerator.gather(completion_lengths)
         self._metrics[mode]["completions/mean_length"].append(agg_completion_lengths.float().mean().item())
         self._metrics[mode]["completions/min_length"].append(agg_completion_lengths.float().min().item())
         self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item())
```
```diff
@@ -1754,6 +1768,7 @@ def _generate_and_score_completions(
             "completion_ids": completion_ids,
             "completion_mask": completion_mask,
             "advantages": advantages,
+            "num_items_in_batch": num_items_in_batch,
         }
         if old_per_token_logps is not None:
             output["old_per_token_logps"] = old_per_token_logps
```
```diff
@@ -1809,7 +1824,7 @@ def compute_liger_loss(self, unwrapped_model, inputs):
         if self.beta != 0.0:
             self._metrics[mode]["kl"].append(self.accelerator.gather(mean_kl).mean().item())
         self._metrics[mode]["clip_ratio"].append(self.accelerator.gather(clip_ratio).mean().item())
-        return loss
+        return loss / self.current_gradient_accumulation_steps

    @profiling_decorator
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
```
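Since the Liger path no longer relies on the Trainer's built-in scaling either, the division above has to reproduce the old averaging itself. A toy arithmetic check (hypothetical numbers) of why per-micro-step division works: gradients from the micro-steps are summed before the optimizer step, so dividing each micro-step loss by the number of accumulation steps is the same as averaging over them.

```python
G = 4                                # gradient accumulation steps
micro_losses = [0.8, 1.2, 1.0, 0.6]  # one loss per micro-batch (made-up values)
accumulated = sum(loss / G for loss in micro_losses)
assert abs(accumulated - sum(micro_losses) / G) < 1e-12  # identical to averaging over the G steps
```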
```diff
@@ -1894,10 +1909,16 @@ def _compute_loss(self, model, inputs):
         if self.loss_type == "grpo":
             loss = ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean()
+            loss = loss / self.current_gradient_accumulation_steps
         elif self.loss_type == "bnpo":
             loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0)
+            loss = loss / self.current_gradient_accumulation_steps
         elif self.loss_type == "dr_grpo":
             loss = (per_token_loss * completion_mask).sum() / (per_token_loss.size(0) * self.max_completion_length)
+            loss = loss / self.current_gradient_accumulation_steps
+        elif self.loss_type == "dapo":
+            normalizer = inputs["num_items_in_batch"] / self.accelerator.num_processes
+            loss = (per_token_loss * completion_mask).sum() / normalizer
         else:
             raise ValueError(f"Unknown loss type: {self.loss_type}")
```

Member (Author): now that we don't rely on the Trainer to scale the loss, we need to do it here.
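Reading the `dapo` branch together with the `num_items_in_batch` bookkeeping above (my interpretation of the diff, not text from the PR): over a global accumulated batch of `G` completions, where `|o_i|` is the number of completion tokens in the i-th completion and `ℓ_{i,t}` its per-token loss, the effective objective is a plain token-level average,

```math
\mathcal{L}_{\text{DAPO}}
  = \frac{1}{N} \sum_{i=1}^{G} \sum_{t=1}^{|o_i|} \ell_{i,t},
  \qquad
  N = \sum_{i=1}^{G} |o_i| = \texttt{num\_items\_in\_batch}.
```

Each process only sums the tokens it holds, and data-parallel training averages the resulting gradients over `num_processes` workers, so dividing every local sum by `N / num_processes` recovers exactly the global `1/N` factor. Likewise, no division by `current_gradient_accumulation_steps` is applied in the `dapo` branch, because `N` already counts tokens across the whole accumulated batch and the per-micro-step sums simply add up during accumulation.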
Review comment: this is for `num_items_in_batch`, which is a single float that needs to be copied into every chunk.