 import re
 from dataclasses import dataclass, field
 import functools
-import warnings, subprocess, re, inspect, psutil, os, math
+import textwrap
+import warnings, subprocess, inspect, psutil, os, math
 from unsloth_zoo.utils import Version, get_quant_type
 from importlib.metadata import version as importlib_version
 from ..device_type import (
@@ -1688,60 +1689,103 @@ def patch_gradient_accumulation_fix(Trainer):
         )

     # Also fix up loss scaling, i.e. negate loss *= self.args.gradient_accumulation_steps
-    if Trainer.training_step.__name__ == "_unsloth_training_step":
-        return
-    if "num_items_in_batch" not in inspect.signature(Trainer.training_step).parameters:
-        return
+    if not (
+        Trainer.training_step.__name__ == "_unsloth_training_step"
+        or "num_items_in_batch"
+        not in inspect.signature(Trainer.training_step).parameters
+    ):
+        function = inspect.getsource(Trainer.training_step)
+        where = function.find("def")
+        function = function.split("\n")
+        function = "\n".join(x[where:] for x in function)
+
+        # Import all variables that need importing
+        import transformers.trainer
+
+        items_in_trainer = dir(transformers.trainer)
+        good_items = []
+        for item in items_in_trainer:
+            if item in function:
+                good_items.append(item)
+        exec(
+            "from transformers.trainer import ("
+            + ", ".join(x for x in good_items)
+            + ")",
+            globals(),
+        )

-    function = inspect.getsource(Trainer.training_step)
-    where = function.find("def")
-    function = function.split("\n")
-    function = "\n".join(x[where:] for x in function)
+        # Accelerate does / self.args.gradient_accumulation_steps internally, so if we already
+        # summed it up and did the division beforehand, we have to negate it.
+        function = function.replace(
+            "loss *= self.args.gradient_accumulation_steps",
+            "if num_items_in_batch is not None: loss *= self.args.gradient_accumulation_steps",
+        )
+        function = function.replace(
+            "def training_step", "def _unsloth_training_step", 1
+        )

-    # Import all variables that need importing
-    import transformers.trainer
-
-    items_in_trainer = dir(transformers.trainer)
-    good_items = []
-    for item in items_in_trainer:
-        if item in function:
-            good_items.append(item)
-    exec(
-        "from transformers.trainer import (" + ", ".join(x for x in good_items) + ")",
-        globals(),
-    )
+        # Fix 4.47.0 issue where num_items_in_batch was removed
+        # See https://github.com/huggingface/transformers/pull/35121
+        function = function.replace(
+            "if self.model_accepts_loss_kwargs:",
+            "if False:",
+        )

-    # Accelerate does / self.args.gradient_accumulation_steps internally, so if we already
-    # summed it up and did the division before hand, we have to negate it.
-    function = function.replace(
-        "loss *= self.args.gradient_accumulation_steps",
-        "if num_items_in_batch is not None: loss *= self.args.gradient_accumulation_steps",
-    )
-    function = function.replace("def training_step", "def _unsloth_training_step", 1)
+        # Fix when num_items_in_batch is None
+        # https://github.com/huggingface/transformers/pull/35207
+        function = re.sub(
+            r"else:\n"
+            r"([\s]{4,})self\.accelerator\.backward\(loss, \*\*kwargs\)\n"
+            r"(.+?)if num_items_in_batch is None\:\n"
+            r"(.+?)return loss\.detach\(\) \/ self\.args\.gradient_accumulation_steps",
+            r"else:\n"
+            r"\2if num_items_in_batch is None:\n"
+            r"\3loss = loss / self.args.gradient_accumulation_steps\n"
+            r"\1self.accelerator.backward(loss, **kwargs)",
+            function,
+        )

-    # Fix 4.47.0 issue where num_items_in_batch was removed
-    # See https://github.com/huggingface/transformers/pull/35121
-    function = function.replace(
-        "if self.model_accepts_loss_kwargs:",
-        "if False:",
-    )
+        exec(function, globals())
+        Trainer.training_step = _unsloth_training_step

-    # Fix when num_items_in_batch is nothing
-    # https://github.com/huggingface/transformers/pull/35207
-    function = re.sub(
-        r"else:\n"
-        r"([\s]{4,})self\.accelerator\.backward\(loss, \*\*kwargs\)\n"
-        r"(.+?)if num_items_in_batch is None\:\n"
-        r"(.+?)return loss\.detach\(\) \/ self\.args\.gradient_accumulation_steps",
-        "else:\n"
-        "\2if num_items_in_batch is None:\n"
-        "\3loss = loss / self.args.gradient_accumulation_steps\n"
-        "\1self.accelerator.backward(loss, **kwargs)",
-        function,
-    )
+    # Prevent double scaling under gradient accumulation
+    # https://github.com/huggingface/transformers/pull/37208
+    # Patch model_accepts_loss_kwargs detection in Trainer.__init__
+    if Trainer.__init__.__name__ != "_unsloth___init__":
+        try:
+            init_function = inspect.getsource(Trainer.__init__)
+        except Exception:
+            init_function = ""
+        if init_function:
+            init_function = textwrap.dedent(init_function)
+
+            # Import all variables that need importing
+            import transformers.trainer
+
+            items_in_trainer = dir(transformers.trainer)
+            good_items = []
+            for item in items_in_trainer:
+                if item in init_function:
+                    good_items.append(item)
+            exec(
+                "from transformers.trainer import ("
+                + ", ".join(x for x in good_items)
+                + ")",
+                globals(),
+            )

-    exec(function, globals())
-    Trainer.training_step = _unsloth_training_step
+            init_function = init_function.replace(
+                "def __init__", "def _unsloth___init__", 1
+            )
+
+            # Force else branch
+            init_function = re.sub(
+                r'if[\s]+hasattr\(\s*unwrapped_model\s*,\s*"accepts_loss_kwargs"\s*\)\s*:',
+                'if hasattr(unwrapped_model, "accepts_loss_kwargs") and False:',
+                init_function,
+            )
+            exec(init_function, globals())
+            Trainer.__init__ = _unsloth___init__


 def patch_tokenizer(model, tokenizer):
|
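The hunk above hinges on the comment that Accelerate divides the loss by self.args.gradient_accumulation_steps internally. The short script below is a minimal numeric sketch of that reasoning, not Unsloth or Transformers code; the variable names and numbers are made up for illustration. When each micro-batch loss is already normalised by the global token count (num_items_in_batch), the extra internal division shrinks the result, so the patch multiplies the loss back up before backward().

grad_accum_steps = 4
tokens_per_micro_batch = [7, 3, 9, 5]   # uneven sequence lengths across micro-batches
loss_sums = [2.8, 1.2, 4.5, 2.0]        # sum of per-token losses in each micro-batch

# Reference: one large batch, mean loss over every token
reference = sum(loss_sums) / sum(tokens_per_micro_batch)

# With num_items_in_batch known, each micro-batch loss is already divided by the
# global token count ...
per_step = [s / sum(tokens_per_micro_batch) for s in loss_sums]

# ... but accelerator.backward() still divides by grad_accum_steps, so the
# accumulated gradient corresponds to a loss that is too small:
double_scaled = sum(l / grad_accum_steps for l in per_step)

# The patch multiplies the loss by grad_accum_steps first, cancelling the
# internal division and recovering the full-batch value:
compensated = sum(l * grad_accum_steps / grad_accum_steps for l in per_step)

print(f"reference     : {reference:.6f}")
print(f"double scaled : {double_scaled:.6f}")   # off by a factor of grad_accum_steps
print(f"compensated   : {compensated:.6f}")     # matches the reference

Run as-is, the reference and compensated values agree while the double-scaled value is smaller by exactly grad_accum_steps, which is the discrepancy the `loss *= self.args.gradient_accumulation_steps` line corrects.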
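Both the training_step and __init__ patches in the hunk use the same mechanic: read the method's source with inspect.getsource, rewrite it as text, exec the rewritten source, and rebind the result on the class. Below is a self-contained toy sketch of that pattern, under the assumption that Greeter, greet, and _patched_greet are hypothetical names used only for this illustration and not part of Unsloth or Transformers.

import inspect
import textwrap

class Greeter:
    def greet(self, name):
        return "Hello " + name

# inspect.getsource needs the class to live in a real source file, e.g. a saved script
source = textwrap.dedent(inspect.getsource(Greeter.greet))

# Rename the function and tweak its body as plain text, like the training_step rewrite
source = source.replace("def greet", "def _patched_greet", 1)
source = source.replace('"Hello "', '"Hi "')

namespace = {}
exec(source, namespace)                        # compile the rewritten source
Greeter.greet = namespace["_patched_greet"]    # swap the method in place

print(Greeter().greet("world"))                # -> "Hi world"

The real patch additionally strips or dedents the leading indentation, renames the function so the original name stays untouched, executes into globals(), and re-imports any transformers.trainer names the source references so that exec can resolve them.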