
Commit 38bdbed

Authored by mmathew23, pre-commit-ci[bot], danielhanchen, and gemini-code-assist[bot]
fix qwen3 vl gradient accumulation (#3598)
* fix qwen3 vl gradient accumulation

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

* Update unsloth/models/_utils.py

  Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent ac7a478 commit 38bdbed
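
The patch below makes the gradient-accumulation loss rescaling in `training_step` conditional on `num_items_in_batch`, following the comment in the diff that the accelerator already divides by `gradient_accumulation_steps` internally. As a rough illustration of why double scaling matters (a toy sketch with made-up numbers, not transformers or Accelerate code):

```python
# Toy illustration of double scaling under gradient accumulation.
# Hypothetical values; not the Trainer or Accelerate implementation.
grad_accum_steps = 4
micro_losses = [2.0, 1.0, 3.0, 2.0]  # per-micro-batch mean losses

# Intended behaviour: each micro-batch loss is divided exactly once, so the
# accumulated gradient matches the gradient of the full-batch mean.
correct = sum(loss / grad_accum_steps for loss in micro_losses)

# Double scaled: the caller divides and the framework divides again,
# shrinking the effective loss (and every gradient) by grad_accum_steps.
double_scaled = sum((loss / grad_accum_steps) / grad_accum_steps for loss in micro_losses)

print(correct)        # 2.0
print(double_scaled)  # 0.5
```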

File tree

2 files changed: +93 −50 lines changed


unsloth/models/_utils.py

Lines changed: 93 additions & 49 deletions
@@ -83,7 +83,8 @@
 import re
 from dataclasses import dataclass, field
 import functools
-import warnings, subprocess, re, inspect, psutil, os, math
+import textwrap
+import warnings, subprocess, inspect, psutil, os, math
 from unsloth_zoo.utils import Version, get_quant_type
 from importlib.metadata import version as importlib_version
 from ..device_type import (
@@ -1688,60 +1689,103 @@ def patch_gradient_accumulation_fix(Trainer):
     )

     # Also fix up loss scaling ie negate loss *= self.args.gradient_accumulation_steps
-    if Trainer.training_step.__name__ == "_unsloth_training_step":
-        return
-    if "num_items_in_batch" not in inspect.signature(Trainer.training_step).parameters:
-        return
+    if not (
+        Trainer.training_step.__name__ == "_unsloth_training_step"
+        or "num_items_in_batch"
+        not in inspect.signature(Trainer.training_step).parameters
+    ):
+        function = inspect.getsource(Trainer.training_step)
+        where = function.find("def")
+        function = function.split("\n")
+        function = "\n".join(x[where:] for x in function)
+
+        # Import all variables that need importing
+        import transformers.trainer
+
+        items_in_trainer = dir(transformers.trainer)
+        good_items = []
+        for item in items_in_trainer:
+            if item in function:
+                good_items.append(item)
+        exec(
+            "from transformers.trainer import ("
+            + ", ".join(x for x in good_items)
+            + ")",
+            globals(),
+        )

-    function = inspect.getsource(Trainer.training_step)
-    where = function.find("def")
-    function = function.split("\n")
-    function = "\n".join(x[where:] for x in function)
+        # Accelerate does / self.args.gradient_accumulation_steps internally, so if we already
+        # summed it up and did the division before hand, we have to negate it.
+        function = function.replace(
+            "loss *= self.args.gradient_accumulation_steps",
+            "if num_items_in_batch is not None: loss *= self.args.gradient_accumulation_steps",
+        )
+        function = function.replace(
+            "def training_step", "def _unsloth_training_step", 1
+        )

-    # Import all variables that need importing
-    import transformers.trainer
-
-    items_in_trainer = dir(transformers.trainer)
-    good_items = []
-    for item in items_in_trainer:
-        if item in function:
-            good_items.append(item)
-    exec(
-        "from transformers.trainer import (" + ", ".join(x for x in good_items) + ")",
-        globals(),
-    )
+        # Fix 4.47.0 issue where num_items_in_batch was removed
+        # See https://github.com/huggingface/transformers/pull/35121
+        function = function.replace(
+            "if self.model_accepts_loss_kwargs:",
+            "if False:",
+        )

-    # Accelerate does / self.args.gradient_accumulation_steps internally, so if we already
-    # summed it up and did the division before hand, we have to negate it.
-    function = function.replace(
-        "loss *= self.args.gradient_accumulation_steps",
-        "if num_items_in_batch is not None: loss *= self.args.gradient_accumulation_steps",
-    )
-    function = function.replace("def training_step", "def _unsloth_training_step", 1)
+        # Fix when num_items_in_batch is nothing
+        # https://github.com/huggingface/transformers/pull/35207
+        function = re.sub(
+            r"else:\n"
+            r"([\s]{4,})self\.accelerator\.backward\(loss, \*\*kwargs\)\n"
+            r"(.+?)if num_items_in_batch is None\:\n"
+            r"(.+?)return loss\.detach\(\) \/ self\.args\.gradient_accumulation_steps",
+            "else:\n"
+            "\2if num_items_in_batch is None:\n"
+            "\3loss = loss / self.args.gradient_accumulation_steps\n"
+            "\1self.accelerator.backward(loss, **kwargs)",
+            function,
+        )

-    # Fix 4.47.0 issue where num_items_in_batch was removed
-    # See https://github.com/huggingface/transformers/pull/35121
-    function = function.replace(
-        "if self.model_accepts_loss_kwargs:",
-        "if False:",
-    )
+        exec(function, globals())
+        Trainer.training_step = _unsloth_training_step

-    # Fix when num_items_in_batch is nothing
-    # https://github.com/huggingface/transformers/pull/35207
-    function = re.sub(
-        r"else:\n"
-        r"([\s]{4,})self\.accelerator\.backward\(loss, \*\*kwargs\)\n"
-        r"(.+?)if num_items_in_batch is None\:\n"
-        r"(.+?)return loss\.detach\(\) \/ self\.args\.gradient_accumulation_steps",
-        "else:\n"
-        "\2if num_items_in_batch is None:\n"
-        "\3loss = loss / self.args.gradient_accumulation_steps\n"
-        "\1self.accelerator.backward(loss, **kwargs)",
-        function,
-    )
+    # Prevent double scaling gradient accumulation
+    # https://github.com/huggingface/transformers/pull/37208
+    # Patch model_accepts_loss_kwargs detection in Trainer.__init__
+    if Trainer.__init__.__name__ != "_unsloth___init__":
+        try:
+            init_function = inspect.getsource(Trainer.__init__)
+        except Exception:
+            init_function = ""
+        if init_function is not None:
+            init_function = textwrap.dedent(init_function)
+
+            # Import all variables that need importing
+            import transformers.trainer
+
+            items_in_trainer = dir(transformers.trainer)
+            good_items = []
+            for item in items_in_trainer:
+                if item in init_function:
+                    good_items.append(item)
+            exec(
+                "from transformers.trainer import ("
+                + ", ".join(x for x in good_items)
+                + ")",
+                globals(),
+            )

-    exec(function, globals())
-    Trainer.training_step = _unsloth_training_step
+            init_function = init_function.replace(
+                "def __init__", "def _unsloth___init__", 1
+            )
+
+            # Force else branch
+            init_function = re.sub(
+                r'if[\s]+hasattr\(\s*unwrapped_model\s*,\s*"accepts_loss_kwargs"\s*\)\s*:',
+                'if hasattr(unwrapped_model, "accepts_loss_kwargs") and False:',
+                init_function,
+            )
+            exec(init_function, globals())
+            Trainer.__init__ = _unsloth___init__


 def patch_tokenizer(model, tokenizer):
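
For reference, a minimal, self-contained sketch of the source-rewriting pattern the hunk above relies on: fetch a method's source with `inspect.getsource`, dedent it, rewrite its name and the statement of interest as strings, `exec` the result, and rebind the method. A hypothetical `Demo` class stands in for `Trainer`; this is not the Unsloth patch itself.

```python
# Sketch of the getsource -> string rewrite -> exec -> rebind pattern.
# Run from a file (inspect.getsource needs the source on disk).
import inspect, textwrap

class Demo:
    def training_step(self, loss, grad_accum_steps, num_items_in_batch=None):
        loss *= grad_accum_steps
        return loss

source = textwrap.dedent(inspect.getsource(Demo.training_step))
# Make the rescale conditional, mirroring the replacement in the diff.
source = source.replace(
    "loss *= grad_accum_steps",
    "if num_items_in_batch is not None: loss *= grad_accum_steps",
)
source = source.replace("def training_step", "def _patched_training_step", 1)
exec(source, globals())
Demo.training_step = _patched_training_step

demo = Demo()
print(demo.training_step(2.0, 4))                         # 2.0 (no rescale)
print(demo.training_step(2.0, 4, num_items_in_batch=32))  # 8.0 (rescaled)
```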

unsloth/models/llama.py

Lines changed: 0 additions & 1 deletion
@@ -3014,7 +3014,6 @@ def get_peft_model(
         print("Unsloth: Applying QAT to mitigate quantization degradation")
         model = FastLlamaModel._prepare_for_qat(model, qat_scheme)

-
     model._saved_temp_tokenizer = _saved_temp_tokenizer

     model = FastLlamaModel.patch_peft_model(model, use_gradient_checkpointing)
