Skip to content

Conversation

@andrewor14
Copy link
Contributor

Summary: Following #3440, this PR extends torchao FP8 + RL support to also handle 128x128 PerBlock granularity (in addition to PerRow).

Example usage:

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Qwen3-8B-Base",
    max_seq_length = 2048,
    load_in_4bit = False,
    fast_inference = True,
    max_lora_rank = 32,
    load_in_fp8 = "block",  # or "row" or True
)

Initial results: TBD

Note:

**Summary:** Following unslothai#3440,
this PR extends torchao FP8 + RL support to also handle 128x128
PerBlock granularity (in addition to PerRow).

**Example usage:**

```
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Qwen3-8B-Base",
    max_seq_length = 2048,
    load_in_4bit = False,
    fast_inference = True,
    max_lora_rank = 32,
    load_in_fp8 = "block",  # or "row" or True
)
```

**Initial results:** TBD

**Note:**
- Requires pytorch/ao#3370
@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello @andrewor14, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the FP8 quantization capabilities within the unsloth library by integrating 128x128 PerBlock granularity. This new option provides users with more control over the quantization process, potentially leading to improved model performance or memory efficiency for specific use cases. The changes involve updating the core quantization logic and refactoring related utility functions to support the new granularity mode.

Highlights

  • Expanded FP8 Granularity: Introduced support for 128x128 PerBlock granularity for FP8 quantization, complementing the existing PerRow method.
  • Flexible FP8 Loading: The load_in_fp8 parameter now accepts "block" or "row" strings, allowing users to specify the desired quantization granularity.
  • Refactored FP8 Configuration: Centralized FP8 mode determination and validation into a new utility function, _get_fp8_mode_and_check_settings.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request successfully extends FP8 support to include 128x128 PerBlock granularity, in addition to the existing PerRow method. The changes are well-implemented across unsloth/kernels/utils.py, unsloth/models/loader.py, and unsloth/models/loader_utils.py. The logic for selecting and configuring the FP8 mode is clear, and the necessary adjustments to the kernel-level functions are correctly handled. I have one suggestion regarding code maintainability by removing a redundant validation check. Overall, the changes are solid and align well with the PR's objective.

Comment on lines +173 to +174
if fp8_mode == "row":
granularity = PerRow()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This validation is redundant, as fp8_mode is already validated in _get_fp8_mode_and_check_settings before being passed to this function. For internal functions, it's better to rely on assertions for contract checking rather than raising user-facing ValueErrors. This avoids duplicated validation logic and makes the code cleaner.

Consider removing this else block. If you want to keep a check for robustness, an assert would be more appropriate, for example:

assert fp8_mode in ['row', 'block']

However, given the call chain, even an assert is likely unnecessary.

# TODO: generalize this to beyond text models?
# Right now using AutoModel removes the `lm_head` layer,
# which is expected later when loading the vllm state dict
model = AutoModelForCausalLM.from_pretrained(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@danielhanchen I had to change this back for this to work. When I tried AutoModel it removed the lm_head from the state dict, which later caused an out of bounds exception on this line: https://github.com/unslothai/unsloth-zoo/blob/54dce973426ee61670e15b720619c8539bf05104/unsloth_zoo/vllm_utils.py#L1110. Maybe we can generalize this to multimodal models later

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh ok oh wait I can make this work for vision models without using AutoModel

@danielhanchen danielhanchen merged commit 264ed42 into unslothai:nightly Nov 22, 2025
1 check passed
danielhanchen added a commit that referenced this pull request Nov 25, 2025
* Enable FP8 + RL training for bf16 models (#3440)

* Enable FP8 + RL training for bf16 models

**Summary:** Enable FP8 + RL training using TorchAO for 1.33x faster training and 42% less model memory usage:
- We quantize the frozen LoRA weights into fp8 and keep the LoRA adapters in bf16
- We leverage TorchAO's `Float8Tensor`, which calls into fbgemm's fp8 x fp8 rowwise matmul kernel
- For now, we need to do an offline quantization first, because vllm doesn't support on-the-fly quantization for torchao yet  (this is in progress: vllm-project/vllm#26327)

**Example usage:**
```
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Qwen3-8B-Base",
    max_seq_length = 2048,
    load_in_4bit = False,
    fast_inference = True,
    max_lora_rank = 32,
    load_in_fp8 = True,  # set this to True
)

\# the rest is the same as before
model = FastLanguageModel.get_peft_model(...)
```

**Initial results:**
```
\# fp8
{'train_runtime': 1725.4337, 'train_samples_per_second': 0.232, 'train_steps_per_second': 0.058, 'train_loss': 0.00015715716748673002, 'epoch': 0.01}

\# bf16
{'train_runtime': 2297.8145, 'train_samples_per_second': 0.174, 'train_steps_per_second': 0.044, 'train_loss': 0.00016081033063528594, 'epoch': 0.01}
```

<img width="1199" height="448" alt="Screenshot 2025-11-11 at 4 10 50 PM" src="https://github.com/user-attachments/assets/b6304afd-89e9-42b1-8064-775807e17b23" />

Test script: https://gist.github.com/andrewor14/5b85119fae46845d07b608d420907423

**Requires:**
- pytorch/ao#3158 (torchao nightly or 0.15.0+)
- unslothai/unsloth-zoo#351

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update utils.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* _get_inference_mode_context_manager

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update utils.py

* Update utils.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>

* Update __init__.py

* Fix/save torchao model loading logic (#3621)

* make loading gpt-oss-BF16 faster. Linked to unsloth-zoo PR #314

* fix model loading and clean merged model directory

* revert default quant

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* revert mapper.py

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* Update loader_utils.py

* Update loader_utils.py

* Add 128x128 PerBlock FP8 + RL (#3629)

* Add 128x128 PerBlock FP8 + RL

**Summary:** Following #3440,
this PR extends torchao FP8 + RL support to also handle 128x128
PerBlock granularity (in addition to PerRow).

**Example usage:**

```
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Qwen3-8B-Base",
    max_seq_length = 2048,
    load_in_4bit = False,
    fast_inference = True,
    max_lora_rank = 32,
    load_in_fp8 = "block",  # or "row" or True
)
```

**Initial results:** TBD

**Note:**
- Requires pytorch/ao#3370

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* Version

* Update vision.py

* Update rl.py

* Add torch 2.9.1

* Fix auto installer

* Update fp8.py

* Float8

* Update fp8.py

* Update mapper.py

* Update mapper.py

* Update loader_utils.py

* Update loader.py

* Update fp8.py

* Versioning

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: andrewor14 <andrewor14@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Roland Tannous <115670425+rolandtannous@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

2 participants