Skip to content

Commit e28b7c2

Browse files
committed
Update loader_utils.py
1 parent c142283 commit e28b7c2

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

‎unsloth/models/loader_utils.py‎

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -21,8 +21,8 @@
2121
# https://github.com/huggingface/transformers/pull/26037 allows 4 bit loading!
2222
from packaging.version import Version
2323
from transformers import (
24-
AutoModelForCausalLM,
25-
AutoTokenizer,
24+
AutoModel,
25+
AutoProcessor,
2626
TorchAoConfig,
2727
__version__ as transformers_version,
2828
)
@@ -194,13 +194,13 @@ def _offline_quantize_to_fp8(model_name: str) -> str:
194194
if not os.path.isdir(new_model_name):
195195
qconfig = _get_torchao_fp8_config()
196196
qconfig = TorchAoConfig(qconfig)
197-
model = AutoModelForCausalLM.from_pretrained(
197+
model = AutoModel.from_pretrained(
198198
model_name,
199199
torch_dtype = "auto",
200200
device_map = "auto",
201201
quantization_config = qconfig,
202202
)
203-
tokenizer = AutoTokenizer.from_pretrained(model_name)
203+
tokenizer = AutoProcessor.from_pretrained(model_name)
204204
model.save_pretrained(new_model_name, safe_serialization = False)
205205
tokenizer.save_pretrained(new_model_name)
206206
return new_model_name

0 commit comments

Comments
 (0)