43 changes: 29 additions & 14 deletions benchmarks/kernels/bench_block_fp8_gemm.py
@@ -1,10 +1,18 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import os
+
+# Disable DeepGEMM for this benchmark to use CUTLASS
+os.environ["VLLM_USE_DEEP_GEMM"] = "0"
+
 import torch
 
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    apply_w8a8_block_fp8_linear,
+    W8A8BlockFp8LinearOp,
 )
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    GroupShape,
+)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     CUTLASS_BLOCK_FP8_SUPPORTED,
@@ -39,13 +47,14 @@ def build_w8a8_block_fp8_runner(M, N, K, block_size, device, use_cutlass):
     fp8_info = torch.finfo(torch.float8_e4m3fn)
     fp8_max, fp8_min = fp8_info.max, fp8_info.min
 
-    # Create random FP8 tensors
+    # Create random input tensor (bfloat16, will be quantized by W8A8BlockFp8LinearOp)
     A_ref = (torch.rand(M, K, dtype=torch.bfloat16, device=device) - 0.5) * 2 * fp8_max
 
+    # Create quantized weight tensor
     B_ref = (torch.rand(N, K, dtype=torch.bfloat16, device=device) - 0.5) * 2 * fp8_max
     B = B_ref.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
 
-    # Create scales
+    # Create weight scales
     block_n, block_k = block_size[0], block_size[1]
     n_tiles = (N + block_n - 1) // block_n
     k_tiles = (K + block_k - 1) // block_k
@@ -55,19 +64,25 @@ def build_w8a8_block_fp8_runner(M, N, K, block_size, device, use_cutlass):
         * factor_for_scale
     )
 
-    # SM90 CUTLASS requires row-major format for scales
-    if use_cutlass and current_platform.is_device_capability(90):
-        Bs = Bs.T.contiguous()
+    # Create W8A8BlockFp8LinearOp instance
+    weight_group_shape = GroupShape(block_n, block_k)
+    act_quant_group_shape = GroupShape(1, block_k)  # Per-token, per-group quantization
+
+    linear_op = W8A8BlockFp8LinearOp(
+        weight_group_shape=weight_group_shape,
+        act_quant_group_shape=act_quant_group_shape,
+        cutlass_block_fp8_supported=use_cutlass,
+        use_aiter_and_is_supported=False,
+    )
 
     def run():
-        if use_cutlass:
-            return apply_w8a8_block_fp8_linear(
-                A_ref, B, block_size, Bs, cutlass_block_fp8_supported=True
-            )
-        else:
-            return apply_w8a8_block_fp8_linear(
-                A_ref, B, block_size, Bs, cutlass_block_fp8_supported=False
-            )
+        return linear_op.apply(
+            input=A_ref,
+            weight=B,
+            weight_scale=Bs,
+            input_scale=None,
+            bias=None,
+        )
 
     return run
 
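Reviewer note: the builder above returns a zero-argument closure, so any timing harness can drive it directly. A minimal sketch of how that might look, assuming Triton's do_bench helper and a CUDA device with block-FP8 CUTLASS support (the shapes and block size below are illustrative, not taken from this PR):

from triton.testing import do_bench

M, N, K = 1024, 4096, 7168  # example problem sizes, not from this PR
block_size = [128, 128]  # typical 128x128 weight blocks
run = build_w8a8_block_fp8_runner(
    M, N, K, block_size, device="cuda", use_cutlass=True
)

latency_ms = do_bench(run, warmup=25, rep=100)  # time per call in milliseconds
tflops = 2 * M * N * K / (latency_ms * 1e-3) / 1e12
print(f"CUTLASS block-FP8 GEMM: {latency_ms:.3f} ms, {tflops:.2f} TFLOP/s")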
@@ -48,7 +48,8 @@ struct cutlass_3x_gemm_fp8_blockwise {
   using ElementBlockScale = float;
 
   using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<
-      ScaleGranularityM, ScaleGranularityN, ScaleGranularityK>;
+      ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
+      cute::GMMA::Major::MN, cute::GMMA::Major::K>;
 
   using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
   using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
@@ -173,7 +173,7 @@ def process_weights_after_loading(self, layer) -> None:
             layer.input_scale = None
 
         if self.strategy == QuantizationStrategy.BLOCK:
-            maybe_post_process_fp8_weight_block(layer, self.cutlass_block_fp8_supported)
+            maybe_post_process_fp8_weight_block(layer)
 
     def apply_weights(
         self,
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/quantization/fp8.py
@@ -543,7 +543,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
             return
 
         if self.block_quant:
-            maybe_post_process_fp8_weight_block(layer, self.cutlass_block_fp8_supported)
+            maybe_post_process_fp8_weight_block(layer)
 
     def apply(
         self,
22 changes: 3 additions & 19 deletions vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -55,17 +55,13 @@ def cutlass_scaled_mm(
     Bs: torch.Tensor,
     block_size: list[int],
     output_dtype: torch.dtype = torch.float16,
-    is_hopper: bool | None = None,
 ) -> torch.Tensor:
-    if is_hopper is None:
-        is_hopper = current_platform.is_device_capability(90)
     return ops.cutlass_scaled_mm(
         A,
         B.T,
         out_dtype=output_dtype,
         scale_a=As,
-        # SM90 block FP8 requires row-major scale_b, which we do ahead of time
-        scale_b=Bs if block_size is not None and is_hopper else Bs.T,
+        scale_b=Bs.T,
     )
 
 
@@ -130,7 +126,7 @@ def _padded_cutlass(
     padded_x_scale[0 : x_scale.shape[0], ...].copy_(x_scale)
 
     output = cutlass_scaled_mm(
-        padded_qx, weight, padded_x_scale, weight_scale, block_size, output_dtype, True
+        padded_qx, weight, padded_x_scale, weight_scale, block_size, output_dtype
     )
     return output[0 : qx.shape[0], ...]
 
@@ -303,7 +299,6 @@ def _run_cutlass(
             weight_scale,
             list(self.weight_group_shape),
             input_2d.dtype,
-            False,
         )
 
     def _run_aiter(
@@ -1124,9 +1119,7 @@ def process_fp8_weight_block_strategy(
     return weight, weight_scale
 
 
-def maybe_post_process_fp8_weight_block(
-    layer: torch.nn.Module, cutlass_block_fp8_supported: bool
-):
+def maybe_post_process_fp8_weight_block(layer: torch.nn.Module):
    assert layer.weight_block_size is not None
 
    from vllm.utils.deep_gemm import (
@@ -1145,15 +1138,6 @@ def maybe_post_process_fp8_weight_block(
         requant_weight_ue8m0_inplace(
             layer.weight.data, layer.weight_scale.data, block_sz
         )
-    # SM90 Block FP8 CUTLASS requires row-major weight scales
-    elif (
-        current_platform.is_device_capability(90)
-        and cutlass_block_fp8_supported
-        and not should_use_deepgemm
-    ):
-        layer.weight_scale = torch.nn.Parameter(
-            layer.weight_scale.data.T.contiguous(), requires_grad=False
-        )
 
 
 def expert_weight_is_col_major(x: torch.Tensor) -> bool:
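Reviewer note: with the is_hopper branch removed from cutlass_scaled_mm and the SM90-only scale transpose dropped from maybe_post_process_fp8_weight_block, weight scales now stay in their natural (N-tiles, K-tiles) layout on every platform and the wrapper always forwards scale_b=Bs.T. A small sketch of the layout contract this implies, with illustrative shapes that are not taken from this PR:

import torch

N, K = 4096, 7168
block_n, block_k = 128, 128  # weight_block_size
weight = torch.empty(N, K, dtype=torch.float8_e4m3fn)
# Weight scales keep their natural per-block layout on every platform.
weight_scale = torch.empty(N // block_n, K // block_k, dtype=torch.float32)
# cutlass_scaled_mm now always hands ops.cutlass_scaled_mm the transposed
# view, i.e. (K-tiles, N-tiles), instead of pre-transposing scales on SM90.
assert weight_scale.T.shape == (K // block_k, N // block_n)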