Add Gemma 3 support for FunctionGemma and other Gemma 3 models #436
base: main
Conversation
Force-pushed from `1f4f847` to `13b3a33`.
I generated the tiny models with:

```python
# Create tiny random Gemma 3 models for Bumblebee testing
# Run with: HF_TOKEN=hf_xxx python create_tiny_gemma3.py
import os

from huggingface_hub import login
from transformers import (
    Gemma3TextConfig,
    Gemma3TextModel,
    Gemma3ForCausalLM,  # No "Text" variant for CausalLM
    Gemma3TextForSequenceClassification,
)

# Login with token
token = os.environ.get("HF_TOKEN")
if token:
    login(token=token)
else:
    print("Warning: HF_TOKEN not set, using cached credentials")

# Tiny config matching Gemma 3 text architecture
config = Gemma3TextConfig(
    vocab_size=1024,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
    head_dim=8,
    intermediate_size=64,
    hidden_activation="gelu_pytorch_tanh",
    max_position_embeddings=512,
    rms_norm_eps=1e-6,
    rope_theta=10000.0,
    attention_bias=False,
    attention_dropout=0.0,
    sliding_window=128,
    pad_token_id=0,
    bos_token_id=2,
    eos_token_id=1,
    tie_word_embeddings=True,
    initializer_range=0.02,
    query_pre_attn_scalar=8,
)

# For sequence classification (same config plus num_labels)
config_seq_cls = Gemma3TextConfig(
    vocab_size=1024,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
    head_dim=8,
    intermediate_size=64,
    hidden_activation="gelu_pytorch_tanh",
    max_position_embeddings=512,
    rms_norm_eps=1e-6,
    rope_theta=10000.0,
    attention_bias=False,
    attention_dropout=0.0,
    sliding_window=128,
    pad_token_id=0,
    bos_token_id=2,
    eos_token_id=1,
    tie_word_embeddings=True,
    initializer_range=0.02,
    query_pre_attn_scalar=8,
    num_labels=2,
)

models = [
    (Gemma3TextModel, config, "tiny-random-Gemma3Model"),
    (Gemma3ForCausalLM, config, "tiny-random-Gemma3ForCausalLM"),
    (Gemma3TextForSequenceClassification, config_seq_cls, "tiny-random-Gemma3ForSequenceClassification"),
]

print("Creating tiny random Gemma 3 models...")
for model_class, model_config, name in models:
    print(f"\nCreating {name}...")
    model = model_class(model_config)

    local_path = f"./{name}"
    model.save_pretrained(local_path)
    print(f"  Saved to {local_path}")

    repo_id = f"nmaroulis/{name}"
    model.push_to_hub(repo_id)
    print(f"  Pushed to https://huggingface.co/{repo_id}")

print("\nDone!")
```
Gemma 3 architecture includes several key differences from Gemma v1:

- QK-norm (RMS normalization on query/key after projection)
- Pre/post FFN layer norms (`pre_feedforward_layernorm`, `post_feedforward_layernorm`)
- Different residual connection order (after `post_attention_layernorm`)
- Alternating local/global attention (sliding window)
- RMS norm with shift=1.0 formula: `output * (1.0 + weight)`

Files added:

- `lib/bumblebee/text/gemma3.ex`: Full Gemma 3 model implementation
- `test/bumblebee/text/gemma3_test.exs`: Unit tests
- `notebooks/function_calling.livemd`: Livebook with FunctionGemma examples

Files modified:

- `lib/bumblebee.ex`: Model and tokenizer registrations
- `lib/bumblebee/layers/transformer.ex`: Per-layer `attention_window_size` support
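To make the shift formula concrete, here is a minimal Nx sketch of the shifted RMS norm (module and function names are hypothetical, not the actual Bumblebee implementation):

```elixir
defmodule Gemma3NormSketch do
  import Nx.Defn

  # Shifted RMS norm: normalize by the root mean square over the last
  # axis, then scale by (1 + weight) instead of plain weight.
  defn shifted_rms_norm(x, weight, opts \\ []) do
    opts = keyword!(opts, eps: 1.0e-6)
    variance = Nx.mean(Nx.pow(x, 2), axes: [-1], keep_axes: true)
    normalized = x * Nx.rsqrt(variance + opts[:eps])
    normalized * (1.0 + weight)
  end
end
```

With `weight` initialized to zeros (as in the Gemma checkpoints), the scale starts at exactly 1.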
Force-pushed from `13b3a33` to `1fc7aaf`.
…ests

- Refactor decoder to use shared Layers.Transformer.blocks infrastructure
- Use per-layer attention_window_size function for alternating local/global attention
- Use query_norm/key_norm options for QK-normalization
- Use custom block_type function for Gemma 3's unique normalization structure
- Add assert_all_close with reference values from Python transformers
- Fix bug in Layers.Transformer.blocks where attention_window_size was duplicated when using a function for per-layer configuration
- Update params_mapping to use query_norm/key_norm naming from shared infrastructure
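As an illustration of such a per-layer `attention_window_size` function (the `{left, right}` tuple for local attention and `nil` for global attention are assumptions for this sketch):

```elixir
# Hypothetical per-layer window function: every 6th layer uses global
# attention (no window), the rest use a local sliding window of 128.
attention_window_size = fn layer_idx ->
  if rem(layer_idx + 1, 6) == 0, do: nil, else: {128, 128}
end

attention_window_size.(4)  #=> {128, 128} (local)
attention_window_size.(5)  #=> nil (global)
```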
Thanks for the review feedback! I've addressed both comments:
Added assert_all_close assertions with reference values obtained from Python transformers:

```python
import torch

input_ids = torch.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]])
```

- `:base`: `model = AutoModel.from_pretrained("nmaroulis/tiny-random-Gemma3Model")`, `hidden_state[[.., 1..3, 1..3]]`:
  `[[-1.6458, 0.7249, -0.5747], [-1.9452, -0.1602, -0.2329], [-2.3408, -0.4665, -0.1177]]`
- `:for_causal_language_modeling`: `model = AutoModelForCausalLM.from_pretrained("nmaroulis/tiny-random-Gemma3ForCausalLM")`, `logits[[.., 1..3, 1..3]]`:
  `[[0.1472, 0.0633, 0.0922], [-0.1089, -0.0344, 0.0755], [0.0112, 0.1083, 0.1461]]`
- `:for_sequence_classification`: `model = AutoModelForSequenceClassification.from_pretrained("nmaroulis/tiny-random-Gemma3ForSequenceClassification")`, `logits`: `[[-0.0060, -0.0212]]`

All tests now verify numerical equivalence with Python transformers.
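For context, the Elixir side of such a check is a short assertion; a hedged sketch using the base model values above (exact test helper usage may differ):

```elixir
# Compare a slice of the Elixir output against the Python reference
# values for the base model case above.
assert_all_close(
  outputs.hidden_state[[.., 1..3, 1..3]],
  Nx.tensor([
    [
      [-1.6458, 0.7249, -0.5747],
      [-1.9452, -0.1602, -0.2329],
      [-2.3408, -0.4665, -0.1177]
    ]
  ])
)
```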
Replaced the custom block iteration with Layers.Transformer.blocks. The key changes:
`Layers.Transformer.blocks(hidden_state, ...)`

The custom `gemma3_block_impl/4` function handles Gemma 3's unique block structure while leveraging the shared attention infrastructure.
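For orientation, a hedged sketch of the block ordering the custom implementation follows, per the differences listed earlier in this PR (helper names, arity, and the `norms` map are hypothetical, not the actual `gemma3_block_impl/4` signature):

```elixir
defmodule Gemma3BlockSketch do
  # Norms wrap both the attention and FFN sub-blocks, and the residual
  # is added only after the post-norm (unlike the usual pre-norm block).
  def block(hidden_state, norms, attention_fun, ffn_fun) do
    # Attention sub-block: pre-norm -> attention -> post-norm -> residual
    shortcut = hidden_state

    hidden_state =
      hidden_state
      |> norms.pre_attention.()
      |> attention_fun.()
      |> norms.post_attention.()
      |> Axon.add(shortcut)

    # FFN sub-block: pre-norm -> ffn -> post-norm -> residual
    shortcut = hidden_state

    hidden_state
    |> norms.pre_feedforward.()
    |> ffn_fun.()
    |> norms.post_feedforward.()
    |> Axon.add(shortcut)
  end
end
```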
Btw. I uploaded the tiny-random models to bumblebee-testing, so you can switch the tests to use those :)
@nyo16 by the way, the gemma3 attention uses a configurable scalar here. I've just pushed a8caabd, which adds support for it. I updated the tiny-random checkpoints I pushed, so that the tests exercise the new option.
- Rename Bumblebee.Text.Gemma3 to Bumblebee.Text.Gemma3Text to distinguish text-only model from future multimodal Gemma3
- Add attention_scale_base config option (from query_pre_attn_scalar)
- Compute attention scale as attention_scale_base ** -0.5
- Update model mappings to use Gemma3Text* variants
- Update tests to use bumblebee-testing models with Python reference values
- Fix duplicate attention_window_size key in transformer.ex after merge
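The scale itself is just an inverse square root of the config value, for example:

```elixir
# attention_scale_base ** -0.5; e.g. with query_pre_attn_scalar = 256
# (a common Gemma 3 value, used here only for illustration):
attention_scale = :math.pow(256, -0.5)
#=> 0.0625
```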
Thanks so much @jonatanklosko! I've updated the branch based on your feedback:
Note: There are still some small numerical differences between the Elixir and Python outputs (max ~0.15), so I used slightly higher tolerances in the tests. This might be worth investigating further if needed, but the model works correctly with the FunctionGemma examples. I tested with the code and got the correct answer; if I find more time I will investigate more.
Co-authored-by: Jonatan Kłosko <jonatanklosko@gmail.com>
```elixir
global_attention_layer_interval: [
  default: 6,
  doc: """
  the interval for global attention layers. In Gemma 3, every Nth layer uses global
  attention while others use local (sliding window) attention. A value of 6 means
  layers 5, 11, 17, 23... use global attention (5:1 local/global ratio)
  """
],
```
We never load this from the HuggingFace config, so it will cause discrepancies (could be the reason why tests failed).
In fact, the "interval" approach is deprecated, see this code.
I think we can go with option `:layer_types`, which is a list of either `:sliding_attention` or `:full_attention`. We have a similar config here:
bumblebee/lib/bumblebee/diffusion/controlnet.ex, lines 414 to 422 in a8caabd:

```elixir
up_block_types: {
  "up_block_types",
  list(
    mapping(%{
      "UpBlock2D" => :up_block,
      "CrossAttnUpBlock2D" => :cross_attention_up_block
    })
  )
},
```
We should also handle "sliding_window_pattern" if present. One way to do it would be to do something like this:

```elixir
# Support sliding_window_pattern for backward compatibility, see https://github.com/huggingface/transformers/blob/v5.0.0rc1/src/transformers/models/gemma3/configuration_gemma3.py#L188-L195
data =
  Map.put_new_lazy(data, "layer_types", fn ->
    pattern = data["sliding_window_pattern"] || 6
    # generate a list of Python-like layer types
  end)
```
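One possible way to fill in that placeholder comment (a sketch, assuming `data["num_hidden_layers"]` is present and that every pattern-th layer, 1-indexed, uses full attention, mirroring the deprecated pattern logic in transformers):

```elixir
# Sketch: derive "layer_types" from the deprecated "sliding_window_pattern".
data =
  Map.put_new_lazy(data, "layer_types", fn ->
    pattern = data["sliding_window_pattern"] || 6

    for idx <- 0..(data["num_hidden_layers"] - 1) do
      if rem(idx + 1, pattern) == 0, do: "full_attention", else: "sliding_attention"
    end
  end)
```

With `sliding_window_pattern=2` and two layers this yields `["sliding_attention", "full_attention"]`, matching the updated checkpoints below.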
Btw. I pushed updated checkpoints to HF with:
"layer_types": [
"sliding_attention",
"full_attention"
],
This way we can test both layer types.
You will need to update the reference values.
We should be within 4-digit precision, or in other words we should not need higher tolerances.

FWIW, I generated the checkpoints using this config:

```python
from transformers import Gemma3TextConfig, Gemma3TextModel, Gemma3ForCausalLM, Gemma3TextForSequenceClassification

config = Gemma3TextConfig(
    vocab_size=1024,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
    intermediate_size=37,
    hidden_act="gelu_pytorch_tanh",
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    max_position_embeddings=512,
    type_vocab_size=16,
    is_decoder=False,
    initializer_range=0.02,
    pad_token_id=0,
    bos_token_id=2,
    eos_token_id=1,
    sliding_window=128,
    query_pre_attn_scalar=8,
    sliding_window_pattern=2,
    # layer_types=[
    #     "sliding_attention",
    #     "full_attention"
    # ]
)

for c in [Gemma3TextModel, Gemma3ForCausalLM, Gemma3TextForSequenceClassification]:
    name = c.__name__
    c(config).save_pretrained(f"bumblebee-testing/tiny-random-{name}", repo_id=f"bumblebee-testing/tiny-random-{name}", push_to_hub=True)
```
Adds support for the Gemma 3 architecture, enabling FunctionGemma (`google/functiongemma-270m-it`) and other Gemma 3 models to run in Bumblebee.

Why FunctionGemma?

Gemma 3 Architecture Changes

Gemma 3 has several key differences from Gemma v1:

- QK-norm (RMS normalization on query/key after projection)
- Pre/post FFN layer norms
- Different residual connection order
- Alternating local/global attention (sliding window)
- RMS norm with shift: `(1 + weight) * normalized` instead of `weight * normalized`

Files Changed
- `lib/bumblebee/text/gemma3.ex` - Full Gemma 3 model implementation with custom decoder supporting QK-norm and extra layer norms
- `lib/bumblebee.ex` - Model and tokenizer registrations for `Gemma3Model`, `Gemma3ForCausalLM`, etc.
- `lib/bumblebee/layers/transformer.ex` - Per-layer `attention_window_size` callback for alternating local/global attention
- `test/bumblebee/text/gemma3_test.exs` - Unit tests (require tiny-random models on HuggingFace)
- `notebooks/function_calling.livemd` - Comprehensive Livebook example with:
  - `FunctionGemma.Schema` - Build function declarations
  - `FunctionGemma.Parser` - Parse function call responses
  - `FunctionGemma.Executor` - Execute parsed calls
  - `SmartHome` - Mock functions demo (lights, thermostat, weather)

Example Usage
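A minimal loading sketch (illustrative only; the full runnable example, including the function-calling prompt format, lives in `notebooks/function_calling.livemd`):

```elixir
repo = {:hf, "google/functiongemma-270m-it"}

{:ok, model_info} = Bumblebee.load_model(repo)
{:ok, tokenizer} = Bumblebee.load_tokenizer(repo)
{:ok, generation_config} = Bumblebee.load_generation_config(repo)

serving = Bumblebee.Text.generation(model_info, tokenizer, generation_config)

# The prompt here is illustrative; FunctionGemma expects a function-calling
# prompt format, which the notebook's FunctionGemma.Schema helper builds.
Nx.Serving.run(serving, "Turn on the living room lights")
```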
Test Plan