Merged
update weight scale transpose check for newer vllm and standby util override
Datta0 committed Nov 14, 2025
commit f52d05f5141ed4a8ee8777aaecad35105f04d929
56 changes: 37 additions & 19 deletions unsloth_zoo/vllm_utils.py
@@ -879,19 +879,34 @@ def _get_vllm_state_dict(llm, return_state_dict = False, config = None, is_visio
capability = torch.cuda.get_device_capability()
sm_cap = capability[0] * 10 + capability[1]

-try:
-    from vllm.utils.deep_gemm import is_deep_gemm_supported as vllm_is_deep_gemm_supported
-    is_deep_gemm_supported = vllm_is_deep_gemm_supported()
-except Exception as e:
-    logger.info(f"Unsloth: Could not import vLLM deep_gemm: {e}")
-    is_deep_gemm_supported = False

try:
-    cutlass_block_fp8_supported = torch.ops._C.cutlass_scaled_mm_supports_block_fp8(sm_cap)
+    # vLLM recently removed the transpose of weight scale for Hopper GPUs.
+    # https://github.com/vllm-project/vllm/pull/28431
+    # So now we check if the weight process function does a transpose of weight scale before doing so
+    # https://github.com/vllm-project/vllm/commit/f9a4087182ffcd9404779fcda876f820b3b26d5f#diff-cce58c0ceb6a9b15a01f117d734b93736acc25ed89921c2eacc58ea05bd34d0eL1155-L1157
+    from vllm.model_executor.layers.quantization.utils.fp8_utils import maybe_post_process_fp8_weight_block
+    from inspect import getsource
+    needs_transpose_check = 'layer.weight_scale.data.T.contiguous()' in getsource(maybe_post_process_fp8_weight_block)
Contributor review comment (severity: critical):

This check using getsource is fragile and can break easily with vllm updates. A version check would be more robust.

More importantly, the logic seems inverted. needs_transpose_check is True for older vllm versions that already perform the transpose. Your code then performs another transpose (on line 948), resulting in a double transpose. The transpose should only happen for newer vllm versions that don't do it.

The condition on line 942 should likely be if not needs_transpose_check:.

For a more robust solution, consider checking the vllm version:

from vllm import __version__ as VLLM_VERSION
from packaging.version import Version

# The transpose was removed in vLLM v0.4.1.
# We need to transpose only if vLLM does not.
should_transpose = Version(VLLM_VERSION) >= Version("0.4.1")

Then use if should_transpose: where you currently use if needs_transpose_check:.

except Exception as e:
logger.info(f"Unsloth: Could not import vLLM cutlass_block_fp8_supported: {e}")
cutlass_block_fp8_supported = False
pass
logger.info(f"Unsloth: Could not import vLLM fp8_utils: {e}")
needs_transpose_check = False

+is_deep_gemm_supported = False
+cutlass_block_fp8_supported = False
+if needs_transpose_check:
+    # Only try to import and check if we need to
+    try:
+        from vllm.utils.deep_gemm import is_deep_gemm_supported as vllm_is_deep_gemm_supported
+        is_deep_gemm_supported = vllm_is_deep_gemm_supported()
+    except Exception as e:
+        logger.info(f"Unsloth: Could not import vLLM deep_gemm: {e}")
+
+    try:
+        cutlass_block_fp8_supported = torch.ops._C.cutlass_scaled_mm_supports_block_fp8(sm_cap)
+    except Exception as e:
+        logger.info(f"Unsloth: Could not import vLLM cutlass_block_fp8_supported: {e}")
+        pass
Contributor review comment (severity: medium):

This pass statement is unnecessary and can be removed. The same applies to the pass on line 950.


def get_state_dict(prefix, kk, state_dict, proj, slice_weights=True, slice_index=-1):
proj = getattr(proj, "base_layer", proj)
@@ -924,18 +939,19 @@ def get_state_dict(prefix, kk, state_dict, proj, slice_weights=True, slice_index
# Also notice that vLLM stores scale in [32,48] which is transpose of what HF expects.
scale_suffix = '.weight_scale_inv'
block_size = proj.weight_block_size[0]
-should_use_deepgemm = is_deep_gemm_supported and getattr(proj, "orig_dtype", torch.bfloat16) == torch.bfloat16 and qweight.shape[0] % 128 == 0 and qweight.shape[1] % 128 == 0
-if sm_cap==90 and cutlass_block_fp8_supported and not should_use_deepgemm:
-    # For H100 (at least), the scale seems to be a transpose of what HF expects, while on L4 it is right shape.
-    # This is done by vLLM based on a few checks that we replicated above.
-    # https://github.com/vllm-project/vllm/blob/294c805f1df9ddf62c2290989710da9d48ab4973/vllm/model_executor/layers/quantization/utils/fp8_utils.py#L1172-L1179
-    weight_scale = weight_scale.T
-    logger.info(f"Unsloth: Transposed weight scale for {prefix} for weight shape {qweight.shape} and scale shape {weight_scale.shape}")
+if needs_transpose_check:
+    should_use_deepgemm = is_deep_gemm_supported and getattr(proj, "orig_dtype", torch.bfloat16) == torch.bfloat16 and qweight.shape[0] % 128 == 0 and qweight.shape[1] % 128 == 0
+    if sm_cap==90 and cutlass_block_fp8_supported and not should_use_deepgemm:
+        # For H100 (at least), the scale seems to be a transpose of what HF expects, while on L4 it is right shape.
+        # This is done by vLLM based on a few checks that we replicated above.
+        # https://github.com/vllm-project/vllm/blob/294c805f1df9ddf62c2290989710da9d48ab4973/vllm/model_executor/layers/quantization/utils/fp8_utils.py#L1172-L1179
+        weight_scale = weight_scale.T
+        logger.info(f"Unsloth: Transposed weight scale for {prefix} for weight shape {qweight.shape} and scale shape {weight_scale.shape}")
pass
a, b = qweight.shape
p, q = weight_scale.shape
# This is just a sanity check to ensure that we don't end up with wrongly sliced weight of shape [0, x] :)
-assert a // p == proj.weight_block_size[0] and b // q == proj.weight_block_size[1], f"Unsloth: vLLM weight has unexpected weight shape {qweight.shape} and scale {weight_scale.shape} and block size {proj.weight_block_size}"
+assert a // p == proj.weight_block_size[0] and b // q == proj.weight_block_size[1], f"Unsloth: vLLM weight for {prefix} has unexpected weight shape {qweight.shape} and scale {weight_scale.shape} and block size {proj.weight_block_size}"
else:
# This is dynamic quantization (aka per row or per column). The scale is of shape [n,1]
# The weight here is of shape [4096, 6144]. We need to transpose and then slice
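For readers unfamiliar with block-wise FP8 scales, here is a small illustration of the shape bookkeeping behind the transpose and the assert above; the numbers are examples taken from the comments in the diff (a [4096, 6144] weight with 128x128 blocks), not values the code guarantees.

block_h, block_w = 128, 128        # proj.weight_block_size
a, b = 4096, 6144                  # example qweight.shape

p, q = a // block_h, b // block_w  # scale grid the assert expects: (32, 48)

# If the scale arrives as (q, p) == (48, 32) instead, a single .T restores the
# (p, q) orientation, after which the sanity check in get_state_dict holds:
assert a // p == block_h and b // q == block_w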
@@ -1548,6 +1564,8 @@ def load_vllm(
assert(conservativeness >= 0.0 and conservativeness <= 1.0)

unsloth_vllm_standby = unsloth_vllm_standby or (os.getenv("UNSLOTH_VLLM_STANDBY", "0") != "0")
+# This would give the flexibility to override the util we set for standby mode. In some extreme cases, this can be helpful.
+standby_util_override = os.getenv("UNSLOTH_VLLM_STANDBY_UTIL_OVERRIDE", "0") != "0"

free_memory, total_memory = get_mem_info()
# If T4 ie 15GB, we use 0.85 since it'll rarely OOM. Other GPUs 0.9
@@ -1561,7 +1579,7 @@ def load_vllm(
elif ten_percent >= 1.0: standby_target_gpu_util = 0.8
else: standby_target_gpu_util = 0.75

-if unsloth_vllm_standby and gpu_memory_utilization < standby_target_gpu_util:
+if unsloth_vllm_standby and gpu_memory_utilization < standby_target_gpu_util and not standby_util_override:
gpu_memory_utilization = standby_target_gpu_util
logger.info(f"Unsloth: Standby mode is enabled. Changing `gpu_memory_utilization` to {gpu_memory_utilization}.")
