Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions example/15_grouped_gemm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_wmma_spl
add_example_executable(example_grouped_gemm_multiple_d_wmma_fp16 grouped_gemm_multiple_d_wmma_fp16.cpp)
add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_multiple_d_wmma_fp16)

add_example_executable(example_grouped_gemm_wmma_fixed_nk_fp16 grouped_gemm_wmma_fixed_nk_fp16.cpp)
add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_wmma_fixed_nk_fp16)


list(APPEND gpu_list_tf32 gfx942 gfx950)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
Expand Down
382 changes: 382 additions & 0 deletions example/15_grouped_gemm/grouped_gemm_wmma_fixed_nk_fp16.cpp

Large diffs are not rendered by default.

7 changes: 4 additions & 3 deletions example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
}
}

std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" << std::endl;
return pass;
}

Expand Down Expand Up @@ -329,9 +330,9 @@ int main(int argc, char* argv[])

for(int i = 0; i < problem_size.group_count; i++)
{
problem_size.Ms.push_back(128 + rand() % 128);
problem_size.Ns.push_back(1024);
problem_size.Ks.push_back(1024);
problem_size.Ms.push_back(256);
problem_size.Ns.push_back(256);
problem_size.Ks.push_back(256);

problem_size.stride_As.push_back(problem_size.Ks[i]);
problem_size.stride_Bs.push_back(problem_size.Ks[i]);
Expand Down
2 changes: 2 additions & 0 deletions example/15_grouped_gemm/run_grouped_gemm_example.inc
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
ComputeDataType>(c_device_tensors[i], c_host_tensors[i]);
#endif
}

std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" << std::endl;
}

if(config.time_kernel)
Expand Down
Loading