Commit 366328f
committed
Enable FP8 + RL training for bf16 models
**Summary:** Enable FP8 + RL training using TorchAO for 1.33x faster training and 42% less model memory usage:
- We quantize the frozen LoRA weights into fp8 and keep the LoRA adapters in bf16
- We leverage TorchAO's `Float8Tensor`, which calls into fbgemm's fp8 x fp8 rowwise matmul kernel
- For now, we need to do an offline quantization first, because vllm doesn't support on-the-fly quantization for torchao yet (this is in progress: vllm-project/vllm#26327)
**Example usage:**
```
model, tokenizer = FastLanguageModel.from_pretrained(
model_name = "unsloth/Qwen3-8B-Base",
max_seq_length = 2048,
load_in_4bit = False,
fast_inference = True,
max_lora_rank = 32,
load_in_fp8 = True, # set this to True
)
\# the rest is the same as before
model = FastLanguageModel.get_peft_model(...)
```
**Initial results:**
```
\# fp8
{'train_runtime': 1725.4337, 'train_samples_per_second': 0.232, 'train_steps_per_second': 0.058, 'train_loss': 0.00015715716748673002, 'epoch': 0.01}
\# bf16
{'train_runtime': 2297.8145, 'train_samples_per_second': 0.174, 'train_steps_per_second': 0.044, 'train_loss': 0.00016081033063528594, 'epoch': 0.01}
```
<img width="1199" height="448" alt="Screenshot 2025-11-11 at 4 10 50 PM" src="https://github.com/user-attachments/assets/b6304afd-89e9-42b1-8064-775807e17b23" />
Test script: https://gist.github.com/andrewor14/5b85119fae46845d07b608d420907423
**Requires:**
- pytorch/ao#3158 (torchao nightly or 0.15.0+)
- unslothai/unsloth-zoo#3511 parent b9d9600 commit 366328f
File tree
5 files changed
+181
-8
lines changed- unsloth
- kernels
- models
5 files changed
+181
-8
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
12 | 12 | | |
13 | 13 | | |
14 | 14 | | |
| 15 | + | |
15 | 16 | | |
16 | 17 | | |
17 | 18 | | |
| |||
211 | 212 | | |
212 | 213 | | |
213 | 214 | | |
| 215 | + | |
| 216 | + | |
| 217 | + | |
| 218 | + | |
214 | 219 | | |
215 | 220 | | |
216 | 221 | | |
| |||
329 | 334 | | |
330 | 335 | | |
331 | 336 | | |
| 337 | + | |
| 338 | + | |
| 339 | + | |
| 340 | + | |
| 341 | + | |
| 342 | + | |
| 343 | + | |
| 344 | + | |
| 345 | + | |
| 346 | + | |
| 347 | + | |
| 348 | + | |
| 349 | + | |
| 350 | + | |
| 351 | + | |
| 352 | + | |
| 353 | + | |
| 354 | + | |
| 355 | + | |
| 356 | + | |
332 | 357 | | |
333 | 358 | | |
334 | 359 | | |
335 | 360 | | |
336 | 361 | | |
337 | 362 | | |
| 363 | + | |
338 | 364 | | |
339 | 365 | | |
340 | 366 | | |
| |||
441 | 467 | | |
442 | 468 | | |
443 | 469 | | |
| 470 | + | |
444 | 471 | | |
445 | 472 | | |
446 | 473 | | |
| |||
551 | 578 | | |
552 | 579 | | |
553 | 580 | | |
| 581 | + | |
554 | 582 | | |
555 | 583 | | |
556 | 584 | | |
| |||
987 | 1015 | | |
988 | 1016 | | |
989 | 1017 | | |
990 | | - | |
991 | | - | |
| 1018 | + | |
| 1019 | + | |
992 | 1020 | | |
993 | 1021 | | |
994 | 1022 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
2012 | 2012 | | |
2013 | 2013 | | |
2014 | 2014 | | |
2015 | | - | |
| 2015 | + | |
2016 | 2016 | | |
2017 | 2017 | | |
2018 | 2018 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
31 | 31 | | |
32 | 32 | | |
33 | 33 | | |
34 | | - | |
| 34 | + | |
| 35 | + | |
| 36 | + | |
| 37 | + | |
| 38 | + | |
| 39 | + | |
35 | 40 | | |
36 | 41 | | |
37 | 42 | | |
| |||
140 | 145 | | |
141 | 146 | | |
142 | 147 | | |
| 148 | + | |
143 | 149 | | |
144 | 150 | | |
145 | 151 | | |
| |||
183 | 189 | | |
184 | 190 | | |
185 | 191 | | |
| 192 | + | |
186 | 193 | | |
187 | 194 | | |
188 | 195 | | |
| |||
212 | 219 | | |
213 | 220 | | |
214 | 221 | | |
| 222 | + | |
| 223 | + | |
| 224 | + | |
| 225 | + | |
| 226 | + | |
| 227 | + | |
| 228 | + | |
| 229 | + | |
| 230 | + | |
| 231 | + | |
215 | 232 | | |
216 | 233 | | |
217 | | - | |
| 234 | + | |
| 235 | + | |
| 236 | + | |
| 237 | + | |
| 238 | + | |
218 | 239 | | |
219 | 240 | | |
220 | 241 | | |
| |||
476 | 497 | | |
477 | 498 | | |
478 | 499 | | |
| 500 | + | |
| 501 | + | |
479 | 502 | | |
480 | 503 | | |
481 | 504 | | |
| |||
554 | 577 | | |
555 | 578 | | |
556 | 579 | | |
| 580 | + | |
| 581 | + | |
| 582 | + | |
557 | 583 | | |
558 | 584 | | |
559 | 585 | | |
| |||
634 | 660 | | |
635 | 661 | | |
636 | 662 | | |
| 663 | + | |
637 | 664 | | |
638 | 665 | | |
639 | 666 | | |
| |||
694 | 721 | | |
695 | 722 | | |
696 | 723 | | |
| 724 | + | |
| 725 | + | |
| 726 | + | |
| 727 | + | |
| 728 | + | |
| 729 | + | |
| 730 | + | |
| 731 | + | |
| 732 | + | |
| 733 | + | |
697 | 734 | | |
698 | 735 | | |
699 | | - | |
| 736 | + | |
| 737 | + | |
| 738 | + | |
| 739 | + | |
| 740 | + | |
700 | 741 | | |
701 | 742 | | |
702 | 743 | | |
| |||
1130 | 1171 | | |
1131 | 1172 | | |
1132 | 1173 | | |
| 1174 | + | |
| 1175 | + | |
| 1176 | + | |
1133 | 1177 | | |
1134 | 1178 | | |
1135 | 1179 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
12 | 12 | | |
13 | 13 | | |
14 | 14 | | |
| 15 | + | |
| 16 | + | |
| 17 | + | |
15 | 18 | | |
16 | 19 | | |
17 | 20 | | |
18 | 21 | | |
19 | | - | |
| 22 | + | |
| 23 | + | |
| 24 | + | |
| 25 | + | |
| 26 | + | |
| 27 | + | |
| 28 | + | |
| 29 | + | |
20 | 30 | | |
21 | 31 | | |
22 | 32 | | |
| |||
144 | 154 | | |
145 | 155 | | |
146 | 156 | | |
| 157 | + | |
| 158 | + | |
| 159 | + | |
| 160 | + | |
| 161 | + | |
| 162 | + | |
| 163 | + | |
| 164 | + | |
| 165 | + | |
| 166 | + | |
| 167 | + | |
| 168 | + | |
| 169 | + | |
| 170 | + | |
| 171 | + | |
| 172 | + | |
| 173 | + | |
| 174 | + | |
| 175 | + | |
| 176 | + | |
| 177 | + | |
| 178 | + | |
| 179 | + | |
| 180 | + | |
| 181 | + | |
| 182 | + | |
| 183 | + | |
| 184 | + | |
| 185 | + | |
| 186 | + | |
| 187 | + | |
| 188 | + | |
| 189 | + | |
| 190 | + | |
| 191 | + | |
| 192 | + | |
| 193 | + | |
| 194 | + | |
| 195 | + | |
| 196 | + | |
| 197 | + | |
| 198 | + | |
| 199 | + | |
| 200 | + | |
| 201 | + | |
| 202 | + | |
| 203 | + | |
| 204 | + | |
| 205 | + | |
| 206 | + | |
| 207 | + | |
| 208 | + | |
| 209 | + | |
| 210 | + | |
| 211 | + | |
| 212 | + | |
| 213 | + | |
| 214 | + | |
| 215 | + | |
| 216 | + | |
| 217 | + | |
| 218 | + | |
| 219 | + | |
| 220 | + | |
| 221 | + | |
| 222 | + | |
| 223 | + | |
| 224 | + | |
| 225 | + | |
| 226 | + | |
| 227 | + | |
| 228 | + | |
| 229 | + | |
| 230 | + | |
| 231 | + | |
| 232 | + | |
| 233 | + | |
| 234 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
536 | 536 | | |
537 | 537 | | |
538 | 538 | | |
539 | | - | |
| 539 | + | |
| 540 | + | |
| 541 | + | |
| 542 | + | |
| 543 | + | |
| 544 | + | |
| 545 | + | |
| 546 | + | |
| 547 | + | |
| 548 | + | |
| 549 | + | |
| 550 | + | |
| 551 | + | |
| 552 | + | |
540 | 553 | | |
541 | 554 | | |
542 | 555 | | |
| |||
0 commit comments