
Commit 5511aab

Add 128x128 PerBlock FP8 + RL
**Summary:** Following unslothai#3440, this PR extends torchao FP8 + RL support to also handle 128x128 PerBlock granularity (in addition to PerRow).

**Example usage:**

```python
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name     = "unsloth/Qwen3-8B-Base",
    max_seq_length = 2048,
    load_in_4bit   = False,
    fast_inference = True,
    max_lora_rank  = 32,
    load_in_fp8    = "block",  # or "row" or True
)
```

**Initial results:** TBD

**Note:**
- Requires pytorch/ao#3370
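For reference, the two modes map to the following torchao granularities. This is only a sketch of the config construction, mirroring `_get_torchao_fp8_config` in `loader_utils.py` below; the `row_config`/`block_config` names are illustrative:

```python
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerBlock,
    PerRow,
)

# load_in_fp8 = "row" (or True): one fp8 scale per row, as before
row_config = Float8DynamicActivationFloat8WeightConfig(
    granularity = PerRow(),
    activation_value_lb = 1e-12,
)

# load_in_fp8 = "block": 1x128 activation blocks, 128x128 weight blocks
# (needs a torchao build with PerBlock support, see pytorch/ao#3370)
block_config = Float8DynamicActivationFloat8WeightConfig(
    granularity = (PerBlock([1, 128]), PerBlock([128, 128])),
    activation_value_lb = 1e-12,
)
```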
1 parent e28b7c2 commit 5511aab

File tree

3 files changed: 71 additions & 42 deletions


unsloth/kernels/utils.py

Lines changed: 14 additions & 19 deletions
```diff
@@ -352,12 +352,7 @@ def _maybe_fake_quantize_activations(
 def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False):
     # TODO: After adding XPU BNB support, check this function
     if isinstance(W, Float8Tensor):
-        # TorchAO Float8Tensor
-        # In the backward pass, rowwise scaled becomes colwise scaled after we
-        # transpose the weight tensor. Use this case to detect backward
-        assert W.ndim == 2
-        if W.block_size[0] == W.shape[0] and W.block_size[1] == 1:
-            return W.dequantize()
+        return W.dequantize()
     if quant_state is None:
         return W
     if W.dtype == torch.float8_e4m3fn:
@@ -465,12 +460,7 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False
 @torch.inference_mode
 def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False):
     if isinstance(W, Float8Tensor):
-        # TorchAO Float8Tensor
-        # In the backward pass, rowwise scaled becomes colwise scaled after we
-        # transpose the weight tensor. Use this case to detect backward
-        assert W.ndim == 2
-        if W.block_size[0] == W.shape[0] and W.block_size[1] == 1:
-            return W.dequantize()
+        return W.dequantize()
     if quant_state is None:
         return W
     if W.dtype == torch.float8_e4m3fn:
@@ -582,12 +572,7 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False
 @torch.inference_mode
 def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False):
     if isinstance(W, Float8Tensor):
-        # TorchAO Float8Tensor
-        # In the backward pass, rowwise scaled becomes colwise scaled after we
-        # transpose the weight tensor. Use this case to detect backward
-        assert W.ndim == 2
-        if W.block_size[0] == W.shape[0] and W.block_size[1] == 1:
-            return W.dequantize()
+        return W.dequantize()
     if quant_state is None:
         return W
     if W.dtype == torch.float8_e4m3fn:
@@ -1021,7 +1006,17 @@ def matmul_lora(X, W, W_quant, A, B, s, out = None):
     else:
         reshape = False

-    if W.dtype == torch.float8_e4m3fn:
+    if isinstance(W, Float8Tensor):
+        assert W.ndim == 2
+        if W.block_size[0] == W.shape[0] and W.block_size[1] == 1:
+            # In the backward pass, rowwise scaled becomes colwise scaled after we
+            # transpose the weight tensor. Use this case to detect backward.
+            # TODO: would be simpler if we simply don't call `matmul_lora` in backward
+            W = W.dequantize()
+        else:
+            W = W.contiguous()
+        out = torch_matmul(X, W.t(), out = out)
+    elif W.dtype == torch.float8_e4m3fn:
         out = fp8_linear(X, W, W_quant)
     else:
         W = fast_dequantize(W, W_quant, use_global_buffer = True)
```
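As a side note on the `matmul_lora` hunk above, here is a minimal sketch (with illustrative, assumed shapes; `looks_colwise_scaled` is a hypothetical helper, not part of this PR) of why the `block_size` check distinguishes a rowwise-scaled forward weight from its colwise-scaled transpose in the backward pass:

```python
def looks_colwise_scaled(shape, block_size):
    # PerRow forward: W is (N, K) with block_size (1, K) -> one scale per row.
    # Backward transposes W, so W.t() is (K, N) with block_size (K, 1),
    # i.e. one scale per column of the transposed tensor.
    return block_size[0] == shape[0] and block_size[1] == 1

print(looks_colwise_scaled((4096, 11008), (1, 11008)))  # False: forward, rowwise scales
print(looks_colwise_scaled((11008, 4096), (11008, 1)))  # True: backward, colwise scales
```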

unsloth/models/loader.py

Lines changed: 13 additions & 7 deletions
```diff
@@ -32,7 +32,7 @@
 from transformers import __version__ as transformers_version
 from peft import PeftConfig, PeftModel
 from .loader_utils import (
-    _check_load_in_fp8_settings,
+    _get_fp8_mode_and_check_settings,
     _offline_quantize_to_fp8,
     _tag_model_with_fp8_torchao_config,
     get_model_name,
@@ -220,19 +220,22 @@ def from_pretrained(
     load_in_4bit = False

     if load_in_fp8:
-        _check_load_in_fp8_settings(
+        fp8_mode = _get_fp8_mode_and_check_settings(
+            load_in_fp8,
             fast_inference,
             full_finetuning,
             load_in_4bit,
             load_in_8bit,
             load_in_16bit,
             use_exact_model_name,
         )
+    else:
+        fp8_mode = None

     old_model_name = model_name
     if not use_exact_model_name:
         if load_in_fp8:
-            model_name = _offline_quantize_to_fp8(model_name)
+            model_name = _offline_quantize_to_fp8(model_name, fp8_mode)
         else:
             model_name = get_model_name(model_name, load_in_4bit)

@@ -578,7 +581,7 @@ def from_pretrained(
         model.config.update({"quantization_config": quantization_config})

     if load_in_fp8:
-        _tag_model_with_fp8_torchao_config(model)
+        _tag_model_with_fp8_torchao_config(model, fp8_mode)

     if is_peft:
         # From https://github.com/huggingface/peft/issues/184
@@ -722,19 +725,22 @@ def from_pretrained(
     load_in_4bit = False

     if load_in_fp8:
-        _check_load_in_fp8_settings(
+        fp8_mode = _get_fp8_mode_and_check_settings(
+            load_in_fp8,
             fast_inference,
             full_finetuning,
             load_in_4bit,
             load_in_8bit,
             load_in_16bit,
             use_exact_model_name,
         )
+    else:
+        fp8_mode = None

     old_model_name = model_name
     if not use_exact_model_name:
         if load_in_fp8:
-            model_name = _offline_quantize_to_fp8(model_name)
+            model_name = _offline_quantize_to_fp8(model_name, fp8_mode)
         else:
             model_name = get_model_name(model_name, load_in_4bit)

@@ -1172,7 +1178,7 @@ def from_pretrained(
         model.config.update({"quantization_config": quantization_config})

     if load_in_fp8:
-        _tag_model_with_fp8_torchao_config(model)
+        _tag_model_with_fp8_torchao_config(model, fp8_mode)

     if is_peft:
         # From https://github.com/huggingface/peft/issues/184
```
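One practical consequence of threading `fp8_mode` through the loader: the offline-quantized checkpoint is now cached per mode. A rough sketch of the naming (paths and model name are illustrative; the actual logic lives in `_offline_quantize_to_fp8` below):

```python
import os
import tempfile

# Illustrative only: mirrors the cache naming used by _offline_quantize_to_fp8.
model_name = "unsloth/Qwen3-8B-Base"
fp8_mode   = "block"   # resolved from load_in_fp8 ("row" when load_in_fp8 is True)

cache_name = model_name.split("/")[-1] + "-fp8-" + fp8_mode
cache_path = os.path.join(tempfile.gettempdir(), cache_name)
print(cache_path)  # e.g. /tmp/Qwen3-8B-Base-fp8-block on Linux
```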

unsloth/models/loader_utils.py

Lines changed: 44 additions & 16 deletions
```diff
@@ -16,13 +16,14 @@
 import os
 import re
 import tempfile
+from typing import Union
 from .mapper import INT_TO_FLOAT_MAPPER, FLOAT_TO_INT_MAPPER, MAP_TO_UNSLOTH_16bit

 # https://github.com/huggingface/transformers/pull/26037 allows 4 bit loading!
 from packaging.version import Version
 from transformers import (
-    AutoModel,
-    AutoProcessor,
+    AutoModelForCausalLM,
+    AutoTokenizer,
     TorchAoConfig,
     __version__ as transformers_version,
 )
@@ -158,20 +159,27 @@ def get_model_name(model_name, load_in_4bit = True):
     return new_model_name if new_model_name is not None else model_name


-def _get_torchao_fp8_config():
+def _get_torchao_fp8_config(fp8_mode: str):
     """
     Return a `torchao.quantization.Float8DynamicActivationFloat8WeightConfig`
     to be used for `load_in_fp8=True`.
     """
-    from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow
+    from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerBlock, PerRow
+
+    if fp8_mode == "row":
+        granularity = PerRow()
+    elif fp8_mode == "block":
+        granularity = (PerBlock([1, 128]), PerBlock([128, 128]))
+    else:
+        raise ValueError("Unsloth: `load_in_fp8` supports only 'row' or 'block'")

     return Float8DynamicActivationFloat8WeightConfig(
-        granularity = PerRow(),
+        granularity = granularity,
         activation_value_lb = 1e-12,
     )


-def _offline_quantize_to_fp8(model_name: str) -> str:
+def _offline_quantize_to_fp8(model_name: str, fp8_mode: str) -> str:
     """
     Quantizes the model to fp8 using torchao and saving the quantized model to a
     temporary location. Return the path to the quantized model.
@@ -186,53 +194,72 @@ def _offline_quantize_to_fp8(model_name: str) -> str:
     )
     """
     temp_dir = tempfile.gettempdir()
-    new_model_name = model_name.split("/")[-1] + "-fp8"
+    new_model_name = model_name.split("/")[-1] + "-fp8-" + fp8_mode
     new_model_name = os.path.join(temp_dir, new_model_name)
     print(
-        f"Quantizing '{model_name}' to fp8, using model_name='{new_model_name}' instead"
+        f"Unsloth: Quantizing '{model_name}' to fp8, using model_name='{new_model_name}' instead"
     )
+
     if not os.path.isdir(new_model_name):
-        qconfig = _get_torchao_fp8_config()
+        qconfig = _get_torchao_fp8_config(fp8_mode)
         qconfig = TorchAoConfig(qconfig)
-        model = AutoModel.from_pretrained(
+        # TODO: generalize this to beyond text models?
+        # Right now using AutoModel removes the `lm_head` layer,
+        # which is expected later when loading the vllm state dict
+        model = AutoModelForCausalLM.from_pretrained(
             model_name,
             torch_dtype = "auto",
             device_map = "auto",
             quantization_config = qconfig,
         )
-        tokenizer = AutoProcessor.from_pretrained(model_name)
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
         model.save_pretrained(new_model_name, safe_serialization = False)
         tokenizer.save_pretrained(new_model_name)
     return new_model_name


-def _tag_model_with_fp8_torchao_config(model: torch.nn.Module):
+def _tag_model_with_fp8_torchao_config(model: torch.nn.Module, fp8_mode: str):
     """
     Tag a model with a `TorchAOConfig` so downstream callers will know what to do with it.
     """
-    base_config = _get_torchao_fp8_config()
+    base_config = _get_torchao_fp8_config(fp8_mode)
     model.torchao_config = TorchAOConfig(
         qat_scheme = None,
         base_config_and_filter_fns = [(base_config, None)],
     )


-def _check_load_in_fp8_settings(
+def _get_fp8_mode_and_check_settings(
+    load_in_fp8: Union[bool, str],
     fast_inference: bool,
     full_finetuning: bool,
     load_in_4bit: bool,
     load_in_8bit: bool,
     load_in_16bit: bool,
     use_exact_model_name: bool,
-):
+) -> str:
     """
-    Assuming `load_in_fp8=True`, raise appropriate errors on incompatible settings
+    Assuming `load_in_fp8` is enabled, raise appropriate errors on incompatible settings
     and environment. Currently this feature requires:
+
     1. H100 GPUs or after
     2. torchao 0.15.0+ (or nightly)
     3. torch 2.9.0+
     4. If fbgemm_gpu_genai is installed, require 1.4.1+
+
+    Returns the fp8 mode, one of "row" or "block".
     """
+    assert load_in_fp8 is not False
+    if load_in_fp8 is True:
+        fp8_mode = "row" # default
+    else:
+        fp8_mode = load_in_fp8
+
+    # Check user settings
+    if fp8_mode not in ["row", "block"]:
+        raise ValueError(
+            f"Unsloth: `load_in_fp8` can only be 'row' or 'block', got '{fp8_mode}'"
+        )
     if not fast_inference:
         raise ValueError(
             "Unsloth: `load_in_fp8` is only supported for `fast_inference` for now"
@@ -284,3 +311,4 @@ def _check_load_in_fp8_settings(
         raise ValueError(
             "Unsloth: `load_in_fp8` is only compatible with fbgemm_gpu_genai 1.4.1+"
         )
+    return fp8_mode
```
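The core of the new `_get_fp8_mode_and_check_settings` helper is the mode resolution. A condensed sketch follows; `resolve_fp8_mode` is a hypothetical stand-in that drops the hardware/version and flag-compatibility checks the real helper also performs:

```python
from typing import Union

def resolve_fp8_mode(load_in_fp8: Union[bool, str]) -> str:
    # load_in_fp8 = True keeps the old PerRow behaviour; strings pick the granularity.
    assert load_in_fp8 is not False
    fp8_mode = "row" if load_in_fp8 is True else load_in_fp8
    if fp8_mode not in ("row", "block"):
        raise ValueError(
            f"Unsloth: `load_in_fp8` can only be 'row' or 'block', got '{fp8_mode}'"
        )
    return fp8_mode

assert resolve_fp8_mode(True) == "row"       # backwards-compatible default
assert resolve_fp8_mode("block") == "block"  # new 128x128 PerBlock mode
```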
