[ROCm] Add missing gemm_a8w8_blockscale import #28378
Conversation
Code Review
The pull request correctly addresses a missing import for gemm_a8w8_blockscale within a conditional block, preventing a NameError in the rocm_aiter_gemm_w8a8_blockscale_impl function. However, the current implementation strategy for gemm_a8w8_blockscale introduces an ambiguity, as the same symbol is dynamically bound to different functions (from aiter and aiter.ops.triton.gemm_a8w8_blockscale) depending on the execution path. While Python's scoping rules allow this, it can lead to confusion and potential subtle bugs if the functions are not perfectly interchangeable. A clearer approach would involve using distinct aliases for these imports to explicitly differentiate the intended function in each branch.
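For illustration, here is a minimal sketch of the distinct-alias approach the review suggests; `use_triton` reuses the condition from the diff discussed below, and the alias names (`triton_gemm_a8w8_blockscale`, `asm_gemm_a8w8_blockscale`) are hypothetical:

```python
# Sketch only: bind each kernel to its own alias so the call site makes the
# chosen implementation explicit. Alias names are hypothetical.
use_triton = not current_platform.is_fp8_fnuz() and is_aiter_triton_kernel_tuned(n, k)
if use_triton:
    from aiter.ops.triton.gemm_a8w8_blockscale import (
        gemm_a8w8_blockscale as triton_gemm_a8w8_blockscale,
    )
    gemm_op = triton_gemm_a8w8_blockscale  # MI350-series triton kernel
else:
    from aiter import gemm_a8w8_blockscale as asm_gemm_a8w8_blockscale
    gemm_op = asm_gemm_a8w8_blockscale  # MI300-series ASM/C++ kernel
```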
yewentao256 left a comment:
Thanks for the work!
This pull request has merge conflicts that must be resolved before it can be merged.
tjtanaa left a comment:
@sarckk A better approach would be to move all the gemm_a8w8_blockscale imports into a separate if/else block. This handles both MI300X and MI355X:
```diff
+use_triton = not current_platform.is_fp8_fnuz() and is_aiter_triton_kernel_tuned(n, k)
+if use_triton:
+    from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale
+else:
+    from aiter import gemm_a8w8_blockscale
 if input_scale is not None:
     q_input = input_2d
 elif not current_platform.is_fp8_fnuz() and is_aiter_triton_kernel_tuned(n, k):
-    from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale
     # MI350 case uses triton kernel
     q_input, input_scale = per_token_group_quant_fp8(
         input_2d,
         group_size,
         column_major_scales=False,
         use_ue8m0=False,
     )
 else:
     # MI300 uses tuned AITER ASM/C++ kernel
     import aiter as rocm_aiter
-    from aiter import gemm_a8w8_blockscale, get_hip_quant
+    from aiter import get_hip_quant
     aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128)
     q_input, input_scale = aiter_per1x128_quant(
         input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8
     )
```
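Hoisting the imports into a single if/else up front binds gemm_a8w8_blockscale before either quant branch runs, so both the MI300 and MI350 paths see the name and the original NameError cannot recur.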
@tjtanaa addressed your comments. Although it seems that the real issue was that the use_triton branch should not have been in vllm/vllm/model_executor/layers/quantization/utils/fp8_utils.py, lines 319 to 350 in d0e186c:
```python
if use_triton:
    gemm_w8a8_blockscale_op = rocm_aiter_ops.triton_gemm_a8w8_blockscale
else:
    gemm_w8a8_blockscale_op = rocm_aiter_ops.gemm_a8w8_blockscale
```
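For context, a hedged sketch of how the selected callable might then be invoked; the argument list is an assumption for illustration, not the actual signature in fp8_utils.py:

```python
# Hypothetical call site: the two kernels are assumed interchangeable here,
# taking the quantized activation and weight plus their blockwise scales.
output = gemm_w8a8_blockscale_op(q_input, weight, input_scale, weight_scale)
```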
hongxiayang: Any reason the gemm_a8w8_blockscale_bpreshuffle kernel is not used?
@hongxiayang gemm_a8w8_blockscale_bpreshuffle is in another PR. Moreover, the AITER commit pinned in Dockerfile.rocm_base does not have that kernel, so it is not yet merged.
tjtanaa left a comment:
LGTM. Thank you for your work @sarckk
Purpose
- BugFix: #27058 seems to have missed the import for `aiter.gemm_a8w8_blockscale`.
- BugFix: #24490 changed it so that we always run the quant op. Fix it so the quant op runs only if `input_scale` is None (see vllm/vllm/model_executor/layers/quantization/utils/fp8_utils.py, lines 319 to 350 in d0e186c; a rough sketch follows this list).
- BugFix: correct the invocation of aiter ops.
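A rough sketch of the intended guard (only the MI350-style quant path is shown; the MI300 path would use AITER's per-1x128 hip quant as in the diff above):

```python
# Quantize only when the caller did not already provide a scale.
if input_scale is not None:
    q_input = input_2d  # input was already quantized upstream
else:
    q_input, input_scale = per_token_group_quant_fp8(
        input_2d,
        group_size,
        column_major_scales=False,
        use_ue8m0=False,
    )
```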
Test Plan
Test Result
Essential Elements of an Effective PR Description Checklist
`supported_models.md` and `examples` for a new model.