Add 128x128 PerBlock FP8 + RL #3629
Changes from all commits:
```diff
@@ -16,13 +16,14 @@
 import os
 import re
 import tempfile
+from typing import Union
 from .mapper import INT_TO_FLOAT_MAPPER, FLOAT_TO_INT_MAPPER, MAP_TO_UNSLOTH_16bit

 # https://github.com/huggingface/transformers/pull/26037 allows 4 bit loading!
 from packaging.version import Version
 from transformers import (
     AutoModel,
     AutoProcessor,
     AutoModelForCausalLM,
+    AutoTokenizer,
     TorchAoConfig,
     __version__ as transformers_version,
 )
```
```diff
@@ -158,20 +159,31 @@ def get_model_name(model_name, load_in_4bit = True):
     return new_model_name if new_model_name is not None else model_name


-def _get_torchao_fp8_config():
+def _get_torchao_fp8_config(fp8_mode: str):
     """
     Return a `torchao.quantization.Float8DynamicActivationFloat8WeightConfig`
     to be used for `load_in_fp8=True`.
     """
-    from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow
+    from torchao.quantization import (
+        Float8DynamicActivationFloat8WeightConfig,
+        PerBlock,
+        PerRow,
+    )

+    if fp8_mode == "row":
+        granularity = PerRow()
+    elif fp8_mode == "block":
+        granularity = (PerBlock([1, 128]), PerBlock([128, 128]))
+    else:
+        raise ValueError("Unsloth: `load_in_fp8` supports only 'row' or 'block'")

     return Float8DynamicActivationFloat8WeightConfig(
-        granularity = PerRow(),
+        granularity = granularity,
         activation_value_lb = 1e-12,
     )


-def _offline_quantize_to_fp8(model_name: str) -> str:
+def _offline_quantize_to_fp8(model_name: str, fp8_mode: str) -> str:
     """
     Quantizes the model to fp8 using torchao and saving the quantized model to a
     temporary location. Return the path to the quantized model.
```
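For context, here is a minimal standalone sketch of what the two granularity modes mean when the returned config is applied with torchao's `quantize_` API. The layer shape is made up for illustration; the config arguments mirror the diff above, and torchao 0.15.0+ is assumed (per the version check later in this file):

```python
import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerBlock,
    PerRow,
    quantize_,
)

linear = torch.nn.Linear(4096, 4096, bias = False).cuda().to(torch.bfloat16)

# "row" mode: one FP8 scale per weight row / per activation token row.
row_config = Float8DynamicActivationFloat8WeightConfig(
    granularity = PerRow(),
    activation_value_lb = 1e-12,
)

# "block" mode: 1x128 blocks for activations, 128x128 tiles for weights,
# i.e. the (PerBlock([1, 128]), PerBlock([128, 128])) pair from the diff.
block_config = Float8DynamicActivationFloat8WeightConfig(
    granularity = (PerBlock([1, 128]), PerBlock([128, 128])),
    activation_value_lb = 1e-12,
)

# Swaps linear.weight for an FP8 tensor subclass in place.
quantize_(linear, block_config)
```

Block scaling bounds quantization error per 128x128 tile, which helps outlier-heavy weights where a single per-row scale would crush small values; the 1x128 / 128x128 pairing is the same tiling used by DeepSeek-style FP8 kernels, which is presumably why it is the supported block shape here.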
```diff
@@ -186,53 +198,72 @@ def _offline_quantize_to_fp8(model_name: str) -> str:
     )
     """
     temp_dir = tempfile.gettempdir()
-    new_model_name = model_name.split("/")[-1] + "-fp8"
+    new_model_name = model_name.split("/")[-1] + "-fp8-" + fp8_mode
     new_model_name = os.path.join(temp_dir, new_model_name)
     print(
-        f"Quantizing '{model_name}' to fp8, using model_name='{new_model_name}' instead"
+        f"Unsloth: Quantizing '{model_name}' to fp8, using model_name='{new_model_name}' instead"
    )

     if not os.path.isdir(new_model_name):
-        qconfig = _get_torchao_fp8_config()
+        qconfig = _get_torchao_fp8_config(fp8_mode)
         qconfig = TorchAoConfig(qconfig)
-        model = AutoModel.from_pretrained(
+        # TODO: generalize this to beyond text models?
+        # Right now using AutoModel removes the `lm_head` layer,
+        # which is expected later when loading the vllm state dict
+        model = AutoModelForCausalLM.from_pretrained(
```
Contributor (Author):
@danielhanchen I had to change this back for this to work. When I tried […]

Contributor:
Oh ok oh wait I can make this work for vision models without using AutoModel
```diff
             model_name,
             torch_dtype = "auto",
             device_map = "auto",
             quantization_config = qconfig,
         )
-        tokenizer = AutoProcessor.from_pretrained(model_name)
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
         model.save_pretrained(new_model_name, safe_serialization = False)
         tokenizer.save_pretrained(new_model_name)
     return new_model_name
```
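A hypothetical call, to make the flow concrete (the model name is illustrative only, and this is a private helper whose signature may shift):

```python
# Writes e.g. /tmp/Llama-3.1-8B-Instruct-fp8-block on first use; subsequent
# calls with the same mode find the directory and skip re-quantization.
path = _offline_quantize_to_fp8(
    "meta-llama/Llama-3.1-8B-Instruct", fp8_mode = "block"
)
```

Keying the cache directory on `fp8_mode` also means row- and block-quantized copies of the same model can coexist in the temp dir.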
```diff
-def _tag_model_with_fp8_torchao_config(model: torch.nn.Module):
+def _tag_model_with_fp8_torchao_config(model: torch.nn.Module, fp8_mode: str):
     """
     Tag a model with a `TorchAOConfig` so downstream callers will know what to do with it.
     """
-    base_config = _get_torchao_fp8_config()
+    base_config = _get_torchao_fp8_config(fp8_mode)
     model.torchao_config = TorchAOConfig(
         qat_scheme = None,
         base_config_and_filter_fns = [(base_config, None)],
     )
```
```diff
-def _check_load_in_fp8_settings(
+def _get_fp8_mode_and_check_settings(
     load_in_fp8: Union[bool, str],
     fast_inference: bool,
     full_finetuning: bool,
     load_in_4bit: bool,
     load_in_8bit: bool,
     load_in_16bit: bool,
     use_exact_model_name: bool,
-):
+) -> str:
     """
-    Assuming `load_in_fp8=True`, raise appropriate errors on incompatible settings
+    Assuming `load_in_fp8` is enabled, raise appropriate errors on incompatible settings
     and environment. Currently this feature requires:

     1. H100 GPUs or after
     2. torchao 0.15.0+ (or nightly)
     3. torch 2.9.0+
     4. If fbgemm_gpu_genai is installed, require 1.4.1+

+    Returns the fp8 mode, one of "row" or "block".
     """
+    assert load_in_fp8 is not False
+    if load_in_fp8 is True:
+        fp8_mode = "row"  # default
+    else:
+        fp8_mode = load_in_fp8

     # Check user settings
+    if fp8_mode not in ["row", "block"]:
+        raise ValueError(
+            f"Unsloth: `load_in_fp8` can only be 'row' or 'block', got '{fp8_mode}'"
+        )
     if not fast_inference:
         raise ValueError(
             "Unsloth: `load_in_fp8` is only supported for `fast_inference` for now"
```
```diff
@@ -263,13 +294,16 @@ def _check_load_in_fp8_settings(
     # Check if torchao has this PR: https://github.com/pytorch/ao/pull/3158,
     # which will be released in 0.15.0.
     if importlib.util.find_spec("torchao") is None:
-        raise ValueError("Unsloth: Please install torchao for on the fly float8 to work!")
+        raise ValueError(
+            "Unsloth: Please install torchao for on the fly float8 to work!"
+        )
     import torchao

-    error_message = \
-        "Unsloth: `load_in_fp8` requires torchao 0.15.0+ (or nightly).\n"\
-        f"You have torchao version={torchao.__version__}\n"\
+    error_message = (
+        "Unsloth: `load_in_fp8` requires torchao 0.15.0+ (or nightly).\n"
+        f"You have torchao version={torchao.__version__}\n"
         "Use `pip install --upgrade --force-reinstall torchao`"
+    )
     if Version(torchao.__version__) < Version("0.15.0"):
         raise ValueError(error_message)
```
```diff
@@ -284,3 +318,4 @@ def _check_load_in_fp8_settings(
         raise ValueError(
             "Unsloth: `load_in_fp8` is only compatible with fbgemm_gpu_genai 1.4.1+"
         )
+    return fp8_mode
```
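End to end, the accepted values for `load_in_fp8` are now `True` (kept as an alias for `"row"`), `"row"`, and `"block"`. A hypothetical user-facing call, assuming the loader threads these arguments through unchanged (the model name is illustrative):

```python
from unsloth import FastLanguageModel

model, tokenizer = FastLanguageModel.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",  # illustrative model choice
    load_in_fp8 = "block",   # True -> "row"; "block" -> 1x128 / 128x128 PerBlock scaling
    fast_inference = True,   # required: fp8 loading currently needs fast_inference
)
```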
Review comment (on the `raise ValueError` fallback in `_get_torchao_fp8_config`):

This validation is redundant, as `fp8_mode` is already validated in `_get_fp8_mode_and_check_settings` before being passed to this function. For internal functions, it's better to rely on assertions for contract checking rather than raising user-facing `ValueError`s. This avoids duplicated validation logic and makes the code cleaner.

Consider removing this `else` block. If you want to keep a check for robustness, an `assert` would be more appropriate, for example: […] However, given the call chain, even an assert is likely unnecessary.
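The reviewer's inline snippet is elided above; a hypothetical reconstruction of the kind of assertion they likely meant:

```python
# In _get_torchao_fp8_config, replacing the user-facing ValueError branch:
assert fp8_mode in ("row", "block"), f"unexpected fp8_mode: {fp8_mode}"
```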