Refactor pagedAttention transpose #33102
base: master
Conversation
zhangYiIntel left a comment
LGTM
Pull request overview
This PR refactors transpose functions from `executor_pa.cpp` to a shared `transpose.hpp` header for reuse across multiple components. The changes enable better code organization and add new parameters to the transpose function signature to support quantization features.

- Moved three `transpose_16NxK` template overloads from `executor_pa.cpp` to `transpose.hpp`
- Updated function signatures to include `tmp`, `group_size`, and `quant_key_bychannel` parameters (see the sketch after this list)
- Modified all call sites to pass the additional parameters (including `nullptr` for the unused `tmp` parameter)
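For orientation, here is a minimal sketch of what the relocated declaration might look like after this change. The template parameters follow the call sites quoted further down; the parameter names, types, and order are assumptions inferred from the summary above, not copied from `transpose.hpp`.

```cpp
// Hypothetical sketch only: parameter names and order are inferred from the
// PR description, not taken from the actual transpose.hpp.
#include <cstddef>
#include "openvino/core/type/element_type.hpp"

template <typename TDST, ov::element::Type_t SRC_PREC>
void transpose_16NxK(TDST* dst,
                     void* src,
                     TDST* tmp,                  // new: scratch buffer, nullptr when unused
                     size_t N,
                     size_t K,
                     size_t block_size,
                     size_t dst_stride,
                     size_t src_stride,
                     size_t group_size,          // new: quantization group size
                     bool quant_key_bychannel);  // new: by-channel key quantization flag
```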
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/transpose.hpp | Added three transpose_16NxK template overloads moved from executor_pa.cpp, including support for quantized types (i8, u8, u4) |
| src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/xattention.hpp | Updated calls to transpose_16NxK to include new parameters (nullptr for tmp, 0 for group_size, false for quant_key_bychannel) |
| src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp | Removed transpose_16NxK function definitions that were moved to transpose.hpp, added include for transpose.hpp |
Diff excerpt (truncated):

```cpp
transpose_16NxK<uint32_t, ov::element::u32>(d, s, N, K >> 1, block_size, dst_stride, src_stride >> 1);
transpose_16NxK<uint32_t, ov::element::u32>(d,
                                            s,
                                            reinterpret_cast<uint32_t*>(0),
```
Copilot AI · Dec 11, 2025
Using `reinterpret_cast<uint32_t*>(0)` to represent `nullptr` is non-idiomatic and less clear. Replace with `nullptr` or `static_cast<uint32_t*>(nullptr)` for better readability.
Suggested change:

```diff
-reinterpret_cast<uint32_t*>(0),
+nullptr,
```
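For context, the truncated call above might continue roughly as follows once the suggestion is applied. The variable names come from the diff excerpt; the trailing `0` and `false` follow the parameter summary in the table above (`group_size` and `quant_key_bychannel`) and are assumptions about the exact values, not lines from the patch.

```cpp
// Assumed continuation of the xattention.hpp call site with the suggestion applied.
transpose_16NxK<uint32_t, ov::element::u32>(d,
                                            s,
                                            nullptr,         // tmp: unused at this call site
                                            N,
                                            K >> 1,
                                            block_size,
                                            dst_stride,
                                            src_stride >> 1,
                                            0,               // group_size: no grouped quantization here
                                            false);          // quant_key_bychannel
```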
This is a valid comment. @mangguo321, could you please explicitly give your opinion on this?
> This is a valid comment. @mangguo321, could you please explicitly give your opinion on this?
Hi @maxnick, this code was originally implemented in executor_pa.cpp and moved here without modification. The original intent is unclear, but I think we can update it to use nullptr for now.
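One speculative explanation for the original typed cast, not taken from the PR: if the template arguments were ever deduced from the call rather than spelled out, a bare `nullptr` would not deduce the pointer type, while a cast to a concrete pointer type would. With the explicit template argument lists used at these call sites, `nullptr` works fine, e.g.:

```cpp
// Standalone illustration (take_buffer is hypothetical, not from the PR).
template <typename T>
void take_buffer(T* /*tmp*/) {}

int main() {
    take_buffer(static_cast<float*>(nullptr)); // OK: T deduced as float from the cast
    // take_buffer(nullptr);                   // error: T cannot be deduced from std::nullptr_t
    take_buffer<float>(nullptr);               // OK: T given explicitly, as in the PR's call sites
    return 0;
}
```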
Diff excerpt (truncated):

```cpp
}
transpose_16NxK<TDST, precision_of<TDST>::value>(dst,
                                                 tmp,
                                                 reinterpret_cast<TDST*>(0),
```
Copilot AI · Dec 11, 2025
Using `reinterpret_cast<TDST*>(0)` to represent `nullptr` is non-idiomatic and less clear. Replace with `nullptr` for better readability.
Suggested change:

```diff
-reinterpret_cast<TDST*>(0),
+nullptr,
```
Replaced with nullptr.
Diff excerpt (truncated):

```cpp
attn_dequant_by_channel_kernel<TDST,
                               SRC_PREC>(s, t, N, K, K / sub_byte_multiplier, src_stride, p_scales, p_zps);
} else {
    static_assert(SRC_PREC == ov::element::i8, "i8 doesn't support by-channel quantization");
```
The assertion fails for types other than i8, but the error message suggests that i8 is the unsupported one. Should the condition be `SRC_PREC != ov::element::i8`?
Yes, you are right. I've fixed it in the latest commit, thanks!
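To make the agreed fix concrete, here is a minimal before/after of the assertion; only the flipped comparison comes from the discussion above, and the surrounding branch is assumed unchanged.

```cpp
// Before: the assertion fires for every precision except i8, contradicting its message.
static_assert(SRC_PREC == ov::element::i8, "i8 doesn't support by-channel quantization");

// After (as agreed above): the assertion fires exactly when SRC_PREC is i8.
static_assert(SRC_PREC != ov::element::i8, "i8 doesn't support by-channel quantization");
```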

Details:
Tickets: