
Conversation

@hutaiHang
Contributor

@hutaiHang hutaiHang commented Jul 25, 2025

What does this PR do?

This PR addresses an issue where the loss scaling during gradient accumulation is incorrect for the final optimizer step of an epoch if the total number of batches is not perfectly divisible by gradient_accumulation_steps.

Currently, the loss for each micro-batch is always divided by the configured args.gradient_accumulation_steps. This leads to the accumulated loss for the final, incomplete cycle being scaled down too much, resulting in an improperly small gradient update for that step.

This fix resolves the issue by dynamically tracking the number of micro-batches processed in each accumulation cycle and using this actual count for loss scaling.

The changes are as follows:

  1. In the _inner_training_loop, a new instance variable self.cur_gradient_accumulation_steps is introduced. It is updated at the start of each optimizer step with the actual number of batches being processed (i.e., len(batch_samples)).
  2. In the training_step method, the loss scaling logic now uses this dynamic self.cur_gradient_accumulation_steps value instead of the fixed self.args.gradient_accumulation_steps.

This ensures that the loss is correctly averaged over the number of batches that actually contributed to the gradient accumulation, regardless of whether the cycle was complete or not. This change has no new dependencies.
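For illustration, here is a minimal, self-contained sketch of the idea in plain PyTorch (a toy model and made-up data, not the actual Trainer code; the variable name cur_gradient_accumulation_steps simply mirrors the description above):

import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.MSELoss()

gradient_accumulation_steps = 4
# 10 micro-batches: the final cycle only contains 10 % 4 = 2 of them.
micro_batches = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(10)]

for start in range(0, len(micro_batches), gradient_accumulation_steps):
    cycle = micro_batches[start:start + gradient_accumulation_steps]
    # Scale by the number of micro-batches actually in this cycle
    # (2 for the final, incomplete cycle instead of the configured 4).
    cur_gradient_accumulation_steps = len(cycle)
    for inputs, targets in cycle:
        loss = loss_fn(model(inputs), targets)
        (loss / cur_gradient_accumulation_steps).backward()
    optimizer.step()
    optimizer.zero_grad()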

Fixes #38837

Before submitting

Who can review?

@hutaiHang
Contributor Author

Hi @SunMarc,

Thanks for your guidance in issue #38837. As you suggested, I've created this Pull Request to address the gradient accumulation scaling issue.

All CI checks have now passed. Could you please take a look when you have a moment?

Thank you!

Member

@qgallouedec qgallouedec left a comment

LGTM!

hutaiHang and others added 3 commits July 25, 2025 23:37
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
@hutaiHang
Contributor Author

Hi @qgallouedec, thanks for the feedback! I've just pushed the requested changes. Could you please take another look when you have a chance?

@qgallouedec
Member

Looks good, let's see if the CI is happy

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@hutaiHang
Contributor Author

@qgallouedec All checks have passed, thanks for the review. 😊

@qgallouedec
Member

Note to self: this will silently break gradient accumulation in GRPOTrainer, because that trainer oversamples from the dataloader. I'll have to find a way to solve that.

Member

@SunMarc SunMarc left a comment

Thanks!

@SunMarc SunMarc merged commit 075dbbc into huggingface:main Jul 29, 2025
25 checks passed
@kaln27
Contributor

kaln27 commented Jul 31, 2025

Hi @hutaiHang

I have a question: what if the last accumulation cycle contains only a single sample? That one sample would then contribute the entire gradient for the optimizer step, so the model would be updated in a direction that may not generalize (the whole point of a large effective batch size is to make each update representative). That could make training unstable.

I think we should keep dividing the loss by self.args.gradient_accumulation_steps before the backward pass, and only rescale the returned loss (which is used purely for logging) to the correct value, for example:

# trainer.py
        if self.use_apex:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            # Finally we need to normalize the loss for reporting if GA loss bug is not fixed during compute loss
            if not self.model_accepts_loss_kwargs and self.compute_loss_func is None:
                loss = loss / self.args.gradient_accumulation_steps

            # Turning off loss scaling w.r.t. gradient accumulation when DeepSpeed is enabled
            # https://github.com/huggingface/transformers/pull/35808
            if self.accelerator.distributed_type == DistributedType.DEEPSPEED:
                kwargs["scale_wrt_gas"] = False

            self.accelerator.backward(loss, **kwargs)

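            # Undo the fixed divisor and re-normalize by the actual number of
            # micro-batches in this cycle so the reported loss stays correct.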
            rescale_loss = loss.detach() * self.args.gradient_accumulation_steps / self.current_gradient_accumulation_steps

            return rescale_loss
@hutaiHang
Contributor Author

Hi @kaln27

Thank you for bringing up this excellent point about training dynamics! It's a very important consideration.

You are absolutely right to be concerned about potential instability if an optimizer step is based on a very small final batch (e.g., a single sample). The gradient from such a small batch can indeed be noisy.

However, I believe the proposed solution in the PR is the correct way to fix the underlying mathematical bug in the Trainer, and the stability concern should be addressed at the data loading level. Let me break down my reasoning:

1. Gradient Magnitude Correctness (The goal of this PR)

The purpose of gradient accumulation is to simulate a larger batch size. The final gradient update should have a magnitude that is the average of the gradients from the processed micro-batches.

  • If we process 4 batches, the loss is summed and then divided by 4 before backward().
  • If we process 1 batch, the loss should be divided by 1 before backward().

This PR ensures this mathematical correctness. If we were to always divide by self.args.gradient_accumulation_steps as you suggested, the gradient for the final, incomplete step would be artificially suppressed (e.g., 1/4 of its correct magnitude). This would mean the model barely learns from those last few samples, which is also undesirable.
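To make the magnitudes concrete, here is a small back-of-the-envelope example (the loss value of 2.0 is invented purely for illustration):

configured_steps = 4          # args.gradient_accumulation_steps
final_cycle_losses = [2.0]    # only one micro-batch left in the epoch

# Always dividing by the configured value suppresses the final update:
fixed_scaling = sum(l / configured_steps for l in final_cycle_losses)            # 0.5

# Dividing by the actual number of micro-batches keeps the correct average:
dynamic_scaling = sum(l / len(final_cycle_losses) for l in final_cycle_losses)   # 2.0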

2. Training Stability (The concern you raised)

The question of whether one wants to perform an update based on a noisy, small batch is a matter of training strategy.

The standard and recommended way to handle this in transformers is to use the dataloader_drop_last=True argument in TrainingArguments. This tells the DataLoader to simply discard the last, incomplete batch, ensuring that every single optimizer step sees a full batch of data.
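For reference, a minimal sketch of that configuration (the output directory and batch sizes are placeholders):

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",                  # placeholder
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,
    # Drop the last incomplete batch so every optimizer step
    # accumulates over a full cycle of micro-batches.
    dataloader_drop_last=True,
)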

Conclusion:

This PR is focused on making the Trainer's default behavior mathematically correct. Users who are concerned about the stability implications of small final batches already have a direct and explicit tool (dataloader_drop_last) to manage this. The Trainer itself should not silently suppress gradients as a "feature," as that hides the true contribution of the data.

I hope this clarifies my approach. Let me know what you think!

zaristei pushed a commit to zaristei/transformers that referenced this pull request Sep 9, 2025
…on steps (huggingface#39659)

* Fix issue[huggingface#38837]: wrong loss scaled in last step of epoch

* chore: trigger CI

* Update src/transformers/trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update src/transformers/modeling_flash_attention_utils.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

---------

Co-authored-by: taihang <taihang@U-2RHYVWX7-2207.local>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>