Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions unsloth/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -3210,9 +3210,15 @@ def patch_peft_model(
)
):
# https://stackoverflow.com/questions/50599045/python-replacing-a-function-within-a-class-of-a-module
mlp_module.forward = types.MethodType(
_apply_lora_mlp, mlp_module
)
if hasattr(mlp_module, "_unsloth_forward"):
# then we've patched the mlp to use TiledMLP
mlp_module._unsloth_forward = types.MethodType(
_apply_lora_mlp, mlp_module
)
else:
mlp_module.forward = types.MethodType(
_apply_lora_mlp, mlp_module
)
n_mlp += 1
else:
logger.warning_once(
Expand Down
14 changes: 14 additions & 0 deletions unsloth/models/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
# https://github.com/huggingface/transformers/pull/26037 allows 4 bit loading!
from unsloth_zoo.utils import Version, _get_dtype
from unsloth_zoo.hf_utils import dtype_from_config
from unsloth_zoo.tiled_mlp import patch_tiled_mlp

transformers_version = Version(transformers_version)
SUPPORTS_FOURBIT = transformers_version >= Version("4.37")
Expand Down Expand Up @@ -566,6 +567,13 @@ def from_pretrained(
)
# Patch it as well!
model = dispatch_model.patch_peft_model(model, use_gradient_checkpointing)

# Patch Tiled MLP
    # To turn on, set UNSLOTH_TILED_MLP to "arctic", "target", or "target:{GB}"
patch_tiled_mlp_choice = os.environ.get("UNSLOTH_TILED_MLP", "0")
if patch_tiled_mlp_choice != "0":
patch_tiled_mlp(model, patch_options_str = patch_tiled_mlp_choice)
Comment on lines +571 to +575
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This block of code for patching Tiled MLP is duplicated in FastModel.from_pretrained (lines 1149-1153). To improve maintainability and avoid code duplication, consider extracting this logic into a private helper function within this module and calling it from both from_pretrained methods.


return model, tokenizer


Expand Down Expand Up @@ -1138,6 +1146,12 @@ def from_pretrained(
print("Unsloth: Applying QAT to mitigate quantization degradation")
model = _prepare_model_for_qat(model, qat_scheme)

# Patch Tiled MLP
    # To turn on, set UNSLOTH_TILED_MLP to "arctic", "target", or "target:{GB}"
patch_tiled_mlp_choice = os.environ.get("UNSLOTH_TILED_MLP", "0")
if patch_tiled_mlp_choice != "0":
patch_tiled_mlp(model, patch_options_str = patch_tiled_mlp_choice)
Comment on lines +1149 to +1153
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This is a duplicate of the Tiled MLP patching logic found in FastLanguageModel.from_pretrained (lines 571-575). This should be refactored into a shared helper function to avoid code duplication.


return model, tokenizer


Expand Down