Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
755 commits
Select commit Hold shift + click to select a range
15a9f88
Update gpt_oss.py
danielhanchen Sep 24, 2025
1ce4f9b
Update gpt_oss.py
danielhanchen Sep 24, 2025
976fcb4
Update gpt_oss.py
danielhanchen Sep 24, 2025
5b05f0d
Update gpt_oss.py
danielhanchen Sep 24, 2025
4bc3d22
Fix Flex Attention autotuning
danielhanchen Sep 24, 2025
0aaa499
Update patching_utils.py
danielhanchen Sep 24, 2025
5b27681
Update patching_utils.py
danielhanchen Sep 24, 2025
ebeb9f8
Update patching_utils.py
danielhanchen Sep 24, 2025
c2c473b
Update mxfp4.py
danielhanchen Sep 24, 2025
ed65c01
Update mxfp4.py
danielhanchen Sep 24, 2025
272a5ea
Update gpt_oss.py
danielhanchen Sep 24, 2025
95772ec
Update attention_sink.py
danielhanchen Sep 24, 2025
bacfc4d
Update patching_utils.py
danielhanchen Sep 24, 2025
8a89348
Update attention_sink.py
danielhanchen Sep 24, 2025
6d9b66b
Update gpt_oss.py
danielhanchen Sep 24, 2025
226c866
prefer_nd_tiling
danielhanchen Sep 24, 2025
2cebcc9
Update patching_utils.py
danielhanchen Sep 24, 2025
f3c5e1f
flex_attention_with_sink
danielhanchen Sep 24, 2025
e485dfc
Compile Flex Attention
danielhanchen Sep 24, 2025
6bd3f70
Update mxfp4.py
danielhanchen Sep 24, 2025
2393dfb
Update mxfp4.py
danielhanchen Sep 24, 2025
44356be
Update mxfp4.py
danielhanchen Sep 24, 2025
d5eebbc
Update mxfp4.py
danielhanchen Sep 24, 2025
93b5b88
Update gpt_oss.py
danielhanchen Sep 24, 2025
5e43a2a
bitsandbytes patch
danielhanchen Sep 24, 2025
2f0acb1
Update bitsandbytes.py
danielhanchen Sep 24, 2025
ebaf9b3
Update gpt_oss.py
danielhanchen Sep 24, 2025
2db0323
Inplace ops
danielhanchen Sep 24, 2025
031d21a
Update gpt_oss.py
danielhanchen Sep 24, 2025
61bb5aa
has_static_cache
danielhanchen Sep 24, 2025
267ab06
Update gpt_oss.py
danielhanchen Sep 24, 2025
bf825b1
Update gpt_oss.py
danielhanchen Sep 24, 2025
bf65ea5
Update gpt_oss.py
danielhanchen Sep 24, 2025
5c22d92
Update gpt_oss.py
danielhanchen Sep 24, 2025
274f7be
Update attention_sink.py
danielhanchen Sep 24, 2025
b2ec9f6
Update gpt_oss.py
danielhanchen Sep 24, 2025
ed572b1
Update gpt_oss.py
danielhanchen Sep 24, 2025
0bdaf45
Update gpt_oss.py
danielhanchen Sep 24, 2025
9fdf256
Update gpt_oss.py
danielhanchen Sep 24, 2025
56f7a73
Update gpt_oss.py
danielhanchen Sep 24, 2025
0c5437e
Update attention_sink.py
danielhanchen Sep 24, 2025
619c462
Update attention_sink.py
danielhanchen Sep 24, 2025
5d87949
Update rl_replacements.py
danielhanchen Sep 24, 2025
7ba642a
Update rl_replacements.py
danielhanchen Sep 24, 2025
040c6f2
Update rl_replacements.py
danielhanchen Sep 24, 2025
96798d8
Update gpt_oss.py
danielhanchen Sep 24, 2025
138f9f7
Update gpt_oss.py
danielhanchen Sep 24, 2025
eb19db9
Update gpt_oss.py
danielhanchen Sep 24, 2025
1f4f0c7
torch compile
danielhanchen Sep 25, 2025
b4afc0a
Update attention_sink.py
danielhanchen Sep 25, 2025
3d2083b
Update common.py
danielhanchen Sep 25, 2025
475a1fa
Update common.py
danielhanchen Sep 25, 2025
a1577f3
Patches
danielhanchen Sep 25, 2025
dc8308b
Compiled mask creation
danielhanchen Sep 25, 2025
15ae568
Update attention_sink.py
danielhanchen Sep 25, 2025
c849066
Update gpt_oss.py
danielhanchen Sep 25, 2025
eb68b54
Update gpt_oss.py
danielhanchen Sep 25, 2025
b4433b0
Revert
danielhanchen Sep 25, 2025
5f0fa7e
Update gpt_oss.py
danielhanchen Sep 25, 2025
274c830
Update gpt_oss.py
danielhanchen Sep 25, 2025
0c52d58
Fix up
danielhanchen Sep 25, 2025
3d9f498
Update attention_sink.py
danielhanchen Sep 25, 2025
dfe12c5
Update attention_sink.py
danielhanchen Sep 25, 2025
02ec222
Update utils.py
danielhanchen Sep 25, 2025
4e57162
Update attention_sink.py
danielhanchen Sep 25, 2025
17a6427
Update attention_sink.py
danielhanchen Sep 25, 2025
2002c9c
Retry
danielhanchen Sep 25, 2025
1ee8d5e
Update gpt_oss.py
danielhanchen Sep 25, 2025
3994e3c
Update gpt_oss.py
danielhanchen Sep 25, 2025
ef81921
Fix Flex
danielhanchen Sep 25, 2025
8cc0e77
Update gpt_oss.py
danielhanchen Sep 25, 2025
31f1624
Update gpt_oss.py
danielhanchen Sep 25, 2025
27fc0a9
Update gpt_oss.py
danielhanchen Sep 25, 2025
e86c541
Update gpt_oss.py
danielhanchen Sep 25, 2025
b4596cc
Update gpt_oss.py
danielhanchen Sep 25, 2025
858b962
Update gpt_oss.py
danielhanchen Sep 25, 2025
b676650
Update gpt_oss.py
danielhanchen Sep 25, 2025
dc1bd58
Update gpt_oss.py
danielhanchen Sep 25, 2025
1fe5a69
Update gpt_oss.py
danielhanchen Sep 25, 2025
524ac7f
Update gpt_oss.py
danielhanchen Sep 25, 2025
bd34939
Update gpt_oss.py
danielhanchen Sep 25, 2025
935ea71
Update gpt_oss.py
danielhanchen Sep 25, 2025
3ea5482
Update gpt_oss.py
danielhanchen Sep 25, 2025
1885f31
Update gpt_oss.py
danielhanchen Sep 25, 2025
ecd9b53
Update gpt_oss.py
danielhanchen Sep 25, 2025
d3b65af
Update gpt_oss.py
danielhanchen Sep 25, 2025
3b75bc9
Update gpt_oss.py
danielhanchen Sep 25, 2025
b43c1b5
Update gpt_oss.py
danielhanchen Sep 25, 2025
db12a8a
Update gpt_oss.py
danielhanchen Sep 25, 2025
889b4fb
Update gpt_oss.py
danielhanchen Sep 25, 2025
f481e2f
Update gpt_oss.py
danielhanchen Sep 25, 2025
c3e3a90
Update gpt_oss.py
danielhanchen Sep 25, 2025
b721c77
Update gpt_oss.py
danielhanchen Sep 25, 2025
7d81867
Update gpt_oss.py
danielhanchen Sep 25, 2025
577a2a0
Update gpt_oss.py
danielhanchen Sep 25, 2025
c0e421b
Update gpt_oss.py
danielhanchen Sep 25, 2025
2605ecb
Update gpt_oss.py
danielhanchen Sep 25, 2025
e850c7d
Update gpt_oss.py
danielhanchen Sep 25, 2025
9af4313
Update gpt_oss.py
danielhanchen Sep 25, 2025
d8a4e50
Update gpt_oss.py
danielhanchen Sep 25, 2025
1b732ba
Update gpt_oss.py
danielhanchen Sep 25, 2025
666f121
Update gpt_oss.py
danielhanchen Sep 25, 2025
b8cfebf
Update gpt_oss.py
danielhanchen Sep 25, 2025
5e88a87
Update gpt_oss.py
danielhanchen Sep 25, 2025
70dfc00
Update gpt_oss.py
danielhanchen Sep 25, 2025
9128339
Update gpt_oss.py
danielhanchen Sep 25, 2025
082cfb7
Update gpt_oss.py
danielhanchen Sep 25, 2025
0f47e5e
Update gpt_oss.py
danielhanchen Sep 25, 2025
d92e62d
Update gpt_oss.py
danielhanchen Sep 25, 2025
5646157
Update gpt_oss.py
danielhanchen Sep 25, 2025
272689b
Update gpt_oss.py
danielhanchen Sep 25, 2025
d10fc7a
Bug fixes
danielhanchen Sep 26, 2025
4396a93
Update patching_utils.py
danielhanchen Sep 26, 2025
ee50724
Update patching_utils.py
danielhanchen Sep 26, 2025
abe89f0
Update patching_utils.py
danielhanchen Sep 26, 2025
edc85ca
Update rl_replacements.py
danielhanchen Sep 26, 2025
efb18b5
Update patching_utils.py
danielhanchen Sep 26, 2025
f16a5a8
Update patching_utils.py
danielhanchen Sep 26, 2025
0dae9dd
Update patching_utils.py
danielhanchen Sep 26, 2025
435de2d
flash attn
danielhanchen Sep 26, 2025
9cd630c
Update gpt_oss.py
danielhanchen Sep 26, 2025
c510029
Update __init__.py
danielhanchen Sep 26, 2025
98080fc
Update attention_sink.py
danielhanchen Sep 26, 2025
5625cfb
Update gpt_oss.py
danielhanchen Sep 26, 2025
62756a8
Update gpt_oss.py
danielhanchen Sep 26, 2025
3f9a9a9
Update gpt_oss.py
danielhanchen Sep 26, 2025
c32eb2e
Update gpt_oss.py
danielhanchen Sep 26, 2025
63a771c
Update gpt_oss.py
danielhanchen Sep 26, 2025
194ff92
Update gpt_oss.py
danielhanchen Sep 26, 2025
be54940
Update gpt_oss.py
danielhanchen Sep 26, 2025
9ebf49f
Update gpt_oss.py
danielhanchen Sep 26, 2025
2b45d36
dropout_p
danielhanchen Sep 26, 2025
7a6941a
Update gpt_oss.py
danielhanchen Sep 26, 2025
588c4f0
Update gpt_oss.py
danielhanchen Sep 26, 2025
aded049
Update attention_sink.py
danielhanchen Sep 26, 2025
33ba6b3
Update gpt_oss.py
danielhanchen Sep 26, 2025
b08753b
Update gpt_oss.py
danielhanchen Sep 26, 2025
9fe8ec0
fix
danielhanchen Sep 26, 2025
5be9e57
Update attention_sink.py
danielhanchen Sep 26, 2025
a218bfc
Update gpt_oss.py
danielhanchen Sep 26, 2025
9fc2694
Update gpt_oss.py
danielhanchen Sep 26, 2025
769301d
Update gpt_oss.py
danielhanchen Sep 26, 2025
d59f62b
Update gpt_oss.py
danielhanchen Sep 26, 2025
92d16d4
Update gpt_oss.py
danielhanchen Sep 26, 2025
0608531
Update gpt_oss.py
danielhanchen Sep 26, 2025
24bb593
Update gpt_oss.py
danielhanchen Sep 26, 2025
c481eb8
Update gpt_oss.py
danielhanchen Sep 26, 2025
68fed93
Update gpt_oss.py
danielhanchen Sep 26, 2025
9ff936f
Update gpt_oss.py
danielhanchen Sep 26, 2025
77343fa
Update gpt_oss.py
danielhanchen Sep 26, 2025
f3e7f8c
Update gpt_oss.py
danielhanchen Sep 26, 2025
5e7e7d3
Update gpt_oss.py
danielhanchen Sep 26, 2025
a508006
Update loss_utils.py
danielhanchen Sep 26, 2025
44e1de7
Update gpt_oss.py
danielhanchen Sep 26, 2025
1079a21
Update gpt_oss.py
danielhanchen Sep 26, 2025
58e5f24
Update gpt_oss.py
danielhanchen Sep 26, 2025
3c61724
Update gpt_oss.py
danielhanchen Sep 26, 2025
bd50ca4
Update gpt_oss.py
danielhanchen Sep 26, 2025
5f8b77c
Update gpt_oss.py
danielhanchen Sep 26, 2025
f2fe3db
Update gpt_oss.py
danielhanchen Sep 26, 2025
04bbc07
Update loss_utils.py
danielhanchen Sep 26, 2025
cb16066
Update gpt_oss.py
danielhanchen Sep 26, 2025
75d7829
Update gpt_oss.py
danielhanchen Sep 26, 2025
679e882
Update gpt_oss.py
danielhanchen Sep 26, 2025
4b61795
Merge branch 'main' into nightly
danielhanchen Sep 26, 2025
c37dff1
Merge branch 'main' into nightly
danielhanchen Sep 26, 2025
b61346a
Merge branch 'main' into nightly
danielhanchen Sep 26, 2025
a8d6aa8
Merge branch 'main' into nightly
danielhanchen Sep 28, 2025
5225692
Update gpt_oss.py
danielhanchen Sep 28, 2025
02326ab
Update gpt_oss.py
danielhanchen Sep 28, 2025
2210555
Update gpt_oss.py
danielhanchen Sep 30, 2025
f7406a4
Update gpt_oss.py
danielhanchen Sep 30, 2025
7020561
Update gpt_oss.py
danielhanchen Sep 30, 2025
e316226
Update gpt_oss.py
danielhanchen Sep 30, 2025
55a0f94
Update gpt_oss.py
danielhanchen Sep 30, 2025
d241d8d
Versioning
danielhanchen Sep 30, 2025
8d752f6
Merge branch 'main' into nightly
danielhanchen Oct 1, 2025
7c40a85
Update saving_utils.py
danielhanchen Oct 5, 2025
114feed
Update saving_utils.py
danielhanchen Oct 5, 2025
5bdbffe
Update saving_utils.py
danielhanchen Oct 5, 2025
79115db
Update saving_utils.py
danielhanchen Oct 5, 2025
51e3889
Update saving_utils.py
danielhanchen Oct 5, 2025
3284083
Update saving_utils.py
danielhanchen Oct 5, 2025
289abf2
Update saving_utils.py
danielhanchen Oct 5, 2025
efe6d76
Update saving_utils.py
danielhanchen Oct 5, 2025
2f5e342
Fix Gemma 3
danielhanchen Oct 5, 2025
3237c4b
Update misc.py
danielhanchen Oct 5, 2025
dc3e28e
Merge branch 'main' into nightly
danielhanchen Oct 5, 2025
22b3cb6
Merge branch 'main' into nightly
danielhanchen Oct 14, 2025
5beb515
Merge branch 'main' into nightly
danielhanchen Oct 16, 2025
bd43a5b
Update rl_environments.py
danielhanchen Oct 17, 2025
9571b67
Update pyproject.toml
danielhanchen Oct 17, 2025
f789e3b
Update rl_environments.py
danielhanchen Oct 17, 2025
c146ca2
Update __init__.py
danielhanchen Oct 17, 2025
5012df2
Merge branch 'main' into nightly
danielhanchen Oct 17, 2025
80f4b15
Merge branch 'main' into nightly
danielhanchen Oct 17, 2025
6857125
Update empty_model.py
danielhanchen Oct 17, 2025
49f3cd0
Update empty_model.py
danielhanchen Oct 17, 2025
7642fbc
Update empty_model.py
danielhanchen Oct 17, 2025
a6a9a53
Merge branch 'main' into nightly
danielhanchen Oct 17, 2025
565d37f
Merge branch 'main' into nightly
danielhanchen Oct 17, 2025
068142c
Merge branch 'main' into nightly
danielhanchen Oct 19, 2025
9b06516
Merge branch 'main' into nightly
danielhanchen Oct 19, 2025
33a55fc
Merge branch 'main' into nightly
danielhanchen Oct 20, 2025
9f9fad5
Update empty_model.py
danielhanchen Oct 20, 2025
c62f0db
Device type
danielhanchen Oct 20, 2025
44539dc
Update vllm_utils.py
danielhanchen Oct 20, 2025
c7f1a85
Update compiler.py
danielhanchen Oct 20, 2025
d98b8dd
Update empty_model.py
danielhanchen Oct 20, 2025
7dccb4f
Update vllm_utils.py
danielhanchen Oct 20, 2025
96b12f6
Update empty_model.py
danielhanchen Oct 20, 2025
b900605
Fixes
danielhanchen Oct 20, 2025
be24a86
Update empty_model.py
danielhanchen Oct 20, 2025
09a56e1
Update empty_model.py
danielhanchen Oct 20, 2025
dd3f5a9
Update __init__.py
danielhanchen Oct 20, 2025
5e914a5
Update vllm_utils.py
danielhanchen Oct 20, 2025
d45333a
Update vllm_utils.py
danielhanchen Oct 20, 2025
aef0696
Update rl_environments.py
danielhanchen Oct 20, 2025
4bbede7
Update cross_entropy_loss.py
danielhanchen Oct 20, 2025
03adb63
Update vllm_utils.py
danielhanchen Oct 20, 2025
4e0786b
Update vllm_utils.py
danielhanchen Oct 20, 2025
21a4404
Update rl_environments.py
danielhanchen Oct 20, 2025
e63cd7b
Update vllm_utils.py
danielhanchen Oct 20, 2025
855d572
Merge branch 'main' into nightly
danielhanchen Oct 20, 2025
60b28fa
Merge branch 'main' into nightly
danielhanchen Oct 22, 2025
f34d525
Merge branch 'main' into nightly
danielhanchen Oct 23, 2025
26fe13e
Merge branch 'main' into nightly
danielhanchen Oct 27, 2025
ac90015
Merge branch 'main' into nightly
danielhanchen Oct 27, 2025
113c8d3
Merge branch 'main' into nightly
danielhanchen Oct 30, 2025
bb81b69
Qwen3 VL vLLM (#324)
Datta0 Oct 31, 2025
0632308
Update __init__.py
danielhanchen Oct 31, 2025
fe09bfd
Update __init__.py
danielhanchen Oct 31, 2025
a5102af
Update __init__.py
danielhanchen Oct 31, 2025
d2fcf41
Update __init__.py
danielhanchen Oct 31, 2025
8b07dcf
Update __init__.py
danielhanchen Oct 31, 2025
6d43f0d
Update __init__.py
danielhanchen Oct 31, 2025
ad18827
Update __init__.py
danielhanchen Oct 31, 2025
c00681e
Merge branch 'main' into nightly
danielhanchen Nov 2, 2025
9321399
Update vllm_utils.py
danielhanchen Nov 2, 2025
32ca2c0
Update vllm_utils.py
danielhanchen Nov 2, 2025
45a2f69
Update pyproject.toml
danielhanchen Nov 2, 2025
6c6c4e8
Update vllm_utils.py
danielhanchen Nov 2, 2025
3a1a097
Update vllm_utils.py
danielhanchen Nov 2, 2025
ed24866
Update vllm_utils.py
danielhanchen Nov 2, 2025
c9b3186
Update vllm_utils.py
danielhanchen Nov 3, 2025
64395ac
Update vllm_utils.py
danielhanchen Nov 3, 2025
60de923
Update vllm_utils.py
danielhanchen Nov 3, 2025
0b339f4
Update __init__.py
danielhanchen Nov 3, 2025
5ae18ab
Update compiler.py
danielhanchen Nov 3, 2025
dac460f
Update __init__.py
danielhanchen Nov 3, 2025
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ dependencies = [
"wheel>=0.42.0",
"numpy",
"accelerate>=0.34.1",
"trl>=0.18.2,!=0.19.0,<=0.23.0",
"trl>=0.18.2,!=0.19.0,<=0.24.0",
"peft>=0.7.1,!=0.11.0",
"protobuf",
"huggingface_hub>=0.34.0",
Expand Down Expand Up @@ -70,7 +70,7 @@ huggingface = [
"wheel>=0.42.0",
"numpy",
"accelerate>=0.34.1",
"trl>=0.18.2,!=0.19.0,<=0.23.0",
"trl>=0.18.2,!=0.19.0,<=0.24.0",
"peft>=0.7.1,!=0.11.0",
"protobuf",
"huggingface_hub>=0.34.0",
Expand Down
37 changes: 18 additions & 19 deletions unsloth_zoo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

__version__ = "2025.10.13"
__version__ = "2025.11.1"

import os
import warnings
Expand Down Expand Up @@ -98,7 +98,7 @@
for key in ("PYTORCH_CUDA_ALLOC_CONF", "PYTORCH_HIP_ALLOC_CONF", "PYTORCH_ALLOC_CONF",):
if "expandable_segments:True" in os.environ.get(key, ""):
warnings.warn(
"Unsloth: `UNSLOTH_VLLM_STANDBY` is on, but requires `expandable_segments` to be off.\n"\
"Unsloth: `UNSLOTH_VLLM_STANDBY` is on, but requires `expandable_segments` to be off. "\
"We will remove `expandable_segments`.",
stacklevel = 2,
)
Expand All @@ -123,6 +123,20 @@ def delete_key(key):
delete_key("PYTORCH_ALLOC_CONF")
pass

# Suppress WARNING:torchao:Skipping import of cpp extensions due to incompatible torch version 2.7.0+cu126 for torchao version 0.14.1
# Please see https://github.com/pytorch/ao/issues/2919 for more info
import logging
torchao_logger = logging.getLogger("torchao")
# Ignore logging messages
class HideLoggingMessage(logging.Filter):
__slots__ = "text",
def __init__(self, text): self.text = text
def filter(self, x): return not (self.text in x.getMessage())
pass
torchao_logger.addFilter(HideLoggingMessage("Skipping import"))
del logging, torchao_logger, HideLoggingMessage

# Get device types and other variables
from .device_type import (
is_hip,
get_device_type,
Expand Down Expand Up @@ -173,6 +187,7 @@ def delete_key(key):
os.environ["UNSLOTH_ZOO_IS_PRESENT"] = "1"
del os


from .temporary_patches import (
encode_conversations_with_harmony,
)
Expand All @@ -193,22 +208,6 @@ def delete_key(key):
# This may have happened because an `Annotated` type alias using the `type` statement was used, or if the `Field()` function was attached to a single member of a union type.
from pydantic.warnings import UnsupportedFieldAttributeWarning
warnings.filterwarnings(action = "ignore", category = UnsupportedFieldAttributeWarning)
del warnings, UnsupportedFieldAttributeWarning
except:
pass

import logging
# Ignore logging messages
class HideLoggingMessage(logging.Filter):
__slots__ = "text",
def __init__(self, text): self.text = text
def filter(self, x): return not (self.text in x.getMessage())
pass

# Skipping import of cpp extensions due to incompatible torch version
try:
from torchao import logger as torchao_logger
torchao_logger.addFilter(HideLoggingMessage("Skipping import"))
del torchao_logger
except:
pass
del HideLoggingMessage
18 changes: 14 additions & 4 deletions unsloth_zoo/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1310,10 +1310,16 @@ def apply_fused_lm_head(forward, module = None):
spaces = finder[0][3]
replacement = cross_entropy_replacement.strip().split("\n")
replacement = "\n".join((len(spaces)-4)*" " + x for x in replacement)
replacement = \
"logits = self.lm_head(hidden_states[:, slice_indices, :]) if os.environ.get('UNSLOTH_RETURN_LOGITS', '0') == '1' else EMPTY_LOGITS\n" + \
(len(spaces)-4)*" " + "loss = None\n" + \
replacement + "\n"
if "slice_indices" in forward:
replacement = \
"logits = self.lm_head(hidden_states[:, slice_indices, :]) if os.environ.get('UNSLOTH_RETURN_LOGITS', '0') == '1' else EMPTY_LOGITS\n" + \
(len(spaces)-4)*" " + "loss = None\n" + \
replacement + "\n"
else:
replacement = \
"logits = self.lm_head(hidden_states) if os.environ.get('UNSLOTH_RETURN_LOGITS', '0') == '1' else EMPTY_LOGITS\n" + \
(len(spaces)-4)*" " + "loss = None\n" + \
replacement + "\n"
try:
forward = regex.sub(
cross_entropy_find,
Expand All @@ -1329,6 +1335,10 @@ def apply_fused_lm_head(forward, module = None):
"logits = self.lm_head(hidden_states[:, slice_indices, :]) if os.environ.get('UNSLOTH_RETURN_LOGITS', '0') == '1' else EMPTY_LOGITS",
"logits = outputs.logits",
)
forward = forward.replace(
"logits = self.lm_head(hidden_states) if os.environ.get('UNSLOTH_RETURN_LOGITS', '0') == '1' else EMPTY_LOGITS",
"logits = outputs.logits",
)
# Fix vocab_size = (vocab_size=
forward = regex.sub(
r"vocab_size[ ]{0,}=[ ]{0,}\(vocab_size[ ]{0,}=",
Expand Down
117 changes: 87 additions & 30 deletions unsloth_zoo/empty_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def copy_attributes(original_model, new_model):
pass


@torch.inference_mode()
@torch.inference_mode
def create_empty_causal_lm(config, dtype = torch.float16):
# All Unsloth Zoo code licensed under LGPLv3
from transformers import AutoModelForCausalLM
Expand Down Expand Up @@ -293,10 +293,10 @@ def _set_config_attrs(config_obj, attrs_to_set):
pass


@torch.inference_mode()
@torch.inference_mode
def create_empty_vision_model(config, dtype = torch.float16):
# All Unsloth Zoo code licensed under LGPLv3
model_type = config.model_type
model_type = get_model_type(config)

from transformers.models.siglip.modeling_siglip import SiglipVisionModel

Expand Down Expand Up @@ -360,6 +360,8 @@ def _init_weights(self, module):
# Set minimal sizes for different model types
if model_type == "qwen2_5_vl":
new_config.vision_config.out_hidden_size = 1
elif model_type == "qwen3_vl":
new_config.vision_config.out_hidden_size = 1


num_layers = max(text_layers, vision_layers)
Expand All @@ -368,7 +370,7 @@ def _init_weights(self, module):
return new_model, original_meta_model, num_layers


@torch.inference_mode()
@torch.inference_mode
def create_empty_model(config, dtype = torch.float16, is_vision_model = False):
# All Unsloth Zoo code licensed under LGPLv3

Expand All @@ -384,7 +386,7 @@ def create_empty_model(config, dtype = torch.float16, is_vision_model = False):
return new_model, original_meta_model, num_layers, layer_names


@torch.inference_mode()
@torch.inference_mode
def set_additional_modules(new_model, quant_state_dict, config):
if hasattr(new_model, "language_model"):
language_model = new_model.language_model
Expand All @@ -404,16 +406,23 @@ def set_additional_modules(new_model, quant_state_dict, config):
# padding_idx = pad_token_id,
# )
# we cannot use the normal embedding init because gemma3 uses Gemma3TextScaledWordEmbedding which wraps around nn.Embedding and has a scaling factor. This new init ensures that we respect the forward from original model.
num_embeddings, embedding_dim = quant_state_dict[embed_tokens_key].shape
embeddings = quant_state_dict[embed_tokens_key]
if isinstance(embeddings, torch.Tensor):
# in the newer vLLM versions, this seems to return a tensor which can't be assigned to embedding weight
# we need to convert that to nn.Paramter and then pass it on
embeddings = torch.nn.Parameter(embeddings, requires_grad = False)
language_model.embed_tokens.weight = embeddings
language_model.embed_tokens.padding_idx = pad_token_id
language_model.embed_tokens.num_embeddings = num_embeddings
language_model.embed_tokens.embedding_dim = embedding_dim
def set_embedding(module, embed_tokens_key, pad_token_id, requires_grad=False):
num_embeddings, embedding_dim = quant_state_dict[embed_tokens_key].shape
embeddings = quant_state_dict[embed_tokens_key]
if isinstance(embeddings, torch.Tensor):
# in the newer vLLM versions, this seems to return a tensor which can't be assigned to embedding weight
# we need to convert that to nn.Paramter and then pass it on
embeddings = torch.nn.Parameter(embeddings, requires_grad = requires_grad)
module.weight = embeddings
module.padding_idx = pad_token_id
module.num_embeddings = num_embeddings
module.embedding_dim = embedding_dim

set_embedding(language_model.embed_tokens, embed_tokens_key, pad_token_id) # This sets the embedding that we generally find in language (sub)model

if 'model.visual.pos_embed.weight' in quant_state_dict:
# This is to handle visual embeddings in Qwen 3 VL
set_embedding(new_model.model.visual.pos_embed, 'model.visual.pos_embed.weight', None, requires_grad=False)

# Norm
norm_key = f"{language_model_prefix}.norm.weight"
Expand Down Expand Up @@ -461,22 +470,25 @@ def set_additional_modules(new_model, quant_state_dict, config):
language_model.tie_weights()

# Process additional keys
# For eg, `merger` in qwen2.5-vl or probably any other projection modules
# For any layers that are potentially in non layered components.
# Preferably norms, embeddings and convolution type layers.
additional_keys = set(
x for x in quant_state_dict.keys()
if not any(substr in x for substr in ("layers", "blocks", embed_tokens_key, norm_key, "lm_head"))
if not any(substr in x for substr in ("layers", "blocks", embed_tokens_key, norm_key, "lm_head", "mlp", "linear", "list"))
)
print(f'Performing substitution for {additional_keys=}')

for key in additional_keys:
replaced_key = re.sub(r"\.(\d+)\.", r"[\1].", key)
# sometimes it can be in new_model.model. instead of new_model.
for prefix in ['new_', 'new_model.']:
for suffix in ['', '.data']:
try:
exec(f"{prefix}{replaced_key}{suffix} = quant_state_dict[key]")
break
except:
continue
try:
val = quant_state_dict[key]
if isinstance(val, torch.Tensor):
val = torch.nn.Parameter(val,requires_grad=False)
exec(f"{prefix}{key} = val")
break
except:
continue

pass
pass
Expand Down Expand Up @@ -531,6 +543,9 @@ def get_model_layer_config(return_non_layered=True):
# Mistral3 vision norms
"model.vision_tower.transformer.layers.{kk}.attention_norm",
"model.vision_tower.transformer.layers.{kk}.ffn_norm",

# qwen3 vl
"model.visual.deepstack_merger_list.{kk}.norm",
},
'vision_layers': {

Expand Down Expand Up @@ -595,8 +610,15 @@ def get_model_layer_config(return_non_layered=True):
"model.vision_tower.transformer.layers.{kk}.feed_forward.up_proj",
"model.vision_tower.transformer.layers.{kk}.feed_forward.down_proj",

# qwen 3 vl
"model.visual.blocks.{kk}.mlp.linear_fc1",
"model.visual.blocks.{kk}.mlp.linear_fc2",

},
'additional_layers': {
# Primarily for layers that are neither language decoder layers or vision transformer layers/blocks.
# Basically anything that is a merger, convertor or bridge in between. Preferably iterable layers

"model.visual.merger.mlp.{kk}",
"model.visual.merger.mlp.{kk}",
'model.language_model.model.layers.{kk}.cross_attn_mlp_gate',
Expand All @@ -605,8 +627,14 @@ def get_model_layer_config(return_non_layered=True):

# Mistral3
"model.multi_modal_projector.patch_merger.merging_layer",
"model.multi_modal_projector.linear_1",
"model.multi_modal_projector.linear_2",
"model.multi_modal_projector.linear_{kk}",
# "model.multi_modal_projector.linear_2",

# qwen 3 vl
"model.visual.deepstack_merger_list.{kk}.linear_fc1",
"model.visual.deepstack_merger_list.{kk}.linear_fc2",
"model.visual.merger.linear_fc{kk}",

},
"non_layered_components":{
# we do not handle quantization for these layers yet
Expand Down Expand Up @@ -636,12 +664,23 @@ def get_model_layer_config(return_non_layered=True):
"model.vision_tower.patch_positional_embedding",
"model.vision_tower.patch_conv",
"model.vision_tower.ln_pre",

# qwen 3 vl
"model.visual.pos_embed",
"model.visual.merger.norm",
}
}

# Convert sets to sorted lists for deterministic order
return {key: sorted(list(value)) for key, value in layer_templates.items() if key!='non_layered_components' or return_non_layered}

def get_model_type(config):
model_type = getattr(config, "model_type", "causal_lm")
if hasattr(config, "vision_config"):
# vllm curretly seems to be having qwen 2.5 vl model type as qwen2_5_vl_text for some reason
# aka vllm_config.model_type is qwen2_5_vl_text but config.vision_config.model_type is qwen2_5_vl
model_type = getattr(config.vision_config, "model_type", model_type)
return model_type

def get_model_layer_counts(config):
"""
Expand All @@ -653,7 +692,7 @@ def get_model_layer_counts(config):
Returns:
int or dict: Number of layers (int for causal_lm, dict for VL models)
"""
model_type = getattr(config, "model_type", "causal_lm")
model_type = get_model_type(config)

if model_type == "mllama":
return {
Expand All @@ -664,7 +703,13 @@ def get_model_layer_counts(config):
elif model_type == "qwen2_5_vl":
return {
"text_layers": getattr(config, "num_hidden_layers", 32),
"vision_layers": getattr(config.vision_config, "num_hidden_layers", 32),
"vision_layers": getattr(config.vision_config, "depth", 32),
}
elif model_type == "qwen3_vl":
return {
"text_layers": getattr(config, "num_hidden_layers", 36),
"vision_layers": getattr(config.vision_config, "depth", 27),
"deepstack_layers": getattr(config.vision_config, "deepstack_depth", 3),
}
elif model_type == "gemma3":
return {
Expand Down Expand Up @@ -699,7 +744,7 @@ def extract_vision_layers(vllm_internals, state_dict, quant_state_dict, get_stat
a model-specific configuration. This approach is more robust and avoids
failures by correctly identifying layer paths and parameters.
"""
model_type = vllm_internals.config.model_type
model_type = get_model_type(vllm_internals.config)
layer_config = get_model_layer_config()

all_layered_templates = (
Expand All @@ -724,7 +769,9 @@ def extract_vision_layers(vllm_internals, state_dict, quant_state_dict, get_stat

if layer_module is not None:
if "qkv" in layer_path:
if model_type == "qwen2_5_vl":
if model_type in ("qwen2_5_vl", "qwen3_vl"):
# If the HF model too prefers having merged qkv, we do this
# This is evident in qwen-2.5-vl and qwen-3-vl so far.
get_state_dict(layer_path, 0, state_dict, layer_module, slice_weights=False)
else:
get_state_dict(f"{layer_path.replace('qkv_proj', 'q_proj')}", 0, state_dict, layer_module)
Expand Down Expand Up @@ -780,3 +827,13 @@ def extract_vision_layers(vllm_internals, state_dict, quant_state_dict, get_stat
weight = component._linear.weight
state_dict[f'{path}.weight'] = weight.reshape(weight.shape[0], 3, 14, 14)
quant_state_dict[f'{path}.weight'] = state_dict[f'{path}.weight']

# for qwen3 vl, only needed in specific vllm which had this PR which uses Linear instead of Conv3d
# https://github.com/vllm-project/vllm/pull/27418
path = "model.visual.patch_embed.proj"
vision_config = vllm_internals.config.vision_config
component = _get_nested_attr(vllm_internals, path)
if component is not None:
weight = component.weight
state_dict[f'{path}.weight'] = weight.reshape(vision_config.hidden_size, vision_config.in_channels, vision_config.temporal_patch_size, vision_config.patch_size, vision_config.patch_size)
quant_state_dict[f'{path}.weight'] = state_dict[f'{path}.weight']
Loading