MXFP4 Cast Transpose Triton [WIP] #422
base: dev
Conversation
wangye805 left a comment
| import numpy as np
| import os
|
| os.environ["USE_TRITON_FUSED_CAST_TRANSPOSE"] = "1"
We have already defined the env var NVTE_USE_CAST_TRANSPOSE_TRITON.
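A minimal sketch of reusing that existing flag instead of introducing a new one; the helper name is hypothetical, only the two env var names appear in this thread:

```python
import os

def _use_triton_cast_transpose() -> bool:
    """Hypothetical helper gating the Triton MXFP4 path on the existing flag.

    NVTE_USE_CAST_TRANSPOSE_TRITON is the flag already defined in TE;
    USE_TRITON_FUSED_CAST_TRANSPOSE is the new one introduced by this PR.
    """
    return os.environ.get("NVTE_USE_CAST_TRANSPOSE_TRITON", "0") == "1"
```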
| def test_quantize_mxfp4(shape, in_dtype, rowwise, columnwise, shuffle_B_matrix):
|     """Test MXFP4 quantization for rowwise/columnwise modes with/without FP4 shuffle.
|
|     Note: FP4 data shuffle (shuffle_B_matrix_for_aiter) is not yet supported in Triton kernel.
If the FP4 data shuffle is not yet supported in the Triton kernel, why do we need to add it here?
| (32768, 160),
| (4096, 1632),
| (8, 32, 1024),
| (16, 8, 4, 512),
Can we add some shapes with prime dimensions, like
TransformerEngine/tests/cpp/operator/test_cast_transpose.cu
Lines 90 to 92 in 9d6b0e5
| {1, 3221},     // Prime 456
| {2333, 1},     // Prime 345
| {1481, 677}};  // Primes 234, 123
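For example, the parametrized shape list quoted above could be extended along those lines (illustrative values only; shapes whose last dim is prime or not a multiple of 32 are exactly the padding edge cases such additions would exercise, and may need padding or skipping):

```python
# Illustrative additions to the existing shape list; prime dimensions exercise
# the non-divisible / padding edge cases, mirroring the C++ test above.
SHAPES = [
    (32768, 160),
    (4096, 1632),
    (8, 32, 1024),
    (16, 8, 4, 512),
    (2333, 1),     # prime first dim
    (1, 3221),     # prime last dim
    (1481, 677),   # both dims prime
]
```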
| data_atol = 20.0 if in_dtype != torch.float32 else 16.0
| scale_atol = 2.0 if in_dtype != torch.float32 else 1.0
The data tolerance seems to be quite large. You can follow our MXFP8 scale and data adjustment scheme:
TransformerEngine/tests/cpp/test_common.cu
Line 730 in 9d6b0e5
| void adjust_ref_for_e8m0_scale_error(const std::string &name,
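For reference, a rough Python sketch of that adjustment idea, assuming [M, K] reference data and [M, K/32] E8M0 exponent tensors (only the name adjust_ref_for_e8m0_scale_error comes from the C++ reference above; the rounding step and helper are assumptions): where the kernel's scale exponent differs from the reference by one step, re-derive the reference block with the kernel's scale so the data comparison stays tight instead of relying on a blanket atol of 16-20.

```python
import numpy as np

# FP4 E2M1 representable magnitudes (1 sign bit, 2 exponent bits, 1 mantissa bit)
E2M1_GRID = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def quantize_block_e2m1(x: np.ndarray, scale_exp: int) -> np.ndarray:
    """Quantize a 1-D block with an E8M0 scale 2**(exp - 127) onto the E2M1 grid
    and dequantize back (round-to-nearest here; the kernel's rounding may differ)."""
    scale = 2.0 ** (int(scale_exp) - 127)
    q = np.clip(np.abs(x) / scale, 0.0, E2M1_GRID[-1])
    idx = np.abs(q[:, None] - E2M1_GRID[None, :]).argmin(axis=1)
    return np.sign(x) * E2M1_GRID[idx] * scale

def adjust_ref_for_e8m0_scale_error(ref, ref_scale, test_scale, block=32):
    """Where the kernel's E8M0 exponent differs from the reference by one step,
    re-derive the reference block with the kernel's scale so the data can be
    compared tightly instead of with a large blanket atol."""
    adjusted = ref.copy()
    diff = test_scale.astype(np.int32) - ref_scale.astype(np.int32)
    rows, cols = np.nonzero(np.abs(diff) == 1)
    for i, j in zip(rows, cols):
        sl = slice(j * block, (j + 1) * block)
        adjusted[i, sl] = quantize_block_e2m1(ref[i, sl], int(test_scale[i, j]))
    return adjusted
```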
|     use_torch_semantics=True
| )
|
| # Compare only valid (non-padded) region - no shuffle extraction needed
What is the FP4 shuffle?
|   .value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) \
| -   .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2); \
| +   .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2) \
| +   .value("kFloat4E2M1", transformer_engine::DType::kFloat4E2M1); \
If we are going to enable kFloat4E2M1, there are other related changes needed. See https://github.com/search?q=repo%3AROCm%2FTransformerEngine%20kFloat4E2M1&type=code for more details.
| - Data: [M, K/2] uint8 tensor (2 FP4 values packed per byte)
| - Scale: [M, K/32] uint8 tensor (E8M0 format, one scale per 32-element block)
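A small sketch of what that layout implies for output allocation and byte packing (the nibble order and the alignment assertion are assumptions, not taken from the PR):

```python
import numpy as np

MXFP4_BLOCK = 32  # one E8M0 scale per 32 elements, per the layout above

def mxfp4_rowwise_shapes(M: int, K: int):
    """Expected rowwise output shapes for the layout described above
    (ignoring any extra padding the implementation may add)."""
    assert K % 2 == 0 and K % MXFP4_BLOCK == 0  # assumed alignment for this sketch
    return (M, K // 2), (M, K // MXFP4_BLOCK)   # packed FP4 data, E8M0 scales

def pack_fp4_pairs(codes: np.ndarray) -> np.ndarray:
    """Pack adjacent 4-bit codes into bytes: element 2i in the low nibble,
    element 2i+1 in the high nibble (the nibble order is an assumption)."""
    codes = np.asarray(codes, dtype=np.uint8).reshape(-1, 2)
    return (codes[:, 0] & 0xF) | ((codes[:, 1] & 0xF) << 4)
```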
Are there alignment/padding requirements for M and K?
| if inp.ndim < 2:
|     return False
TE currently supports 2D matrices obtained by flattening higher-dimensional tensors:
TransformerEngine/transformer_engine/common/common.h
Lines 238 to 262 in 9d6b0e5
|   size_t flat_first_dim() const {
|     const auto &full_shape = shape();
|     size_t ret = 1;
|     if (!full_shape.empty()) {
|       for (size_t i = 0; i < full_shape.size() - 1; i++) {
|         ret *= full_shape[i];
|       }
|     }
|     return ret;
|   }
|   /*! Matrix width after tensor is flattened to 2D
|    *
|    * If a tensor has dimensions (D1, D2, ..., Dn), it is reinterpreted
|    * as a (D1*D2*...*D(n-1), Dn) matrix.
|    */
|   size_t flat_last_dim() const {
|     const auto &full_shape = shape();
|     if (full_shape.empty()) {
|       return 1;
|     } else {
|       return full_shape.back();
|     }
|   }
| };
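A Python sketch of that same convention, showing how the Triton path could treat any tensor with ndim >= 2 as a 2-D matrix rather than only exact 2-D inputs (the function name is hypothetical):

```python
import torch

def flatten_to_2d(inp: torch.Tensor) -> torch.Tensor:
    """Mirror the flat_first_dim/flat_last_dim convention quoted above:
    a (D1, ..., Dn) tensor is treated as a (D1*...*D(n-1), Dn) matrix."""
    if inp.ndim < 2:
        raise ValueError("expected at least 2 dimensions")
    return inp.reshape(-1, inp.shape[-1])
```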
| # Allocate PADDED scale tensors for shuffle compatibility
| rowwise_scale_N = K // MXFP4_BLOCK_SCALING_SIZE
| rowwise_scale_M_pad = cdiv(M, 256) * 256
I presume this 256 is from some alignment/padding requirement?
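For reference, a minimal sketch of the padding arithmetic in the quoted lines; whether 256 comes from an AITER shuffle tile-size requirement is exactly the assumption being questioned here:

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division, as used in the quoted allocation code."""
    return (a + b - 1) // b

# e.g. M = 1000 rows of scales padded up to the next multiple of 256
assert cdiv(1000, 256) * 256 == 1024
```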
| @@ -0,0 +1,178 @@
| # Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
You will need to add this pytest to our CI script, somewhere near:
TransformerEngine/ci/pytorch.sh
Line 74 in 9d6b0e5
| run_default_fa 1 triton_kernels/test_norms.py
Description
- Implements the MXFP4 rowwise and columnwise FP32/BF16 -> MXFP4 fused quantization + cast kernel
- Verify tolerances and functional unit tests
The Triton te_cast_transpose_mxfp4_triton kernel currently outputs FP4 data in a linear [M, N/2] layout with contiguous byte packing. AITER's gemm_a4w4 requires the B matrix in the MFMA shuffle layout for the tensor cores. This layout shuffle can be fused into the Triton kernel in the future.