33 changes: 14 additions & 19 deletions unsloth/kernels/utils.py
@@ -352,12 +352,7 @@ def _maybe_fake_quantize_activations(
def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False):
# TODO: After adding XPU BNB support, check this function
if isinstance(W, Float8Tensor):
# TorchAO Float8Tensor
# In the backward pass, rowwise scaled becomes colwise scaled after we
# transpose the weight tensor. Use this case to detect backward
assert W.ndim == 2
if W.block_size[0] == W.shape[0] and W.block_size[1] == 1:
return W.dequantize()
return W.dequantize()
if quant_state is None:
return W
if W.dtype == torch.float8_e4m3fn:
@@ -465,12 +460,7 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False):
@torch.inference_mode
def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False):
if isinstance(W, Float8Tensor):
# TorchAO Float8Tensor
# In the backward pass, rowwise scaled becomes colwise scaled after we
# transpose the weight tensor. Use this case to detect backward
assert W.ndim == 2
if W.block_size[0] == W.shape[0] and W.block_size[1] == 1:
return W.dequantize()
return W.dequantize()
if quant_state is None:
return W
if W.dtype == torch.float8_e4m3fn:
@@ -582,12 +572,7 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False):
@torch.inference_mode
def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False):
if isinstance(W, Float8Tensor):
# TorchAO Float8Tensor
# In the backward pass, rowwise scaled becomes colwise scaled after we
# transpose the weight tensor. Use this case to detect backward
assert W.ndim == 2
if W.block_size[0] == W.shape[0] and W.block_size[1] == 1:
return W.dequantize()
return W.dequantize()
if quant_state is None:
return W
if W.dtype == torch.float8_e4m3fn:
Expand Down Expand Up @@ -1021,7 +1006,17 @@ def matmul_lora(X, W, W_quant, A, B, s, out = None):
else:
reshape = False

if W.dtype == torch.float8_e4m3fn:
if isinstance(W, Float8Tensor):
assert W.ndim == 2
if W.block_size[0] == W.shape[0] and W.block_size[1] == 1:
# In the backward pass, rowwise scaled becomes colwise scaled after we
# transpose the weight tensor. Use this case to detect backward.
# TODO: would be simpler if we simply don't call `matmul_lora` in backward
W = W.dequantize()
else:
W = W.contiguous()
out = torch_matmul(X, W.t(), out = out)
elif W.dtype == torch.float8_e4m3fn:
out = fp8_linear(X, W, W_quant)
else:
W = fast_dequantize(W, W_quant, use_global_buffer = True)
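
A note on the `block_size` check that moved into `matmul_lora` above: a rowwise-scaled `Float8Tensor` keeps one scale per output row, so once the weight is transposed for the backward matmul the scaling effectively becomes column-wise, which is what the check detects. Below is a minimal sketch of that shape arithmetic with made-up sizes (illustrative only; in the real code the `block_size` comes from torchao's `Float8Tensor`):

out_features, in_features = 4096, 11008  # hypothetical layer sizes

# Forward: W has shape (out, in) with one scale per row, i.e. block_size == (1, in)
shape, block_size = (out_features, in_features), (1, in_features)
print(block_size[0] == shape[0] and block_size[1] == 1)  # False -> keep the fp8 path

# Backward: W.t() has shape (in, out) and block_size (in, 1), i.e. column-wise
# scales, so the check fires and the weight is dequantized instead
shape, block_size = (in_features, out_features), (in_features, 1)
print(block_size[0] == shape[0] and block_size[1] == 1)  # True -> W.dequantize()
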
20 changes: 13 additions & 7 deletions unsloth/models/loader.py
@@ -32,7 +32,7 @@
from transformers import __version__ as transformers_version
from peft import PeftConfig, PeftModel
from .loader_utils import (
_check_load_in_fp8_settings,
_get_fp8_mode_and_check_settings,
_offline_quantize_to_fp8,
_tag_model_with_fp8_torchao_config,
get_model_name,
@@ -220,19 +220,22 @@ def from_pretrained(
load_in_4bit = False

if load_in_fp8:
_check_load_in_fp8_settings(
fp8_mode = _get_fp8_mode_and_check_settings(
load_in_fp8,
fast_inference,
full_finetuning,
load_in_4bit,
load_in_8bit,
load_in_16bit,
use_exact_model_name,
)
else:
fp8_mode = None

old_model_name = model_name
if not use_exact_model_name:
if load_in_fp8:
model_name = _offline_quantize_to_fp8(model_name)
model_name = _offline_quantize_to_fp8(model_name, fp8_mode)
else:
model_name = get_model_name(model_name, load_in_4bit)

@@ -578,7 +581,7 @@ def from_pretrained(
model.config.update({"quantization_config": quantization_config})

if load_in_fp8:
_tag_model_with_fp8_torchao_config(model)
_tag_model_with_fp8_torchao_config(model, fp8_mode)

if is_peft:
# From https://github.com/huggingface/peft/issues/184
@@ -722,19 +725,22 @@ def from_pretrained(
load_in_4bit = False

if load_in_fp8:
_check_load_in_fp8_settings(
fp8_mode = _get_fp8_mode_and_check_settings(
load_in_fp8,
fast_inference,
full_finetuning,
load_in_4bit,
load_in_8bit,
load_in_16bit,
use_exact_model_name,
)
else:
fp8_mode = None

old_model_name = model_name
if not use_exact_model_name:
if load_in_fp8:
model_name = _offline_quantize_to_fp8(model_name)
model_name = _offline_quantize_to_fp8(model_name, fp8_mode)
else:
model_name = get_model_name(model_name, load_in_4bit)

@@ -1172,7 +1178,7 @@ def from_pretrained(
model.config.update({"quantization_config": quantization_config})

if load_in_fp8:
_tag_model_with_fp8_torchao_config(model)
_tag_model_with_fp8_torchao_config(model, fp8_mode)

if is_peft:
# From https://github.com/huggingface/peft/issues/184
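
From the caller's side, the loader change above means `load_in_fp8` now carries the scaling mode. A hypothetical usage sketch, not part of this diff (the model name is a placeholder and the entry point is assumed to be the usual `FastLanguageModel.from_pretrained`):

from unsloth import FastLanguageModel

# load_in_fp8 accepts True (which defaults to "row") or the strings "row" / "block";
# per the checks in loader_utils it currently requires fast_inference = True.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name     = "unsloth/Llama-3.1-8B-Instruct",  # placeholder
    fast_inference = True,
    load_in_fp8    = "block",
)
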
75 changes: 55 additions & 20 deletions unsloth/models/loader_utils.py
@@ -16,13 +16,14 @@
import os
import re
import tempfile
from typing import Union
from .mapper import INT_TO_FLOAT_MAPPER, FLOAT_TO_INT_MAPPER, MAP_TO_UNSLOTH_16bit

# https://github.com/huggingface/transformers/pull/26037 allows 4 bit loading!
from packaging.version import Version
from transformers import (
AutoModel,
AutoProcessor,
AutoModelForCausalLM,
AutoTokenizer,
TorchAoConfig,
__version__ as transformers_version,
)
@@ -158,20 +159,31 @@ def get_model_name(model_name, load_in_4bit = True):
return new_model_name if new_model_name is not None else model_name


def _get_torchao_fp8_config():
def _get_torchao_fp8_config(fp8_mode: str):
"""
Return a `torchao.quantization.Float8DynamicActivationFloat8WeightConfig`
to be used for `load_in_fp8=True`.
"""
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow
from torchao.quantization import (
Float8DynamicActivationFloat8WeightConfig,
PerBlock,
PerRow,
)

if fp8_mode == "row":
granularity = PerRow()
Review comment on lines +173 to +174 (Contributor, severity: medium):

This validation is redundant, as fp8_mode is already validated in _get_fp8_mode_and_check_settings before being passed to this function. For internal functions, it's better to rely on assertions for contract checking rather than raising user-facing ValueErrors. This avoids duplicated validation logic and makes the code cleaner.

Consider removing this else block. If you want to keep a check for robustness, an assert would be more appropriate, for example:

assert fp8_mode in ['row', 'block']

However, given the call chain, even an assert is likely unnecessary.

elif fp8_mode == "block":
granularity = (PerBlock([1, 128]), PerBlock([128, 128]))
else:
raise ValueError("Unsloth: `load_in_fp8` supports only 'row' or 'block'")

return Float8DynamicActivationFloat8WeightConfig(
granularity = PerRow(),
granularity = granularity,
activation_value_lb = 1e-12,
)


def _offline_quantize_to_fp8(model_name: str) -> str:
def _offline_quantize_to_fp8(model_name: str, fp8_mode: str) -> str:
"""
Quantizes the model to fp8 using torchao and saves the quantized model to a
temporary location. Returns the path to the quantized model.
@@ -186,53 +198,72 @@ def _offline_quantize_to_fp8(model_name: str) -> str:
)
"""
temp_dir = tempfile.gettempdir()
new_model_name = model_name.split("/")[-1] + "-fp8"
new_model_name = model_name.split("/")[-1] + "-fp8-" + fp8_mode
new_model_name = os.path.join(temp_dir, new_model_name)
print(
f"Quantizing '{model_name}' to fp8, using model_name='{new_model_name}' instead"
f"Unsloth: Quantizing '{model_name}' to fp8, using model_name='{new_model_name}' instead"
)

if not os.path.isdir(new_model_name):
qconfig = _get_torchao_fp8_config()
qconfig = _get_torchao_fp8_config(fp8_mode)
qconfig = TorchAoConfig(qconfig)
model = AutoModel.from_pretrained(
# TODO: generalize this beyond text models?
# Right now using AutoModel removes the `lm_head` layer,
# which is expected later when loading the vllm state dict
model = AutoModelForCausalLM.from_pretrained(
Comment from the PR author (Contributor):

@danielhanchen I had to change this back for this to work. When I tried AutoModel it removed the lm_head from the state dict, which later caused an out of bounds exception on this line: https://github.com/unslothai/unsloth-zoo/blob/54dce973426ee61670e15b720619c8539bf05104/unsloth_zoo/vllm_utils.py#L1110. Maybe we can generalize this to multimodal models later

Reply (Contributor):

Oh ok oh wait I can make this work for vision models without using AutoModel

model_name,
torch_dtype = "auto",
device_map = "auto",
quantization_config = qconfig,
)
tokenizer = AutoProcessor.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model.save_pretrained(new_model_name, safe_serialization = False)
tokenizer.save_pretrained(new_model_name)
return new_model_name


def _tag_model_with_fp8_torchao_config(model: torch.nn.Module):
def _tag_model_with_fp8_torchao_config(model: torch.nn.Module, fp8_mode: str):
"""
Tag a model with a `TorchAOConfig` so downstream callers will know what to do with it.
"""
base_config = _get_torchao_fp8_config()
base_config = _get_torchao_fp8_config(fp8_mode)
model.torchao_config = TorchAOConfig(
qat_scheme = None,
base_config_and_filter_fns = [(base_config, None)],
)


def _check_load_in_fp8_settings(
def _get_fp8_mode_and_check_settings(
load_in_fp8: Union[bool, str],
fast_inference: bool,
full_finetuning: bool,
load_in_4bit: bool,
load_in_8bit: bool,
load_in_16bit: bool,
use_exact_model_name: bool,
):
) -> str:
"""
Assuming `load_in_fp8=True`, raise appropriate errors on incompatible settings
Assuming `load_in_fp8` is enabled, raise appropriate errors on incompatible settings
and environment. Currently this feature requires:

1. H100 GPUs or newer
2. torchao 0.15.0+ (or nightly)
3. torch 2.9.0+
4. If fbgemm_gpu_genai is installed, require 1.4.1+

Returns the fp8 mode, one of "row" or "block".
"""
assert load_in_fp8 is not False
if load_in_fp8 is True:
fp8_mode = "row" # default
else:
fp8_mode = load_in_fp8

# Check user settings
if fp8_mode not in ["row", "block"]:
raise ValueError(
f"Unsloth: `load_in_fp8` can only be 'row' or 'block', got '{fp8_mode}'"
)
if not fast_inference:
raise ValueError(
"Unsloth: `load_in_fp8` is only supported for `fast_inference` for now"
@@ -263,13 +294,16 @@ def _check_load_in_fp8_settings(
# Check if torchao has this PR: https://github.com/pytorch/ao/pull/3158,
# which will be released in 0.15.0.
if importlib.util.find_spec("torchao") is None:
raise ValueError("Unsloth: Please install torchao for on the fly float8 to work!")
raise ValueError(
"Unsloth: Please install torchao for on the fly float8 to work!"
)
import torchao

error_message = \
"Unsloth: `load_in_fp8` requires torchao 0.15.0+ (or nightly).\n"\
f"You have torchao version={torchao.__version__}\n"\
error_message = (
"Unsloth: `load_in_fp8` requires torchao 0.15.0+ (or nightly).\n"
f"You have torchao version={torchao.__version__}\n"
"Use `pip install --upgrade --force-reinstall torchao`"
)
if Version(torchao.__version__) < Version("0.15.0"):
raise ValueError(error_message)

@@ -284,3 +318,4 @@ def _check_load_in_fp8_settings(
raise ValueError(
"Unsloth: `load_in_fp8` is only compatible with fbgemm_gpu_genai 1.4.1+"
)
return fp8_mode
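
For reference, the two modes returned by `_get_fp8_mode_and_check_settings` map to the following torchao configs. This sketch simply mirrors `_get_torchao_fp8_config` above; it assumes torchao 0.15.0+ and that the granularity tuple pairs activation and weight granularities:

from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerBlock,
    PerRow,
)

# "row": one dynamic fp8 scale per row
row_config = Float8DynamicActivationFloat8WeightConfig(
    granularity = PerRow(),
    activation_value_lb = 1e-12,
)

# "block": 1x128 blocks for activations, 128x128 blocks for weights (assumed pairing)
block_config = Float8DynamicActivationFloat8WeightConfig(
    granularity = (PerBlock([1, 128]), PerBlock([128, 128])),
    activation_value_lb = 1e-12,
)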