Conversation

@sarthak-amd (Collaborator) commented Jan 20, 2026

Description

Implements the rowwise and columnwise FP32/BF16 -> MXFP4 fused quantization + cast kernel.

  • Verify tolerances and functional unit tests.

  • The Triton te_cast_transpose_mxfp4_triton kernel currently outputs FP4 data in a linear [M, N/2] layout with contiguous byte packing. AITER's gemm_a4w4 requires the B matrix in the MFMA shuffle layout for the tensor cores; this layout shuffle can be fused into the Triton kernel in the future.
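For reference, a minimal sketch of what the linear [M, N/2] byte packing means. This is not the kernel itself; the nibble order shown (even column in the low nibble) is an assumption for illustration, and the MFMA shuffle layout required by gemm_a4w4 would be a permutation of this packed form.

import torch

def pack_fp4_linear(codes: torch.Tensor) -> torch.Tensor:
    """Pack 4-bit FP4 (E2M1) code points into a contiguous [M, N/2] uint8 tensor.

    `codes` is an [M, N] uint8 tensor of values in 0..15. Nibble order is an
    illustrative assumption, not the kernel's documented convention.
    """
    assert codes.dtype == torch.uint8 and codes.shape[-1] % 2 == 0
    lo = codes[..., 0::2] & 0x0F  # even columns -> low nibble (assumed)
    hi = codes[..., 1::2] & 0x0F  # odd columns  -> high nibble (assumed)
    return (hi << 4) | lo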

@wangye805 (Collaborator) left a comment

import numpy as np
import os

os.environ["USE_TRITON_FUSED_CAST_TRANSPOSE"] = "1"
Collaborator:

We previously defined the env var NVTE_USE_CAST_TRANSPOSE_TRITON.
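For illustration, a minimal sketch of gating on the existing switch instead of introducing a new one; the "1"/"0" parsing convention here is an assumption, not TE's documented behavior.

import os

# Reuse the already-defined switch; exact parsing in TE may differ.
use_triton_cast_transpose = bool(int(os.getenv("NVTE_USE_CAST_TRANSPOSE_TRITON", "0")))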

def test_quantize_mxfp4(shape, in_dtype, rowwise, columnwise, shuffle_B_matrix):
    """Test MXFP4 quantization for rowwise/columnwise modes with/without FP4 shuffle.

    Note: FP4 data shuffle (shuffle_B_matrix_for_aiter) is not yet supported in Triton kernel.
Collaborator:

If the FP4 data shuffle is not yet supported in the Triton kernel, why do we need to add it here?
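If the parameter stays in the test for now, one option is to mark the unsupported shuffle case explicitly rather than exercising it. A hedged pytest sketch; the list name is hypothetical:

import pytest

# Sketch: keep shuffle_B_matrix in the parametrization but skip the True case
# until the Triton kernel supports the FP4 shuffle layout.
SHUFFLE_B_MATRIX_PARAMS = [
    False,
    pytest.param(True, marks=pytest.mark.skip(reason="FP4 shuffle not yet supported in Triton kernel")),
]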

(32768, 160),
(4096, 1632),
(8, 32, 1024),
(16, 8, 4, 512),
Collaborator:

Can we add some prime numbers like

{1, 3221}, // Prime 456
{2333, 1}, // Prime 345
{1481, 677}}; // Primes 234, 123
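For illustration, the suggested prime dimensions translated into the pytest shape list quoted above. This is only a sketch: MXFP4's two-values-per-byte packing and 32-element scale blocks may require padding or skipping some of these shapes, which ties into the alignment question raised later in this review.

# Sketch: prime-dimension shapes mirroring the quoted mxfp8 C++ test cases.
PRIME_SHAPES = [
    (1, 3221),
    (2333, 1),
    (1481, 677),
]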

Comment on lines +127 to +128
data_atol = 20.0 if in_dtype != torch.float32 else 16.0
scale_atol = 2.0 if in_dtype != torch.float32 else 1.0
Collaborator:

The data tolerance seems quite large. You can follow our mxfp8 scale and data adjustment scheme:

void adjust_ref_for_e8m0_scale_error(const std::string &name,
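A hypothetical Python analogue of that C++ helper, sketching the idea only; the exact TE semantics and bookkeeping are assumptions.

import torch

MXFP4_BLOCK = 32  # elements covered by one E8M0 scale

def adjust_ref_for_e8m0_scale_error(ref_vals, ref_scale_u8, test_scale_u8):
    # ref_vals:      reference FP4 values decoded to float, shape [M, K]
    # ref_scale_u8:  reference E8M0 scales (biased exponents), shape [M, K // 32]
    # test_scale_u8: kernel-produced E8M0 scales, same shape
    #
    # A one-step exponent difference is a legitimate rounding outcome of the
    # amax -> E8M0 conversion and changes the decode scale by a factor of 2.
    # Where that happens, re-express the reference value under the test scale
    # instead of loosening the data tolerance for every element.
    diff = test_scale_u8.to(torch.int32) - ref_scale_u8.to(torch.int32)
    factor = torch.where(diff.abs() == 1,
                         torch.exp2(-diff.float()),
                         torch.ones_like(diff, dtype=torch.float32))
    factor = factor.repeat_interleave(MXFP4_BLOCK, dim=-1)
    return ref_vals * factor[..., : ref_vals.shape[-1]]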

use_torch_semantics=True
)

# Compare only valid (non-padded) region - no shuffle extraction needed
Collaborator:

What is the FP4 shuffle?

.value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) \
.value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2); \
.value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2) \
.value("kFloat4E2M1", transformer_engine::DType::kFloat4E2M1); \
Collaborator:

If we are going to enable kFloat4E2M1, there are other related changes needed; see https://github.com/search?q=repo%3AROCm%2FTransformerEngine%20kFloat4E2M1&type=code for more details.


Comment on lines +61 to +62
- Data: [M, K/2] uint8 tensor (2 FP4 values packed per byte)
- Scale: [M, K/32] uint8 tensor (E8M0 format, one scale per 32-element block)
Collaborator:

Are there alignment/padding requirements for M and K?

Comment on lines +113 to +114
if inp.ndim < 2:
    return False
Collaborator:

TE currently supports 2D matrices flattened from high-dimensional tensors:

  size_t flat_first_dim() const {
    const auto &full_shape = shape();
    size_t ret = 1;
    if (!full_shape.empty()) {
      for (size_t i = 0; i < full_shape.size() - 1; i++) {
        ret *= full_shape[i];
      }
    }
    return ret;
  }
  /*! Matrix width after tensor is flattened to 2D
   *
   * If a tensor has dimensions (D1, D2, ..., Dn), it is reinterpreted
   * as a (D1*D2*...*D(n-1), Dn) matrix.
   */
  size_t flat_last_dim() const {
    const auto &full_shape = shape();
    if (full_shape.empty()) {
      return 1;
    } else {
      return full_shape.back();
    }
  }
};
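A minimal Python sketch of that flattening convention on the Triton side, so higher-dimensional inputs are flattened rather than rejected by the ndim check; the helper name is hypothetical.

import math
import torch

def flatten_to_2d(inp: torch.Tensor) -> torch.Tensor:
    # Mirror flat_first_dim()/flat_last_dim(): (D1, ..., Dn) is reinterpreted
    # as a (D1*...*D(n-1), Dn) matrix before quantization.
    if inp.ndim < 2:
        return inp.reshape(1, -1)  # 0-D/1-D inputs become a single-row matrix
    return inp.reshape(math.prod(inp.shape[:-1]), inp.shape[-1])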


# Allocate PADDED scale tensors for shuffle compatibility
rowwise_scale_N = K // MXFP4_BLOCK_SCALING_SIZE
rowwise_scale_M_pad = cdiv(M, 256) * 256
Collaborator:

I presume this 256 is from some alignment/padding requirement?
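For reference, a minimal sketch of the ceiling-division helper assumed in the quoted allocation and the resulting rounded-up size; why the multiple is 256 specifically is the open question above.

def cdiv(a: int, b: int) -> int:
    # Ceiling division: number of b-sized blocks needed to cover a.
    return (a + b - 1) // b

# Example: M = 1000 -> rowwise_scale_M_pad = cdiv(1000, 256) * 256 = 1024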

@@ -0,0 +1,178 @@
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
Collaborator:

You will need to add this pytest to our CI script (somewhere near
run_default_fa 1 triton_kernels/test_norms.py
), otherwise it won't be tested.
