Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
bc363fa
add MX scale pre-swizzling for gfx1250
matthiasdiener Apr 27, 2026
a6ca3af
switch to mxfp4
matthiasdiener Apr 27, 2026
d1ee5bd
tensile-like implementation
matthiasdiener Apr 28, 2026
d1647ee
Merge remote-tracking branch 'upstream/dev' into mdiener/mxfp8-swizzle
matthiasdiener Apr 29, 2026
1fff6d9
Merge remote-tracking branch 'origin/dev' into mdiener/mxfp8-swizzle
matthiasdiener May 1, 2026
d714038
gfx1250 swizzle_xor changes for FP4
matthiasdiener May 1, 2026
76ca4b1
change line endings to unix, trim trailing whitespace
matthiasdiener May 1, 2026
81a0a27
Merge branch 'mdiener/swizzle_xor-1250' into mdiener/mxfp8-swizzle
matthiasdiener May 1, 2026
2991bcf
fix arch
matthiasdiener May 1, 2026
8ceb89c
[WIP] e2e gemm test, not working yet
matthiasdiener May 1, 2026
167d2eb
fix for gfx1250
matthiasdiener May 3, 2026
5d46537
k-tile
matthiasdiener May 3, 2026
313a6b7
extend tests
matthiasdiener May 3, 2026
2a8eeb5
remove ifdef
matthiasdiener May 3, 2026
c37a781
undo BLK32_UE8M0_32_8_EXT
matthiasdiener May 4, 2026
5d2d38f
Merge remote-tracking branch 'upstream/dev' into mdiener/mxfp8-swizzle
matthiasdiener May 5, 2026
f093f64
Revert "change line endings to unix, trim trailing whitespace"
matthiasdiener May 5, 2026
ecbffea
Revert "gfx1250 swizzle_xor changes for FP4"
matthiasdiener May 5, 2026
33fca6e
Merge remote-tracking branch 'origin/dev' into mdiener/mxfp8-swizzle
matthiasdiener May 13, 2026
b55a538
address review comments
matthiasdiener May 13, 2026
398cc3c
cleanups
matthiasdiener May 13, 2026
384d590
re-add scale swizzle hooks in GEMM paths for gfx1250
matthiasdiener May 13, 2026
5c5a902
cleanups
matthiasdiener May 13, 2026
2c05ec5
arch fixes
matthiasdiener May 14, 2026
5552b09
more test fixes gfx1250
matthiasdiener May 18, 2026
bdee033
Merge remote-tracking branch 'origin/dev' into mdiener/mxfp8-swizzle
matthiasdiener May 19, 2026
90db6f4
address review comments
matthiasdiener May 19, 2026
2a6302d
additional padding
matthiasdiener May 19, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tests/cpp/operator/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ list(APPEND test_cuda_sources
test_multi_unpadding.cu
test_causal_softmax.cu
test_swap_first_dims.cu
test_swizzle.cu
../test_common.cu)
if(USE_CUDA)
list(APPEND test_cuda_sources
test_cast_float8blockwise.cu
test_swizzle.cu)
test_cast_float8blockwise.cu)
else()
list(APPEND test_cuda_sources
test_cublaslt_gemm.cu
Expand Down
107 changes: 90 additions & 17 deletions tests/cpp/operator/test_cublaslt_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <gtest/gtest.h>
#include <transformer_engine/cast.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/swizzle.h>
#include <transformer_engine/transformer_engine.h>
#include "../test_common.h"

Expand All @@ -30,7 +31,15 @@ std::vector<std::tuple<size_t, size_t, size_t>> test_case_sizes = {

std::vector<std::tuple<size_t, size_t, size_t>> test_case_sizes_mxfp8 = {
{32, 128, 16},
{64, 128, 32},
{128, 128, 64},
{64, 256, 32},
{128, 384, 64},
{256, 512, 128},
{512, 1024, 256},
{768, 3072, 4096},
{1024, 2048, 128},
{4096, 8192, 64},
};

// A, B, Bias, Gelu, D
Expand Down Expand Up @@ -303,6 +312,42 @@ void cpu_rowwise_to_columnwise(
}
}

#ifdef __HIP_PLATFORM_AMD__
// Swizzle MXFP8 scale_inv of a test::Tensor in-place for gfx1250.
static void swizzle_mxfp8_scales(test::Tensor &t, bool rowwise) {
using namespace transformer_engine;
void *scale_ptr = rowwise ? t.rowwise_scale_inv_dptr()
: t.columnwise_scale_inv_dptr();
if (!scale_ptr) return;
const NVTEShape scale_shape = rowwise ? t.rowwise_scale_inv_shape()
: t.columnwise_scale_inv_shape();
const NVTEShape data_shape = rowwise ? t.rowwise_shape()
: t.columnwise_shape();
size_t num_scales = 1;
for (size_t d = 0; d < scale_shape.ndim; d++) num_scales *= scale_shape.data[d];
uint8_t *d_tmp = nullptr;
NVTE_CHECK_CUDA(cudaMalloc(&d_tmp, num_scales));
TensorWrapper input_tw(NVTE_MXFP8_1D_SCALING);
TensorWrapper output_tw(NVTE_MXFP8_1D_SCALING);
output_tw.set_with_gemm_swizzled_scales(true);
if (rowwise) {
input_tw.set_rowwise_data(nullptr, t.dtype(), data_shape);
input_tw.set_rowwise_scale_inv(scale_ptr, DType::kFloat8E8M0, scale_shape);
output_tw.set_rowwise_data(nullptr, t.dtype(), data_shape);
output_tw.set_rowwise_scale_inv(d_tmp, DType::kFloat8E8M0, scale_shape);
} else {
input_tw.set_columnwise_data(nullptr, t.dtype(), data_shape);
input_tw.set_columnwise_scale_inv(scale_ptr, DType::kFloat8E8M0, scale_shape);
output_tw.set_columnwise_data(nullptr, t.dtype(), data_shape);
output_tw.set_columnwise_scale_inv(d_tmp, DType::kFloat8E8M0, scale_shape);
}
nvte_swizzle_scaling_factors(input_tw.data(), output_tw.data(), 0);
NVTE_CHECK_CUDA(cudaDeviceSynchronize());
NVTE_CHECK_CUDA(cudaMemcpy(scale_ptr, d_tmp, num_scales, cudaMemcpyDeviceToDevice));
NVTE_CHECK_CUDA(cudaFree(d_tmp));
}
#endif // __HIP_PLATFORM_AMD__

std::pair<double, double> getTestTolerances(const DType type, bool use_fp8, bool use_mxfp8) {
auto [atol, rtol] = getTolerances(type);

Expand All @@ -318,6 +363,14 @@ std::pair<double, double> getTestTolerances(const DType type, bool use_fp8, bool
else if (use_fp8) {
atol = 1e-3;
rtol = std::max(rtol, 1e-2);
#ifdef __HIP_PLATFORM_AMD__
// Relax for gfx1250
cudaDeviceProp prop;
(void)cudaGetDeviceProperties(&prop, 0);
if (prop.major == 12 && type == DType::kBFloat16) {
rtol = std::max(rtol, 5e-2);
}
#endif
}
else if (type == DType::kBFloat16) {
//relax for certain prime number TN gemm
Expand Down Expand Up @@ -496,6 +549,33 @@ void performTest(const TestParams& params) {
#endif
Tensor Workspace("Workspace", TShape{ workspace_size }, DType::kByte);

//perform the reference gemm on GPU (before swizzle, which modifies scales in-place)
Tensor RefD("RefD", TShape{ params.n, params.m }, dtype);
Tensor RefPreGeluOut;

if (params.use_gelu) {
RefPreGeluOut = Tensor("RefPreGeluOut", TShape{ params.n, params.m }, gelu_type);
}

run_reference<A_Type, B_Type, Bias_Type, Gelu_Type, D_Type>(
params,
A,
B,
params.use_bias ? &bias : nullptr,
D,
RefD,
params.use_gelu ? &RefPreGeluOut : nullptr);

#ifdef __HIP_PLATFORM_AMD__
// On gfx1250, hipBLASLt MXFP8 kernels expect pre-swizzled scales.
if (use_mxfp8 && prop.major == 12) {
if (!a_colwise) swizzle_mxfp8_scales(A, true);
if (a_colwise) swizzle_mxfp8_scales(A, false);
if (!b_colwise) swizzle_mxfp8_scales(B, true);
if (b_colwise) swizzle_mxfp8_scales(B, false);
}
#endif

//perform the gemm in GPU
nvte_cublas_gemm(A.data(),
B.data(),
Expand All @@ -517,23 +597,6 @@ void performTest(const TestParams& params) {
pre_gelu_out.to_cpu();
}

//perform the reference gemm on GPU
Tensor RefD("RefD", TShape{ params.n, params.m }, dtype);
Tensor RefPreGeluOut;

if (params.use_gelu) {
RefPreGeluOut = Tensor("RefPreGeluOut", TShape{ params.n, params.m }, gelu_type);
}

run_reference<A_Type, B_Type, Bias_Type, Gelu_Type, D_Type>(
params,
A,
B,
params.use_bias ? &bias : nullptr,
D,
RefD,
params.use_gelu ? &RefPreGeluOut : nullptr);

// check if error message happens in running
(void)cudaDeviceSynchronize();
auto err = cudaGetLastError();
Expand Down Expand Up @@ -605,6 +668,16 @@ void performDqTest(const TestParams &params) {
nvte_dequantize(A_fp8.data(), A_ref.data(), 0);
nvte_dequantize(B_fp8.data(), B_ref.data(), 0);

// On gfx1250, hipBLASLt MXFP8 kernels expect pre-swizzled scales.
if (prop.major == 12) {
const bool a_colwise = !params.transa;
const bool b_colwise = params.transb;
if (!a_colwise) swizzle_mxfp8_scales(A_fp8, true);
if (a_colwise) swizzle_mxfp8_scales(A_fp8, false);
if (!b_colwise) swizzle_mxfp8_scales(B_fp8, true);
if (b_colwise) swizzle_mxfp8_scales(B_fp8, false);
}

Tensor bias;
Tensor pre_gelu_out;

Expand Down
180 changes: 180 additions & 0 deletions tests/cpp/operator/test_swizzle.cu
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,183 @@ INSTANTIATE_TEST_SUITE_P(
std::to_string(std::get<2>(info.param));
return name;
});

#ifdef __HIP_PLATFORM_AMD__

// MX pre-swizzle test (gfx1250 Tensile 3D layout)
//
// Tensile 3D: {K_scale, M}.reshape({K_scale, padM/4, 4}).permute({1, 0, 2})
// For source (m, k): dst = (m/4) * (K*4) + k*4 + (m%4)

// CPU reference for Tensile 3D MX scale pre-swizzle.
// Row-major input [M, K], output is a flat permuted array.
void compute_ref_mx_swizzle_row(const uint8_t *h_input, uint8_t *h_output,
const int M, const int K,
const int orig_M, const int orig_K) {
constexpr int GROUP = 4;
for (int m = 0; m < M; m++) {
for (int k = 0; k < K; k++) {
uint8_t val = 127; // E8M0 identity: 2^0 = 1.0
if (m < orig_M && k < orig_K) {
val = h_input[m * orig_K + k];
}
int group = k / GROUP;
int within = k % GROUP;
int dst = group * (M * GROUP) + m * GROUP + within;
h_output[dst] = val;
}
}
}

void compute_ref_mx_swizzle_col(const uint8_t *h_input, uint8_t *h_output,
const int M, const int K,
const int orig_M, const int orig_K) {
constexpr int GROUP = 4;
for (int m = 0; m < M; m++) {
for (int k = 0; k < K; k++) {
uint8_t val = 127;
if (m < orig_M && k < orig_K) {
val = h_input[k * orig_M + m];
}
int group = k / GROUP;
int within = k % GROUP;
int dst = group * (M * GROUP) + m * GROUP + within;
h_output[dst] = val;
}
}
}

static size_t roundup_sz(size_t val, size_t mult) {
return ((val + mult - 1) / mult) * mult;
}

class MxSwizzleTestSuite
: public ::testing::TestWithParam<
std::tuple<std::pair<int, int>, bool>> {};

TEST_P(MxSwizzleTestSuite, TestMxSwizzle) {
Comment thread
alextmagro marked this conversation as resolved.
using namespace transformer_engine;
using namespace test;

cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, 0);
if (prop.major < 12) {
GTEST_SKIP() << "MXFP8 pre-swizzle is only supported on gfx1250";
}

const auto dims = std::get<0>(GetParam());
const bool rowwise = std::get<1>(GetParam());

// Original (unpadded) scale dimensions
const size_t orig_M = dims.first;
const size_t orig_K = dims.second;

// Padded dimensions: K-tiled layout requires K_scale padded to multiple of 4
const size_t M = orig_M;
const size_t K = roundup_sz(orig_K, 4);

// Allocate host input (unpadded) and fill with random data
const size_t input_size = orig_M * orig_K;
std::unique_ptr<uint8_t[]> h_input(new uint8_t[input_size]);
std::mt19937 rng(42);
for (size_t i = 0; i < input_size; i++) {
h_input[i] = static_cast<uint8_t>(rng() % 256);
}

// Allocate device input
uint8_t *d_input = nullptr;
NVTE_CHECK_CUDA(cudaMalloc(&d_input, input_size));
NVTE_CHECK_CUDA(cudaMemcpy(d_input, h_input.get(), input_size, cudaMemcpyHostToDevice));

// Allocate device output (padded size)
const size_t output_size = M * K;
uint8_t *d_output = nullptr;
NVTE_CHECK_CUDA(cudaMalloc(&d_output, output_size));
NVTE_CHECK_CUDA(cudaMemset(d_output, 0, output_size));

// Build TensorWrapper for input and output
TensorWrapper input_tw(NVTE_MXFP8_1D_SCALING);
TensorWrapper output_tw(NVTE_MXFP8_1D_SCALING);
output_tw.set_with_gemm_swizzled_scales(true);

// Data shape must be consistent with scale shape for validation.
// Scale shapes use padded K; data shapes use unpadded dims
// (kernel derives original_M/K from them).
if (rowwise) {
std::vector<size_t> data_shape_in = {orig_M, orig_K * 32};
std::vector<size_t> data_shape_out = {M, K * 32};
std::vector<size_t> scale_shape_in = {M, K};
std::vector<size_t> scale_shape_out = {M, K};
input_tw.set_rowwise_data(nullptr, DType::kFloat8E4M3, data_shape_in);
input_tw.set_rowwise_scale_inv(d_input, DType::kFloat8E8M0, scale_shape_in);
output_tw.set_rowwise_data(nullptr, DType::kFloat8E4M3, data_shape_out);
output_tw.set_rowwise_scale_inv(d_output, DType::kFloat8E8M0, scale_shape_out);
} else {
std::vector<size_t> data_shape_in = {orig_K * 32, orig_M};
std::vector<size_t> data_shape_out = {K * 32, M};
std::vector<size_t> scale_shape_in = {K, M};
std::vector<size_t> scale_shape_out = {K, M};
input_tw.set_columnwise_data(nullptr, DType::kFloat8E4M3, data_shape_in);
input_tw.set_columnwise_scale_inv(d_input, DType::kFloat8E8M0, scale_shape_in);
output_tw.set_columnwise_data(nullptr, DType::kFloat8E4M3, data_shape_out);
output_tw.set_columnwise_scale_inv(d_output, DType::kFloat8E8M0, scale_shape_out);
}

nvte_swizzle_scaling_factors(input_tw.data(), output_tw.data(), 0);

NVTE_CHECK_CUDA(cudaDeviceSynchronize());

// Copy output back to host
std::unique_ptr<uint8_t[]> h_output(new uint8_t[output_size]);
NVTE_CHECK_CUDA(cudaMemcpy(h_output.get(), d_output, output_size, cudaMemcpyDeviceToHost));

// Compute reference
std::unique_ptr<uint8_t[]> h_ref(new uint8_t[output_size]);
memset(h_ref.get(), 0, output_size);
if (rowwise) {
compute_ref_mx_swizzle_row(h_input.get(), h_ref.get(), M, K, orig_M, orig_K);
} else {
compute_ref_mx_swizzle_col(h_input.get(), h_ref.get(), M, K, orig_M, orig_K);
}

// Compare
compareResults("mx_swizzle", h_output.get(), h_ref.get(), output_size);

cudaFree(d_input);
cudaFree(d_output);
}

namespace {

// Scale dimensions (M_scale, K_scale).
// K_scale will be padded to multiple of 4 by the test.
std::vector<std::pair<int, int>> mx_scale_dims = {
{4, 4}, // minimal
{8, 4}, // small
{32, 8}, // medium
{64, 16}, // larger
{96, 8}, // non-power-of-2 M
{128, 32}, // big
{256, 64}, // bigger
{512, 128}, // stress inter-tile
{1024, 256}, // large
{4096, 256}, // max stress
};

} // namespace

INSTANTIATE_TEST_SUITE_P(
OperatorTest,
MxSwizzleTestSuite,
::testing::Combine(
::testing::ValuesIn(mx_scale_dims),
::testing::Values(true, false)
),
[](const testing::TestParamInfo<MxSwizzleTestSuite::ParamType>& info) {
std::string name = "M" + std::to_string(std::get<0>(info.param).first) +
"_K" + std::to_string(std::get<0>(info.param).second) +
(std::get<1>(info.param) ? "_row" : "_col");
return name;
});

#endif // __HIP_PLATFORM_AMD__
Loading
Loading