Merged
Changes from 1 commit
Fix FP8 for models with non 8 multiple weights
Datta0 committed Oct 23, 2025
commit 78f7c79dd9a17cafc0542abaeaa75ea3109caeb1
27 changes: 17 additions & 10 deletions unsloth/kernels/fp8.py
@@ -342,16 +342,13 @@ class FbgemmFp8Linear(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, weight, weight_scale, bias=None):
        if weight.shape[0] != weight_scale.shape[0]:
            if weight.shape[1] == weight_scale.shape[0]:
                # This is generally the case when we do backward pass. The only way is to dequantize as there is no column wise fp8 matmul
                W_deq = weight_dequant(weight, weight_scale).T
                x = torch_matmul(x, W_deq)
                del W_deq
                return x
            else:
                raise ValueError(f"Shapes are incompatible {weight.shape=}, {weight_scale.shape=}, {x.shape=}")
        else:

        if weight.shape[0] == weight_scale.shape[0] and (weight.shape[0]%8==0 and weight.shape[1]%8==0):
            # Edit: The kernel seems to expect the weight to have dimensions divisible by 8. Otherwise it throws `RuntimeError: cutlass cannot implement`.
            # One option is to pad the weight and weight scale to a multiple of 8 and perform a F8F8BF16 operation.
            # I benchmarked that for speed but observed that dequantize + bf16 matmul is significantly faster than padding + f8f8bf16 matmul, so we go that route.
            # So essentially, f8f8bf16_rowwise only happens when the shapes are proper (no transposes) and divisible by 8.

            # quantize_fp8_per_row will squash the leading dimensions, so save the desired shape here
            output_shape = (*x.shape[:-1], -1)
            # x_quantized and x_scale are not necessarily on the same device as x, this is an issue.
@@ -378,6 +375,16 @@ def forward(ctx, x, weight, weight_scale, bias=None):
            output = output.to(x.device, x.dtype)
            output = output.reshape(output_shape)
            del x_quantized, x_scale
        elif (weight.shape[0] != weight_scale.shape[0] and weight.shape[1] == weight_scale.shape[0]) or (weight.shape[0]%8!=0 or weight.shape[1]%8!=0):
            # Either the weight/scale is transposed or the weight's shape is not divisible by 8. In both cases, dequantizing is the preferred route.
            # The transpose case generally shows up in the backward pass, where we compute dY @ W instead of x @ W.T as in the forward pass.
            # The shape case shows up e.g. in the MLP of Qwen 2.5 VL 7B, where gate_proj has shape (3420, 1280) and 3420/8 = 427.5.

            W_deq = weight_dequant(weight, weight_scale).T
            output = torch_matmul(x, W_deq)
            del W_deq
        else:
            raise ValueError(f"Shapes are incompatible {weight.shape=}, {weight_scale.shape=}, {x.shape=}")

        ctx.weight = weight
        ctx.weight_scale = weight_scale
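
Condensed from the diff above, a hedged sketch of the dispatch rule the new forward implements: the fbgemm rowwise fp8 kernel is used only when the scale lines up with the weight's leading dimension (no transpose) and both weight dimensions are multiples of 8; a transposed weight or a non-multiple-of-8 shape falls back to dequantize + bf16 matmul. The helper name choose_fp8_matmul_path and the string labels are illustrative only, not part of this PR.

import torch

def choose_fp8_matmul_path(weight: torch.Tensor, weight_scale: torch.Tensor) -> str:
    # Hypothetical helper mirroring the shape checks in the new forward above.
    rows_match = weight.shape[0] == weight_scale.shape[0]
    transposed = weight.shape[1] == weight_scale.shape[0]
    divisible_by_8 = weight.shape[0] % 8 == 0 and weight.shape[1] % 8 == 0

    if rows_match and divisible_by_8:
        # Proper orientation and cutlass-friendly shapes: rowwise fp8 matmul.
        return "f8f8bf16_rowwise"
    if (not rows_match and transposed) or not divisible_by_8:
        # Backward-pass transpose, or a shape like Qwen 2.5 VL 7B's
        # (3420, 1280) gate_proj where 3420 % 8 != 0.
        return "dequantize_then_matmul"
    raise ValueError(f"Shapes are incompatible {weight.shape=}, {weight_scale.shape=}")

# Stand-in tensors with the shapes mentioned in the comments above:
w = torch.zeros(3420, 1280, dtype=torch.bfloat16)
s = torch.ones(3420, 1)
print(choose_fp8_matmul_path(w, s))  # -> "dequantize_then_matmul"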
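
For the fallback branch, a minimal sketch of what dequantize + bf16 matmul amounts to, assuming a per-output-row scale (weight is (out_features, in_features) in fp8 and weight_scale holds one scale per row, which is what the rowwise path implies). The function name and the inline dequantization are illustrative; the PR itself calls this file's weight_dequant and torch_matmul, whose exact scaling scheme is not shown in this diff.

import torch

def fp8_dequant_matmul_fallback(x, weight, weight_scale, bias=None):
    # Upcast the fp8 weight to the activation dtype (e.g. bf16), apply the
    # assumed per-row scale, then run a plain matmul. Memory-heavier than a
    # fused fp8 kernel, but per the benchmark note in the first hunk this
    # route beat padding the operands up to a multiple of 8.
    w_deq = weight.to(x.dtype) * weight_scale.reshape(-1, 1).to(x.dtype)
    out = torch.matmul(x, w_deq.T)
    if bias is not None:
        out = out + bias
    return out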