
Commit 75c8675

Authored by mmathew23, with gemini-code-assist[bot] and pre-commit-ci[bot]
Patch in tiled mlp (#3584)
* Patch in tiled mlp

* Update unsloth/models/llama.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 0779d69 · commit 75c8675

File tree: 2 files changed (+23, −3)


unsloth/models/llama.py

Lines changed: 9 additions & 3 deletions
@@ -3214,9 +3214,15 @@ def patch_peft_model(
             )
         ):
             # https://stackoverflow.com/questions/50599045/python-replacing-a-function-within-a-class-of-a-module
-            mlp_module.forward = types.MethodType(
-                _apply_lora_mlp, mlp_module
-            )
+            if hasattr(mlp_module, "_unsloth_forward"):
+                # then we've patched the mlp to use TiledMLP
+                mlp_module._unsloth_forward = types.MethodType(
+                    _apply_lora_mlp, mlp_module
+                )
+            else:
+                mlp_module.forward = types.MethodType(
+                    _apply_lora_mlp, mlp_module
+                )
             n_mlp += 1
         else:
             logger.warning_once(
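Why the branch matters: assigning through types.MethodType rebinds forward on that single module instance only. When tiled MLP has already been applied, a wrapper owns forward and delegates to _unsloth_forward, so the LoRA fast path must be installed on the inner hook or the tiling wrapper would be clobbered. A minimal, self-contained sketch of the pattern, where ToyMLP and fast_forward are hypothetical stand-ins for the real MLP modules and _apply_lora_mlp:

# Sketch of the instance-level method patching above; ToyMLP and
# fast_forward are hypothetical stand-ins, not Unsloth code.
import types

class ToyMLP:
    def forward(self, x):
        return x + 1

def fast_forward(self, x):
    # Replacement implementation; self is the specific patched instance.
    return x + 2

mlp_module = ToyMLP()

if hasattr(mlp_module, "_unsloth_forward"):
    # A tiled-MLP wrapper already owns .forward and delegates to
    # ._unsloth_forward, so swap the inner hook instead of .forward.
    mlp_module._unsloth_forward = types.MethodType(fast_forward, mlp_module)
else:
    # Unwrapped module: rebind .forward directly, on this instance only.
    mlp_module.forward = types.MethodType(fast_forward, mlp_module)

print(mlp_module.forward(1))  # prints 3 (patched), not 2 (original)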

unsloth/models/loader.py

Lines changed: 14 additions & 0 deletions
@@ -57,6 +57,7 @@
 # https://github.com/huggingface/transformers/pull/26037 allows 4 bit loading!
 from unsloth_zoo.utils import Version, _get_dtype
 from unsloth_zoo.hf_utils import dtype_from_config
+from unsloth_zoo.tiled_mlp import patch_tiled_mlp
 
 transformers_version = Version(transformers_version)
 SUPPORTS_FOURBIT = transformers_version >= Version("4.37")
@@ -566,6 +567,13 @@ def from_pretrained(
         )
         # Patch it as well!
         model = dispatch_model.patch_peft_model(model, use_gradient_checkpointing)
+
+        # Patch Tiled MLP
+        # to turn on, set UNSLOTH_TILED_MLP to "arctic", "target", or "target:{GB}"
+        patch_tiled_mlp_choice = os.environ.get("UNSLOTH_TILED_MLP", "0")
+        if patch_tiled_mlp_choice != "0":
+            patch_tiled_mlp(model, patch_options_str = patch_tiled_mlp_choice)
+
         return model, tokenizer
 
 
@@ -1138,6 +1146,12 @@ def from_pretrained(
             print("Unsloth: Applying QAT to mitigate quantization degradation")
             model = _prepare_model_for_qat(model, qat_scheme)
 
+        # Patch Tiled MLP
+        # to turn on, set UNSLOTH_TILED_MLP to "arctic", "target", or "target:{GB}"
+        patch_tiled_mlp_choice = os.environ.get("UNSLOTH_TILED_MLP", "0")
+        if patch_tiled_mlp_choice != "0":
+            patch_tiled_mlp(model, patch_options_str = patch_tiled_mlp_choice)
+
         return model, tokenizer
 
 
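With this commit, tiled MLP stays off unless UNSLOTH_TILED_MLP is set to something other than "0" before the model is loaded. A hedged usage sketch, assuming the option strings noted in the comments ("arctic", "target", "target:{GB}"); "target:8" reads as an 8 GB target under the target:{GB} form, which is an assumption, the checkpoint name is illustrative, and the actual option parsing lives in unsloth_zoo.tiled_mlp.patch_tiled_mlp:

import os

# Opt in before from_pretrained runs; the default "0" leaves MLPs unpatched.
# Assumed option strings per the commit comments: "arctic", "target", "target:{GB}".
os.environ["UNSLOTH_TILED_MLP"] = "target:8"

from unsloth import FastLanguageModel

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/llama-3-8b-bnb-4bit",  # illustrative checkpoint
    load_in_4bit = True,
)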
