Commit dce2fda

Update fp8.py
1 parent fc178b5 commit dce2fda

File tree

1 file changed: +12 -6 lines changed

unsloth/kernels/fp8.py

Lines changed: 12 additions & 6 deletions
@@ -18,6 +18,8 @@
 from torch.nn import functional as F
 import math
 from unsloth_zoo.log import logger
+from unsloth_zoo.temporary_patches.common import torch_compile
+torch_matmul = torch.matmul
 
 try:
     from transformers.integrations.finegrained_fp8 import FP8Linear
@@ -35,17 +37,16 @@
     from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import triton_quantize_fp8_block
 except ImportError:
     triton_quantize_fp8_block = None
+    logger.log("Unsloth: Could not find fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm.triton_quantize_fp8_block")
 
 try:
     from torchao.prototype.blockwise_fp8_inference.blockwise_quantization import (
         blockwise_fp8_gemm as torchao_blockwise_gemm,
     )
 except ImportError:
     torchao_blockwise_gemm = None
+    logger.log("Unsloth: Could not find torchao.prototype.blockwise_fp8_inference.blockwise_quantization.blockwise_fp8_gemm")
 
-from unsloth_zoo.temporary_patches.common import torch_compile
-
-torch_matmul = torch.matmul
 
 @triton.jit
 def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
@@ -60,6 +61,7 @@ def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
     s = tl.load(s_ptr + pid_m * n + pid_n)
     y = x * s
     tl.store(y_ptr + offs, y, mask=mask)
+pass
 
 
 def weight_dequant_block(x: torch.Tensor, s: torch.Tensor, block_size: int = 128, dtype=torch.bfloat16) -> torch.Tensor:
@@ -73,6 +75,7 @@ def weight_dequant_block(x: torch.Tensor, s: torch.Tensor, block_size: int = 128
     grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE']))
     weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
     return y
+pass
 
 def weight_dequant(x: torch.Tensor, s: torch.Tensor, dtype=torch.bfloat16):
     if s.shape[1] == 1:
@@ -89,7 +92,7 @@ def weight_dequant(x: torch.Tensor, s: torch.Tensor, dtype=torch.bfloat16):
     else:
         # this is block quantized weight
         return weight_dequant_block(x, s, dtype=dtype)
-
+pass
 
 # Copied from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py
 @triton.jit
@@ -106,6 +109,7 @@ def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
     y = y.to(y_ptr.dtype.element_ty)
     tl.store(y_ptr + offs, y)
     tl.store(s_ptr + pid, s)
+pass
 
 def act_quant(x: torch.Tensor, block_size: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
     if not x.is_contiguous():
@@ -119,7 +123,7 @@ def grid(meta):
 
     act_quant_kernel[grid](x, y, s, BLOCK_SIZE = block_size)
     return y, s
-
+pass
 
 # Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/quantization/fp8_kernel.py
 @triton.jit
@@ -205,7 +209,7 @@ def _w8a8_block_fp8_matmul(
     c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
     c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
     tl.store(c_ptrs, c, mask = c_mask)
-
+pass
 
 def w8a8_block_fp8_matmul_triton(
     A: torch.Tensor,
@@ -284,6 +288,7 @@ def grid(META):
         GROUP_SIZE_M = 8,
     )
     return C
+pass
 
 def torchao_block_matmul(
     act_q: torch.Tensor,
@@ -301,6 +306,7 @@ def torchao_block_matmul(
         block_size=block_size[1],
     )
     return out.to(output_dtype)
+pass
 
 # This torchao FP8 matmul seems to be ~3x faster than the w8a8_block_fp8_matmul_triton. Though this is 15-30% slower than fbgemm implementation.
 # But this gives very comparable results when it comes to training loss, so we prefer using it when available.
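For readers skimming the diff, a minimal usage sketch of the dequant helpers touched above follows. The shapes and dtypes are illustrative assumptions inferred from the code shown (s.shape[1] == 1 selecting the row-quantized path, otherwise weight_dequant_block with 128x128 tiles); none of this is part of the commit itself.

import torch

# Assumed: a 256x256 FP8 weight with one float32 scale per 128x128 tile.
w_fp8 = torch.randn(256, 256, device = "cuda").to(torch.float8_e4m3fn)
s_block = torch.rand(2, 2, device = "cuda", dtype = torch.float32)  # cdiv(256, 128) = 2 per dim
w_bf16 = weight_dequant(w_fp8, s_block)  # s.shape[1] != 1 -> block-quantized branch

# Per-row scales of shape (rows, 1) would take the s.shape[1] == 1 branch instead.
s_row = torch.rand(256, 1, device = "cuda", dtype = torch.float32)
w_bf16_row = weight_dequant(w_fp8, s_row)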
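The closing comment explains why torchao's blockwise GEMM is preferred when importable. Below is a sketch of the kind of dispatch that the availability flags set up in this commit (torchao_blockwise_gemm being None on a failed import) make possible; the wrapper name fp8_block_matmul and the exact argument order of the two matmul helpers are assumptions for illustration, not code from fp8.py.

def fp8_block_matmul(x, w_fp8, w_scale, block_size = (128, 128)):
    # Quantize activations into FP8 blocks of block_size[1] plus per-block scales.
    x_q, x_s = act_quant(x, block_size = block_size[1])
    if torchao_blockwise_gemm is not None:
        # Preferred path per the comment above: ~3x faster than the Triton
        # kernel with comparable training loss (argument order assumed).
        return torchao_block_matmul(x_q, x_s, w_fp8, w_scale, block_size, x.dtype)
    # Fallback: the pure-Triton blockwise FP8 GEMM (argument order assumed).
    return w8a8_block_fp8_matmul_triton(x_q, w_fp8, x_s, w_scale, block_size, x.dtype)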
