Extend TorchAOConfig to support mobile usecases #3587
Conversation
Summary of Changes

Hello @metascroy, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request significantly enhances the flexibility and deployment capabilities of TorchAO-based quantization within the system. It introduces a more sophisticated configuration for defining layer-specific quantization rules and pre-quantization model transformations, particularly benefiting embedding layers. Furthermore, it provides a utility for converting QAT-trained models into a standard HuggingFace format, streamlining the process of distributing quantized models.
Code Review
This pull request introduces significant enhancements to the TorchAOConfig for more flexible quantization-aware training. Key changes include support for multiple quantization configurations, a new pre-quantization transform, and utilities for untying embeddings and converting models. While the new features are powerful, I've identified a critical issue with a duplicate class definition and a potential NameError due to a missing import. Addressing these will ensure the stability and correctness of the new functionalities.
class TorchAOConfig:
    qat_scheme: str = "int4"
    base_config: AOBaseConfig = field(
        default_factory = lambda: Int4WeightOnlyConfig(group_size = 128)
    base_config_and_filter_fns: List[Tuple["AOBaseConfig", Optional[Callable]]] = field(
        default_factory = lambda: [(Int4WeightOnlyConfig(group_size = 128), None)]
    )
    group_size: int = 128
    filter_fn: Optional[Callable] = None

    def __post_init__(self):
        if self.filter_fn is None:
            self.filter_fn = (

@dataclass
class TorchAOConfig:
    qat_scheme: str = "int4"

    # Each (config, filter_fn) pair defines a quantization rule
    base_config_and_filter_fns: List[
The TorchAOConfig class is defined twice in this file. The first definition appears to be an older, less-featured version, while the second contains the new enhancements. Because the second definition silently overwrites the first, the code may still run, but the duplication is confusing; the first, outdated definition should be removed.
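(Editorial aside: for context, here is a minimal sketch of how a list of (config, filter_fn) rules can be applied with torchao's `quantize_`. The helper name `apply_quantization_rules` and the example rule set are hypothetical and not part of this PR.)

```python
# Illustrative sketch only: applies a list of (base_config, filter_fn) rules
# in sequence with torchao's quantize_. Only quantize_ and Int4WeightOnlyConfig
# are real torchao APIs; the helper and rule set below are hypothetical.
from typing import Callable, List, Optional, Tuple

import torch.nn as nn
from torchao.quantization import Int4WeightOnlyConfig, quantize_


def apply_quantization_rules(
    model: nn.Module,
    rules: List[Tuple[object, Optional[Callable]]],
) -> nn.Module:
    """Apply each (base_config, filter_fn) rule to the matching modules."""
    for base_config, filter_fn in rules:
        if filter_fn is None:
            # No filter: fall back to quantize_'s default (linear layers only),
            # mirroring the old single-config behavior.
            quantize_(model, base_config)
        else:
            quantize_(model, base_config, filter_fn=filter_fn)
    return model


# Example rule set: int4 weight-only quantization for linear layers only.
rules = [
    (Int4WeightOnlyConfig(group_size=128), lambda m, fqn: isinstance(m, nn.Linear)),
]
```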
    # 4) Prevent future automatic re-tying
    def _no_tie(self):
        return
The alias nn is used here for torch.nn, but it's not imported in this file's scope. This will likely cause a NameError at runtime. To fix this, you should either add import torch.nn as nn at the top of the file or use torch.nn.Parameter directly for better clarity and to avoid reliance on imports from other modules.
out_proj.weight = torch.nn.Parameter(W)  # new storage, keeps dtype/device
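(Editorial aside: a rough sketch of what untying input/output embeddings can look like, incorporating the reviewer's suggested `torch.nn.Parameter` fix and an explicit import. It assumes a HuggingFace-style model; attribute and method names here are illustrative, and the PR's actual `_untie_input_output_embeddings` may differ in detail.)

```python
# Rough sketch of untying input/output embeddings. Assumes a HuggingFace-style
# model (get_input_embeddings/get_output_embeddings, config.tie_word_embeddings,
# tie_weights); the PR's _untie_input_output_embeddings may differ.
import types

import torch
import torch.nn as nn  # explicit import, as suggested in the review


def untie_input_output_embeddings(model: nn.Module) -> nn.Module:
    in_emb = model.get_input_embeddings()
    out_proj = model.get_output_embeddings()
    if out_proj is None or out_proj.weight is not in_emb.weight:
        return model  # nothing to untie

    # 1) Give the output projection its own storage so input and output
    #    embeddings can be quantized differently.
    W = out_proj.weight.detach().clone()
    out_proj.weight = torch.nn.Parameter(W)  # new storage, keeps dtype/device

    # 2) Record that the weights are no longer tied.
    if hasattr(model, "config"):
        model.config.tie_word_embeddings = False

    # 3) Prevent future automatic re-tying by HF utilities.
    def _no_tie(self):
        return

    model.tie_weights = types.MethodType(_no_tie, model)
    return model
```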
Thank you!
This PR updates TorchAOConfig to be more flexible. Specifically:
We allow multiple base_config/filter_fn pairs so that users can quantize different layers differently, or quantize both linear and embedding layers (see the configuration sketch after this list). Previously only linear layers were quantized, and all linear layers used the same quantization config.
Adds a new prequantization_transform to TorchAOConfig that is applied before calling quantize_. This is useful for untying parameters if you want to quantize the input/output embeddings differently in tied models.
Adds a utility _untie_input_output_embeddings to untie input/output embeddings.
Redefines the int8-int4 scheme to be more mobile-friendly by adding embedding quantization.
Adds a new _convert_torchao_model utility to convert the QAT'd torch model into a format that can be distributed on HuggingFace.
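(Editorial aside: a sketch of how the extended TorchAOConfig might be populated, based on the field names visible in this PR (base_config_and_filter_fns, prequantization_transform). The particular torchao config classes chosen here for an "int8-int4"-style scheme are assumptions, not read from this diff; TorchAOConfig and _untie_input_output_embeddings come from the module this PR modifies, whose import path is not shown in this excerpt.)

```python
# Sketch only: the torchao configs used for the "int8-int4" scheme and the
# exact TorchAOConfig defaults are assumptions. TorchAOConfig and
# _untie_input_output_embeddings come from the module this PR modifies.
import torch
import torch.nn as nn
from torchao.quantization import (
    Int8DynamicActivationInt4WeightConfig,
    IntxWeightOnlyConfig,
)
from torchao.quantization.granularity import PerGroup

config = TorchAOConfig(
    qat_scheme="int8-int4",
    base_config_and_filter_fns=[
        # Linear layers: int8 dynamic activations + int4 weights.
        (
            Int8DynamicActivationInt4WeightConfig(group_size=32),
            lambda m, fqn: isinstance(m, nn.Linear),
        ),
        # Embedding layers: int8 weight-only, grouped (illustrative choice).
        (
            IntxWeightOnlyConfig(weight_dtype=torch.int8, granularity=PerGroup(32)),
            lambda m, fqn: isinstance(m, nn.Embedding),
        ),
    ],
    # Runs before quantize_, e.g. to untie tied input/output embeddings.
    prequantization_transform=_untie_input_output_embeddings,
)
```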
With these changes, we can QAT a model for ExecuTorch with:
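(The original snippet is not included in this excerpt; the sketch below is a hypothetical reconstruction. The qat_scheme keyword on from_pretrained, the full_finetuning flag, the example checkpoint name, and the _convert_torchao_model call signature are all assumptions based on the description that follows.)

```python
# Hypothetical end-to-end flow; the qat_scheme keyword, full_finetuning flag,
# example checkpoint name, and _convert_torchao_model signature are assumptions
# based on this PR's description, not verified against the final API.
from unsloth import FastLanguageModel

# 1) Load the model with the mobile-friendly QAT scheme. After this call the
#    model contains torchao FakeQuantizedLinear / FakeQuantizedEmbedding
#    modules, as described below.
model, tokenizer = FastLanguageModel.from_pretrained(
    "unsloth/Llama-3.2-1B-Instruct",  # example checkpoint
    full_finetuning=True,
    qat_scheme="int8-int4",
)

# 2) ... run quantization-aware training with your usual trainer ...

# 3) Convert the QAT'd model to tensor subclasses with an HF quantization
#    config attached, so it can be pushed to HuggingFace and lowered to
#    ExecuTorch.
_convert_torchao_model(model)
```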
After calling FastLanguageModel.from_pretrained, the model has torchao FakeQuantizedLinear and FakeQuantizedEmbedding modules as specified by the qat_scheme ("int8-int4").
After calling _convert_torchao_model, the model has tensor subclasses and an HF quantization config attached.