-
-
Notifications
You must be signed in to change notification settings - Fork 11.7k
[Bugfix] Fix persistent_masked_m_silu_mul_quant tests #28366
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request introduces two important bug fixes for the persistent_masked_m_silu_mul_quant kernel. The first fix correctly handles cases where the number of groups is not divisible by the number of warps by falling back to a 1-warp configuration. The second fix adds a parameter to control ue8m0 scale ceiling to align the kernel's behavior with the reference implementation in tests. The changes are logical and well-targeted. I've found one critical issue in an updated validation check that needs to be addressed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
💡 Codex Review
Here are some automated review suggestions for this pull request.
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
…8366) Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Signed-off-by: xuebwang-amd <xuebwang@amd.com>
…8366) Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Purpose
(8, 128, 128 * 33, fp8_dtype). This is because when hidden size is>= 4096we launch a kernel with 8 warps and the kernel requires that the NUM_GROUPS (hidden_size / 128) divides evenly between all warps. For this case, the PR just falls back to the 1 warp case.Test Plan
test_silu_mul_fp8_quant_deep_gemm.pylocallyTest Result