Beautify
Datta0 committed Oct 27, 2025
commit 0d61b243d22d2a808b33dd333f3704e9cbd8e5d0
unsloth/kernels/fp8.py (6 changes: 3 additions & 3 deletions)
@@ -357,7 +357,7 @@ class FbgemmFp8Linear_matmul(torch.autograd.Function):
@staticmethod
def forward(ctx, x, weight, weight_scale, bias=None):

- if weight.shape[0] == weight_scale.shape[0] and (weight.shape[0]%8==0 and weight.shape[1]%8==0):
+ if weight.shape[0] == weight_scale.shape[0] and (weight.shape[0] % 8 == 0 and weight.shape[1] % 8 == 0):
# Edit: The kernel seems to expect that the weight has dimensions divisible by 8. Otherwise it throws `RuntimeError: cutlass cannot implement`.
# One thing we can do is pad the weight and weight scale to a multiple of 8 and perform an F8F8BF16 operation.
# I tried benchmarking that for speed but observed that dequantize+bf16 matmul is significantly faster than padding+f8f8bf16 matmul, so we'll go that route.
@@ -389,7 +389,7 @@ def forward(ctx, x, weight, weight_scale, bias=None):
output = output.to(x.device, x.dtype)
output = output.reshape(output_shape)
del x_quantized, x_scale
- elif (weight.shape[0] != weight_scale.shape[0] and weight.shape[1] == weight_scale.shape[0]) or (weight.shape[0]//8!=0 or weight.shape[1]//8!=0):
+ elif (weight.shape[0] != weight_scale.shape[0] and weight.shape[1] == weight_scale.shape[0]) or (weight.shape[0] // 8 != 0 or weight.shape[1] // 8 != 0):
# Either the weight/scale is transposed or its shape is not divisible by 8. In both cases, dequantizing is the preferred way.
# The transpose case generally shows up in the backward pass, where we do dY@W instead of the X@W.T we do in the forward.
# The shape case I noticed in the MLP of Qwen 2.5 VL 7B, where the gate proj has shape (3420, 1280) and 3420/8 = 427.5
@@ -412,7 +412,7 @@ def backward(ctx, grad_output):
return grad_X, None, None, None, None

@torch_compile
- def fbgemm_fp8_linear(X, weight, weight_scale, bias=None, ):
+ def fbgemm_fp8_linear(X, weight, weight_scale, bias=None):
return FbgemmFp8Linear_matmul.apply(X, weight, weight_scale, bias)


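The comments in the hunks above explain why the code falls back to a dequantize + bf16 matmul when the weight shape is not a multiple of 8, or when the weight/scale arrive transposed in the backward pass. Below is a minimal sketch of that fallback path only, assuming a row-wise FP8-quantized `weight` with a per-output-row `weight_scale`; the function name and tensor shapes are illustrative assumptions, not the repo's actual implementation.

```python
import torch

def dequantize_bf16_matmul(x, weight, weight_scale, bias=None):
    # Hypothetical sketch of the dequantize + bf16 matmul fallback described above.
    # Assumes `weight` is FP8 with shape (out_features, in_features) and
    # `weight_scale` holds one scale per output row.
    # Upcast the FP8 weight to bf16 and undo the per-row quantization scale.
    w_bf16 = weight.to(torch.bfloat16) * weight_scale.to(torch.bfloat16).reshape(-1, 1)
    # Plain bf16 matmul, x @ W.T, matching a standard nn.Linear forward.
    out = torch.matmul(x.to(torch.bfloat16), w_bf16.t())
    if bias is not None:
        out = out + bias
    return out.to(x.dtype)
```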