Conversation


@AndreasKaratzas AndreasKaratzas commented Jan 1, 2026

Fixes ROCm accuracy failures in language generation tests by disabling the flash_sdp and mem_efficient_sdp backends for HuggingFace Transformers inference.

Problem

Tests comparing HuggingFace and vLLM outputs were failing on ROCm due to numerical divergence:

```
AssertionError: Test5:
hf:   'fact' (logprob=-7.794)
vllm: 'upset' (logprob=-8.045)
```

The root cause is accuracy issues in ROCm's flash attention and memory-efficient SDP implementations in HuggingFace Transformers (#30167).

Solution

Add a conftest.py that configures PyTorch SDP backends on ROCm:

  • Disable flash_sdp
  • Disable mem_efficient_sdp
  • Enable math_sdp (more accurate reference implementation)

This ensures HuggingFace uses a numerically stable attention path for baseline comparisons.
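As a rough sketch of the idea (not the exact patch), the toggle logic can be written as a small helper that takes the backend module and the platform check as parameters, which also makes it unit-testable without a GPU. Here `backends` stands in for `torch.backends.cuda` and `is_rocm` for `current_platform.is_rocm()`; both parameter names are illustrative assumptions:

```python
def configure_sdp_backends(backends, is_rocm):
    """On ROCm, force the math SDP path for HF baseline runs.

    Sketch only: `backends` is assumed to expose the same enable_* API
    as torch.backends.cuda. Returns True if backends were reconfigured.
    """
    if not is_rocm:
        return False
    backends.enable_flash_sdp(False)          # inaccurate on ROCm
    backends.enable_mem_efficient_sdp(False)  # inaccurate on ROCm
    backends.enable_math_sdp(True)            # slower, but numerically stable
    return True
```

On non-ROCm platforms the helper is a no-op, so the same conftest.py remains safe on CUDA.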

Related

Testing

  • Verified pytest -v -s tests/models/language/generation/test_common.py::test_models[False-True-5-32-TitanML/tiny-mixtral] passes with this fix
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
@AndreasKaratzas AndreasKaratzas changed the title [ROCm] Fix language generation test accuracy by disabling HF flash_sdp and mem_efficient_sdp Jan 1, 2026
@mergify mergify bot added the rocm Related to AMD ROCm label Jan 1, 2026
@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a workaround for ROCm accuracy issues in language generation tests by disabling flash_sdp and mem_efficient_sdp backends in HuggingFace Transformers. The change is implemented by adding a conftest.py file that applies these settings at the start of the test session. The approach is sound and the code is clear. My only suggestion is to use a more semantically appropriate pytest hook (pytest_sessionstart instead of pytest_collection_modifyitems) to improve maintainability and adherence to pytest best practices.

Comment on lines 12 to 28
```python
def pytest_collection_modifyitems(config, items):
    """Configure ROCm-specific settings based on collected tests."""
    if not current_platform.is_rocm():
        return

    # Disable Flash/MemEfficient SDP on ROCm to avoid HF Transformers
    # accuracy issues: https://github.com/vllm-project/vllm/issues/30167
    # TODO: Remove once ROCm SDP accuracy issues are resolved on HuggingFace
    torch.backends.cuda.enable_flash_sdp(False)
    torch.backends.cuda.enable_mem_efficient_sdp(False)
    torch.backends.cuda.enable_math_sdp(True)
    warnings.warn(
        "ROCm: Disabled flash_sdp and mem_efficient_sdp, enabled math_sdp "
        "to avoid HuggingFace Transformers accuracy issues",
        UserWarning,
        stacklevel=1,
    )
```
Severity: high

The pytest_collection_modifyitems hook is intended for modifying the list of collected test items, not for test session setup. Using it for setup can be misleading and is not idiomatic. A more appropriate hook for this kind of session-level setup is pytest_sessionstart, which is executed once at the beginning of the test session. Alternatively, a session-scoped autouse fixture could be used. Using the correct hook improves code clarity and maintainability.

Suggested change

```diff
-def pytest_collection_modifyitems(config, items):
-    """Configure ROCm-specific settings based on collected tests."""
+def pytest_sessionstart(session):
+    """Configure ROCm-specific settings before test session starts."""
     if not current_platform.is_rocm():
         return
     # Disable Flash/MemEfficient SDP on ROCm to avoid HF Transformers
     # accuracy issues: https://github.com/vllm-project/vllm/issues/30167
     # TODO: Remove once ROCm SDP accuracy issues are resolved on HuggingFace
     torch.backends.cuda.enable_flash_sdp(False)
     torch.backends.cuda.enable_mem_efficient_sdp(False)
     torch.backends.cuda.enable_math_sdp(True)
     warnings.warn(
         "ROCm: Disabled flash_sdp and mem_efficient_sdp, enabled math_sdp "
         "to avoid HuggingFace Transformers accuracy issues",
         UserWarning,
         stacklevel=1,
     )
```
@AndreasKaratzas (Contributor, Author)
Done :)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>