Add 128x128 PerBlock FP8 + RL #3629
**Summary:** Following unslothai#3440, this PR extends torchao FP8 + RL support to also handle 128x128 PerBlock granularity (in addition to PerRow).

**Example usage:**
```
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Qwen3-8B-Base",
    max_seq_length = 2048,
    load_in_4bit = False,
    fast_inference = True,
    max_lora_rank = 32,
    load_in_fp8 = "block",  # or "row" or True
)
```

**Initial results:** TBD

**Note:**
- Requires pytorch/ao#3370
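For reviewers less familiar with the two granularities, here is a rough sketch of how many scale factors each one implies for a single weight matrix. This is illustrative arithmetic only, not torchao's implementation; `num_scales` and the 4096x4096 shape are hypothetical, and only the 128x128 block size comes from this PR.

```python
import math

def num_scales(rows, cols, mode, block=(128, 128)):
    """Count the scale factors needed for a (rows, cols) weight matrix.

    "row"   -> one scale per row
    "block" -> one scale per block[0] x block[1] tile (partial tiles count)
    """
    if mode == "row":
        return rows
    if mode == "block":
        return math.ceil(rows / block[0]) * math.ceil(cols / block[1])
    raise ValueError(f"unknown mode: {mode!r}")

# Hypothetical weight shape, chosen only for illustration.
rows, cols = 4096, 4096
print(num_scales(rows, cols, "row"))    # 4096
print(num_scales(rows, cols, "block"))  # 1024
```

Per-row scaling gives every row its own scale across all columns, while per-block scaling shares one scale across a 128x128 tile, so the two trade off differently depending on how outliers are distributed in the weights.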
Summary of Changes

Hello @andrewor14, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the FP8 quantization capabilities within the …
Code Review
This pull request successfully extends FP8 support to include 128x128 PerBlock granularity, in addition to the existing PerRow method. The changes are well-implemented across unsloth/kernels/utils.py, unsloth/models/loader.py, and unsloth/models/loader_utils.py. The logic for selecting and configuring the FP8 mode is clear, and the necessary adjustments to the kernel-level functions are correctly handled. I have one suggestion regarding code maintainability by removing a redundant validation check. Overall, the changes are solid and align well with the PR's objective.
    if fp8_mode == "row":
        granularity = PerRow()
This validation is redundant, as fp8_mode is already validated in _get_fp8_mode_and_check_settings before being passed to this function. For internal functions, it's better to rely on assertions for contract checking rather than raising user-facing ValueErrors. This avoids duplicated validation logic and makes the code cleaner.
Consider removing this else block. If you want to keep a check for robustness, an assert would be more appropriate, for example: `assert fp8_mode in ['row', 'block']`. However, given the call chain, even an assert is likely unnecessary.
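A minimal sketch of the split suggested here: validate the user's input once at the boundary, then use only an assertion in the internal helper. The function names `normalize_fp8_mode` and `granularity_for` are hypothetical, the `PerRow` and `PerBlock` classes are stand-ins so the snippet runs without torchao, and mapping `True` to `"row"` is an assumption based on the PR description rather than something confirmed by the diff.

```python
from dataclasses import dataclass

# Hypothetical stand-ins so the sketch runs without torchao installed;
# the real granularity classes come from torchao.
@dataclass(frozen=True)
class PerRow:
    pass

@dataclass(frozen=True)
class PerBlock:
    block_size: tuple

def normalize_fp8_mode(load_in_fp8):
    """User-facing check: runs once and raises ValueError on bad input."""
    if load_in_fp8 is None or load_in_fp8 is False:
        return None  # FP8 disabled
    if load_in_fp8 is True:
        return "row"  # assumption: True selects the per-row mode
    if load_in_fp8 in ("row", "block"):
        return load_in_fp8
    raise ValueError(
        f"load_in_fp8 must be True, 'row' or 'block', got {load_in_fp8!r}"
    )

def granularity_for(fp8_mode):
    """Internal helper: the caller has already validated fp8_mode."""
    assert fp8_mode in ("row", "block"), fp8_mode
    if fp8_mode == "row":
        return PerRow()
    return PerBlock((128, 128))

print(granularity_for(normalize_fp8_mode("block")))
```

With this shape, a bad value surfaces as a `ValueError` at the public entry point, and the assertion only documents the internal contract.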
    # TODO: generalize this to beyond text models?
    # Right now using AutoModel removes the `lm_head` layer,
    # which is expected later when loading the vllm state dict
    model = AutoModelForCausalLM.from_pretrained(
@danielhanchen I had to change this back for this to work. When I tried AutoModel it removed the lm_head from the state dict, which later caused an out of bounds exception on this line: https://github.com/unslothai/unsloth-zoo/blob/54dce973426ee61670e15b720619c8539bf05104/unsloth_zoo/vllm_utils.py#L1110. Maybe we can generalize this to multimodal models later
Oh ok oh wait I can make this work for vision models without using AutoModel
* Enable FP8 + RL training for bf16 models (#3440)

  * Enable FP8 + RL training for bf16 models

    **Summary:** Enable FP8 + RL training using TorchAO for 1.33x faster training and 42% less model memory usage:
    - We quantize the frozen LoRA weights into fp8 and keep the LoRA adapters in bf16
    - We leverage TorchAO's `Float8Tensor`, which calls into fbgemm's fp8 x fp8 rowwise matmul kernel
    - For now, we need to do an offline quantization first, because vllm doesn't support on-the-fly quantization for torchao yet (this is in progress: vllm-project/vllm#26327)

    **Example usage:**
    ```
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name = "unsloth/Qwen3-8B-Base",
        max_seq_length = 2048,
        load_in_4bit = False,
        fast_inference = True,
        max_lora_rank = 32,
        load_in_fp8 = True,  # set this to True
    )

    # the rest is the same as before
    model = FastLanguageModel.get_peft_model(...)
    ```

    **Initial results:**
    ```
    # fp8
    {'train_runtime': 1725.4337, 'train_samples_per_second': 0.232, 'train_steps_per_second': 0.058, 'train_loss': 0.00015715716748673002, 'epoch': 0.01}

    # bf16
    {'train_runtime': 2297.8145, 'train_samples_per_second': 0.174, 'train_steps_per_second': 0.044, 'train_loss': 0.00016081033063528594, 'epoch': 0.01}
    ```

    <img width="1199" height="448" alt="Screenshot 2025-11-11 at 4 10 50 PM" src="https://github.com/user-attachments/assets/b6304afd-89e9-42b1-8064-775807e17b23" />

    Test script: https://gist.github.com/andrewor14/5b85119fae46845d07b608d420907423

    **Requires:**
    - pytorch/ao#3158 (torchao nightly or 0.15.0+)
    - unslothai/unsloth-zoo#351

  * [pre-commit.ci] auto fixes from pre-commit.com hooks

    for more information, see https://pre-commit.ci
  * Update utils.py
  * [pre-commit.ci] auto fixes from pre-commit.com hooks

    for more information, see https://pre-commit.ci
  * _get_inference_mode_context_manager
  * [pre-commit.ci] auto fixes from pre-commit.com hooks

    for more information, see https://pre-commit.ci
  * Update utils.py
  * Update utils.py
  * [pre-commit.ci] auto fixes from pre-commit.com hooks

    for more information, see https://pre-commit.ci

  ---------

  Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  Co-authored-by: Daniel Han <danielhanchen@gmail.com>

* Update __init__.py

* Fix/save torchao model loading logic (#3621)

  * make loading gpt-oss-BF16 faster. Linked to unsloth-zoo PR #314
  * fix model loading and clean merged model directory
  * revert default quant
  * [pre-commit.ci] auto fixes from pre-commit.com hooks

    for more information, see https://pre-commit.ci
  * revert mapper.py

  ---------

  Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* Update loader_utils.py

* Update loader_utils.py

* Add 128x128 PerBlock FP8 + RL (#3629)

  * Add 128x128 PerBlock FP8 + RL

    **Summary:** Following #3440, this PR extends torchao FP8 + RL support to also handle 128x128 PerBlock granularity (in addition to PerRow).

    **Example usage:**
    ```
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name = "unsloth/Qwen3-8B-Base",
        max_seq_length = 2048,
        load_in_4bit = False,
        fast_inference = True,
        max_lora_rank = 32,
        load_in_fp8 = "block",  # or "row" or True
    )
    ```

    **Initial results:** TBD

    **Note:**
    - Requires pytorch/ao#3370

  * [pre-commit.ci] auto fixes from pre-commit.com hooks

    for more information, see https://pre-commit.ci

  ---------

  Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* Version
* Update vision.py
* Update rl.py
* Add torch 2.9.1
* Fix auto installer
* Update fp8.py
* Float8
* Update fp8.py
* Update mapper.py
* Update mapper.py
* Update loader_utils.py
* Update loader.py
* Update fp8.py
* Versioning
* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

---------

Co-authored-by: andrewor14 <andrewor14@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Roland Tannous <115670425+rolandtannous@users.noreply.github.com>
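As a quick sanity check, the 1.33x speedup quoted for #3440 can be recomputed from the `train_runtime` and `train_samples_per_second` values reported in that commit message. The variable names below are illustrative; the numbers are copied from the results above, and the memory-usage claim is not covered because no memory figures are given.

```python
# Values copied from the fp8 and bf16 results reported for #3440.
fp8_runtime_s = 1725.4337
bf16_runtime_s = 2297.8145
fp8_samples_per_s = 0.232
bf16_samples_per_s = 0.174

# Speedup measured two ways: wall-clock ratio and throughput ratio.
runtime_speedup = bf16_runtime_s / fp8_runtime_s
throughput_speedup = fp8_samples_per_s / bf16_samples_per_s

print(f"runtime speedup:    {runtime_speedup:.2f}x")     # 1.33x
print(f"throughput speedup: {throughput_speedup:.2f}x")  # 1.33x
```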