]

import torch
- from typing import Union, Optional, List, Any, Callable, Tuple
+ from typing import Union, Optional, List, Any, Callable, Tuple, Iterator
from platform import system as platform_system

platform_system = platform_system()
@@ -2013,18 +2013,110 @@ def error_out_no_vllm(*args, **kwargs):
@dataclass
class TorchAOConfig:
    qat_scheme: str = "int4"
-     base_config: AOBaseConfig = field(
-         default_factory = lambda: Int4WeightOnlyConfig(group_size = 128)
-     )
-     group_size: int = 128
-     filter_fn: Optional[Callable] = None

-     def __post_init__(self):
-         if self.filter_fn is None:
-             self.filter_fn = (
+     # Each (config, filter_fn) pair defines a quantization rule
+     base_config_and_filter_fns: List[
+         Tuple["AOBaseConfig", Optional[Callable[[torch.nn.Module, str], bool]]]
+     ] = field(
+         default_factory = lambda: [
+             (
+                 Int4WeightOnlyConfig(group_size = 128),
                lambda m, _: isinstance(m, torch.nn.Linear)
-                 and m.in_features >= self.group_size
-             )
+                 and getattr(m, "in_features", 0) >= 128,
+             ),
+         ]
+     )
+
+     # Optional transformation to apply before quantization setup
+     prequantization_transform: Optional[Callable[[torch.nn.Module], None]] = None
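# A minimal sketch (illustrative, not part of the diff above) of how a caller might build a
# custom TorchAOConfig with two (config, filter_fn) rules, one for embeddings and one for
# large-enough Linear layers, reusing only config classes that already appear in this file.
from torchao.quantization import Int4WeightOnlyConfig, IntxWeightOnlyConfig
from torchao.quantization.granularity import PerAxis

custom_torchao_config = TorchAOConfig(
    qat_scheme = "int4",
    base_config_and_filter_fns = [
        # Rule 1: int8 weight-only quantization for embedding tables
        (
            IntxWeightOnlyConfig(weight_dtype = torch.int8, granularity = PerAxis(0)),
            lambda m, fqn: isinstance(m, torch.nn.Embedding),
        ),
        # Rule 2: int4 weight-only quantization for Linear layers with in_features >= 128
        (
            Int4WeightOnlyConfig(group_size = 128),
            lambda m, fqn: isinstance(m, torch.nn.Linear) and getattr(m, "in_features", 0) >= 128,
        ),
    ],
)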
+
+
+ def _untie_input_output_embeddings(model: torch.nn.Module) -> None:
+     """
+     Utility to untie input/output embeddings in a HuggingFace model.
+     This is useful if we want to quantize the input/output embeddings differently.
+     The model is modified in-place.
+     """
+
+     # 1) Persist setting in config
+     if hasattr(model.config, "tie_word_embeddings"):
+         model.config.tie_word_embeddings = False
+
+     # 2) Find input and output embeddings
+     in_emb = model.get_input_embeddings()
+     out_proj = model.get_output_embeddings() or getattr(model, "lm_head", None)
+     if out_proj is None:
+         raise AttributeError("Couldn't locate output projection (lm_head).")
+
+     # (Optional) sanity check: shapes should match [vocab, hidden]
+     assert (
+         out_proj.weight.shape == in_emb.weight.shape
+     ), f"Shape mismatch: out_proj {out_proj.weight.shape} vs in_emb {in_emb.weight.shape}"
+
+     # 3) Only clone if they are actually tied (shared storage)
+     if out_proj.weight.data_ptr() == in_emb.weight.data_ptr():
+         with torch.no_grad():
+             W = in_emb.weight.detach().clone()
+             out_proj.weight = torch.nn.Parameter(W)  # new storage, keeps dtype/device
+
+     # 4) Prevent future automatic re-tying
+     def _no_tie(self):
+         return
+
+     model.tie_weights = _no_tie.__get__(model, model.__class__)
+
+     # 5) Verify no shared storage
+     assert (
+         out_proj.weight.data_ptr() != in_emb.weight.data_ptr()
+     ), "Embeddings still tied!"
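# Usage sketch for the helper above (illustrative; the checkpoint name below is an assumed
# placeholder for any HuggingFace causal LM that ties lm_head to its input embeddings, not
# a model referenced elsewhere in this file).
from transformers import AutoModelForCausalLM

tied_model = AutoModelForCausalLM.from_pretrained("some-org/tied-embedding-model")
_untie_input_output_embeddings(tied_model)
# After untying, the two weight tensors no longer share storage:
assert (
    tied_model.lm_head.weight.data_ptr()
    != tied_model.get_input_embeddings().weight.data_ptr()
)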
+
+
+ def _filter_fn_to_fqns(
+     model: torch.nn.Module,
+     filter_fn: Callable[[torch.nn.Module, str], bool],
+ ) -> Iterator[str]:
+     """
+     Given a model and a filter function (m, fqn) -> bool,
+     yield fully qualified names (FQNs) of modules that match.
+     """
+     for fqn, module in model.named_modules():
+         if filter_fn(module, fqn):
+             yield fqn
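# A quick illustration of the helper above on a toy module tree (the layer layout is an
# assumption made up for this example, not taken from the surrounding code):
toy = torch.nn.Sequential(
    torch.nn.Embedding(10, 8),
    torch.nn.Linear(8, 8),
)
linear_fqns = list(_filter_fn_to_fqns(toy, lambda m, _: isinstance(m, torch.nn.Linear)))
# -> ["1"]: only the Linear submodule matches the filter function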
+
+
+ def _convert_torchao_model(model):
+     from transformers import TorchAoConfig
+     from torchao.quantization import quantize_, ModuleFqnToConfig
+     from torchao.quantization.qat import QATConfig
+     from torchao.utils import TorchAOBaseTensor
+
+     module_to_fqn_dict = {}
+     for base_config, filter_fn in model._torchao_config.base_config_and_filter_fns:
+         quantize_(model, QATConfig(base_config, step = "convert"), filter_fn = filter_fn)
+
+         # Default filter function used for quantize_
+         if filter_fn is None:
+             if "_default" in module_to_fqn_dict:
+                 raise ValueError("Cannot use multiple default quantization configs")
+             module_to_fqn_dict["_default"] = base_config
+         else:
+             for fqn in _filter_fn_to_fqns(model, filter_fn):
+                 if fqn in module_to_fqn_dict:
+                     raise ValueError(f"Found multiple quantization configs for {fqn}")
+                 module_to_fqn_dict[fqn] = base_config
+
+     in_emb = model.get_input_embeddings()
+     out_proj = model.get_output_embeddings() or getattr(model, "lm_head", None)
+     kwargs = {}
+     if isinstance(in_emb.weight, TorchAOBaseTensor) or (
+         out_proj is not None and isinstance(out_proj.weight, TorchAOBaseTensor)
+     ):
+         kwargs["include_input_output_embeddings"] = True
+         kwargs["modules_to_not_convert"] = []
+
+     quant_config = ModuleFqnToConfig(module_to_fqn_dict)
+     quantization_config = TorchAoConfig(quant_type = quant_config, **kwargs)
+     model.config.quantization_config = quantization_config
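# Illustration of the mapping the loop above produces for the "int8-int4" scheme defined
# further down: matching Embedding FQNs map to the int8 weight-only config, while the rule
# whose filter_fn is None lands under "_default". The FQN "model.embed_tokens" is only an
# assumed example of a typical HuggingFace module name.
from torchao.quantization import Int8DynamicActivationIntxWeightConfig, IntxWeightOnlyConfig
from torchao.quantization.granularity import PerAxis, PerGroup

example_module_to_fqn_dict = {
    "model.embed_tokens": IntxWeightOnlyConfig(weight_dtype = torch.int8, granularity = PerAxis(0)),
    "_default": Int8DynamicActivationIntxWeightConfig(
        weight_dtype = torch.int4, weight_granularity = PerGroup(32)
    ),
}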


def _prepare_model_for_qat(
@@ -2041,13 +2133,11 @@ def _prepare_model_for_qat(
    For more details: https://dev-discuss.pytorch.org/t/speeding-up-qat-by-1-89x-with-lora/2700
    """
    from torchao.quantization import PerRow, quantize_
-     from torchao.quantization.granularity import PerGroup
+     from torchao.quantization.granularity import PerGroup, PerAxis
    from torchao.quantization.qat import QATConfig

    if not isinstance(qat_scheme, TorchAOConfig):
-         filter_fn = None
-         group_size = None
-         base_config = None
+         torchao_config: Optional[TorchAOConfig] = None
        if qat_scheme == "fp8-int4":
            from torchao.quantization import Float8DynamicActivationInt4WeightConfig

@@ -2057,22 +2147,42 @@ def _prepare_model_for_qat(
                lambda m, _: isinstance(m, torch.nn.Linear)
                and m.in_features >= group_size
            )
+             torchao_config = TorchAOConfig(
+                 qat_scheme = qat_scheme,
+                 base_config_and_filter_fns = [(base_config, filter_fn)],
+             )
        elif qat_scheme == "fp8-fp8":
            from torchao.quantization import Float8DynamicActivationFloat8WeightConfig

            base_config = Float8DynamicActivationFloat8WeightConfig(
                granularity = PerRow()
            )
+             torchao_config = TorchAOConfig(
+                 qat_scheme = qat_scheme, base_config_and_filter_fns = [(base_config, None)]
+             )
        elif qat_scheme == "int8-int4":
-             from torchao.quantization import Int8DynamicActivationIntxWeightConfig
-
-             group_size = 32
-             base_config = Int8DynamicActivationIntxWeightConfig(
-                 weight_dtype = torch.int4, weight_granularity = PerGroup(group_size)
+             from torchao.quantization import (
+                 Int8DynamicActivationIntxWeightConfig,
+                 IntxWeightOnlyConfig,
            )
-             filter_fn = (
-                 lambda m, _: isinstance(m, torch.nn.Linear)
-                 and m.in_features >= group_size
+
+             torchao_config = TorchAOConfig(
+                 qat_scheme = qat_scheme,
+                 base_config_and_filter_fns = [
+                     (
+                         IntxWeightOnlyConfig(
+                             weight_dtype = torch.int8, granularity = PerAxis(0)
+                         ),
+                         lambda m, fqn: isinstance(m, torch.nn.Embedding),
+                     ),
+                     (
+                         Int8DynamicActivationIntxWeightConfig(
+                             weight_dtype = torch.int4, weight_granularity = PerGroup(32)
+                         ),
+                         None,
+                     ),
+                 ],
+                 prequantization_transform = _untie_input_output_embeddings,
            )
        elif qat_scheme == "int4":
            from torchao.quantization import Int4WeightOnlyConfig
@@ -2083,30 +2193,28 @@ def _prepare_model_for_qat(
                lambda m, _: isinstance(m, torch.nn.Linear)
                and m.in_features >= group_size
            )
+             torchao_config = TorchAOConfig(
+                 qat_scheme = qat_scheme,
+                 base_config_and_filter_fns = [(base_config, filter_fn)],
+             )
        else:
            raise ValueError(f"Unexpected QAT scheme {qat_scheme}")
-         # Save TorchAO schemes
-         torchao_config = TorchAOConfig(
-             qat_scheme = qat_scheme,
-             base_config = base_config,
-             group_size = group_size,
-             filter_fn = filter_fn,
-         )
+         assert torchao_config is not None, f"TorchAOConfig was not set for {qat_scheme}"
    else:
        torchao_config = qat_scheme
-         qat_scheme = torchao_config.qat_scheme
-         base_config = torchao_config.base_config
-         group_size = torchao_config.group_size
-         filter_fn = torchao_config.filter_fn

    # Save Torchao metadata everywhere
    inner_model = model
    while hasattr(inner_model, "model"):
        inner_model._torchao_config = torchao_config
        inner_model = inner_model.model
    inner_model._torchao_config = torchao_config
-     # Quantize with TorchAO
-     quantize_(model, QATConfig(base_config, step = "prepare"), filter_fn = filter_fn)
+
+     if torchao_config.prequantization_transform is not None:
+         torchao_config.prequantization_transform(model)
+     for base_config, filter_fn in torchao_config.base_config_and_filter_fns:
+         quantize_(model, QATConfig(base_config, step = "prepare"), filter_fn = filter_fn)
+
    return model
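# End-to-end sketch of how these pieces fit together (illustrative; `model` stands for any
# HuggingFace-style model object, the argument order for _prepare_model_for_qat is assumed
# from its body since the full signature is outside this hunk, and the training loop is
# elided because it is not part of this file):
#
#     model = _prepare_model_for_qat(model, "int8-int4")   # insert fake quantization for QAT
#     ...                                                   # run QAT fine-tuning as usual
#     _convert_torchao_model(model)                         # swap in real quantized weights
#                                                           # and record quantization_config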