
Commit 2f11d8d

Enable FP8 + RL training for bf16 models
**Summary:** Enable FP8 + RL training using TorchAO for 1.33x faster training and 42% less model memory usage:

- We quantize the frozen base weights to fp8 and keep the LoRA adapters in bf16
- We leverage TorchAO's `Float8Tensor`, which calls into fbgemm's fp8 x fp8 rowwise matmul kernel
- For now, we need to do an offline quantization first, because vllm doesn't support on-the-fly quantization for torchao yet (in progress: vllm-project/vllm#26327)

**Example usage:**

```
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Qwen3-8B-Base",
    max_seq_length = 2048,
    load_in_4bit = False,
    fast_inference = True,
    max_lora_rank = 32,
    load_in_fp8 = True, # set this to True
)

# the rest is the same as before
model = FastLanguageModel.get_peft_model(...)
```

**Initial results:**

```
# fp8
{'train_runtime': 1725.4337, 'train_samples_per_second': 0.232, 'train_steps_per_second': 0.058, 'train_loss': 0.00015715716748673002, 'epoch': 0.01}

# bf16
{'train_runtime': 2297.8145, 'train_samples_per_second': 0.174, 'train_steps_per_second': 0.044, 'train_loss': 0.00016081033063528594, 'epoch': 0.01}
```

<img width="1199" height="448" alt="Screenshot 2025-11-11 at 4 10 50 PM" src="https://github.com/user-attachments/assets/b6304afd-89e9-42b1-8064-775807e17b23" />

Test script: https://gist.github.com/andrewor14/5b85119fae46845d07b608d420907423

**Requires:**

- pytorch/ao#3158 (torchao nightly or 0.15.0+)
- unslothai/unsloth-zoo#351
1 parent: b9d9600
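As a rough illustration of the scheme described in the summary (not the actual Unsloth code path), the sketch below quantizes a frozen base `nn.Linear` to fp8 with TorchAO's rowwise `Float8DynamicActivationFloat8WeightConfig` and adds a bf16 LoRA delta on top. The layer sizes, `rank`, and the `lora_A`/`lora_B`/`lora_forward` names are illustrative assumptions.

```python
# Minimal sketch (not the Unsloth implementation): fp8 rowwise base weight + bf16 LoRA delta.
# Assumes torchao 0.15.0+ and an fp8-capable CUDA GPU; sizes and names are illustrative.
import torch
import torch.nn as nn
from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig, PerRow

base = nn.Linear(4096, 4096, bias = False).to(torch.bfloat16).cuda()
quantize_(base, Float8DynamicActivationFloat8WeightConfig(granularity = PerRow()))
# base.weight is now held as a torchao fp8 tensor subclass; its matmuls should
# dispatch to fp8 x fp8 rowwise kernels on supported hardware.

# Frozen fp8 base + trainable bf16 LoRA adapters (rank 32 here is arbitrary)
rank = 32
lora_A = nn.Parameter(torch.randn(rank, 4096, dtype = torch.bfloat16, device = "cuda") * 0.01)
lora_B = nn.Parameter(torch.zeros(4096, rank, dtype = torch.bfloat16, device = "cuda"))

def lora_forward(x):
    # base matmul runs in fp8, LoRA delta stays in bf16
    return base(x) + (x @ lora_A.t()) @ lora_B.t()

x = torch.randn(2, 4096, dtype = torch.bfloat16, device = "cuda")
out = lora_forward(x)
```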

File tree

6 files changed (+234, -10 lines)


unsloth/kernels/utils.py

Lines changed: 30 additions & 2 deletions
```diff
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import importlib
 import triton
 import ctypes
 
@@ -211,6 +212,10 @@ def _get_tensor_stream(tensor: torch_Tensor) -> c_void_p:
 torch_bfloat16 = torch.bfloat16
 
 
+# Whether torchao can be imported
+_HAS_TORCHAO = importlib.util.find_spec("torchao") is not None
+
+
 def QUANT_STATE(W):
     return getattr(W, "quant_state", None)
 
@@ -329,12 +334,33 @@ def _maybe_fake_quantize_activations(
     return X
 
 
+def _maybe_dequantize_torchao_float8_tensor(w: torch.Tensor) -> torch.Tensor:
+    """
+    Dequantize `w` if it is a `torchao.quantization.Float8Tensor` and only
+    during the backward pass, when the tensor is no longer rowwise scaled
+    because it's been transposed.
+    """
+    if not _HAS_TORCHAO:
+        return w
+    from torchao.quantization import Float8Tensor
+    if not isinstance(w, Float8Tensor):
+        return w
+    # In the backward pass, rowwise scaled becomes colwise scaled after we
+    # transpose the weight tensor. Use this case to detect backward
+    assert w.ndim == 2
+    if w.block_size[0] == w.shape[0] and w.block_size[1] == 1:
+        return w.dequantize()
+    else:
+        return w
+
+
 # INTEL GPU Specific Logic
 if DEVICE_TYPE == "xpu" and HAS_XPU_STREAM:
 
     @torch.inference_mode
     def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False):
         # TODO: After adding XPU BNB support, check this function
+        W = _maybe_dequantize_torchao_float8_tensor(W)
         if quant_state is None:
             return W
         if W.dtype == torch.float8_e4m3fn:
 
@@ -441,6 +467,7 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False
 
     @torch.inference_mode
     def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False):
+        W = _maybe_dequantize_torchao_float8_tensor(W)
         if quant_state is None:
             return W
         if W.dtype == torch.float8_e4m3fn:
 
@@ -551,6 +578,7 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False
 
     @torch.inference_mode
     def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False):
+        W = _maybe_dequantize_torchao_float8_tensor(W)
         if quant_state is None:
             return W
         if W.dtype == torch.float8_e4m3fn:
 
@@ -987,8 +1015,8 @@ def matmul_lora(X, W, W_quant, A, B, s, out = None):
     if W.dtype == torch.float8_e4m3fn:
         out = fp8_linear(X, W, W_quant)
     else:
-        W = fast_dequantize(W.t(), W_quant, use_global_buffer = True)
-        out = torch_matmul(X, W, out = out)
+        W = fast_dequantize(W, W_quant, use_global_buffer = True)
+        out = torch_matmul(X, W.t(), out = out)
     if W_quant is not None:
         del W
```

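For context on the `block_size` check in `_maybe_dequantize_torchao_float8_tensor`: with rowwise scaling, a weight of shape `(N, K)` carries one scale per row, so its blocks span whole rows; after `.t()` the shape is `(K, N)` and each block spans a whole column, which is what the helper treats as the backward-pass case. Below is a small, hedged sketch of that shape logic using plain tuples, no torchao; the exact `block_size` convention is an assumption inferred from the code above.

```python
# Hedged sketch of the block_size check in _maybe_dequantize_torchao_float8_tensor.
# Assumes rowwise scaling stores block_size == (1, K) for a (N, K) weight, as implied above.
def looks_colwise_scaled(shape, block_size):
    """True when a 2D tensor's scaling blocks span whole columns (one scale per column)."""
    assert len(shape) == 2 and len(block_size) == 2
    return block_size[0] == shape[0] and block_size[1] == 1

N, K = 4096, 11008
print(looks_colwise_scaled((N, K), (1, K)))   # False: rowwise-scaled forward weight
print(looks_colwise_scaled((K, N), (K, 1)))   # True: the same weight after .t(), seen in backward
```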
unsloth/models/_utils.py

Lines changed: 20 additions & 1 deletion
```diff
@@ -2012,7 +2012,7 @@ def error_out_no_vllm(*args, **kwargs):
 
 @dataclass
 class TorchAOConfig:
-    qat_scheme: str = "int4"
+    qat_scheme: Optional[str] = "int4"
 
     # Each (config, filter_fn) pair defines a quantization rule
     base_config_and_filter_fns: List[
 
@@ -2262,3 +2262,22 @@ def verify_fp8_support_if_applicable(model_config):
         raise ValueError(
             f"Unsloth: FP8 quantization is only supported on L4 and higher GPUs with compute capability 8.9 or higher. You are using {torch.cuda.get_device_name()}. Refer to https://developer.nvidia.com/cuda-gpus for more details."
         )
+
+
+def _get_inference_mode_context_manager(model: torch.nn.Module):
+    """
+    If the state dict was quantized using torchao, calling ops like `aten.t()`
+    in inference mode fails with the following error. This is a bug in PyTorch
+    that affects all tensor subclasses:
+
+        Cannot set version_counter for inference tensor
+
+    For now, we work around this issue by using `torch.no_grad()` in this case.
+    See https://github.com/pytorch/pytorch/issues/164872 for more details.
+    Otherwise, just return `torch.inference_mode()`.
+    """
+    torchao_config = getattr(model, "torchao_config", None)
+    if torchao_config is not None and torchao_config.qat_scheme is None:
+        return torch.no_grad()
+    else:
+        return torch.inference_mode()
```

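A short, hedged usage sketch of the new helper, mirroring how `unsloth_fast_generate` uses it in `llama.py` below; `generate_safely` is a hypothetical wrapper, and `model` is any module that may carry the `torchao_config` attribute set by the loader.

```python
# Hedged usage sketch: pick no_grad() vs inference_mode() based on the model's torchao_config.
from unsloth.models._utils import _get_inference_mode_context_manager

def generate_safely(model, *args, **kwargs):
    # For fp8 (torchao) checkpoints, torchao_config.qat_scheme is None, so this returns
    # torch.no_grad() and avoids the inference-tensor version_counter bug; otherwise it
    # returns torch.inference_mode() as before.
    with _get_inference_mode_context_manager(model):
        return model.generate(*args, **kwargs)
```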
unsloth/models/llama.py

Lines changed: 5 additions & 2 deletions
```diff
@@ -21,7 +21,10 @@
 from ._utils import patch_unsloth_smart_gradient_checkpointing
 from ._utils import __version__, importlib_version
 from ._utils import move_to_device
-from ._utils import _prepare_model_for_qat
+from ._utils import (
+    _get_inference_mode_context_manager,
+    _prepare_model_for_qat,
+)
 from torch.nn.functional import scaled_dot_product_attention
 from transformers import __version__ as transformers_version
 from unsloth_zoo.utils import Version, _get_dtype
 
@@ -2030,7 +2033,7 @@ def unsloth_fast_generate(
 
     # Mixed precision autocast
     with (
-        torch.inference_mode(),
+        _get_inference_mode_context_manager(self),
        torch.autocast(device_type = DEVICE_TYPE_TORCH, dtype = dtype),
     ):
         output = self._old_generate(*args, **kwargs)
```

unsloth/models/loader.py

Lines changed: 47 additions & 3 deletions
```diff
@@ -31,7 +31,12 @@
 from transformers import AutoConfig
 from transformers import __version__ as transformers_version
 from peft import PeftConfig, PeftModel
-from .loader_utils import get_model_name
+from .loader_utils import (
+    _check_load_in_fp8_settings,
+    _offline_quantize_to_fp8,
+    _tag_model_with_fp8_torchao_config,
+    get_model_name,
+)
 import os, contextlib, sys
 
 try:
 
@@ -140,6 +145,7 @@ def from_pretrained(
     max_lora_rank = 64,
     disable_log_stats = True,
     qat_scheme = None,
+    load_in_fp8 = False, # fp8 LoRA
     *args,
     **kwargs,
 ):
 
@@ -183,6 +189,7 @@ def from_pretrained(
             max_lora_rank = max_lora_rank,
             disable_log_stats = disable_log_stats,
             qat_scheme = qat_scheme,
+            load_in_fp8 = load_in_fp8,
             *args,
             **kwargs,
         )
 
@@ -212,9 +219,23 @@ def from_pretrained(
         )
         load_in_4bit = False
 
+    if load_in_fp8:
+        _check_load_in_fp8_settings(
+            fast_inference,
+            full_finetuning,
+            load_in_4bit,
+            load_in_8bit,
+            load_in_16bit,
+            use_exact_model_name,
+        )
+
     old_model_name = model_name
     if not use_exact_model_name:
-        model_name = get_model_name(model_name, load_in_4bit)
+        if load_in_fp8:
+            model_name = _offline_quantize_to_fp8(model_name)
+        else:
+            model_name = get_model_name(model_name, load_in_4bit)
+
     # Check if pre-quantized models are allowed
     # For eg AMD GPUs need blocksize = 128, but our pre-quants are blocksize = 64
     if not ALLOW_PREQUANTIZED_MODELS and model_name.lower().endswith(
 
@@ -476,6 +497,8 @@ def from_pretrained(
             random_state = random_state,
             max_lora_rank = max_lora_rank,
             disable_log_stats = disable_log_stats,
+            qat_scheme = qat_scheme,
+            load_in_fp8 = load_in_fp8,
             *args,
             **kwargs,
         )
 
@@ -554,6 +577,9 @@ def from_pretrained(
         }
         model.config.update({"quantization_config": quantization_config})
 
+    if load_in_fp8:
+        _tag_model_with_fp8_torchao_config(model)
+
     if is_peft:
         # From https://github.com/huggingface/peft/issues/184
         # Now add PEFT adapters
 
@@ -634,6 +660,7 @@ def from_pretrained(
     max_lora_rank = 64,
     disable_log_stats = True,
     qat_scheme = None,
+    load_in_fp8 = False, # fp8 LoRA
     *args,
     **kwargs,
 ):
 
@@ -694,9 +721,23 @@ def from_pretrained(
         )
         load_in_4bit = False
 
+    if load_in_fp8:
+        _check_load_in_fp8_settings(
+            fast_inference,
+            full_finetuning,
+            load_in_4bit,
+            load_in_8bit,
+            load_in_16bit,
+            use_exact_model_name,
+        )
+
     old_model_name = model_name
     if not use_exact_model_name:
-        model_name = get_model_name(model_name, load_in_4bit)
+        if load_in_fp8:
+            model_name = _offline_quantize_to_fp8(model_name)
+        else:
+            model_name = get_model_name(model_name, load_in_4bit)
+
     # Check if pre-quantized models are allowed
     # For eg AMD GPUs need blocksize = 128, but our pre-quants are blocksize = 64
     if not ALLOW_PREQUANTIZED_MODELS and model_name.lower().endswith(
 
@@ -1130,6 +1171,9 @@ def from_pretrained(
         }
         model.config.update({"quantization_config": quantization_config})
 
+    if load_in_fp8:
+        _tag_model_with_fp8_torchao_config(model)
+
     if is_peft:
         # From https://github.com/huggingface/peft/issues/184
         # Now add PEFT adapters
```

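Before reaching for `load_in_fp8 = True`, the hardware and version requirements enforced by `_check_load_in_fp8_settings` (defined in `loader_utils.py` below) can be checked up front. A hedged, standalone preflight sketch; `fp8_lora_preflight` is a hypothetical helper that mirrors only the environment checks, not the flag checks.

```python
# Hedged preflight sketch mirroring the environment checks in _check_load_in_fp8_settings
# (not the Unsloth code itself, and not exhaustive).
import importlib.util
import torch
from packaging.version import Version

def fp8_lora_preflight() -> bool:
    ok = True
    # Hopper (compute capability 9.0) or newer
    if not (torch.cuda.is_available() and torch.version.cuda and torch.cuda.get_device_capability() >= (9, 0)):
        print("Needs an H100-class GPU (compute capability 9.0+)"); ok = False
    # torch 2.9.0+
    if Version(torch.__version__) < Version("2.9.0"):
        print("Needs torch 2.9.0+"); ok = False
    # torchao 0.15.0+ (or nightly), for https://github.com/pytorch/ao/pull/3158
    if importlib.util.find_spec("torchao") is None:
        print("Needs torchao 0.15.0+ (or nightly)"); ok = False
    else:
        import torchao
        if Version(torchao.__version__) < Version("0.15.0"):
            print("Needs torchao 0.15.0+ (or nightly)"); ok = False
    return ok

fp8_lora_preflight()
```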
unsloth/models/loader_utils.py

Lines changed: 129 additions & 1 deletion
```diff
@@ -12,11 +12,23 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import importlib
+import os
+import re
+import tempfile
 from .mapper import INT_TO_FLOAT_MAPPER, FLOAT_TO_INT_MAPPER, MAP_TO_UNSLOTH_16bit
 
 # https://github.com/huggingface/transformers/pull/26037 allows 4 bit loading!
 from packaging.version import Version
-from transformers import __version__ as transformers_version
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    TorchAoConfig,
+    __version__ as transformers_version,
+)
+from unsloth.models._utils import TorchAOConfig
+from unsloth_zoo.utils import Version
+import torch
 
 transformers_version = Version(transformers_version)
 SUPPORTS_FOURBIT = transformers_version >= Version("4.37")
 
@@ -144,3 +156,119 @@ def get_model_name(model_name, load_in_4bit = True):
            'pip install --upgrade --no-cache-dir "git+https://github.com/unslothai/unsloth-zoo.git"\n'
        )
     return new_model_name if new_model_name is not None else model_name
+
+
+def _get_torchao_fp8_config():
+    """
+    Return a `torchao.quantization.Float8DynamicActivationFloat8WeightConfig`
+    to be used for `load_in_fp8=True`.
+    """
+    from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow
+
+    return Float8DynamicActivationFloat8WeightConfig(
+        granularity=PerRow(),
+        activation_value_lb=1e-12,
+    )
+
+
+def _offline_quantize_to_fp8(model_name: str) -> str:
+    """
+    Quantize the model to fp8 using torchao and save the quantized model to a
+    temporary location. Return the path to the quantized model.
+
+    Note: Once on-the-fly quantization is added in vllm in
+    https://github.com/vllm-project/vllm/pull/26327, we should
+    dynamically quantize the model there instead:
+
+        llm = LLM(
+            ...
+            hf_overrides={"quantization_config_file": "torchao_config.json"},
+        )
+    """
+    temp_dir = tempfile.gettempdir()
+    new_model_name = model_name.split("/")[-1] + "-fp8"
+    new_model_name = os.path.join(temp_dir, new_model_name)
+    print(f"Quantizing '{model_name}' to fp8, using model_name='{new_model_name}' instead")
+    if not os.path.isdir(new_model_name):
+        qconfig = _get_torchao_fp8_config()
+        qconfig = TorchAoConfig(qconfig)
+        model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            torch_dtype="auto",
+            device_map="auto",
+            quantization_config=qconfig,
+        )
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        model.save_pretrained(new_model_name, safe_serialization=False)
+        tokenizer.save_pretrained(new_model_name)
+    return new_model_name
+
+
+def _tag_model_with_fp8_torchao_config(model: torch.nn.Module):
+    """
+    Tag a model with a `TorchAOConfig` so downstream callers will know what to do with it.
+    """
+    base_config = _get_torchao_fp8_config()
+    model.torchao_config = TorchAOConfig(
+        qat_scheme=None,
+        base_config_and_filter_fns=[(base_config, None)],
+    )
+
+
+def _check_load_in_fp8_settings(
+    fast_inference: bool,
+    full_finetuning: bool,
+    load_in_4bit: bool,
+    load_in_8bit: bool,
+    load_in_16bit: bool,
+    use_exact_model_name: bool,
+):
+    """
+    Assuming `load_in_fp8=True`, raise appropriate errors on incompatible settings
+    and environment. Currently this feature requires:
+      1. H100 GPUs or newer
+      2. torchao 0.15.0+ (or nightly)
+      3. torch 2.9.0+
+      4. If fbgemm_gpu_genai is installed, require 1.4.1+
+    """
+    if not fast_inference:
+        raise ValueError("Unsloth: `load_in_fp8` is only supported for `fast_inference` for now")
+    if full_finetuning:
+        raise ValueError("Unsloth: `load_in_fp8` is not compatible with full finetuning")
+    if load_in_4bit or load_in_8bit or load_in_16bit:
+        raise ValueError(
+            "Unsloth: `load_in_fp8` is not compatible with `load_in_4bit`, `load_in_8bit` or `load_in_16bit`",
+        )
+    if use_exact_model_name:
+        raise ValueError("Unsloth: `load_in_fp8` requires `use_exact_model_name=False`")
+
+    # Check if this is Hopper or above
+    if not (
+        torch.cuda.is_available()
+        and torch.version.cuda
+        and torch.cuda.get_device_capability() >= (9, 0)
+    ):
+        raise ValueError("Unsloth: `load_in_fp8` requires H100 GPUs or newer")
+
+    # Check if torch >= 2.9.0
+    if Version(torch.__version__) < Version("2.9.0"):
+        raise ValueError("Unsloth: `load_in_fp8` requires torch 2.9.0+")
+
+    # Check if torchao has this PR: https://github.com/pytorch/ao/pull/3158,
+    # which will be released in 0.15.0.
+    error_message = "Unsloth: `load_in_fp8` requires torchao 0.15.0+ (or nightly)"
+    if importlib.util.find_spec("torchao") is None:
+        raise ValueError(error_message)
+    import torchao
+
+    if Version(torchao.__version__) < Version("0.15.0"):
+        raise ValueError(error_message)
+
+    # Check if fbgemm_gpu_genai is installed, if so, require >= 1.4.1
+    if (
+        importlib.util.find_spec("fbgemm_gpu") is not None and
+        importlib.util.find_spec("fbgemm_gpu.experimental") is not None
+    ):
+        import fbgemm_gpu.experimental.gen_ai
+
+        if Version(fbgemm_gpu.__version__) < Version("1.4.1"):
+            raise ValueError("Unsloth: `load_in_fp8` is only compatible with fbgemm_gpu_genai 1.4.1+")
```

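For reference, `_offline_quantize_to_fp8` writes the quantized checkpoint under the system temp directory and reuses it on later runs. Below is a hedged sketch of the path derivation and of reloading a cached checkpoint; the reload step assumes transformers' TorchAo integration restores the quantized weights from a non-safetensors save, matching how the checkpoint was written above.

```python
# Hedged sketch: where the offline fp8 checkpoint lands, and reusing it if already present.
import os
import tempfile
from transformers import AutoModelForCausalLM

model_name = "unsloth/Qwen3-8B-Base"
fp8_path = os.path.join(tempfile.gettempdir(), model_name.split("/")[-1] + "-fp8")
# e.g. /tmp/Qwen3-8B-Base-fp8 on Linux

if os.path.isdir(fp8_path):
    # Reload the already-quantized checkpoint instead of quantizing again
    model = AutoModelForCausalLM.from_pretrained(fp8_path, torch_dtype = "auto", device_map = "auto")
```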