
Conversation


@bidlekm bidlekm commented Jan 28, 2026

Proposed changes

Please describe the motivation behind the pull request, whether it enables a new feature or fixes a bug. If there are associated pull requests or issues, please link them to the pull request.

Checklist

Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.

  • I have added tests relevant to the introduced functionality, and the unit tests are passing locally
  • I have added the test to the REGRESSION_TESTS list defined at the top of tests/CMakeLists.txt, IF the test takes more than 30 seconds to run.
  • I have added inline documentation which enables the maintainers to understand the motivation
  • I have removed the stale documentation which is no longer relevant after this pull request
  • (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request
  • I have run clang-format on all changed files
  • Any dependent changes have been merged

Discussion

If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered

Contributor

@zsotakal zsotakal left a comment

Looks good! A few minor things, mostly related to code quality.


for(int i = 0; i < problem_size.group_count; i++)
{
problem_size.Ms.push_back(128 + rand() % 128);
Contributor

Did you intend to change the xdl example?

id_local += grid_size_grp;
}

#undef TRACE_THREAD
Contributor

Is this needed, or is it just a leftover from troubleshooting?

typename BElementwiseOperation,
typename CDEElementwiseOperation,
GemmSpecialization GemmSpec,
ck::index_t NumGemmKPrefetchStage,
Contributor

I think this parameter is XDL only

std::ostringstream err;
err << "Not all gemms have same value for main_k0_block_loop! in " << __FILE__
<< ":" << __LINE__ << ", in function: " << __func__;
// throw std::runtime_error(err.str());
Contributor

Did you intentionally remove the runtime_error?
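
If the check is still meant to be fatal, restoring it would look roughly like this (just a sketch of the pattern already in the quoted code, not a new API):

std::ostringstream err;
err << "Not all gemms have same value for main_k0_block_loop! in " << __FILE__
    << ":" << __LINE__ << ", in function: " << __func__;
// fail loudly instead of silently continuing with inconsistent groups
throw std::runtime_error(err.str());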

arg.c_element_op_);
};

// const auto tail_num =
Contributor

I couldn't find this tail-number logic anywhere in the existing code. Also, there are seemingly related commented-out sections below. Can you explain a bit what the motivation behind it was?

{
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
constexpr bool padM = GemmSpec == GemmSpecialization::MKPadding ||
// using GemmSpecialization = tensor_operation::device::GemmSpecialization;
Contributor

dead code (repeats multiple times)

b_k_n[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}, num_thread);
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k[i]);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n[i]);
max_abs_in_val = 10.f;
Contributor

Isn't this supposed to be 5.0f?
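
Since the inputs are filled from [-5, 5] right above, something along these lines seems more consistent (assuming max_abs_in_val is meant to bound the absolute value of the generated inputs):

ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k[i]);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n[i]);
// match the [-5, 5] fill range used above
max_abs_in_val = 5.f;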

target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE ${PROFILER_LIBS})

rocm_install(TARGETS ${PROFILER_EXECUTABLE} COMPONENT profiler)

Contributor

This code block looks intended for troubleshooting. Please remove it.

#include "ck/utility/tuple.hpp"

#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
Contributor

This header is not used.

>(static_cast<void*>(p_shared),
splitk_batch_offset,
kernel_arg,
block_2_etile_map,
Contributor

block_2_etile_map and splitk_batch_offset can be defined as const.
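
For example (the initializer expressions below are hypothetical placeholders for whatever the surrounding code already does; the point is only the added const, since neither object is modified after construction):

// hypothetical helpers standing in for the existing initializer expressions
const auto splitk_batch_offset = make_splitk_batch_offset(kernel_arg);
const auto block_2_etile_map   = make_block_2_etile_map(kernel_arg);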

1, 1, 1, 1, 1))>;

template <typename UnderlyingBlockToCTileMap>
struct OffsettedBlockToCTileMapMLoops
Contributor

OffsettedBlockToCTileMapMLoops is the same struct as in the Xdl variant. Maybe consider moving it to the base class DeviceGroupedGemmFixedNK?
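
Roughly what I have in mind (the file location is illustrative; the body would stay exactly as it is today):

// shared between the Xdl and Wmma fixed-NK devices, e.g. next to
// DeviceGroupedGemmFixedNK (hypothetical location)
template <typename UnderlyingBlockToCTileMap>
struct OffsettedBlockToCTileMapMLoops
{
    // unchanged definition, kept in a single place
};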

};

template <index_t MPerBlock_, index_t NPerBlock_>
struct BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops
Contributor

Same as my other comment. Maybe consider moving this struct definition to the base class?

}
}

// private:
Contributor

I can see this comes from the XDL variant, but this line serves no purpose other than (in the best case) annoying the reader or (in the worst case) making them wonder about the reason behind it and lose time on it.

}
}

// if constexpr(std::is_same<ADataType, ck::bhalf_t>::value)
Contributor

Fix or remove this

{
ignore = tail_num;
lambda(std::integral_constant<TailNumber, TailNumber::Full>{});
// switch(tail_num)
Contributor

Same as my comment above, fix this or remove it.
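
If the dispatch is still wanted, restoring the switch would look roughly like this (which TailNumber values actually need handling here is an assumption on my side):

switch(tail_num)
{
case TailNumber::Full:
    lambda(std::integral_constant<TailNumber, TailNumber::Full>{});
    break;
case TailNumber::Odd:
    lambda(std::integral_constant<TailNumber, TailNumber::Odd>{});
    break;
case TailNumber::Even:
    lambda(std::integral_constant<TailNumber, TailNumber::Even>{});
    break;
default:
    lambda(std::integral_constant<TailNumber, TailNumber::Full>{});
    break;
}

Otherwise drop the ignore and the commented-out lines and keep only the TailNumber::Full call.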

#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/utility/loop_scheduler.hpp"
Contributor

This is not used; include scheduler_enum instead.

#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
Contributor

This include is not used

static constexpr InstanceVariant InstanceVariants[] = {

make_tuple(GemmDefault, IntrawaveScheduler, PipelineV1),
// make_tuple(GemmDefault, InterwaveScheduler, PipelineV1),
Contributor

Remove those comments

bool pass = true;
for(int kbatch : kbatches)
{
pass &= ck::profiler::profile_grouped_gemm_fixed_nk_impl<ADataType,
Contributor

This function call may throw. What happens if it does? Should you maybe pass fail_if_no_supported_instances as an argument and return early without throwing? Or at least use a try-catch block here?
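
A try-catch sketch of what I mean (the error-handling policy is of course up to you; std::cerr needs <iostream>):

bool pass = true;
for(int kbatch : kbatches)
{
    try
    {
        // same call and template arguments as above
        pass &= ck::profiler::profile_grouped_gemm_fixed_nk_impl</* ... */>(/* ... */);
    }
    catch(const std::exception& e)
    {
        std::cerr << e.what() << std::endl;
        // or return early here, depending on fail_if_no_supported_instances
        pass = false;
    }
}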

@chris-tsiaousis-hpc
Contributor

Added some comments you might want to address. Great work overall! :)
