
Conversation

@metascroy
Contributor

This PR updates TorchAOConfig to be more flexible. Specifically:

  1. We allow multiple base_config/filter_fn pairs so that users can quantize different layers differently, or quantize both linear and embedding layers (see the sketch after this list). Previously, only linear layers were quantized, and all of them used the same quantization.

  2. Adds a new prequantization_transform to TorchAOConfig that is applied before calling quantize_. This is useful for untying parameters if you want to quantize the input/output embeddings differently in tied models.

  3. Adds a utility _untie_input_output_embeddings to untie the input/output embeddings.

  4. Redefines the int8-int4 scheme to be more mobile-friendly by adding embedding quantization.

  5. Adds a new _convert_torchao_model utility to convert the QAT'd torch model to a format that can be distributed on HuggingFace.
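
For illustration, here is a rough sketch of how the pieces fit together. The import path for TorchAOConfig and _untie_input_output_embeddings and the specific configs below are placeholders, not necessarily the exact settings the int8-int4 scheme uses:

import torch
from torchao.quantization import Int4WeightOnlyConfig, Int8WeightOnlyConfig

# Import path assumed for this example; the actual module layout may differ.
from unsloth.models._utils import TorchAOConfig, _untie_input_output_embeddings

qat_config = TorchAOConfig(
    # Each (base_config, filter_fn) pair defines one quantization rule.
    base_config_and_filter_fns=[
        # Embedding layers: illustrative 8-bit weight-only config.
        (Int8WeightOnlyConfig(), lambda m, fqn: isinstance(m, torch.nn.Embedding)),
        # Remaining layers: a None filter presumably falls back to the default linear filter.
        (Int4WeightOnlyConfig(group_size=128), None),
    ],
    # Applied before quantize_, e.g. to untie tied input/output embeddings.
    prequantization_transform=_untie_input_output_embeddings,
)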

With these changes, we can QAT a model for ExecuTorch with:

from unsloth import FastLanguageModel
import torch

# model_id and max_seq_length are assumed to be defined earlier.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_id,
    max_seq_length=max_seq_length,
    dtype=torch.bfloat16,
    load_in_4bit=False,
    full_finetuning=True,
    qat_scheme="int8-int4",
)

...

trainer_stats = trainer.train()
...

from unsloth.models._utils import _convert_torchao_model
_convert_torchao_model(model)

After calling FastLanguageModel.from_pretrained, the model has torchao FakeQuantizedLinear and FakeQuantizedEmbedding modules as specified by the qat_scheme ("int8-int4").

After calling _convert_torchao_model, the model's weights are torchao tensor subclasses and an HF quantization config is attached to the model's configuration.
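
As an illustrative sanity check of the two stages (module classes are matched by name to avoid pinning exact torchao import paths; the save directory is a placeholder):

# Before conversion: the QAT scheme should have swapped in fake-quantized modules.
fake_quant = [
    name for name, module in model.named_modules()
    if type(module).__name__ in ("FakeQuantizedLinear", "FakeQuantizedEmbedding")
]
print(f"Found {len(fake_quant)} fake-quantized modules, e.g. {fake_quant[:3]}")

# After _convert_torchao_model(model): weights are torchao tensor subclasses and the
# quantization config (assumed to live on model.config) is attached, so the model
# can be saved or pushed to the Hub as usual.
print(model.config.quantization_config)
model.save_pretrained("qat-int8-int4-model")  # placeholder output directory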

@gemini-code-assist
Contributor

Summary of Changes

Hello @metascroy, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the flexibility and deployment capabilities of TorchAO-based quantization within the system. It introduces a more sophisticated configuration for defining layer-specific quantization rules and pre-quantization model transformations, particularly benefiting embedding layers. Furthermore, it provides a crucial utility for converting QAT-trained models into a standard HuggingFace format, streamlining the process of distributing quantized models.

Highlights

  • Flexible Quantization Rules: The TorchAOConfig has been updated to support multiple base_config and filter_fn pairs. This allows users to apply different quantization strategies to various layers within a model, such as linear and embedding layers, providing more granular control than before, when only linear layers were quantized and all with the same configuration.
  • Prequantization Transformation Hook: A new prequantization_transform field has been added to TorchAOConfig. This optional callable allows for model modifications, like untying parameters, to be applied before the actual quantization process begins, which is particularly useful for custom quantization needs.
  • Untie Input/Output Embeddings Utility: A new utility function, _untie_input_output_embeddings, has been introduced. This function modifies a HuggingFace model in-place to untie its input and output embeddings, enabling them to be quantized differently.
  • Enhanced int8-int4 QAT Scheme: The int8-int4 QAT scheme has been redefined to be more mobile-friendly. It now includes 8-bit per-axis quantization for embedding layers and 4-bit per-group quantization for other layers, along with the automatic untying of input/output embeddings via the new prequantization_transform. A rough sketch of what such a recipe looks like with raw torchao calls follows this list.
  • HuggingFace Model Conversion Utility: A new utility, _convert_torchao_model, has been added. This function converts a QAT-prepared Torch model into a format compatible with HuggingFace, attaching a TorchAoConfig to the model's configuration for easier distribution.
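
As a rough sketch of the underlying torchao mechanism (not the unsloth implementation itself), each (config, filter_fn) pair corresponds to one quantize_ call. The config classes, parameter names, and the group size of 32 below follow recent torchao releases and are assumptions; the scheme in this PR may use different settings:

import torch
import torch.nn as nn
from torchao.quantization import (
    Int8DynamicActivationIntxWeightConfig,
    IntxWeightOnlyConfig,
    quantize_,
)
from torchao.quantization.granularity import PerAxis, PerGroup

# Rule 1: embedding layers -> 8-bit per-axis weight-only quantization.
quantize_(
    model,
    IntxWeightOnlyConfig(weight_dtype=torch.int8, granularity=PerAxis(0)),
    filter_fn=lambda m, fqn: isinstance(m, nn.Embedding),
)

# Rule 2: linear layers -> 8-bit dynamic activations, 4-bit per-group weights.
quantize_(
    model,
    Int8DynamicActivationIntxWeightConfig(
        weight_dtype=torch.int4,
        weight_granularity=PerGroup(32),
    ),
    filter_fn=lambda m, fqn: isinstance(m, nn.Linear),
)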

@metascroy changed the title up Nov 11, 2025

@gemini-code-assist bot left a comment


Code Review

This pull request introduces significant enhancements to the TorchAOConfig for more flexible quantization-aware training. Key changes include support for multiple quantization configurations, a new pre-quantization transform, and utilities for untying embeddings and converting models. While the new features are powerful, I've identified a critical issue with a duplicate class definition and a potential NameError due to a missing import. Addressing these will ensure the stability and correctness of the new functionalities.

Comment on lines 2014 to 2017
class TorchAOConfig:
    qat_scheme: str = "int4"
    base_config: AOBaseConfig = field(
        default_factory = lambda: Int4WeightOnlyConfig(group_size = 128)
    )
    base_config_and_filter_fns: List[Tuple["AOBaseConfig", Optional[Callable]]] = field(
        default_factory = lambda: [(Int4WeightOnlyConfig(group_size = 128), None)]
    )
    group_size: int = 128
    filter_fn: Optional[Callable] = None

    def __post_init__(self):
        if self.filter_fn is None:
            self.filter_fn = (

@dataclass
class TorchAOConfig:
    qat_scheme: str = "int4"

    # Each (config, filter_fn) pair defines a quantization rule
    base_config_and_filter_fns: List[

critical

The TorchAOConfig class is defined twice in this file. The first definition appears to be an older, less-featured version, while the second one contains the new enhancements. This duplication will lead to the second definition overwriting the first, but it's confusing and should be cleaned up by removing the first, outdated definition.

# 4) Prevent future automatic re-tying
def _no_tie(self):
    return


critical

The alias nn is used here for torch.nn, but it's not imported in this file's scope. This will likely cause a NameError at runtime. To fix this, you should either add import torch.nn as nn at the top of the file or use torch.nn.Parameter directly for better clarity and to avoid reliance on imports from other modules.

Suggested change
out_proj.weight = torch.nn.Parameter(W) # new storage, keeps dtype/device
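
For context, a minimal sketch of the untie pattern this comment refers to might look like the following. This is not the PR's actual _untie_input_output_embeddings implementation; it is a hypothetical helper that relies only on standard transformers accessors:

import torch

def untie_input_output_embeddings(model):
    # Hypothetical helper, for illustration only.
    out_proj = model.get_output_embeddings()   # usually lm_head
    in_emb = model.get_input_embeddings()      # usually embed_tokens
    if out_proj is not None and out_proj.weight is in_emb.weight:
        W = in_emb.weight.detach().clone()
        # Fully qualified torch.nn avoids the NameError flagged above.
        out_proj.weight = torch.nn.Parameter(W)  # new storage, keeps dtype/device
    # Discourage future automatic re-tying by HF utilities.
    model.config.tie_word_embeddings = False
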
@danielhanchen
Contributor

Thank you!

@danielhanchen merged commit b9d9600 into unslothai:main on Nov 14, 2025
2 checks passed
