CK Tile MXFP8 Group GEMM gfx1250#578
Conversation
…rash; remaining issue is numerical validation vs BF16 sequential reference.
| test_dequantize_mxfp8.cu | ||
| test_dequantize_nvfp4.cu | ||
| test_cast_nvfp4_transpose.cu | ||
| test_ck_grouped_mxfp8.cu |
There was a problem hiding this comment.
It should be for non CUDA only
| // Currently only support cutlass group gemm on Hopper Arch | ||
| if (!(is_hopper && use_cutlass)) { | ||
| // if (!(is_hopper && use_cutlass)) { | ||
| if (!use_cutlass) { |
| delay_wgrad_compute, | ||
| ): | ||
| os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1" | ||
| os.environ["NVTE_ROCM_ENABLE_MXFP8"] = "1" |
There was a problem hiding this comment.
I think this should only be set when the recipe we are testing is mxfp8.
There was a problem hiding this comment.
Good point. Looking at the parametrization, MXFP8BlockScaling is only added to fp8_recipes when NVTE_ROCM_ENABLE_MXFP8=1 is already set before test collection. So setting it inside this test is redundant and also broader than intended. Removed in 746afea
|
|
||
| // Treat TE tensors as generalized 2D matrices by flattening: | ||
| // (D1, D2, ..., Dn) -> (D1*...*D(n-1), Dn), consistent with TE Tensor::flat_*_dim. | ||
| static inline bool get_flat_2d_dims(const transformer_engine::Tensor& t, |
There was a problem hiding this comment.
Re-use get_flat_2d_dims from ck_grouped_gemm_common.h
There was a problem hiding this comment.
I think some portion of the code is already present in ck_grouped_gemm_common.h inside ck_grouped_gemm folder. What was the reasoning behind having a separate directory for ck_mx_grouped_gemm?
There was a problem hiding this comment.
No, there really was not a good reason for this. I agree that it makes more sense to keep it all under the same directory, and re-use the common functions already defined in the shared header. I have made these changes in 175855d
Description
This PR integrates CK Tile MXFP8 grouped GEMM backend with TDM into TE. Replaces 3rdparty/aiter with 3rdparty/rocm-libraries for the gfx1250 changes from CK.
Fixes # (16490)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: