
Conversation

@rolandtannous
Collaborator

@rolandtannous rolandtannous commented Nov 20, 2025

Problem

1 - AutoModel: We previously used `AutoModel.from_pretrained` to load models in our implementation of the `save_pretrained_torchao` method. However, `AutoModel` loads only the base model, without a task-specific head (see the short sketch after this list).

2 - Merged model directory: `save_pretrained_torchao` merges the model (if necessary) before converting, but does not delete the merged model directory after the process is done, which consumes unnecessary disk space and can confuse users.
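
A quick illustration of the difference between the two loaders (the model name is just an example; any causal LM checkpoint behaves the same way):

```
from transformers import AutoModel, AutoModelForCausalLM

# AutoModel returns the bare backbone, AutoModelForCausalLM attaches the LM head
backbone = AutoModel.from_pretrained("unsloth/Qwen3-8B-Base")
causal = AutoModelForCausalLM.from_pretrained("unsloth/Qwen3-8B-Base")
print(hasattr(backbone, "lm_head"), hasattr(causal, "lm_head"))  # False True
```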

Solution

1 - AutoModel:

  • We test whether the model is a VLM or a text-only model
  • We set auto_model to AutoModelForImageTextToText if it is a VLM, and to AutoModelForCausalLM otherwise
  • We likewise set auto_processor to either AutoProcessor or AutoTokenizer based on the same test

This loads the model with the correct task head:
```
is_vlm = False
if hasattr(self, "config") and hasattr(self.config, "architectures"):
    is_vlm = any(
        x.endswith(("ForConditionalGeneration", "ForVisionText2Text"))
        for x in self.config.architectures
    )
    is_vlm = is_vlm or hasattr(self.config, "vision_config")
auto_model = AutoModelForImageTextToText if is_vlm else AutoModelForCausalLM
auto_processor = AutoProcessor if is_vlm else AutoTokenizer

tokenizer = auto_processor.from_pretrained(arguments["save_directory"])
...
model = auto_model.from_pretrained(
    arguments["save_directory"],
    device_map = "auto",
    quantization_config = quantization_config,
    **kwargs,
)
```

2 - Merged model directory: We remove the `save_directory` directory tree after the conversion process is complete. The torchao-converted model is saved to `torchao_save_directory` instead.

```
import shutil

if os.path.exists(save_directory):
    try:
        shutil.rmtree(save_directory)
    except Exception:
        # Best-effort cleanup: ignore failures to remove the merged directory
        pass
```

Solves #3599

@gemini-code-assist
Contributor

Summary of Changes

Hello @rolandtannous, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request refines the save_pretrained_torchao functionality by enhancing model loading to correctly identify and load Vision-Language Models (VLMs) or text models with their appropriate task-specific heads. Additionally, it introduces a cleanup mechanism to automatically remove temporary merged model directories after conversion, optimizing disk space usage. A minor update to the model mapper is also included.

Highlights

  • Improved torchao model loading logic: The save_pretrained_torchao method now intelligently selects between AutoModelForCausalLM and AutoModelForImageTextToText (and their respective processors) based on whether the model is a Vision-Language Model (VLM) or a text model. This ensures models are loaded with the correct task-head, addressing an issue where AutoModel loaded only the base model.
  • Clean up of merged model directories: After the torchao conversion process, the temporary merged model directory is now automatically removed. This prevents unnecessary disk space consumption and potential user confusion by ensuring only the final torchao converted model directory remains.
  • Added BF16 model to mapper: The unsloth/gpt-oss-20b-unsloth-bnb-4bit entry in the model mapper has been updated to include unsloth/gpt-oss-20b-BF16, expanding the recognized model aliases.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request correctly addresses two issues with the save_pretrained_torchao method. First, it now uses the appropriate AutoModel class (AutoModelForCausalLM or AutoModelForImageTextToText) based on whether the model is a Vision Language Model, ensuring that the correct model head is loaded. Second, it adds a cleanup step to remove the temporary merged model directory, which helps conserve disk space. The logic for detecting VLMs and the overall changes look good. I've left one comment regarding error handling in the cleanup step to make it more robust.
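
For illustration, one way to make that cleanup more robust (the variable name mirrors the PR's save_directory; the path and warning text are examples only) is to catch the specific exception and log it rather than silently passing:

```
import logging
import os
import shutil

logger = logging.getLogger(__name__)
save_directory = "merged_model"  # example path; in the PR this is the merged-model directory

if os.path.exists(save_directory):
    try:
        shutil.rmtree(save_directory)
    except OSError as exc:
        # Surface the failure instead of swallowing it silently
        logger.warning(f"Failed to remove merged model directory {save_directory}: {exc}")
```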

@danielhanchen danielhanchen changed the base branch from main to nightly November 20, 2025 08:02
@danielhanchen danielhanchen merged commit 2765b51 into unslothai:nightly Nov 20, 2025
1 check passed
danielhanchen added a commit that referenced this pull request Nov 25, 2025
* Enable FP8 + RL training for bf16 models (#3440)

* Enable FP8 + RL training for bf16 models

**Summary:** Enable FP8 + RL training using TorchAO for 1.33x faster training and 42% less model memory usage:
- We quantize the frozen base weights to fp8 and keep the LoRA adapters in bf16
- We leverage TorchAO's `Float8Tensor`, which calls into fbgemm's fp8 x fp8 rowwise matmul kernel
- For now, we need to do an offline quantization first, because vLLM doesn't yet support on-the-fly quantization for torchao (this is in progress: vllm-project/vllm#26327); a hypothetical sketch of this offline step is shown below
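
For reference, a hypothetical sketch of what such an offline quantization step could look like with torchao (the model name, output path, and config choices are illustrative; the actual Unsloth/vLLM integration may differ):

```
import torch
from transformers import AutoModelForCausalLM
from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig
from torchao.quantization.granularity import PerRow

# Load the bf16 base model, then quantize its linear weights to fp8 (rowwise)
model = AutoModelForCausalLM.from_pretrained(
    "unsloth/Qwen3-8B-Base", torch_dtype = torch.bfloat16, device_map = "cuda"
)
quantize_(model, Float8DynamicActivationFloat8WeightConfig(granularity = PerRow()))

# Tensor subclasses may not serialize to safetensors, so disable safe serialization
model.save_pretrained("Qwen3-8B-Base-fp8-rowwise", safe_serialization = False)
```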

**Example usage:**
```
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Qwen3-8B-Base",
    max_seq_length = 2048,
    load_in_4bit = False,
    fast_inference = True,
    max_lora_rank = 32,
    load_in_fp8 = True,  # set this to True
)

# the rest is the same as before
model = FastLanguageModel.get_peft_model(...)
```

**Initial results:**
```
# fp8
{'train_runtime': 1725.4337, 'train_samples_per_second': 0.232, 'train_steps_per_second': 0.058, 'train_loss': 0.00015715716748673002, 'epoch': 0.01}

# bf16
{'train_runtime': 2297.8145, 'train_samples_per_second': 0.174, 'train_steps_per_second': 0.044, 'train_loss': 0.00016081033063528594, 'epoch': 0.01}
```

<img width="1199" height="448" alt="Screenshot 2025-11-11 at 4 10 50 PM" src="https://github.com/user-attachments/assets/b6304afd-89e9-42b1-8064-775807e17b23" />

Test script: https://gist.github.com/andrewor14/5b85119fae46845d07b608d420907423

**Requires:**
- pytorch/ao#3158 (torchao nightly or 0.15.0+)
- unslothai/unsloth-zoo#351

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update utils.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* _get_inference_mode_context_manager

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update utils.py

* Update utils.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>

* Update __init__.py

* Fix/save torchao model loading logic (#3621)

* make loading gpt-oss-BF16 faster. Linked to unsloth-zoo PR #314

* fix model loading and clean merged model directory

* revert default quant

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* revert mapper.py

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* Update loader_utils.py

* Update loader_utils.py

* Add 128x128 PerBlock FP8 + RL (#3629)

* Add 128x128 PerBlock FP8 + RL

**Summary:** Following #3440,
this PR extends torchao FP8 + RL support to also handle 128x128
PerBlock granularity (in addition to PerRow).

**Example usage:**

```
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Qwen3-8B-Base",
    max_seq_length = 2048,
    load_in_4bit = False,
    fast_inference = True,
    max_lora_rank = 32,
    load_in_fp8 = "block",  # or "row" or True
)
```

**Initial results:** TBD

**Note:**
- Requires pytorch/ao#3370

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* Version

* Update vision.py

* Update rl.py

* Add torch 2.9.1

* Fix auto installer

* Update fp8.py

* Float8

* Update fp8.py

* Update mapper.py

* Update mapper.py

* Update loader_utils.py

* Update loader.py

* Update fp8.py

* Versioning

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: andrewor14 <andrewor14@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Roland Tannous <115670425+rolandtannous@users.noreply.github.com>