Commit c4f8c67
Enable FP8 + RL training for bf16 models
**Summary:** Enable FP8 + RL training using TorchAO for 1.33x faster training and 40% lower memory usage:

- We quantize the frozen base model weights to fp8 and keep the LoRA adapters in bf16
- We leverage TorchAO's `Float8Tensor`, which calls into fbgemm's fp8 x fp8 rowwise matmul kernel
- For now, we need to do an offline quantization first, because vLLM doesn't support on-the-fly quantization for torchao yet (this is in progress: vllm-project/vllm#26327)

**Example usage:**
```
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Qwen3-8B-Base",
    max_seq_length = 2048,
    load_in_4bit = False,
    fast_inference = True,
    max_lora_rank = 32,
    load_in_fp8 = True, # set this to True
)

# the rest is the same as before
model = FastLanguageModel.get_peft_model(...)
```

**Initial results:**
```
# fp8
{'train_runtime': 1725.4337, 'train_samples_per_second': 0.232, 'train_steps_per_second': 0.058, 'train_loss': 0.00015715716748673002, 'epoch': 0.01}

# bf16
{'train_runtime': 2297.8145, 'train_samples_per_second': 0.174, 'train_steps_per_second': 0.044, 'train_loss': 0.00016081033063528594, 'epoch': 0.01}
```

<img width="1199" height="448" alt="Screenshot 2025-11-11 at 4 10 50 PM" src="https://github.com/user-attachments/assets/b6304afd-89e9-42b1-8064-775807e17b23" />

Test script: https://gist.github.com/andrewor14/5b85119fae46845d07b608d420907423

**Requires:**
- pytorch/ao#3158 (torchao nightly)
- unslothai/unsloth-zoo#351
1 parent aa7cfa1 commit c4f8c67
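For context, the offline quantization step mentioned in the summary is plain transformers + torchao; Unsloth runs it automatically via `_offline_quantize_to_fp8` (see `unsloth/models/loader_utils.py` below) when `load_in_fp8 = True`. A minimal sketch of that step, assuming a torchao nightly and a recent transformers; the output directory name is illustrative:

```python
# Hedged sketch of the offline fp8 quantization the commit performs on first load.
# The model name comes from the summary above; the save path is illustrative.
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow

model_name = "unsloth/Qwen3-8B-Base"
qconfig = TorchAoConfig(Float8DynamicActivationFloat8WeightConfig(granularity = PerRow()))

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype = "auto",            # activations and LoRA adapters stay bf16
    device_map = "auto",
    quantization_config = qconfig,   # frozen base weights become torchao Float8Tensors
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Saved without safetensors so the tensor-subclass weights round-trip, matching the diff below
model.save_pretrained("Qwen3-8B-Base-fp8", safe_serialization = False)
tokenizer.save_pretrained("Qwen3-8B-Base-fp8")
```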

File tree: 6 files changed (+148, −6 lines)


unsloth/kernels/fast_lora.py

Lines changed: 11 additions & 0 deletions
```diff
@@ -14,6 +14,7 @@
 
 import torch
 from .utils import (
+    _maybe_dequantize_torchao_float8_tensor,
     _maybe_fake_quantize_activations,
     fast_dequantize,
     QUANT_STATE,
@@ -128,6 +129,10 @@ def backward(ctx, dY: torch.Tensor):
         ) = ctx.custom_saved_tensors
         gateA, gateB, upA, upB, downA, downB, X, e, g = ctx.saved_tensors
 
+        gateW = _maybe_dequantize_torchao_float8_tensor(gateW)
+        upW = _maybe_dequantize_torchao_float8_tensor(upW)
+        downW = _maybe_dequantize_torchao_float8_tensor(downW)
+
         batch, seq_len, hd = X.shape
         dY = dY.view(-1, dY.shape[-1])
         X = X.view(-1, X.shape[-1])
@@ -420,6 +425,10 @@ def backward(ctx, dQ, dK, dV):
             VB,
         ) = ctx.saved_tensors
 
+        QW = _maybe_dequantize_torchao_float8_tensor(QW)
+        KW = _maybe_dequantize_torchao_float8_tensor(KW)
+        VW = _maybe_dequantize_torchao_float8_tensor(VW)
+
         batch, seq_len, hd = X.shape
         dQ = dQ.view(-1, dQ.shape[-1])
         dK = dK.reshape(-1, dK.shape[-1]) # view doesn't work on K.T
@@ -593,6 +602,8 @@ def backward(ctx, dY: torch.Tensor):
         W, W_quant, S = ctx.custom_saved_tensors
         A, B, X = ctx.saved_tensors
 
+        W = _maybe_dequantize_torchao_float8_tensor(W)
+
         batch, seq_len, hd = X.shape
         dY = dY.reshape(-1, dY.shape[-1]) # Must be reshape
         X = X.reshape(-1, X.shape[-1]) # Must be reshape
```
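Why the backward pass dequantizes: the fp8 rowwise kernel is used for the forward matmul, while the LoRA backward here falls back to plain bf16 matmuls against the frozen base weight, so the `Float8Tensor` is expanded first. A toy sketch of the gradient-w.r.t.-input step, under illustrative shape conventions (this is not Unsloth's exact kernel code):

```python
# Toy illustration of the pattern added above: dequantize the frozen fp8
# base weight before the backward matmuls.
# Illustrative shapes: W [out, in], A [r, in], B [out, r], dY [n, out].
import torch

def lora_backward_dX(dY: torch.Tensor, W: torch.Tensor, A: torch.Tensor, B: torch.Tensor, s: float):
    if hasattr(W, "dequantize"):   # e.g. a torchao Float8Tensor
        W = W.dequantize()         # back to a plain bf16 tensor
    # d(input) for y = x @ W.T + s * (x @ A.T) @ B.T
    return dY @ W + s * (dY @ B) @ A
```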

unsloth/kernels/utils.py

Lines changed: 19 additions & 0 deletions
```diff
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import importlib
 import triton
 import ctypes
 
@@ -211,6 +212,10 @@ def _get_tensor_stream(tensor: torch_Tensor) -> c_void_p:
 torch_bfloat16 = torch.bfloat16
 
 
+# Whether torchao can be imported
+_HAS_TORCHAO = importlib.util.find_spec("torchao") is not None
+
+
 def QUANT_STATE(W):
     return getattr(W, "quant_state", None)
 
@@ -329,6 +334,20 @@ def _maybe_fake_quantize_activations(
     return X
 
 
+def _maybe_dequantize_torchao_float8_tensor(x: torch.Tensor) -> torch.Tensor:
+    """
+    If `x` is a `torchao.quantization.Float8Tensor`, dequantize it.
+    This is used in the backward pass of LoRA autograd functions.
+    """
+    if not _HAS_TORCHAO:
+        return x
+    from torchao.quantization import Float8Tensor
+    if isinstance(x, Float8Tensor):
+        return x.dequantize()
+    else:
+        return x
+
+
 # INTEL GPU Specific Logic
 if DEVICE_TYPE == "xpu" and HAS_XPU_STREAM:
```
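A quick roundtrip showing what the helper does in practice. This is a sketch assuming a torchao nightly that exports `Float8Tensor` from `torchao.quantization` (the same import the helper relies on) and an fp8-capable CUDA device:

```python
# Sketch: quantize a bf16 linear with torchao, then dequantize its weight via the new helper.
import torch
from torchao.quantization import (
    quantize_,
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
)
from unsloth.kernels.utils import _maybe_dequantize_torchao_float8_tensor

linear = torch.nn.Linear(256, 256, bias = False, dtype = torch.bfloat16, device = "cuda")
quantize_(linear, Float8DynamicActivationFloat8WeightConfig(granularity = PerRow()))

W = linear.weight                                    # now backed by a torchao Float8Tensor
W_bf16 = _maybe_dequantize_torchao_float8_tensor(W)  # plain bf16 tensor for backward matmuls
print(type(W).__name__, W_bf16.dtype)                # non-Float8 tensors pass through unchanged
```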

unsloth/models/_utils.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -2012,7 +2012,7 @@ def error_out_no_vllm(*args, **kwargs):
 
 @dataclass
 class TorchAOConfig:
-    qat_scheme: str = "int4"
+    qat_scheme: Optional[str] = "int4"
     base_config: AOBaseConfig = field(
         default_factory = lambda: Int4WeightOnlyConfig(group_size = 128)
     )
```
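The `Optional[str]` change is what lets the fp8 path reuse this dataclass: `_tag_model_with_fp8_torchao_config` (in `loader_utils.py` below) sets `qat_scheme = None` to mean "torchao-quantized but not QAT", and the RL patch further down keys off that. A sketch of the two ways the config is now populated, with names mirroring the diffs in this commit:

```python
# Sketch: the fp8 variant (qat_scheme=None) is what gets attached to the model for fp8 LoRA;
# the int4 QAT variant keeps the previous default.
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow
from unsloth.models._utils import TorchAOConfig

# fp8 LoRA: quantized base weights, no QAT scheme
fp8_tag = TorchAOConfig(
    qat_scheme = None,
    base_config = Float8DynamicActivationFloat8WeightConfig(granularity = PerRow()),
)

# previous default: int4 QAT
int4_tag = TorchAOConfig()  # qat_scheme="int4", Int4WeightOnlyConfig(group_size=128)

assert fp8_tag.qat_scheme is None and int4_tag.qat_scheme == "int4"
```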

unsloth/models/loader.py

Lines changed: 48 additions & 3 deletions
```diff
@@ -31,7 +31,11 @@
 from transformers import AutoConfig
 from transformers import __version__ as transformers_version
 from peft import PeftConfig, PeftModel
-from .loader_utils import get_model_name
+from .loader_utils import (
+    _offline_quantize_to_fp8,
+    _tag_model_with_fp8_torchao_config,
+    get_model_name,
+)
 import os, contextlib, sys
 
 try:
@@ -139,6 +143,7 @@ def from_pretrained(
         max_lora_rank = 64,
         disable_log_stats = True,
         qat_scheme = None,
+        load_in_fp8 = False, # fp8 LoRA
         *args,
         **kwargs,
     ):
@@ -182,6 +187,7 @@ def from_pretrained(
                 max_lora_rank = max_lora_rank,
                 disable_log_stats = disable_log_stats,
                 qat_scheme = qat_scheme,
+                load_in_fp8 = load_in_fp8,
                 *args,
                 **kwargs,
             )
@@ -211,9 +217,24 @@ def from_pretrained(
             )
             load_in_4bit = False
 
+        if load_in_fp8 and not fast_inference:
+            raise ValueError("Unsloth: `load_in_fp8` is only supported for `fast_inference` for now")
+        if load_in_fp8 and full_finetuning:
+            raise ValueError("Unsloth: `load_in_fp8` is not compatible with full finetuning")
+        if load_in_fp8 and (load_in_4bit or load_in_8bit or load_in_16bit):
+            raise ValueError(
+                "Unsloth: `load_in_fp8` is not compatible with `load_in_4bit`, `load_in_8bit` or `load_in_16bit`",
+            )
+        if load_in_fp8 and use_exact_model_name:
+            raise ValueError("Unsloth: `load_in_fp8` requires `use_exact_model_name=False`")
+
         old_model_name = model_name
         if not use_exact_model_name:
-            model_name = get_model_name(model_name, load_in_4bit)
+            if load_in_fp8:
+                model_name = _offline_quantize_to_fp8(model_name)
+            else:
+                model_name = get_model_name(model_name, load_in_4bit)
+
         # Check if pre-quantized models are allowed
         # For eg AMD GPUs need blocksize = 128, but our pre-quants are blocksize = 64
         if not ALLOW_PREQUANTIZED_MODELS and model_name.lower().endswith(
@@ -475,6 +496,8 @@ def from_pretrained(
                 random_state = random_state,
                 max_lora_rank = max_lora_rank,
                 disable_log_stats = disable_log_stats,
+                qat_scheme = qat_scheme,
+                load_in_fp8 = load_in_fp8,
                 *args,
                 **kwargs,
             )
@@ -553,6 +576,9 @@ def from_pretrained(
         }
         model.config.update({"quantization_config": quantization_config})
 
+        if load_in_fp8:
+            _tag_model_with_fp8_torchao_config(model)
+
         if is_peft:
             # From https://github.com/huggingface/peft/issues/184
             # Now add PEFT adapters
@@ -621,6 +647,7 @@ def from_pretrained(
         max_lora_rank = 64,
         disable_log_stats = True,
         qat_scheme = None,
+        load_in_fp8 = False, # fp8 LoRA
         *args,
         **kwargs,
     ):
@@ -681,9 +708,24 @@ def from_pretrained(
             )
             load_in_4bit = False
 
+        if load_in_fp8 and not fast_inference:
+            raise ValueError("Unsloth: `load_in_fp8` is only supported for `fast_inference` for now")
+        if load_in_fp8 and full_finetuning:
+            raise ValueError("Unsloth: `load_in_fp8` is not compatible with full finetuning")
+        if load_in_fp8 and (load_in_4bit or load_in_8bit or load_in_16bit):
+            raise ValueError(
+                "Unsloth: `load_in_fp8` is not compatible with `load_in_4bit`, `load_in_8bit` or `load_in_16bit`",
+            )
+        if load_in_fp8 and use_exact_model_name:
+            raise ValueError("Unsloth: `load_in_fp8` requires `use_exact_model_name=False`")
+
         old_model_name = model_name
         if not use_exact_model_name:
-            model_name = get_model_name(model_name, load_in_4bit)
+            if load_in_fp8:
+                model_name = _offline_quantize_to_fp8(model_name)
+            else:
+                model_name = get_model_name(model_name, load_in_4bit)
+
         # Check if pre-quantized models are allowed
         # For eg AMD GPUs need blocksize = 128, but our pre-quants are blocksize = 64
         if not ALLOW_PREQUANTIZED_MODELS and model_name.lower().endswith(
@@ -1117,6 +1159,9 @@ def from_pretrained(
         }
         model.config.update({"quantization_config": quantization_config})
 
+        if load_in_fp8:
+            _tag_model_with_fp8_torchao_config(model)
+
         if is_peft:
             # From https://github.com/huggingface/peft/issues/184
             # Now add PEFT adapters
```
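Putting the new flag and its guards together: `load_in_fp8` requires vLLM fast inference, excludes full finetuning and the other `load_in_*` precisions, and needs `use_exact_model_name = False` so the loader can swap in the fp8 copy. A valid combination, adapted from the commit summary (adapter hyperparameters are illustrative):

```python
# Flag combination that passes the checks added above; mirrors the commit summary.
from unsloth import FastLanguageModel

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name     = "unsloth/Qwen3-8B-Base",
    max_seq_length = 2048,
    load_in_4bit   = False,  # must be off: fp8 excludes 4/8/16-bit loading
    fast_inference = True,   # required: fp8 currently only works with the vLLM path
    max_lora_rank  = 32,
    load_in_fp8    = True,   # first load quantizes the base weights offline to a temp dir
)

# LoRA adapters stay in bf16; the rest of the RL setup is unchanged
model = FastLanguageModel.get_peft_model(model, r = 32, lora_alpha = 64)
```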

unsloth/models/loader_utils.py

Lines changed: 55 additions & 1 deletion
```diff
@@ -12,11 +12,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
+import tempfile
 from .mapper import INT_TO_FLOAT_MAPPER, FLOAT_TO_INT_MAPPER, MAP_TO_UNSLOTH_16bit
 
 # https://github.com/huggingface/transformers/pull/26037 allows 4 bit loading!
 from packaging.version import Version
-from transformers import __version__ as transformers_version
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    TorchAoConfig,
+    __version__ as transformers_version,
+)
+from unsloth.models._utils import TorchAOConfig
+import torch
 
 transformers_version = Version(transformers_version)
 SUPPORTS_FOURBIT = transformers_version >= Version("4.37")
@@ -144,3 +153,48 @@ def get_model_name(model_name, load_in_4bit = True):
             'pip install --upgrade --no-cache-dir "git+https://github.com/unslothai/unsloth-zoo.git"\n'
         )
     return new_model_name if new_model_name is not None else model_name
+
+
+def _offline_quantize_to_fp8(model_name: str) -> str:
+    """
+    Quantize the model to fp8 using torchao and save the quantized model to a
+    temporary location. Return the path to the quantized model.
+
+    Note: Once on-the-fly quantization is added in vllm in
+    https://github.com/vllm-project/vllm/pull/26327, we should
+    dynamically quantize the model there instead:
+
+        llm = LLM(
+            ...
+            hf_overrides={"quantization_config_file": "torchao_config.json"},
+        )
+    """
+    from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow
+
+    temp_dir = tempfile.gettempdir()
+    new_model_name = model_name.split("/")[-1] + "-fp8"
+    new_model_name = os.path.join(temp_dir, new_model_name)
+    print(f"Quantizing '{model_name}' to fp8, using model_name='{new_model_name}' instead")
+    if not os.path.isdir(new_model_name):
+        qconfig = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
+        qconfig = TorchAoConfig(qconfig)
+        model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            torch_dtype="auto",
+            device_map="auto",
+            quantization_config=qconfig,
+        )
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        model.save_pretrained(new_model_name, safe_serialization=False)
+        tokenizer.save_pretrained(new_model_name)
+    return new_model_name
+
+
+def _tag_model_with_fp8_torchao_config(model: torch.nn.Module):
+    """
+    Tag a model with a `TorchAOConfig` so downstream callers will know what to do with it.
+    """
+    from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow
+
+    base_config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
+    model.torchao_config = TorchAOConfig(qat_scheme=None, base_config=base_config)
```
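What `_offline_quantize_to_fp8` produces on disk, as a sketch: a cached fp8 copy under the system temp dir whose name is derived from the model name, reused on subsequent loads instead of requantizing. The model name below is the one from the summary; the exact temp path is platform-dependent:

```python
# Sketch of the helper's caching behaviour; heavy on the first call, a cheap path lookup after.
import os, tempfile
from unsloth.models.loader_utils import _offline_quantize_to_fp8

path = _offline_quantize_to_fp8("unsloth/Qwen3-8B-Base")  # e.g. /tmp/Qwen3-8B-Base-fp8
assert path == os.path.join(tempfile.gettempdir(), "Qwen3-8B-Base-fp8")
assert os.path.isdir(path)  # the cached fp8 copy is reused on the next load
```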

unsloth/models/rl_replacements.py

Lines changed: 14 additions & 1 deletion
```diff
@@ -536,7 +536,20 @@ def _get_per_token_logps_and_entropies(
         )
 
         with torch.amp.autocast(device_type = "cuda", dtype = self._autocast_dtype):
-            with torch.inference_mode():
+            # If the state dict was quantized using torchao, we will run into
+            # the following error when calling ops like aten.t() in inference mode.
+            # This is a bug in PyTorch that affects all tensor subclasses.
+            #
+            # Cannot set version_counter for inference tensor
+            #
+            # For now, we work around this issue by using torch.no_grad in this case.
+            # See https://github.com/pytorch/pytorch/issues/164872 for more details
+            torchao_config = getattr(model, "torchao_config", None)
+            if torchao_config is not None and torchao_config.qat_scheme is None:
+                ctx_manager = torch.no_grad()
+            else:
+                ctx_manager = torch.inference_mode()
+            with ctx_manager:
                 if pixel_values is None:
                     attention_mask = input_ids != self.processing_class.pad_token_id
                     attention_mask = attention_mask.to(attention_mask.dtype)
```
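The workaround in isolation, as a sketch: models tagged with a torchao config and no QAT scheme (the fp8 LoRA case) take the `torch.no_grad()` branch, because `torch.inference_mode()` currently breaks on tensor subclasses (pytorch/pytorch#164872); everything else keeps `inference_mode()`:

```python
# Standalone sketch of the context-manager selection added above.
import torch

def logps_context(model: torch.nn.Module):
    torchao_config = getattr(model, "torchao_config", None)
    if torchao_config is not None and torchao_config.qat_scheme is None:
        # fp8-quantized (tensor subclass) weights: inference_mode() raises
        # "Cannot set version_counter for inference tensor", so use no_grad()
        return torch.no_grad()
    return torch.inference_mode()

# Untagged models keep the old behaviour
plain = torch.nn.Linear(4, 4)
with logps_context(plain):
    _ = plain(torch.randn(2, 4))
```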
