Version
danielhanchen committed Nov 22, 2025
commit c11af7a0ecefa306ee1e7e31155e94a66c57497e
35 changes: 24 additions & 11 deletions unsloth/models/loader_utils.py
@@ -21,15 +21,11 @@

# https://github.com/huggingface/transformers/pull/26037 allows 4 bit loading!
from packaging.version import Version
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TorchAoConfig,
    __version__ as transformers_version,
)
from transformers import __version__ as transformers_version
from unsloth.models._utils import TorchAOConfig
from unsloth_zoo.utils import Version
import torch
import gc

transformers_version = Version(transformers_version)
SUPPORTS_FOURBIT = transformers_version >= Version("4.37")
@@ -205,19 +201,36 @@ def _offline_quantize_to_fp8(model_name: str, fp8_mode: str) -> str:
    )

    if not os.path.isdir(new_model_name):
        from transformers import (
            AutoModelForCausalLM,
            AutoModelForImageTextToText,
            AutoTokenizer,
            AutoProcessor,
            TorchAoConfig,
            AutoConfig,
        )
        qconfig = _get_torchao_fp8_config(fp8_mode)
        qconfig = TorchAoConfig(qconfig)
        # TODO: generalize this to beyond text models?
        # Right now using AutoModel removes the `lm_head` layer,
        # which is expected later when loading the vllm state dict
        model = AutoModelForCausalLM.from_pretrained(
        config = AutoConfig.from_pretrained(model_name)
        is_vlm = any(
            x.endswith(("ForConditionalGeneration", "ForVisionText2Text"))
            for x in config.architectures
        )
        is_vlm = is_vlm or hasattr(config, "vision_config")
        auto_model = AutoModelForImageTextToText if is_vlm else AutoModelForCausalLM
        auto_processor = AutoProcessor if is_vlm else AutoTokenizer
        model = auto_model.from_pretrained(
            model_name,
            torch_dtype = "auto",
            device_map = "auto",
            quantization_config = qconfig,
        )
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        tokenizer = auto_processor.from_pretrained(model_name)
        model.save_pretrained(new_model_name, safe_serialization = False)
        del model
        for _ in range(2):
            torch.cuda.empty_cache()
            gc.collect()
        tokenizer.save_pretrained(new_model_name)
    return new_model_name
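
Note on the VLM detection introduced above: the sketch below isolates the same config-based check as a standalone helper, for illustration only. It assumes only public transformers Auto classes; the helper name pick_auto_classes and the model id in the usage comment are placeholders, not part of this patch, and the `or []` guard on config.architectures is an extra defensive assumption not present in the diff.

from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoModelForImageTextToText,
    AutoProcessor,
    AutoTokenizer,
)

def pick_auto_classes(model_name: str):
    # Loading only the config is enough to inspect the declared architectures.
    config = AutoConfig.from_pretrained(model_name)
    # Vision-language checkpoints typically register *ForConditionalGeneration /
    # *ForVisionText2Text architectures or carry a nested vision_config.
    architectures = config.architectures or []
    is_vlm = any(
        arch.endswith(("ForConditionalGeneration", "ForVisionText2Text"))
        for arch in architectures
    ) or hasattr(config, "vision_config")
    auto_model = AutoModelForImageTextToText if is_vlm else AutoModelForCausalLM
    auto_processor = AutoProcessor if is_vlm else AutoTokenizer
    return auto_model, auto_processor

# Usage (hypothetical model id): a text-only checkpoint resolves to
# (AutoModelForCausalLM, AutoTokenizer); a VLM resolves to
# (AutoModelForImageTextToText, AutoProcessor).
# auto_model, auto_processor = pick_auto_classes("meta-llama/Llama-3.2-1B")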
