
Conversation


@sarckk sarckk commented Nov 10, 2025

Purpose

BugFix: #27058 seems to have missed the import for aiter.gemm_a8w8_blockscale.

BugFix: #24490 changed the logic so that we always run the quant op. Fix: run the quant op only if input_scale is None:

```python
if input_scale is not None:
    q_input = input_2d
# MI350 case uses triton kernel
if (
    not current_platform.is_fp8_fnuz()
    and rocm_aiter_ops.is_triton_gemm_w8a8_tuned(n, k)
):
    q_input, input_scale = per_token_group_quant_fp8(
        input_2d,
        self.act_quant_group_shape.col,
        column_major_scales=False,
        use_ue8m0=False,
    )
    return rocm_aiter_ops.triton_gemm_a8w8_blockscale(
        q_input,
        weight,
        input_scale,
        weight_scale,
        input_2d.dtype,
    )
# MI300 uses tuned AITER ASM/C++ kernel
else:
    q_input, input_scale = rocm_aiter_ops.per_1x128_fp8_quant(input_2d)
    return rocm_aiter_ops.gemm_w8a8_blockscale(
        q_input,
        weight,
        input_scale,
        weight_scale,
        input_2d.dtype,
    )
```

BugFix: correct invocation of aiter ops
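To make the missing-import bug concrete: when a symbol is imported only inside one branch of a function, Python treats the name as function-local everywhere in that function, so the other branch fails at call time. A minimal standalone sketch (stand-in names only, not vLLM's actual code):

```python
def run_gemm(use_triton: bool) -> int:
    """Stand-in for a dispatch function whose kernel import only
    happens on one branch (the failure mode this PR fixes)."""
    if use_triton:
        # stand-in for: from aiter.ops.triton... import gemm_a8w8_blockscale
        from operator import mul as gemm_kernel
        return gemm_kernel(2, 3)
    # bug: the equivalent `from aiter import gemm_a8w8_blockscale` is
    # missing here, so the name is unbound on this path
    return gemm_kernel(2, 3)  # raises NameError (UnboundLocalError)


run_gemm(True)  # fine: the import ran on this branch
try:
    run_gemm(False)
except NameError as exc:
    print(f"missing-import bug: {exc}")
```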

Test Plan

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@mergify mergify bot added the rocm Related to AMD ROCm label Nov 10, 2025

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

The pull request correctly addresses a missing import for gemm_a8w8_blockscale within a conditional block, preventing a NameError in the rocm_aiter_gemm_w8a8_blockscale_impl function. However, the current implementation strategy for gemm_a8w8_blockscale introduces an ambiguity, as the same symbol is dynamically bound to different functions (from aiter and aiter.ops.triton.gemm_a8w8_blockscale) depending on the execution path. While Python's scoping rules allow this, it can lead to confusion and potential subtle bugs if the functions are not perfectly interchangeable. A clearer approach would involve using distinct aliases for these imports to explicitly differentiate the intended function in each branch.
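The distinct-alias approach suggested above could look roughly like the following sketch. The aiter imports are replaced with local stand-ins so the example runs anywhere; the real import lines are shown in comments:

```python
def pick_blockscale_kernel(use_triton: bool):
    """Bind each implementation to its own alias so a reader can tell
    which kernel a call site refers to. Sketch only: stand-in callables
    replace the real aiter imports, which need a ROCm build."""
    if use_triton:
        # real code would be:
        # from aiter.ops.triton.gemm_a8w8_blockscale import (
        #     gemm_a8w8_blockscale as triton_gemm_a8w8_blockscale,
        # )
        def triton_gemm_a8w8_blockscale(*args):
            return "triton"

        return triton_gemm_a8w8_blockscale
    # real code would be:
    # from aiter import gemm_a8w8_blockscale as asm_gemm_a8w8_blockscale
    def asm_gemm_a8w8_blockscale(*args):
        return "asm"

    return asm_gemm_a8w8_blockscale


print(pick_blockscale_kernel(True)())   # triton
print(pick_blockscale_kernel(False)())  # asm
```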


wuhuikx commented Nov 10, 2025

@HaiShaw @tjtanaa could you please help review?

@houseroad houseroad added ready-for-merge Indicate this PR is ready to be merged by the maintainers, used by reviewers without merge access. ready ONLY add when PR is ready to merge/full CI is needed labels Nov 10, 2025

@yewentao256 yewentao256 left a comment


Thanks for the work!


mergify bot commented Nov 10, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @sarckk.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Nov 10, 2025
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
@sarckk sarckk force-pushed the fix-missing-aiter-import branch from 059e08f to e667127 Compare November 10, 2025 16:44
@mergify mergify bot removed the needs-rebase label Nov 10, 2025

@tjtanaa tjtanaa left a comment


@sarckk A better approach would be to move all the gemm_a8w8_blockscale imports into a separate if-else block up front. This handles both MI300x and MI355x:

```diff
+    use_triton = not current_platform.is_fp8_fnuz() and is_aiter_triton_kernel_tuned(n, k)
+
+    if use_triton:
+        from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale
+    else:
+        from aiter import gemm_a8w8_blockscale

     if input_scale is not None:
         q_input = input_2d
     elif not current_platform.is_fp8_fnuz() and is_aiter_triton_kernel_tuned(n, k):
-        from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale

         # MI350 case uses triton kernel
         q_input, input_scale = per_token_group_quant_fp8(
             input_2d,
             group_size,
             column_major_scales=False,
             use_ue8m0=False,
         )
     else:
         # MI300 uses tuned AITER ASM/C++ kernel
         import aiter as rocm_aiter
-        from aiter import gemm_a8w8_blockscale, get_hip_quant
+        from aiter import get_hip_quant
         aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128)
         q_input, input_scale = aiter_per1x128_quant(
             input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8
         )
```

Signed-off-by: Yong Hoon Shin <yhshin@meta.com>

sarckk commented Nov 10, 2025

@tjtanaa addressed your comments, although it seems the real issue was that the use_triton branch should not have been in an elif? It looks like your #24490 changed this, so I think we no longer need this PR:

```python
if input_scale is not None:
    q_input = input_2d
# MI350 case uses triton kernel
if (
    not current_platform.is_fp8_fnuz()
    and rocm_aiter_ops.is_triton_gemm_w8a8_tuned(n, k)
):
    q_input, input_scale = per_token_group_quant_fp8(
        input_2d,
        self.act_quant_group_shape.col,
        column_major_scales=False,
        use_ue8m0=False,
    )
    return rocm_aiter_ops.triton_gemm_a8w8_blockscale(
        q_input,
        weight,
        input_scale,
        weight_scale,
        input_2d.dtype,
    )
# MI300 uses tuned AITER ASM/C++ kernel
else:
    q_input, input_scale = rocm_aiter_ops.per_1x128_fp8_quant(input_2d)
    return rocm_aiter_ops.gemm_w8a8_blockscale(
        q_input,
        weight,
        input_scale,
        weight_scale,
        input_2d.dtype,
    )
```

```python
if use_triton:
    gemm_w8a8_blockscale_op = rocm_aiter_ops.triton_gemm_a8w8_blockscale
else:
    gemm_w8a8_blockscale_op = rocm_aiter_ops.gemm_a8w8_blockscale
```
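Completing that pattern: the chosen callable is then invoked at a single call site, which sidesteps the rebinding ambiguity flagged in the earlier review. A runnable sketch with stand-ins for the two rocm_aiter_ops kernels (the real ops need a ROCm build):

```python
def gemm_blockscale_dispatch(use_triton: bool, *gemm_args):
    # stand-ins for rocm_aiter_ops.triton_gemm_a8w8_blockscale and
    # rocm_aiter_ops.gemm_a8w8_blockscale; real code binds the actual ops
    def triton_kernel(*args):
        return ("triton", args)

    def asm_kernel(*args):
        return ("asm", args)

    gemm_w8a8_blockscale_op = triton_kernel if use_triton else asm_kernel
    # one call site regardless of which kernel was selected
    return gemm_w8a8_blockscale_op(*gemm_args)


print(gemm_blockscale_dispatch(True, "q_input", "weight")[0])   # triton
print(gemm_blockscale_dispatch(False, "q_input", "weight")[0])  # asm
```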
@hongxiayang left a comment

Any reason the gemm_a8w8_blockscale_bpreshuffle is not used?

@hongxiayang the gemm_a8w8_blockscale_bpreshuffle is in another PR that is not yet merged; moreover, the AITER commit pinned in Dockerfile.rocm_base does not have that kernel.

Signed-off-by: Yong Hoon Shin <yhshin@meta.com>

@tjtanaa tjtanaa left a comment


LGTM. Thank you for your work @sarckk

@tjtanaa tjtanaa enabled auto-merge (squash) November 10, 2025 21:35
@tjtanaa tjtanaa merged commit 0211435 into vllm-project:main Nov 10, 2025
54 checks passed
@sarckk sarckk deleted the fix-missing-aiter-import branch November 10, 2025 23:14
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Nov 13, 2025
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
devpatelio pushed a commit to SumanthRH/vllm that referenced this pull request Nov 29, 2025
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
