From 56780d13dfca69127c52187db24cfdcbcba2720c Mon Sep 17 00:00:00 2001 From: Evgeny Date: Thu, 21 May 2026 17:35:56 +0000 Subject: [PATCH 1/2] Enable colwise only 2d nvfp4 Signed-off-by: Evgeny --- .../nvfp4/test_nvfp4_quantize_exact.py | 97 +++++++++++++++++++ ...quantize_transpose_vector_blockwise_fp4.cu | 70 +++++++++++-- 2 files changed, 160 insertions(+), 7 deletions(-) diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py index 53569d90d9..371de99068 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py @@ -534,3 +534,100 @@ def test_nvfp4_quantization_noncontiguous_inputs( torch.testing.assert_close(qx_amax_t, ref_amax_t, atol=0.0, rtol=0.0) torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N", + [ + # Aligned tiles + (128, 128), + (256, 256), + (512, 512), + (2048, 2048), + # Padded tiles (non-multiple of kTileDim=128) + (256, 272), + (304, 304), + (320, 256), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +def test_nvfp4_2d_columnwise_only_matches_both_directions( + x_dtype: torch.dtype, + M: int, + N: int, +): + """Bitwise check: 2D NVFP4 with columnwise-only must produce the same + columnwise data/scales as the columnwise half of (rowwise + columnwise) 2D. + + Exercises the columnwise-only path through the 2D-amax-only pass added to + ``quantize_transpose_vector_blockwise_fp4.cu``. Before that change, this + configuration was rejected by + ``NVTE_CHECK(return_identity || !use_2d_quantization)``. + """ + te_dtype = tex.DType.kFloat4E2M1 + device = "cuda" + + torch.manual_seed(0) + torch.cuda.manual_seed(0) + x = torch.randn((M, N), dtype=x_dtype, device=device) + + def _make_quantizer(*, rowwise: bool, columnwise: bool) -> NVFP4Quantizer: + return NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=rowwise, + columnwise=columnwise, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + with_2d_quantization=True, + row_scaled_nvfp4=False, + ) + + # Reference: produce both directions in a single kernel call. + q_both = _make_quantizer(rowwise=True, columnwise=True) + out_both = q_both(x) + + # SUT: produce columnwise only (the path that hits the new amax-only pass). + q_col_only = _make_quantizer(rowwise=False, columnwise=True) + out_col_only = q_col_only(x) + + # Columnwise data/scales/amax must be bitwise identical between the two paths. + # If amax_smem is populated differently in the column-only path, scales diverge, + # and the FP4 cast (which divides by encode_scale) produces different bytes. + assert out_both._columnwise_data is not None + assert out_col_only._columnwise_data is not None + torch.testing.assert_close( + out_col_only._columnwise_data.view(dtype=torch.uint8), + out_both._columnwise_data.view(dtype=torch.uint8), + atol=0, + rtol=0, + ) + + # Compare only the valid (in-bounds) region of the columnwise scale tensor. + # The padded tail (rows K..round_up(K, 128), cols ceil(M/16)..round_up(.., 4)) + # exists for cuBLAS alignment and is NEVER written by the kernel — its bytes + # are whatever ``at::empty`` returned, which differs between two allocations. + NVFP4_BLOCK = 16 + valid_outer = N # cols of input == rows of columnwise scale tensor + valid_inner = (M + NVFP4_BLOCK - 1) // NVFP4_BLOCK + assert out_both._columnwise_scale_inv is not None + assert out_col_only._columnwise_scale_inv is not None + col_sx_both = out_both._columnwise_scale_inv.view(dtype=torch.uint8) + col_sx_col_only = out_col_only._columnwise_scale_inv.view(dtype=torch.uint8) + torch.testing.assert_close( + col_sx_col_only[:valid_outer, :valid_inner], + col_sx_both[:valid_outer, :valid_inner], + atol=0, + rtol=0, + ) + + assert out_both._amax_columnwise is not None + assert out_col_only._amax_columnwise is not None + torch.testing.assert_close( + out_col_only._amax_columnwise, out_both._amax_columnwise, atol=0, rtol=0 + ) + + # Sanity: column-only path must not allocate a rowwise output. + assert out_col_only._rowwise_data is None diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu index cf9821f1a9..8f970864b0 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu @@ -353,14 +353,11 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo extern __shared__ char smem_base[]; SMemVec* smem = reinterpret_cast(&smem_base[0]); - // 2D block scaling is not supported for E8 scaling MXFP4 or for colwise only mode. - // Instead of static_assert, return early if these invalid modes are detected. + // 2D block scaling is not supported for E8 scaling MXFP4. + // Instead of static_assert, return early if this invalid mode is detected. if constexpr (kIs2DBlockScaling && kIsE8Scaling) { return; } - if constexpr (kIs2DBlockScaling && !kReturnIdentity) { - return; - } // for 128x128 block, 2D block scaling means there will be 8x8 amax values for nvfp4, 4x4 for 2D mxfp4 // use constexpr to define the size, when not using 2D, use minimal size 1x1 constexpr int kFP4BlockScalingSize = 16; @@ -576,6 +573,67 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo } } + // Step 2.5: 2D-amax-only pass for columnwise-only mode. + // When only the transposed output is requested but 2D block scaling is enabled, the columnwise + // reads in Step 3 (line ~660 below) still need amax_smem populated. Re-run the load + local-amax + // + 2D warp/smem reduction from Step 2 (steps 2.1-2.3), skipping the rowwise scale/quantize/store + // writes that Step 2 normally does. Same amax_smem values as the rowwise-enabled path, so the + // dgrad/wgrad columnwise output of (rowwise=False, columnwise=True, 2D) is bitwise identical to + // the columnwise half of (rowwise=True, columnwise=True, 2D). + if constexpr (!kReturnIdentity && kIs2DBlockScaling) { + constexpr int r_stride = + kThreadsPerBlock / kNumThreadsStore; // stride in rows of shared memory + constexpr int num_iterations = kTileDim / r_stride; // 4 iterations for kTileDim=128 + const int c_s = + (threadIdx.x % kNumThreadsStore) * (kNVecOut / kNVecSMem); // Column in shared memory + int r_s = threadIdx.x / kNumThreadsStore; // Row in shared memory +#pragma unroll + for (int iter = 0; iter < num_iterations; ++iter) { + SMemVec smem_vec[kNVecOut / kNVecSMem]; + // Step 2.1 (amax-only): Load from shared memory to registers +#pragma unroll + for (int i = 0; i < kNVecOut / kNVecSMem; ++i) { + int c = c_s + i; + int r = r_s; + smem_vec[i] = smem[r * kSMemCol + c]; + } + // Step 2.2 (amax-only): Compute local amax + CType amax = 0; +#pragma unroll + for (int i = 0; i < kNVecOut / kNVecSMem; ++i) { +#pragma unroll + for (int j = 0; j < kNVecSMem; ++j) { + __builtin_assume(amax >= 0); + amax = fmaxf(amax, fabsf(smem_vec[i].data.elt[j])); + } + } + // Step 2.3 (amax-only): 2D warp + smem amax reduction (mirrors Step 2's 2D path) + constexpr int kNumRowsPerIter = kThreadsPerBlock / kNumThreadsStore; // 32 + int warp_idx = threadIdx.x / kThreadsPerWarp; // 0 ~ 7 + int tid_in_warp_x = threadIdx.x % kNumThreadsStore; + int tid_in_warp_y = (threadIdx.x / kNumThreadsStore) % kNumRowsPerWarp; + CType amax_warp_reduced = groupMax( + amax, WARP_REDUCE_AMAX_GROUP_MASKS[tid_in_warp_x]); + int data_row_idx = iter * kNumRowsPerIter + warp_idx * kNumRowsPerWarp + tid_in_warp_y; + if (tid_in_warp_y == 0) { + amax_smem_red[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x] + [warp_idx % k2DBlockAmaxReduceDim] = amax_warp_reduced; + } + __syncthreads(); + + if (data_row_idx % kFP4BlockScalingSize == 0) { + CType amax_2d = 0.0; + for (int i = 0; i < k2DBlockAmaxReduceDim; i++) { + amax_2d = fmaxf(amax_2d, + amax_smem_red[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x][i]); + } + amax_smem[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x] = amax_2d; + } + __syncthreads(); + r_s += r_stride; + } + } + // Step 3: Transpose, cast and store to output_t if constexpr (kReturnTranspose) { constexpr int c_stride = @@ -731,8 +789,6 @@ void quantize_transpose_vector_blockwise_fp4( NVTE_CHECK(return_identity || return_transpose, "At least one of return_identity or return_transpose must be true."); - NVTE_CHECK(return_identity || !use_2d_quantization, - "2D block quantization is only supported when return_identity is true."); NVTE_CHECK(!row_scaled_nvfp4 || (return_identity && !return_transpose), "Row-scaled NVFP4 quantization only supports rowwise quantization."); NVTE_CHECK(!row_scaled_nvfp4 || !use_2d_quantization, From 61a238713c5f229172da4a778ab280918a25ab63 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 21 May 2026 17:40:23 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../transpose/quantize_transpose_vector_blockwise_fp4.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu index 8f970864b0..0efef2c7af 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu @@ -624,8 +624,8 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo if (data_row_idx % kFP4BlockScalingSize == 0) { CType amax_2d = 0.0; for (int i = 0; i < k2DBlockAmaxReduceDim; i++) { - amax_2d = fmaxf(amax_2d, - amax_smem_red[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x][i]); + amax_2d = + fmaxf(amax_2d, amax_smem_red[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x][i]); } amax_smem[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x] = amax_2d; }