
Commit d7fb886

Bug fixes (#331)
* Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update __init__.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update compiler.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Dannightly (#304) * gpt oss inference fix * gpt oss fix bf16 * gpt oss fix bf16 * gpt oss fix bf16 * gpt oss fix bf16 * gpt oss fix bf16 * gpt oss fix bf16 --------- Co-authored-by: DoubleMathew <mmathew23@gmail.com> * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Fix Flex Attention autotuning * Update patching_utils.py * Update patching_utils.py * Update patching_utils.py * Update mxfp4.py * Update mxfp4.py * Update gpt_oss.py * Update attention_sink.py * Update patching_utils.py * Update attention_sink.py * Update gpt_oss.py * prefer_nd_tiling * Update patching_utils.py * flex_attention_with_sink * Compile Flex Attention * Update mxfp4.py * Update mxfp4.py * Update mxfp4.py * Update mxfp4.py * Update gpt_oss.py * bitsandbytes patch * Update bitsandbytes.py * Update gpt_oss.py * Inplace ops * Update gpt_oss.py * has_static_cache * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update attention_sink.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update attention_sink.py * Update attention_sink.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * torch compile * Update attention_sink.py * Update common.py * Update common.py * Patches * Compiled mask creation * Update attention_sink.py * Update gpt_oss.py * Update gpt_oss.py * Revert * Update gpt_oss.py * Update gpt_oss.py * Fix up * Update attention_sink.py * Update attention_sink.py * Update utils.py * Update attention_sink.py * Update attention_sink.py * Retry * Update gpt_oss.py * Update gpt_oss.py * Fix Flex * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Bug fixes * Update patching_utils.py * Update patching_utils.py * Update patching_utils.py * Update rl_replacements.py * Update patching_utils.py * Update patching_utils.py * Update patching_utils.py * flash attn * Update gpt_oss.py * Update __init__.py * Update attention_sink.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py 
* Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * dropout_p * Update gpt_oss.py * Update gpt_oss.py * Update attention_sink.py * Update gpt_oss.py * Update gpt_oss.py * fix * Update attention_sink.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update loss_utils.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update loss_utils.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Update gpt_oss.py * Versioning * Update saving_utils.py * Update saving_utils.py * Update saving_utils.py * Update saving_utils.py * Update saving_utils.py * Update saving_utils.py * Update saving_utils.py * Update saving_utils.py * Fix Gemma 3 * Update misc.py * Update rl_environments.py * Update pyproject.toml * Update rl_environments.py * Update __init__.py * Update empty_model.py * Update empty_model.py * Update empty_model.py * Update empty_model.py * Device type * Update vllm_utils.py * Update compiler.py * Update empty_model.py * Update vllm_utils.py * Update empty_model.py * Fixes * Update empty_model.py * Update empty_model.py * Update __init__.py --------- Co-authored-by: DoubleMathew <mmathew23@gmail.com>
Parent: 677086d · Commit: d7fb886

File tree: 7 files changed, +104 −26 lines

unsloth_zoo/__init__.py
Lines changed: 13 additions & 1 deletion

@@ -14,9 +14,10 @@
 # You should have received a copy of the GNU Lesser General Public License
 # along with this program. If not, see <https://www.gnu.org/licenses/>.

-__version__ = "2025.10.7"
+__version__ = "2025.10.8"

 import os
+import warnings
 # Hugging Face Hub faster downloads
 if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ:
     os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
@@ -101,3 +102,14 @@
     execute_with_time_limit,
     Benchmarker,
 )
+
+# Stop some pydantic warnings
+try:
+    # pydantic/_internal/_generate_schema.py:2249: UnsupportedFieldAttributeWarning: The 'frozen' attribute with value True
+    # was provided to the `Field()` function, which has no effect in the context it was used.
+    # 'frozen' is field-specific metadata, and can only be attached to a model field using `Annotated` metadata or by assignment.
+    # This may have happened because an `Annotated` type alias using the `type` statement was used, or if the `Field()` function was attached to a single member of a union type.
+    from pydantic.warnings import UnsupportedFieldAttributeWarning
+    warnings.filterwarnings(action = "ignore", category = UnsupportedFieldAttributeWarning)
+except:
+    pass
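The tail added to __init__.py is a category-scoped warning filter. Below is a minimal, self-contained sketch of the same pattern; DummyWarning is a stand-in for pydantic's UnsupportedFieldAttributeWarning, which only exists in newer pydantic releases (hence the try/except guard in the patch).

```python
import warnings

# Stand-in for pydantic.warnings.UnsupportedFieldAttributeWarning;
# filtering by category silences only this class (and its subclasses).
class DummyWarning(UserWarning):
    pass

warnings.filterwarnings(action = "ignore", category = DummyWarning)

warnings.warn("suppressed", DummyWarning)   # silenced by the filter
warnings.warn("still shown", UserWarning)   # other categories unaffected
```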

unsloth_zoo/compiler.py
Lines changed: 12 additions & 1 deletion

@@ -2035,7 +2035,14 @@ def unsloth_compile_transformers(
     except ModuleNotFoundError:
         return
     modeling_file = eval(model_location)
-    if hasattr(modeling_file, "__UNSLOTH_PATCHED__"): return
+    if hasattr(modeling_file, "__UNSLOTH_PATCHED__"):
+        # Get __UNSLOTH_SUPPORTS_SDPA__
+        if hasattr(modeling_file, "__UNSLOTH_SUPPORTS_SDPA__"):
+            if supports_sdpa is not None:
+                assert(type(supports_sdpa) is list and len(supports_sdpa) == 1)
+                supports_sdpa[0] = modeling_file.__UNSLOTH_SUPPORTS_SDPA__
+        return
+    pass

     # Use transformers model_type logger to suppress message: Remove `use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`
     exec("model_logger.addFilter(HideLoggingMessage('`use_cache`'))", globals(), locals())
@@ -2189,6 +2196,7 @@ def replaced_tqdm(*args, **kwargs):
     torch_modules = [x for x in torch_modules if x not in removal]

     # Check SDPA to load as eager or SDPA (Pixtral / Mistral 3 for eg doesn't have SDPA)
+    final_supports_sdpa = True
     if supports_sdpa is not None:
         assert(type(supports_sdpa) is list and len(supports_sdpa) == 1)
         if ("_supports_sdpa = True" in full_source) and ("_supports_sdpa = False" not in full_source):
@@ -2197,7 +2205,10 @@ def replaced_tqdm(*args, **kwargs):
             if supports_sdpa[0] != False: supports_sdpa[0] = True
         else:
             supports_sdpa[0] = False
+            final_supports_sdpa = False
         pass
+    # Save supports_sdpa to solve secondary imports
+    modeling_file.__UNSLOTH_SUPPORTS_SDPA__ = final_supports_sdpa

     # Get functions which are called
     called_functions = []
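Two idioms drive the compiler.py change: supports_sdpa is a one-element list used as a mutable out-parameter, and the final decision is cached on the patched module as __UNSLOTH_SUPPORTS_SDPA__ so a second call (a "secondary import") can return early yet still report the flag. A minimal sketch with hypothetical names, not the real patching logic:

```python
import types

def patch_module(mod, supports_sdpa = None):
    # Re-entry: recover the cached decision instead of re-scanning the source
    if hasattr(mod, "__UNSLOTH_PATCHED__"):
        if supports_sdpa is not None and hasattr(mod, "__UNSLOTH_SUPPORTS_SDPA__"):
            supports_sdpa[0] = mod.__UNSLOTH_SUPPORTS_SDPA__
        return
    mod.__UNSLOTH_PATCHED__ = True
    final_supports_sdpa = True  # in the real code this comes from a source scan
    if supports_sdpa is not None:
        supports_sdpa[0] = final_supports_sdpa
    # Persist the flag so later calls see the same answer
    mod.__UNSLOTH_SUPPORTS_SDPA__ = final_supports_sdpa

mod  = types.ModuleType("fake_modeling_file")
flag = [None]
patch_module(mod, flag)  # first call patches and fills flag
patch_module(mod, flag)  # second call short-circuits but still fills flag
print(flag[0])           # True
```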

unsloth_zoo/empty_model.py
Lines changed: 30 additions & 18 deletions

@@ -223,33 +223,45 @@ def copy_attributes(original_model, new_model):
     if dict_skipped_count > 0:
         print(f"📋 Skipped {dict_skipped_count} non-config dictionaries")
     if skipped_count > 0:
-       print(f"⏭️ Skipped {skipped_count} total attributes (tensors, modules, non-config dicts, etc.)")
+        print(f"⏭️ Skipped {skipped_count} total attributes (tensors, modules, non-config dicts, etc.)")
         if skipped_count <= 10:
-           print(f"   Skipped: {skipped_attrs}")
+            print(f"   Skipped: {skipped_attrs}")
         else:
-           print(f"   Sample: {skipped_attrs[:5]}... and {skipped_count-5} more")
+            print(f"   Sample: {skipped_attrs[:5]}... and {skipped_count-5} more")
+    pass


 @torch.inference_mode()
 def create_empty_causal_lm(config, dtype = torch.float16):
     # All Unsloth Zoo code licensed under LGPLv3
     from transformers import AutoModelForCausalLM
-    try:
-        from accelerate import init_empty_weights
-        # Suppress warning on uninited weights
-        old_warn = os.environ.get("UNSLOTH_WARN_UNINITIALIZED", "1")
-        os.environ["UNSLOTH_WARN_UNINITIALIZED"] = "0"
-        with init_empty_weights():
-            model_name = getattr(config, 'model_name')
-            kwargs = {"torch_dtype" if HAS_TORCH_DTYPE else "dtype" : dtype_from_config(config)}
-            if model_name is not None:
-                # This would persist quantization information.
+    from accelerate import init_empty_weights
+    # Suppress warning on uninited weights
+    old_warn = os.environ.get("UNSLOTH_WARN_UNINITIALIZED", "1")
+    os.environ["UNSLOTH_WARN_UNINITIALIZED"] = "0"
+    model_name = getattr(config, 'model_name')
+    kwargs = {"torch_dtype" if HAS_TORCH_DTYPE else "dtype" : dtype_from_config(config)}
+    original_meta_model = None
+    error = None
+    with init_empty_weights(include_buffers = True):
+        if model_name is not None:
+            try:
+                # This would persist quantization information for FP8 weights
                 original_meta_model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
-            else:
+            except Exception as e:
+                error = str(e)
+                original_meta_model = None
+        if original_meta_model is None:
+            try:
+                # We must do this for 4.57.0 and above
                 original_meta_model = AutoModelForCausalLM.from_config(config)
-        # Suppress warning on uninited weights
-        os.environ["UNSLOTH_WARN_UNINITIALIZED"] = old_warn
-    except Exception as e:
+            except Exception as e:
+                error = str(e)
+                original_meta_model = None
+        pass
+    # Suppress warning on uninited weights
+    os.environ["UNSLOTH_WARN_UNINITIALIZED"] = old_warn
+    if error is not None and original_meta_model is None:
         print(f"Failed to create original_meta_model for AutoModelForCausalLM. Error {e}")
         original_meta_model = None

@@ -302,7 +314,7 @@ def _init_weights(self, module):
     try:
         # Use accelerate's init_empty_weights, not transformers.modeling_utils
         from accelerate import init_empty_weights
-        with init_empty_weights():
+        with init_empty_weights(include_buffers = True):
             original_meta_model = model_cls(config)
     except Exception as e:
         print(f"Failed to create original_meta_model for {model_cls.__name__}. Error {e}")

unsloth_zoo/fused_losses/cross_entropy_loss.py
Lines changed: 1 addition & 1 deletion

@@ -26,7 +26,7 @@
 import functools
 import math
 from ..temporary_patches.common import UNSLOTH_ENABLE_LOGGING, torch_compile_options, logger
-from unsloth import DEVICE_TYPE
+from ..device_type import DEVICE_TYPE

 @functools.cache
 def _get_mapping(autograd):

unsloth_zoo/rl_replacements.py
Lines changed: 1 addition & 1 deletion

@@ -23,7 +23,7 @@
 import os
 import numpy as np
 from typing import Union, Callable, Optional, List, Dict
-from unsloth import DEVICE_TYPE
+from .device_type import DEVICE_TYPE
 from .temporary_patches.common import torch_compile_options
 RL_REPLACEMENTS = dict()

unsloth_zoo/temporary_patches/gpt_oss.py
Lines changed: 1 addition & 1 deletion

@@ -535,7 +535,7 @@ def forward(self, hidden_states):


 # Combo kernels uses too much VRAM for low memory GPUs
-from unsloth import DEVICE_TYPE
+from ..device_type import DEVICE_TYPE
 if DEVICE_TYPE == "xpu":
     device_memory = torch.xpu.memory.mem_get_info(0)[-1]
 else:
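This and the two previous one-line fixes replace an absolute import from the top-level unsloth package with a relative import of a local device_type module, so reading DEVICE_TYPE cannot trigger a circular import of unsloth itself. The commit does not show device_type.py; the sketch below is only a guess at what such a module might contain, and the detection order is an assumption:

```python
# Hypothetical device_type.py; the real module's contents are not part
# of this diff, so treat the probing order here as an assumption.
import torch

def _detect_device_type() -> str:
    if torch.cuda.is_available():
        return "cuda"
    # Intel GPUs expose a torch.xpu namespace in recent PyTorch builds
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    return "cpu"

DEVICE_TYPE = _detect_device_type()
```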

unsloth_zoo/vllm_utils.py
Lines changed: 46 additions & 3 deletions

@@ -39,6 +39,8 @@
 import math
 import gc
 import os
+import ast
+import sys
 import torch
 import json
 import psutil
@@ -59,8 +61,7 @@
     UNSLOTH_ENABLE_LOGGING,
 )
 from .log import logger
-from unsloth import DEVICE_TYPE
-from unsloth.models.vision import VLLM_SUPPORTED_VLM
+from .device_type import DEVICE_TYPE
 global LORA_REQUEST_ID

 # Ignore logging messages
@@ -2360,9 +2361,50 @@ def _test_is_same_vlm(model, new_model, processor, test_backward=False):
             mismatches.append(layer_name)
     print(f"Backward gradient statistics match for {len(matches)} layers: {matches}")
     print(f"Backward gradient statistics mismatch for {len(mismatches)} layers: {mismatches}")
+pass


-pass
+def _read_unsloth_vision_source() -> str:
+    _VISION_TAIL = ("unsloth", "models", "vision.py")
+    from importlib.metadata import files, PackageNotFoundError, PackagePath
+    from pathlib import Path
+    # 1) Via installed distribution metadata (no import of the package)
+    try:
+        for entry in files("unsloth") or ():
+            if isinstance(entry, PackagePath):
+                parts = entry.parts
+                if len(parts) >= 3 and tuple(parts[-3:]) == _VISION_TAIL:
+                    return entry.read_text(encoding = "utf-8")
+    except PackageNotFoundError:
+        pass
+
+    # 2) Fallback: scan sys.path for a plain file
+    for base in map(Path, sys.path):
+        candidate = base.joinpath(*_VISION_TAIL)
+        if candidate.is_file():
+            return candidate.read_text(encoding = "utf-8")
+    raise FileNotFoundError("Could not locate unsloth/models/vision.py without importing it")
+pass
+
+
+def get_vllm_supported_vlm(_VAR_NAME = "VLLM_SUPPORTED_VLM"):
+    """
+    Parse VLLM_SUPPORTED_VLM from unsloth/models/vision.py as a literal.
+    """
+    src = _read_unsloth_vision_source()
+    tree = ast.parse(src)
+
+    # Support: `VLLM_SUPPORTED_VLM = [...]` and `VLLM_SUPPORTED_VLM: list[str] = [...]`
+    for node in tree.body:
+        if isinstance(node, ast.Assign):
+            if any(getattr(t, "id", None) == _VAR_NAME for t in node.targets):
+                return ast.literal_eval(node.value)
+        elif isinstance(node, ast.AnnAssign):
+            if getattr(node.target, "id", None) == _VAR_NAME:
+                return ast.literal_eval(node.value)
+    raise ValueError(f"{_VAR_NAME} not found as a literal in unsloth/models/vision.py")
+pass

 @torch.inference_mode
 def _test_get_vllm_state_dict(
@@ -2419,6 +2461,7 @@ def _test_get_vllm_state_dict(
     if not is_vision_model:
         model_class = AutoModelForCausalLM
     else:
+        VLLM_SUPPORTED_VLM = get_vllm_supported_vlm()
         if model_type in VLLM_SUPPORTED_VLM:
             import transformers
             model_class = getattr(transformers, config.architectures[0])
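get_vllm_supported_vlm() reads VLLM_SUPPORTED_VLM out of unsloth/models/vision.py as text and evaluates only the literal, so the list is available without importing unsloth (which the removed from unsloth.models.vision import line required). A self-contained illustration of the technique; the source string and its entries are stand-ins, not the real list:

```python
import ast

# Stand-in for the text of unsloth/models/vision.py
src = """
SOME_FLAG = True
VLLM_SUPPORTED_VLM: list[str] = ["model_a", "model_b"]
"""

def read_constant(source: str, name: str):
    # Walk top-level statements only; never executes the module
    for node in ast.parse(source).body:
        if isinstance(node, ast.Assign) and any(
            getattr(t, "id", None) == name for t in node.targets
        ):
            return ast.literal_eval(node.value)
        if (isinstance(node, ast.AnnAssign) and node.value is not None
                and getattr(node.target, "id", None) == name):
            return ast.literal_eval(node.value)
    raise ValueError(f"{name} not found as a literal")

print(read_constant(src, "VLLM_SUPPORTED_VLM"))  # ['model_a', 'model_b']
```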
