 from torch.nn import functional as F
 import math
 from unsloth_zoo.log import logger
+from unsloth_zoo.temporary_patches.common import torch_compile
+torch_matmul = torch.matmul

 try:
     from transformers.integrations.finegrained_fp8 import FP8Linear
...
     from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import triton_quantize_fp8_block
 except ImportError:
     triton_quantize_fp8_block = None
+    logger.log("Unsloth: Could not find fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm.triton_quantize_fp8_block")

 try:
     from torchao.prototype.blockwise_fp8_inference.blockwise_quantization import (
         blockwise_fp8_gemm as torchao_blockwise_gemm,
     )
 except ImportError:
     torchao_blockwise_gemm = None
+    logger.log("Unsloth: Could not find torchao.prototype.blockwise_fp8_inference.blockwise_quantization.blockwise_fp8_gemm")

-from unsloth_zoo.temporary_patches.common import torch_compile
-
-torch_matmul = torch.matmul

 @triton.jit
 def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
@@ -60,6 +61,7 @@ def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
     s = tl.load(s_ptr + pid_m * n + pid_n)
     y = x * s
     tl.store(y_ptr + offs, y, mask = mask)
+pass


 def weight_dequant_block(x: torch.Tensor, s: torch.Tensor, block_size: int = 128, dtype = torch.bfloat16) -> torch.Tensor:
@@ -73,6 +75,7 @@ def weight_dequant_block(x: torch.Tensor, s: torch.Tensor, block_size: int = 128
     grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE']))
     weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE = block_size)
     return y
+pass

 def weight_dequant(x: torch.Tensor, s: torch.Tensor, dtype = torch.bfloat16):
     if s.shape[1] == 1:
@@ -89,7 +92,7 @@ def weight_dequant(x: torch.Tensor, s: torch.Tensor, dtype=torch.bfloat16):
     else:
         # this is block quantized weight
         return weight_dequant_block(x, s, dtype = dtype)
-
+pass

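For intuition, the block dequant above is just a per-tile scale multiply. A rough pure-PyTorch equivalent (a sketch only, assuming x is the FP8 weight of shape [M, N] and s holds one float scale per block_size x block_size tile):

def weight_dequant_block_reference(x, s, block_size = 128, dtype = torch.bfloat16):
    # Slow reference: upcast the FP8 weight, then scale each
    # (block_size x block_size) tile by its entry in s, which has shape
    # [ceil(M / block_size), ceil(N / block_size)].
    M, N = x.shape
    y = x.to(torch.float32)
    for i in range(0, M, block_size):
        for j in range(0, N, block_size):
            y[i:i + block_size, j:j + block_size] *= s[i // block_size, j // block_size].float()
    return y.to(dtype)
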
 # Copied from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py
 @triton.jit
@@ -106,6 +109,7 @@ def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
     y = y.to(y_ptr.dtype.element_ty)
     tl.store(y_ptr + offs, y)
     tl.store(s_ptr + pid, s)
+pass

 def act_quant(x: torch.Tensor, block_size: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
     if not x.is_contiguous():
@@ -119,7 +123,7 @@ def grid(meta):

     act_quant_kernel[grid](x, y, s, BLOCK_SIZE = block_size)
     return y, s
-
+pass

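The quantizer above is per-group symmetric FP8. A rough PyTorch sketch (assuming the absmax / 448 scaling used by the DeepSeek-V3 kernel it is copied from, 448 being the float8_e4m3fn max, and a last dim divisible by block_size):

def act_quant_reference(x, block_size = 128):
    # Slow reference: one scale per contiguous group of block_size values
    # along the last dimension, then quantize each group to float8_e4m3fn.
    g = x.float().view(-1, block_size)
    s = g.abs().amax(dim = -1, keepdim = True) / 448.0
    y = (g / s).to(torch.float8_e4m3fn).view(x.shape)
    return y, s.view(*x.shape[:-1], x.shape[-1] // block_size)
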
 # Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/quantization/fp8_kernel.py
 @triton.jit
@@ -205,7 +209,7 @@ def _w8a8_block_fp8_matmul(
     c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
     c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
     tl.store(c_ptrs, c, mask = c_mask)
-
+pass

 def w8a8_block_fp8_matmul_triton(
     A: torch.Tensor,
@@ -284,6 +288,7 @@ def grid(META):
         GROUP_SIZE_M = 8,
     )
     return C
+pass

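Numerically, the blockwise W8A8 GEMM is equivalent to dequantizing both operands and running a plain matmul. A slow sketch under assumed layouts (A is [M, K] FP8 with per-[1, block_k] activation scales As, B is [N, K] FP8 with per-[block_n, block_k] weight scales Bs; the real kernel never materializes the dequantized tensors):

def w8a8_block_fp8_matmul_reference(A, As, B, Bs, block_size = (128, 128), output_dtype = torch.bfloat16):
    # Slow reference: broadcast each scale over its block, dequantize, matmul.
    block_n, block_k = block_size
    A_deq = A.float() * As.float().repeat_interleave(block_k, dim = -1)[..., :A.shape[-1]]
    Bs_f = Bs.float().repeat_interleave(block_n, dim = 0)[:B.shape[0]]
    Bs_f = Bs_f.repeat_interleave(block_k, dim = 1)[:, :B.shape[1]]
    B_deq = B.float() * Bs_f
    return (A_deq @ B_deq.t()).to(output_dtype)
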
 def torchao_block_matmul(
     act_q: torch.Tensor,
@@ -301,6 +306,7 @@ def torchao_block_matmul(
         block_size = block_size[1],
     )
     return out.to(output_dtype)
+pass

 # This torchao FP8 matmul seems to be ~3x faster than the w8a8_block_fp8_matmul_triton. Though this is 15-30% slower than fbgemm implementation.
 # But this gives very comparable results when it comes to training loss, so we prefer using it when available.
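Given that preference, backend selection can be as simple as checking whether the optional torchao import above succeeded. A hypothetical helper (the dispatch below is an illustrative assumption, not code from this commit):

def _preferred_block_fp8_gemm():
    # Hypothetical: prefer torchao's blockwise GEMM when its import succeeded,
    # otherwise fall back to the Triton kernel defined in this file.
    if torchao_blockwise_gemm is not None:
        return torchao_block_matmul
    return w8a8_block_fp8_matmul_triton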