
Commit 42df147

Merge pull request #67 from unslothai/main
Fix downcasting and upcasting
2 parents a3567e4 + f499fd4 commit 42df147

File tree: 1 file changed (+10, -5 lines)

gemma/model.py

Lines changed: 10 additions & 5 deletions
@@ -179,12 +179,14 @@ def _norm(self, x):
         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
 
     def forward(self, x):
-        x = self._norm(x.float()).type_as(x)
+        # Llama does x.to(float16) * w whilst Gemma2 is (x * w).to(float16)
+        # See https://github.com/huggingface/transformers/pull/29402
+        output = self._norm(x.float())
         if self.add_unit_offset:
-            output = x * (1 + self.weight)
+            output = output * (1 + self.weight.float())
         else:
-            output = x * self.weight
-        return output
+            output = output * self.weight.float()
+        return output.type_as(x)
 
 
 class GemmaMLP(nn.Module):
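The comment in the hunk above is the crux of the fix: Llama-style RMSNorm casts the normalized activations back to the input dtype and then multiplies by the weight, whereas Gemma2 multiplies in float32 and downcasts only the final product. The two orderings round at different points and so give slightly different half-precision results. Below is a minimal sketch of that difference; the standalone `rms_norm` helper, the bfloat16 dtype, and the tensor shapes are illustrative assumptions rather than code from this repository:

```python
import torch

def rms_norm(x, eps=1e-6):
    # Normalize in float32 for numerical stability (shared by both variants).
    x32 = x.float()
    return x32 * torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + eps)

x = torch.randn(2, 4, dtype=torch.bfloat16)
w = torch.randn(4, dtype=torch.bfloat16)

# Old path (Llama-style): downcast first, then multiply by (1 + w) in bfloat16.
old = rms_norm(x).type_as(x) * (1 + w)

# New path (Gemma2-style): multiply in float32, downcast once at the very end.
new = (rms_norm(x) * (1 + w.float())).type_as(x)

# The two orderings typically disagree by a few units in the last place.
print((old.float() - new.float()).abs().max())
```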
@@ -546,7 +548,10 @@ def forward(
         # [batch_size, input_len, hidden_size]
         hidden_states = self.embedder(input_token_ids)
         # Gemma normalizes the embedding by sqrt(hidden_size).
-        hidden_states = hidden_states * (self.config.hidden_size**0.5)
+        # Gemma2 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
+        # See https://github.com/huggingface/transformers/pull/29402
+        normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
+        hidden_states = hidden_states * normalizer
 
         hidden_states = self.model(
             hidden_states=hidden_states,
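The second hunk matters because sqrt(3072) ≈ 55.4256 does not survive a round trip through reduced precision unchanged, so materializing the normalizer as a tensor in the activation dtype reproduces the rounding that the reference Gemma2 implementation applies (see the linked transformers PR). A quick sketch of the rounding itself; showing bfloat16 here is an assumption on my part, chosen because it yields exactly the 55.5 value mentioned in the comment (float16 rounds to a nearby but different value):

```python
import torch

hidden_size = 3072
exact = hidden_size ** 0.5   # 55.42562584220407 as a Python float (float64)

# Materializing the scalar in a low-precision dtype rounds it.
as_bf16 = torch.tensor(exact, dtype=torch.bfloat16)
as_fp16 = torch.tensor(exact, dtype=torch.float16)

print(exact)           # 55.42562584220407
print(as_bf16.item())  # 55.5
print(as_fp16.item())  # 55.4375
```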
