@@ -179,12 +179,14 @@ def _norm(self, x):
         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
 
     def forward(self, x):
-        x = self._norm(x.float()).type_as(x)
+        # Llama does x.to(float16) * w whilst Gemma2 is (x * w).to(float16)
+        # See https://github.com/huggingface/transformers/pull/29402
+        output = self._norm(x.float())
         if self.add_unit_offset:
-            output = x * (1 + self.weight)
+            output = output * (1 + self.weight.float())
         else:
-            output = x * self.weight
-        return output
+            output = output * self.weight.float()
+        return output.type_as(x)
 
 
 class GemmaMLP(nn.Module):
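The hunk above moves the downcast in RMSNorm: normalization and the weight multiplication now both run in float32, and only the final result is cast back to the input dtype. A minimal sketch of the two orderings (illustrative only, not part of this commit; the toy shapes, the `rms` helper, and ignoring the `add_unit_offset` branch are assumptions):

```python
import torch

def rms(v: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Same normalization as _norm above, computed in whatever dtype v carries.
    return v * torch.rsqrt(v.pow(2).mean(-1, keepdim=True) + eps)

torch.manual_seed(0)
x = torch.randn(2, 8, dtype=torch.bfloat16)  # toy activations
w = torch.randn(8, dtype=torch.bfloat16)     # toy RMSNorm weight

# Old ordering (Llama-style): downcast right after normalizing, then multiply by the weight.
old = rms(x.float()).type_as(x) * w

# New ordering (Gemma2-style, as in the hunk above): multiply in float32, downcast at the end.
new = (rms(x.float()) * w.float()).type_as(x)

# The two results typically differ by a small rounding amount in half precision.
print((old - new).abs().max())
```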
@@ -546,7 +548,10 @@ def forward(
         # [batch_size, input_len, hidden_size]
         hidden_states = self.embedder(input_token_ids)
         # Gemma normalizes the embedding by sqrt(hidden_size).
-        hidden_states = hidden_states * (self.config.hidden_size**0.5)
+        # Gemma2 downcasts the below to bfloat16, causing sqrt(3072)=55.4256 to become 55.5
+        # See https://github.com/huggingface/transformers/pull/29402
+        normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
+        hidden_states = hidden_states * normalizer
 
         hidden_states = self.model(
             hidden_states=hidden_states,
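The rounding the new comment describes is easy to reproduce; a small check (illustrative only, not part of this commit, using the hidden size of 3072 mentioned in the comment) shows how sqrt(3072) rounds once the normalizer is materialized in a half-precision dtype, which is why it is now built with `torch.tensor(..., dtype=hidden_states.dtype)` instead of being left as a full-precision Python float:

```python
import torch

# sqrt(3072) in float64 vs. what it becomes when stored in half-precision dtypes
# before it ever multiplies the embeddings.
exact = 3072 ** 0.5
print(exact)                                             # 55.42562584220407
print(torch.tensor(exact, dtype=torch.bfloat16).item())  # 55.5
print(torch.tensor(exact, dtype=torch.float16).item())   # 55.4375
```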