173 implement device grouped gemm fixed nk for rdna4 #3668
base: develop
Conversation
…ement-device_grouped_gemm_fastgelu-for-rdna4
zsotakal left a comment
Looks good! A few minor things, mostly related to code quality.
for(int i = 0; i < problem_size.group_count; i++)
{
    problem_size.Ms.push_back(128 + rand() % 128);
Did you intend to change the XDL example?
        id_local += grid_size_grp;
    }

#undef TRACE_THREAD
Is this needed, or is it just a leftover from troubleshooting?
typename BElementwiseOperation,
typename CDEElementwiseOperation,
GemmSpecialization GemmSpec,
ck::index_t NumGemmKPrefetchStage,
I think this parameter is XDL-only.
std::ostringstream err;
err << "Not all gemms have same value for main_k0_block_loop! in " << __FILE__
    << ":" << __LINE__ << ", in function: " << __func__;
// throw std::runtime_error(err.str());
Did you intentionally remove the runtime_error?
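For reference, a minimal sketch of keeping the original error path. The function name and its bool argument here are illustrative placeholders, not the actual device-op code:

// Sketch only: re-enable the exception instead of silently continuing.
#include <sstream>
#include <stdexcept>

void check_main_k0_block_loop(bool all_gemms_have_same_value)
{
    if(!all_gemms_have_same_value)
    {
        std::ostringstream err;
        err << "Not all gemms have same value for main_k0_block_loop! in " << __FILE__
            << ":" << __LINE__ << ", in function: " << __func__;
        // Throwing preserves the original failure behaviour from the XDL variant.
        throw std::runtime_error(err.str());
    }
}

int main()
{
    check_main_k0_block_loop(true); // consistent case, no throw
    return 0;
}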
        arg.c_element_op_);
};

// const auto tail_num =
I couldn't find this tail-number logic anywhere in the existing code. Also, there are seemingly related commented-out sections below. Can you explain a bit what the motivation behind it was?
{
    using GemmSpecialization = tensor_operation::device::GemmSpecialization;
    constexpr bool padM = GemmSpec == GemmSpecialization::MKPadding ||
    // using GemmSpecialization = tensor_operation::device::GemmSpecialization;
dead code (repeats multiple times)
b_k_n[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}, num_thread);
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k[i]);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n[i]);
max_abs_in_val = 10.f;
Isn't this supposed to be 5.0f?
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE ${PROFILER_LIBS})

rocm_install(TARGETS ${PROFILER_EXECUTABLE} COMPONENT profiler)
This code block looks intended for troubleshooting. Please remove it.
#include "ck/utility/tuple.hpp"

#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
This header is not used.
    >(static_cast<void*>(p_shared),
      splitk_batch_offset,
      kernel_arg,
      block_2_etile_map,
block_2_etile_map and splitk_batch_offset can be defined as const.
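Roughly this pattern; a standalone sketch, not the actual kernel-launch code, with placeholder types and values:

// Sketch: objects built once before the kernel launch and never modified
// afterwards can be declared const so the compiler enforces it.
struct SplitKBatchOffset
{
    long a_k_split_offset;
    long b_k_split_offset;
};

SplitKBatchOffset MakeSplitKBatchOffset(int k_batch_id)
{
    // Placeholder computation; the real offsets come from the argument struct.
    return {k_batch_id * 64L, k_batch_id * 128L};
}

int main()
{
    const SplitKBatchOffset splitk_batch_offset = MakeSplitKBatchOffset(2);
    const auto block_2_etile_map = 0L; // stand-in for the real tile-map object
    return splitk_batch_offset.a_k_split_offset >= block_2_etile_map ? 0 : 1;
}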
        1, 1, 1, 1, 1))>;

template <typename UnderlyingBlockToCTileMap>
struct OffsettedBlockToCTileMapMLoops
OffsettedBlockToCTileMapMLoops is the same struct as in the XDL variant. Maybe consider moving it to the base class DeviceGroupedGemmFixedNK?
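Roughly what I have in mind; this is illustrative only, and the header name and member layout are placeholders rather than the actual CK definitions:

// Sketch: keep a single definition of the shared tile-map wrapper in a common
// header and reuse it from both the XDL and WMMA fixed-NK device ops.

// --- hypothetical shared header, e.g. offsetted_block_to_ctile_map_mloops.hpp ---
template <typename UnderlyingBlockToCTileMap>
struct OffsettedBlockToCTileMapMLoops
{
    UnderlyingBlockToCTileMap underlying_map_;
    int block_start_;
};

// --- both device ops then instantiate the same shared template ---
struct DummyMap
{
    int tiles_;
};

int main()
{
    const OffsettedBlockToCTileMapMLoops<DummyMap> map{DummyMap{8}, 0};
    return map.underlying_map_.tiles_ == 8 ? 0 : 1;
}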
};

template <index_t MPerBlock_, index_t NPerBlock_>
struct BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops
Same as my other comment. Maybe consider moving this struct definition to the base class?
    }
}

// private:
I can see this comes from the XDL variant, but this line serves no purpose other than (in the best case) annoying the reader or (in the worst case) making them wonder about the reason behind it and lose time on it.
    }
}

// if constexpr(std::is_same<ADataType, ck::bhalf_t>::value)
Fix or remove this
{
    ignore = tail_num;
    lambda(std::integral_constant<TailNumber, TailNumber::Full>{});
    // switch(tail_num)
Same as my comment above, fix this or remove it.
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/utility/loop_scheduler.hpp"
This is not used; include scheduler_enum instead.
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
This include is not used
static constexpr InstanceVariant InstanceVariants[] = {

    make_tuple(GemmDefault, IntrawaveScheduler, PipelineV1),
    // make_tuple(GemmDefault, InterwaveScheduler, PipelineV1),
Remove those comments
bool pass = true;
for(int kbatch : kbatches)
{
    pass &= ck::profiler::profile_grouped_gemm_fixed_nk_impl<ADataType,
This function call may throw. What happens if it does? Should you maybe pass fail_if_no_supported_instances as an argument and return early without throwing? Or at least use a try/catch block here?
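A rough sketch of the try/catch option; run_profile_case here is only a stand-in for the profile_grouped_gemm_fixed_nk_impl call with its template and runtime arguments:

#include <iostream>
#include <stdexcept>

// Stand-in for the real profiler call; assume it throws when no supported instance exists.
bool run_profile_case(int kbatch)
{
    if(kbatch <= 0)
        throw std::runtime_error("no supported instances for this kbatch");
    return true;
}

int main()
{
    const int kbatches[] = {1, 2, 4};
    bool pass            = true;
    for(int kbatch : kbatches)
    {
        try
        {
            pass &= run_profile_case(kbatch);
        }
        catch(const std::exception& e)
        {
            // Record the failure and keep testing the remaining kbatch values
            // instead of letting the exception terminate the whole test.
            std::cerr << "kbatch " << kbatch << " failed: " << e.what() << '\n';
            pass = false;
        }
    }
    return pass ? 0 : 1;
}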
Added some comments you might want to address. Great work overall! :)
Proposed changes
Please describe the motivation behind the pull request, whether it enables a new feature or fixes a bug. If there are associated pull requests or issues, please link them to the pull request.
Checklist
Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.
[ ] clang-format on all changed files
Discussion
If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered.