Commit 2f11d8d
committed
Enable FP8 + RL training for bf16 models
**Summary:** Enable FP8 + RL training using TorchAO for 1.33x faster training and 42% less model memory usage:
- We quantize the frozen LoRA weights into fp8 and keep the LoRA adapters in bf16
- We leverage TorchAO's `Float8Tensor`, which calls into fbgemm's fp8 x fp8 rowwise matmul kernel
- For now, we need to do an offline quantization first, because vllm doesn't support on-the-fly quantization for torchao yet (this is in progress: vllm-project/vllm#26327)
**Example usage:**
```
model, tokenizer = FastLanguageModel.from_pretrained(
model_name = "unsloth/Qwen3-8B-Base",
max_seq_length = 2048,
load_in_4bit = False,
fast_inference = True,
max_lora_rank = 32,
load_in_fp8 = True, # set this to True
)
\# the rest is the same as before
model = FastLanguageModel.get_peft_model(...)
```
**Initial results:**
```
\# fp8
{'train_runtime': 1725.4337, 'train_samples_per_second': 0.232, 'train_steps_per_second': 0.058, 'train_loss': 0.00015715716748673002, 'epoch': 0.01}
\# bf16
{'train_runtime': 2297.8145, 'train_samples_per_second': 0.174, 'train_steps_per_second': 0.044, 'train_loss': 0.00016081033063528594, 'epoch': 0.01}
```
<img width="1199" height="448" alt="Screenshot 2025-11-11 at 4 10 50 PM" src="https://github.com/user-attachments/assets/b6304afd-89e9-42b1-8064-775807e17b23" />
Test script: https://gist.github.com/andrewor14/5b85119fae46845d07b608d420907423
**Requires:**
- pytorch/ao#3158 (torchao nightly or 0.15.0+)
- unslothai/unsloth-zoo#3511 parent b9d9600 commit 2f11d8d
File tree
6 files changed
+234
-10
lines changed- unsloth
- kernels
- models
6 files changed
+234
-10
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
12 | 12 | | |
13 | 13 | | |
14 | 14 | | |
| 15 | + | |
15 | 16 | | |
16 | 17 | | |
17 | 18 | | |
| |||
211 | 212 | | |
212 | 213 | | |
213 | 214 | | |
| 215 | + | |
| 216 | + | |
| 217 | + | |
| 218 | + | |
214 | 219 | | |
215 | 220 | | |
216 | 221 | | |
| |||
329 | 334 | | |
330 | 335 | | |
331 | 336 | | |
| 337 | + | |
| 338 | + | |
| 339 | + | |
| 340 | + | |
| 341 | + | |
| 342 | + | |
| 343 | + | |
| 344 | + | |
| 345 | + | |
| 346 | + | |
| 347 | + | |
| 348 | + | |
| 349 | + | |
| 350 | + | |
| 351 | + | |
| 352 | + | |
| 353 | + | |
| 354 | + | |
| 355 | + | |
| 356 | + | |
332 | 357 | | |
333 | 358 | | |
334 | 359 | | |
335 | 360 | | |
336 | 361 | | |
337 | 362 | | |
| 363 | + | |
338 | 364 | | |
339 | 365 | | |
340 | 366 | | |
| |||
441 | 467 | | |
442 | 468 | | |
443 | 469 | | |
| 470 | + | |
444 | 471 | | |
445 | 472 | | |
446 | 473 | | |
| |||
551 | 578 | | |
552 | 579 | | |
553 | 580 | | |
| 581 | + | |
554 | 582 | | |
555 | 583 | | |
556 | 584 | | |
| |||
987 | 1015 | | |
988 | 1016 | | |
989 | 1017 | | |
990 | | - | |
991 | | - | |
| 1018 | + | |
| 1019 | + | |
992 | 1020 | | |
993 | 1021 | | |
994 | 1022 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
2012 | 2012 | | |
2013 | 2013 | | |
2014 | 2014 | | |
2015 | | - | |
| 2015 | + | |
2016 | 2016 | | |
2017 | 2017 | | |
2018 | 2018 | | |
| |||
2262 | 2262 | | |
2263 | 2263 | | |
2264 | 2264 | | |
| 2265 | + | |
| 2266 | + | |
| 2267 | + | |
| 2268 | + | |
| 2269 | + | |
| 2270 | + | |
| 2271 | + | |
| 2272 | + | |
| 2273 | + | |
| 2274 | + | |
| 2275 | + | |
| 2276 | + | |
| 2277 | + | |
| 2278 | + | |
| 2279 | + | |
| 2280 | + | |
| 2281 | + | |
| 2282 | + | |
| 2283 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
21 | 21 | | |
22 | 22 | | |
23 | 23 | | |
24 | | - | |
| 24 | + | |
| 25 | + | |
| 26 | + | |
| 27 | + | |
25 | 28 | | |
26 | 29 | | |
27 | 30 | | |
| |||
2030 | 2033 | | |
2031 | 2034 | | |
2032 | 2035 | | |
2033 | | - | |
| 2036 | + | |
2034 | 2037 | | |
2035 | 2038 | | |
2036 | 2039 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
31 | 31 | | |
32 | 32 | | |
33 | 33 | | |
34 | | - | |
| 34 | + | |
| 35 | + | |
| 36 | + | |
| 37 | + | |
| 38 | + | |
| 39 | + | |
35 | 40 | | |
36 | 41 | | |
37 | 42 | | |
| |||
140 | 145 | | |
141 | 146 | | |
142 | 147 | | |
| 148 | + | |
143 | 149 | | |
144 | 150 | | |
145 | 151 | | |
| |||
183 | 189 | | |
184 | 190 | | |
185 | 191 | | |
| 192 | + | |
186 | 193 | | |
187 | 194 | | |
188 | 195 | | |
| |||
212 | 219 | | |
213 | 220 | | |
214 | 221 | | |
| 222 | + | |
| 223 | + | |
| 224 | + | |
| 225 | + | |
| 226 | + | |
| 227 | + | |
| 228 | + | |
| 229 | + | |
| 230 | + | |
| 231 | + | |
215 | 232 | | |
216 | 233 | | |
217 | | - | |
| 234 | + | |
| 235 | + | |
| 236 | + | |
| 237 | + | |
| 238 | + | |
218 | 239 | | |
219 | 240 | | |
220 | 241 | | |
| |||
476 | 497 | | |
477 | 498 | | |
478 | 499 | | |
| 500 | + | |
| 501 | + | |
479 | 502 | | |
480 | 503 | | |
481 | 504 | | |
| |||
554 | 577 | | |
555 | 578 | | |
556 | 579 | | |
| 580 | + | |
| 581 | + | |
| 582 | + | |
557 | 583 | | |
558 | 584 | | |
559 | 585 | | |
| |||
634 | 660 | | |
635 | 661 | | |
636 | 662 | | |
| 663 | + | |
637 | 664 | | |
638 | 665 | | |
639 | 666 | | |
| |||
694 | 721 | | |
695 | 722 | | |
696 | 723 | | |
| 724 | + | |
| 725 | + | |
| 726 | + | |
| 727 | + | |
| 728 | + | |
| 729 | + | |
| 730 | + | |
| 731 | + | |
| 732 | + | |
| 733 | + | |
697 | 734 | | |
698 | 735 | | |
699 | | - | |
| 736 | + | |
| 737 | + | |
| 738 | + | |
| 739 | + | |
| 740 | + | |
700 | 741 | | |
701 | 742 | | |
702 | 743 | | |
| |||
1130 | 1171 | | |
1131 | 1172 | | |
1132 | 1173 | | |
| 1174 | + | |
| 1175 | + | |
| 1176 | + | |
1133 | 1177 | | |
1134 | 1178 | | |
1135 | 1179 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
12 | 12 | | |
13 | 13 | | |
14 | 14 | | |
| 15 | + | |
| 16 | + | |
| 17 | + | |
| 18 | + | |
15 | 19 | | |
16 | 20 | | |
17 | 21 | | |
18 | 22 | | |
19 | | - | |
| 23 | + | |
| 24 | + | |
| 25 | + | |
| 26 | + | |
| 27 | + | |
| 28 | + | |
| 29 | + | |
| 30 | + | |
| 31 | + | |
20 | 32 | | |
21 | 33 | | |
22 | 34 | | |
| |||
144 | 156 | | |
145 | 157 | | |
146 | 158 | | |
| 159 | + | |
| 160 | + | |
| 161 | + | |
| 162 | + | |
| 163 | + | |
| 164 | + | |
| 165 | + | |
| 166 | + | |
| 167 | + | |
| 168 | + | |
| 169 | + | |
| 170 | + | |
| 171 | + | |
| 172 | + | |
| 173 | + | |
| 174 | + | |
| 175 | + | |
| 176 | + | |
| 177 | + | |
| 178 | + | |
| 179 | + | |
| 180 | + | |
| 181 | + | |
| 182 | + | |
| 183 | + | |
| 184 | + | |
| 185 | + | |
| 186 | + | |
| 187 | + | |
| 188 | + | |
| 189 | + | |
| 190 | + | |
| 191 | + | |
| 192 | + | |
| 193 | + | |
| 194 | + | |
| 195 | + | |
| 196 | + | |
| 197 | + | |
| 198 | + | |
| 199 | + | |
| 200 | + | |
| 201 | + | |
| 202 | + | |
| 203 | + | |
| 204 | + | |
| 205 | + | |
| 206 | + | |
| 207 | + | |
| 208 | + | |
| 209 | + | |
| 210 | + | |
| 211 | + | |
| 212 | + | |
| 213 | + | |
| 214 | + | |
| 215 | + | |
| 216 | + | |
| 217 | + | |
| 218 | + | |
| 219 | + | |
| 220 | + | |
| 221 | + | |
| 222 | + | |
| 223 | + | |
| 224 | + | |
| 225 | + | |
| 226 | + | |
| 227 | + | |
| 228 | + | |
| 229 | + | |
| 230 | + | |
| 231 | + | |
| 232 | + | |
| 233 | + | |
| 234 | + | |
| 235 | + | |
| 236 | + | |
| 237 | + | |
| 238 | + | |
| 239 | + | |
| 240 | + | |
| 241 | + | |
| 242 | + | |
| 243 | + | |
| 244 | + | |
| 245 | + | |
| 246 | + | |
| 247 | + | |
| 248 | + | |
| 249 | + | |
| 250 | + | |
| 251 | + | |
| 252 | + | |
| 253 | + | |
| 254 | + | |
| 255 | + | |
| 256 | + | |
| 257 | + | |
| 258 | + | |
| 259 | + | |
| 260 | + | |
| 261 | + | |
| 262 | + | |
| 263 | + | |
| 264 | + | |
| 265 | + | |
| 266 | + | |
| 267 | + | |
| 268 | + | |
| 269 | + | |
| 270 | + | |
| 271 | + | |
| 272 | + | |
| 273 | + | |
| 274 | + | |
0 commit comments