Commit b9d9600

Extend TorchAOConfig to support mobile use cases (unslothai#3587)

* up
* up
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  (for more information, see https://pre-commit.ci)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

1 parent 6c73f7a commit b9d9600
File tree

1 file changed: unsloth/models/_utils.py (+144, -36 lines)
```diff
@@ -74,7 +74,7 @@
 ]
 
 import torch
-from typing import Union, Optional, List, Any, Callable, Tuple
+from typing import Union, Optional, List, Any, Callable, Tuple, Iterator
 from platform import system as platform_system
 
 platform_system = platform_system()
```
```diff
@@ -2013,18 +2013,110 @@ def error_out_no_vllm(*args, **kwargs):
 @dataclass
 class TorchAOConfig:
     qat_scheme: str = "int4"
-    base_config: AOBaseConfig = field(
-        default_factory = lambda: Int4WeightOnlyConfig(group_size = 128)
-    )
-    group_size: int = 128
-    filter_fn: Optional[Callable] = None
 
-    def __post_init__(self):
-        if self.filter_fn is None:
-            self.filter_fn = (
+    # Each (config, filter_fn) pair defines a quantization rule
+    base_config_and_filter_fns: List[
+        Tuple["AOBaseConfig", Optional[Callable[[torch.nn.Module, str], bool]]]
+    ] = field(
+        default_factory = lambda: [
+            (
+                Int4WeightOnlyConfig(group_size = 128),
                 lambda m, _: isinstance(m, torch.nn.Linear)
-                and m.in_features >= self.group_size
-            )
+                and getattr(m, "in_features", 0) >= 128,
+            ),
+        ]
+    )
+
+    # Optional transformation to apply before quantization setup
+    prequantization_transform: Optional[Callable[[torch.nn.Module], None]] = None
+
+
+def _untie_input_output_embeddings(model: torch.nn.Module) -> None:
+    """
+    Utility to untie input/output embeddings in a HuggingFace model.
+    This is useful if we want to quantize the input/output embeddings differently.
+    Model is modified in-place.
+    """
+
+    # 1) Persist setting in config
+    if hasattr(model.config, "tie_word_embeddings"):
+        model.config.tie_word_embeddings = False
+
+    # 2) Find input and output embeddings
+    in_emb = model.get_input_embeddings()
+    out_proj = model.get_output_embeddings() or getattr(model, "lm_head", None)
+    if out_proj is None:
+        raise AttributeError("Couldn't locate output projection (lm_head).")
+
+    # (Optional) sanity: shapes should match [vocab, hidden]
+    assert (
+        out_proj.weight.shape == in_emb.weight.shape
+    ), f"Shape mismatch: out_proj {out_proj.weight.shape} vs in_emb {in_emb.weight.shape}"
+
+    # 3) Only clone if they are actually tied (shared storage)
+    if out_proj.weight.data_ptr() == in_emb.weight.data_ptr():
+        with torch.no_grad():
+            W = in_emb.weight.detach().clone()
+            out_proj.weight = torch.nn.Parameter(W)  # new storage, keeps dtype/device
+
+    # 4) Prevent future automatic re-tying
+    def _no_tie(self):
+        return
+
+    model.tie_weights = _no_tie.__get__(model, model.__class__)
+
+    # 5) Verify no shared storage
+    assert (
+        out_proj.weight.data_ptr() != in_emb.weight.data_ptr()
+    ), "Embeddings still tied!"
+
+
+def _filter_fn_to_fqns(
+    model: torch.nn.Module,
+    filter_fn: Callable[[torch.nn.Module, str], bool],
+) -> Iterator[str]:
+    """
+    Given a model and a filter function (m, fqn) -> bool,
+    yield fully qualified names (FQNs) of modules that match.
+    """
+    for fqn, module in model.named_modules():
+        if filter_fn(module, fqn):
+            yield fqn
+
+
+def _convert_torchao_model(model):
+    from transformers import TorchAoConfig
+    from torchao.quantization import quantize_, ModuleFqnToConfig
+    from torchao.quantization.qat import QATConfig
+    from torchao.utils import TorchAOBaseTensor
+
+    module_to_fqn_dict = {}
+    for base_config, filter_fn in model._torchao_config.base_config_and_filter_fns:
+        quantize_(model, QATConfig(base_config, step = "convert"), filter_fn = filter_fn)
+
+        # Default filter function used for quantize_
+        if filter_fn is None:
+            if "_default" in module_to_fqn_dict:
+                raise ValueError("Cannot use multiple default quantization configs")
+            module_to_fqn_dict["_default"] = base_config
+        else:
+            for fqn in _filter_fn_to_fqns(model, filter_fn):
+                if fqn in module_to_fqn_dict:
+                    raise ValueError(f"Found multiple quantization configs for {fqn}")
+                module_to_fqn_dict[fqn] = base_config
+
+    in_emb = model.get_input_embeddings()
+    out_proj = model.get_output_embeddings() or getattr(model, "lm_head", None)
+    kwargs = {}
+    if isinstance(in_emb.weight, TorchAOBaseTensor) or (
+        out_proj is not None and isinstance(out_proj.weight, TorchAOBaseTensor)
+    ):
+        kwargs["include_input_output_embeddings"] = True
+        kwargs["modules_to_not_convert"] = []
+
+    quant_config = ModuleFqnToConfig(module_to_fqn_dict)
+    quantization_config = TorchAoConfig(quant_type = quant_config, **kwargs)
+    model.config.quantization_config = quantization_config
 
 
 def _prepare_model_for_qat(
```
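The refactored `TorchAOConfig` above replaces the single `base_config` / `group_size` / `filter_fn` trio with a list of `(config, filter_fn)` rules plus an optional `prequantization_transform`. Below is a minimal sketch of how a caller might assemble a custom multi-rule config under the new layout; the specific rule choices are illustrative and not taken from this commit, though the config classes and import paths are the ones used in the diff.

```python
# Illustrative sketch only: a custom multi-rule TorchAOConfig using the new
# base_config_and_filter_fns layout. The rules below are examples, not part of
# the commit; the config classes are the ones imported in this file.
import torch
from torchao.quantization import Int4WeightOnlyConfig, IntxWeightOnlyConfig
from torchao.quantization.granularity import PerAxis
from unsloth.models._utils import TorchAOConfig

custom_config = TorchAOConfig(
    qat_scheme = "int4",
    base_config_and_filter_fns = [
        # Rule 1: int8 weight-only, per output channel, for embedding layers
        (
            IntxWeightOnlyConfig(weight_dtype = torch.int8, granularity = PerAxis(0)),
            lambda m, fqn: isinstance(m, torch.nn.Embedding),
        ),
        # Rule 2: int4 weight-only for Linear layers wide enough for group size 128
        (
            Int4WeightOnlyConfig(group_size = 128),
            lambda m, fqn: isinstance(m, torch.nn.Linear) and m.in_features >= 128,
        ),
    ],
)
```

Because each rule carries its own filter function, `_convert_torchao_model` can later resolve the matched modules to FQNs via `_filter_fn_to_fqns` and emit a single `ModuleFqnToConfig` for serialization.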
```diff
@@ -2041,13 +2133,11 @@ def _prepare_model_for_qat(
     For more details: https://dev-discuss.pytorch.org/t/speeding-up-qat-by-1-89x-with-lora/2700
     """
     from torchao.quantization import PerRow, quantize_
-    from torchao.quantization.granularity import PerGroup
+    from torchao.quantization.granularity import PerGroup, PerAxis
     from torchao.quantization.qat import QATConfig
 
     if not isinstance(qat_scheme, TorchAOConfig):
-        filter_fn = None
-        group_size = None
-        base_config = None
+        torchao_config: Optional[TorchAOConfig] = None
         if qat_scheme == "fp8-int4":
             from torchao.quantization import Float8DynamicActivationInt4WeightConfig
 
```
```diff
@@ -2057,22 +2147,42 @@
                 lambda m, _: isinstance(m, torch.nn.Linear)
                 and m.in_features >= group_size
             )
+            torchao_config = TorchAOConfig(
+                qat_scheme = qat_scheme,
+                base_config_and_filter_fns = [(base_config, filter_fn)],
+            )
         elif qat_scheme == "fp8-fp8":
             from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
 
             base_config = Float8DynamicActivationFloat8WeightConfig(
                 granularity = PerRow()
             )
+            torchao_config = TorchAOConfig(
+                qat_scheme = qat_scheme, base_config_and_filter_fns = [(base_config, None)]
+            )
         elif qat_scheme == "int8-int4":
-            from torchao.quantization import Int8DynamicActivationIntxWeightConfig
-
-            group_size = 32
-            base_config = Int8DynamicActivationIntxWeightConfig(
-                weight_dtype = torch.int4, weight_granularity = PerGroup(group_size)
+            from torchao.quantization import (
+                Int8DynamicActivationIntxWeightConfig,
+                IntxWeightOnlyConfig,
             )
-            filter_fn = (
-                lambda m, _: isinstance(m, torch.nn.Linear)
-                and m.in_features >= group_size
+
+            torchao_config = TorchAOConfig(
+                qat_scheme = qat_scheme,
+                base_config_and_filter_fns = [
+                    (
+                        IntxWeightOnlyConfig(
+                            weight_dtype = torch.int8, granularity = PerAxis(0)
+                        ),
+                        lambda m, fqn: isinstance(m, torch.nn.Embedding),
+                    ),
+                    (
+                        Int8DynamicActivationIntxWeightConfig(
+                            weight_dtype = torch.int4, weight_granularity = PerGroup(32)
+                        ),
+                        None,
+                    ),
+                ],
+                prequantization_transform = _untie_input_output_embeddings,
             )
         elif qat_scheme == "int4":
             from torchao.quantization import Int4WeightOnlyConfig
```
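For context on the `int8-int4` branch above: it applies different configs to `nn.Embedding` and `nn.Linear` modules, which only works if the input embedding and `lm_head` do not share storage, hence `prequantization_transform = _untie_input_output_embeddings`. The toy sketch below (plain torch modules, not a real HuggingFace model) shows the core untying step: clone the shared weight into new storage so the two modules can be quantized independently.

```python
# Toy illustration (not a HuggingFace model): why embeddings are untied before
# applying different quantization rules to the embedding and the lm_head.
import torch

emb = torch.nn.Embedding(1000, 64)
lm_head = torch.nn.Linear(64, 1000, bias = False)
lm_head.weight = emb.weight  # tied weights: shared storage, as with tie_word_embeddings
assert lm_head.weight.data_ptr() == emb.weight.data_ptr()

# Same core step as _untie_input_output_embeddings: clone into new storage
with torch.no_grad():
    lm_head.weight = torch.nn.Parameter(emb.weight.detach().clone())

assert lm_head.weight.data_ptr() != emb.weight.data_ptr()  # no longer tied
```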
```diff
@@ -2083,30 +2193,28 @@ def _prepare_model_for_qat(
                 lambda m, _: isinstance(m, torch.nn.Linear)
                 and m.in_features >= group_size
             )
+            torchao_config = TorchAOConfig(
+                qat_scheme = qat_scheme,
+                base_config_and_filter_fns = [(base_config, filter_fn)],
+            )
         else:
             raise ValueError(f"Unexpected QAT scheme {qat_scheme}")
-        # Save TorchAO schemes
-        torchao_config = TorchAOConfig(
-            qat_scheme = qat_scheme,
-            base_config = base_config,
-            group_size = group_size,
-            filter_fn = filter_fn,
-        )
+        assert torchao_config is not None, f"TorchAOConfig was not set for {qat_scheme}"
     else:
         torchao_config = qat_scheme
-        qat_scheme = torchao_config.qat_scheme
-        base_config = torchao_config.base_config
-        group_size = torchao_config.group_size
-        filter_fn = torchao_config.filter_fn
 
     # Save Torchao metadata everywhere
     inner_model = model
     while hasattr(inner_model, "model"):
         inner_model._torchao_config = torchao_config
         inner_model = inner_model.model
     inner_model._torchao_config = torchao_config
-    # Quantize with TorchAO
-    quantize_(model, QATConfig(base_config, step = "prepare"), filter_fn = filter_fn)
+
+    if torchao_config.prequantization_transform is not None:
+        torchao_config.prequantization_transform(model)
+    for base_config, filter_fn in torchao_config.base_config_and_filter_fns:
+        quantize_(model, QATConfig(base_config, step = "prepare"), filter_fn = filter_fn)
+
     return model
 
 
```
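Taken together, the prepare path now runs one `quantize_(..., step = "prepare")` call per rule (after the optional prequantization transform), and `_convert_torchao_model` later performs the matching convert step and records the quantization config on `model.config`. Below is a rough end-to-end sketch of that flow; note that these underscore helpers are internal to `unsloth/models/_utils.py`, the placeholder checkpoint is arbitrary, and the exact call signature of `_prepare_model_for_qat` is assumed here.

```python
# Rough sketch of the prepare -> train -> convert flow implemented above.
# Assumptions: a HuggingFace causal LM, and that _prepare_model_for_qat is
# invoked as (model, qat_scheme); these helpers are internal, not a public API.
from transformers import AutoModelForCausalLM
from unsloth.models._utils import _prepare_model_for_qat, _convert_torchao_model

model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder checkpoint

# "prepare": untie embeddings (int8-int4 scheme only) and insert QAT fake-quant
model = _prepare_model_for_qat(model, "int8-int4")

# ... QAT fine-tuning loop runs here ...

# "convert": swap fake-quant for real quantized tensors and attach a
# ModuleFqnToConfig-based quantization_config to model.config for saving
_convert_torchao_model(model)
```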