[Common] Enable NVFP4 2D block scaling in columnwise only #3027
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Greptile Summary
This PR enables NVFP4 2D block scaling in columnwise-only mode by removing the early-return guard and host-side check that previously blocked
Confidence Score: 4/5
Safe to merge; the new amax-only pass correctly mirrors the existing Step 2 reduction and is guarded by compile-time template parameters, leaving the existing rowwise and dual-direction paths completely unchanged.
The CUDA kernel change is surgical: the new Step 2.5 block is mutually exclusive with Step 2 at compile time, uses identical warp-shuffle and shared-memory reduction logic, and the final __syncthreads() inside the loop correctly makes amax_smem visible before Step 3 reads it. The bitwise-equality test covers both aligned and padded shapes in two dtypes. The one minor issue is a step-label name collision that has no runtime impact.
transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu: the new Step 2.5 block and surrounding synchronization are the only areas worth a second look.
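To make the "2D amax" that the review refers to concrete, here is a minimal host-side C++ sketch of a per-block absolute maximum over a row-major matrix. It is not the kernel's code: the function name `block_amax_2d`, the constant `kBlockDim`, and the 16x16 block edge are assumptions made for illustration, and the real kernel performs this reduction with warp shuffles and shared memory rather than a serial loop. Blocks that overhang the matrix edge only see in-bounds elements, which roughly models the padded-shape case mentioned above.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Assumed block edge for 2D scaling; illustrative, not taken from the PR.
constexpr std::size_t kBlockDim = 16;

// Returns one amax per kBlockDim x kBlockDim block of a row-major matrix.
// Edge blocks cover only the in-bounds elements (the "padded" case).
std::vector<float> block_amax_2d(const std::vector<float>& data,
                                 std::size_t rows, std::size_t cols) {
  assert(data.size() == rows * cols);
  const std::size_t block_rows = (rows + kBlockDim - 1) / kBlockDim;
  const std::size_t block_cols = (cols + kBlockDim - 1) / kBlockDim;
  std::vector<float> amax(block_rows * block_cols, 0.0f);
  for (std::size_t r = 0; r < rows; ++r) {
    for (std::size_t c = 0; c < cols; ++c) {
      // Each element contributes to exactly one block's running maximum.
      float& slot = amax[(r / kBlockDim) * block_cols + (c / kBlockDim)];
      slot = std::max(slot, std::fabs(data[r * cols + c]));
    }
  }
  return amax;
}
```

Because the amax depends only on the block contents and not on the direction in which the block is later quantized, a columnwise-only pass can reuse the same value that the rowwise pass would have produced.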
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[Kernel Entry] --> B{"kIs2DBlockScaling<br/>&& kIsE8Scaling?"}
B -- yes --> Z[Early return]
B -- no --> C["Step 1: Load input to smem<br/>__syncthreads"]
C --> D{kReturnIdentity?}
D -- yes --> E["Step 2: Cast and store rowwise<br/>2.1 load smem to regs<br/>2.2 local amax<br/>2.3 2D warp+smem reduction to amax_smem<br/>2.4-2.8 scale / quant / store"]
D -- no --> F{kIs2DBlockScaling?}
F -- yes --> G["Step 2.5 NEW: amax-only pass<br/>load smem to regs<br/>local amax<br/>2D warp+smem reduction to amax_smem<br/>no scale/quant/store"]
F -- no --> H{kReturnTranspose?}
E --> H
G --> H
H -- yes --> I["Step 3: Transpose cast and store columnwise<br/>read amax_smem for 2D path<br/>compute scale / quant / store"]
H -- no --> J[Done]
I --> J
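The control flow in the flowchart can be sketched as a small host-side C++ model. This is a sketch under stated assumptions, not the kernel itself: `quantize_tile_model`, `Trace`, `Tile`, `tile_amax`, and the 16-element tile edge are hypothetical names and values, and a plain `float` stands in for the shared-memory `amax_smem` slot. Only the template parameter names (`kReturnIdentity`, `kReturnTranspose`, `kIs2DBlockScaling`) come from the review text. The model shows that the amax-only pass and Step 2 are mutually exclusive at compile time and that Step 3 sees the same amax in either mode.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cmath>
#include <cstddef>

constexpr std::size_t kTile = 16;  // assumed tile edge, illustrative only
using Tile = std::array<float, kTile * kTile>;

// Records which steps ran; hypothetical helper for this sketch.
struct Trace {
  bool ran_rowwise_store = false;
  bool ran_amax_only = false;
  bool ran_columnwise_store = false;
  float amax_seen_by_step3 = -1.0f;  // -1 means Step 3 did not read the slot
};

inline float tile_amax(const Tile& t) {
  float m = 0.0f;
  for (float v : t) m = std::max(m, std::fabs(v));
  return m;
}

template <bool kReturnIdentity, bool kReturnTranspose, bool kIs2DBlockScaling>
Trace quantize_tile_model(const Tile& tile) {
  Trace trace;
  float amax_smem = -1.0f;  // stands in for the shared-memory amax slot

  if constexpr (kReturnIdentity) {
    // Step 2: rowwise path; reduces the amax and also scales/quantizes/stores.
    if constexpr (kIs2DBlockScaling) amax_smem = tile_amax(tile);
    trace.ran_rowwise_store = true;
  } else if constexpr (kIs2DBlockScaling) {
    // New pass: reduce the amax only, with no scale/quant/store.
    amax_smem = tile_amax(tile);
    trace.ran_amax_only = true;
  }

  if constexpr (kReturnTranspose) {
    // Step 3: the 2D path reads the amax produced by one of the branches above.
    if constexpr (kIs2DBlockScaling) {
      assert(amax_smem >= 0.0f);  // must have been written before this read
      trace.amax_seen_by_step3 = amax_smem;
    }
    trace.ran_columnwise_store = true;
  }
  return trace;
}
```

Instantiating the model as `quantize_tile_model<false, true, true>` corresponds to the columnwise-only 2D case this PR enables, while `quantize_tile_model<true, true, true>` corresponds to the existing dual-direction case.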
Reviews (1): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."
    }
    }

    // Step 2.5: 2D-amax-only pass for columnwise-only mode.
Step label collision with existing substep
The new outer-level block is named "Step 2.5" at line 576, but that same label is already used at line 522 for the "Write scale_inv" substep inside Step 2's loop (if constexpr (kReturnIdentity)). A future reader scanning the file will find two distinct "Step 2.5" sections with different semantics. Consider renaming the new block to something like "Step 2b" or "Step 2.5 (outer)" to distinguish it from the // Step 2.5: Write scale_inv substep inside the inner loop.
This is just the fallback kernel being changed. Does the main kernel already support this?
Description
Enabling 2D NVFP4 quantization in columnwise-only mode.
Needed by HybridQuantizer (PR #2817) for MXFP8 fwd + NVFP4 bwd on W.
Fixes # (issue)