Add NVTE_BACKWARD_OVERRIDE=high_precision|dequantized #2644

zianglih wants to merge 63 commits into NVIDIA:main from

Conversation
Greptile Summary

This PR adds the `NVTE_BACKWARD_OVERRIDE=high_precision|dequantized` env var.

Confidence Score: 5/5. Safe to merge — all previously flagged blocking issues are resolved; the remaining findings are P2 style suggestions. All previously raised P0/P1 concerns (duplicate recipe fields, recipe `None` crash, LayerNormMLP assertion message quality, unnecessary saved tensors, DelayedScaling interaction) are addressed in this revision. The feature is guarded by explicit assertions with clear error messages for unsupported combinations. The only new findings are a defensive `getattr` suggestion in fuser.py and a minor asymmetry in empty-tensor guards for MXFP8 storage, both P2. Comprehensive test coverage was added.

Important files changed: transformer_engine/pytorch/ops/fuser.py (minor: direct attribute access on `recipe.backward_override`), transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py (minor: asymmetric empty-tensor guard in `dequantize` vs `_FromMXFP8Func.forward`).
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[Forward Pass - quantized fprop] --> B{NVTE_BACKWARD_OVERRIDE}
    B -->|None| C[Default: save rowwise+columnwise quantized tensors]
    B -->|high_precision| D[Save original unquantized input and weight]
    B -->|dequantized| E[Save rowwise-only quantized tensors]
    C --> F[Backward: quantized dgrad and wgrad GEMMs]
    D --> G[Backward: high-precision dgrad and wgrad using original fp16/bf16/fp32 operands]
    E --> H[Backward: dequantize saved tensors then high-precision GEMMs]
    subgraph Supported
        L[Linear]
        M[LayerNormLinear]
        N[GroupedLinear]
        O[ops.Linear / fused ops]
    end
    subgraph Unsupported - assertion error with clear message
        P[LayerNormMLP]
        Q[DelayedScaling recipe]
    end
```
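The save-for-backward dispatch in the flowchart can be sketched in plain Python. This is a toy model of the branching only, not Transformer Engine code; the function name and tensor arguments are illustrative.

```python
def select_saved_tensors(override, x, w, x_q_rowwise, x_q_colwise,
                         w_q_rowwise, w_q_colwise):
    """Toy model: which operands fprop saves for backward under each
    NVTE_BACKWARD_OVERRIDE mode (all names here are illustrative)."""
    if override is None:
        # Default: keep both rowwise and columnwise quantized copies
        # for the quantized dgrad/wgrad GEMMs.
        return {"x": (x_q_rowwise, x_q_colwise),
                "w": (w_q_rowwise, w_q_colwise)}
    if override == "high_precision":
        # Save the original unquantized operands; backward GEMMs run
        # in fp16/bf16/fp32.
        return {"x": (x,), "w": (w,)}
    if override == "dequantized":
        # Save rowwise-only quantized tensors; backward dequantizes
        # them before the high-precision GEMMs.
        return {"x": (x_q_rowwise,), "w": (w_q_rowwise,)}
    raise ValueError(f"unknown NVTE_BACKWARD_OVERRIDE value: {override!r}")
```

Note that both override modes save fewer tensors per operand than the default path.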
Reviews (43): Last reviewed commit: "Merge branch 'main' into keep-bwd"

I'll work on potential unit test breakage.
```diff
  # Note: dgrad GEMM requires row-wise usage, wgrad GEMM
  # requires column-wise usage
- if ctx.grad_output_quantizer is not None:
+ if ctx.grad_output_quantizer is not None and use_fp8_bwd:
```
this line seems redundant since you already skip the quantization step in base.py grad_output_preprocess?
```python
not ctx.use_bias
and not ctx.requires_wgrad
and ctx.grad_output_quantizer is not None
and use_fp8_bwd
```
```python
recipe = cls.get_fp8_recipe()
if recipe is not None and recipe.delayed():
    # Ignore NVTE_KEEP_BACKWARD_UNQUANTIZED when delayed scaling is used
    return False
```
Maybe it's better to assert an error for delayed scaling? Okay with both.

I agree. If the user specifies an unsupported combination, I think it's better to fail loudly than to secretly disobey their instructions.
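The "fail loudly" behavior agreed on here could look roughly like the sketch below. The function name, the use of the recipe's `delayed()` method, and the message wording are assumptions for illustration, not the merged code.

```python
def use_backward_override(recipe, override):
    """Sketch: decide whether the backward override applies, asserting
    instead of silently ignoring it under delayed scaling."""
    if override is None:
        return False
    # Fail loudly rather than secretly disobeying the user's instruction;
    # the `recipe is None` check also avoids crashing when no recipe is set.
    assert recipe is None or not recipe.delayed(), (
        "NVTE_BACKWARD_OVERRIDE is not supported with the DelayedScaling recipe"
    )
    return True
```

The `recipe is None` guard mirrors the "recipe None crash" concern flagged in the review summary.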
```diff
  # Note: dgrad GEMM requires row-wise usage, wgrad GEMM
  # requires column-wise usage
- if ctx.grad_output_quantizer is not None:
+ if ctx.grad_output_quantizer is not None and use_fp8_bwd:
```
this seems redundant too if we skip quant in grad_output_preprocess
Naming part I agree, but I have no strong opinion.
Hi @ksivaman, thanks for reviewing!

Currently the

Yes, we have very good reasons in RL use cases, since it best preserves the chain rule and serves as an STE. Our experiments showed clearly more stable gradient curves compared with

Yes, I can change the naming to
Can you clarify the dequant method here? For fprop, we quantize and get input_fp8 and weight_fp8, and then for backward you dequantize both, is that right?
Hi @zhongbozhu,

This is exactly right. The fprop uses the quantized compute specified by the quantization recipe, with no behavioral changes. In bwd, input_fp8 is dequantized for a high-precision wgrad, weight_fp8 is dequantized for a high-precision dgrad, the gradient is always kept in high precision, and gradient quantization never happens. The motivation for this dequantized design is RL.
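The dequantized-backward arithmetic described above can be illustrated with a toy symmetric quantizer in NumPy. `fake_quant`, the bit width, and the scale choice are illustrative stand-ins for TE's FP8 quantization, not its actual implementation.

```python
import numpy as np

def fake_quant(t, n_bits=4):
    """Toy symmetric quantize-dequantize: scale to n_bits, round, rescale.
    A stand-in for FP8 quantization, not TE's implementation."""
    scale = np.abs(t).max() / (2 ** (n_bits - 1) - 1)
    return np.round(t / scale) * scale

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8)).astype(np.float32)   # activation
w = rng.standard_normal((3, 8)).astype(np.float32)   # weight
g = rng.standard_normal((4, 3)).astype(np.float32)   # incoming gradient dL/dy

x_dq, w_dq = fake_quant(x), fake_quant(w)

# fprop: quantized compute (unchanged by the override)
y = x_dq @ w_dq.T                                    # shape (4, 3)

# "dequantized" backward: high-precision GEMMs on the dequantized fprop
# operands; the gradient itself is never quantized, so the chain rule
# through y = x_dq @ w_dq.T holds exactly (an STE through the rounding).
dgrad = g @ w_dq                                     # dL/dx_dq, shape (4, 8)
wgrad = g.T @ x_dq                                   # dL/dw_dq, shape (3, 8)

# "high_precision" backward: same GEMMs on the original operands instead
dgrad_hp = g @ w
wgrad_hp = g.T @ x
```

The `dequantized` gradients are exact gradients of the quantized fprop, while the `high_precision` ones use the original operands and therefore deviate slightly from the fprop's actual compute graph.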
Signed-off-by: Ziang Li <ziangli@umich.edu>

…zed` Signed-off-by: Ziang Li <ziangli@umich.edu>

`NVTE_BACKWARD_MODE=default|unquant|dequant` → `NVTE_BACKWARD_OVERRIDE=high_precision|dequantized`

Signed-off-by: Ziang Li <ziangli@umich.edu>
@victordion I think you are describing the default TE 1d recipe or requantized behavior.

Right. My mistake. My mental model assumed there was a requantize happening. Thanks for responding!

Regarding the env var design: since this feature is mainly used by RL, there has to be a way for the user to directly override the bwd behavior in the RL framework instead of plumbing all the way through Megatron.

/te-ci L0 L1

All PyTorch CI passed. Some failed JAX tests are due to

/te-ci L0 L1

/te-ci L0 L1
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
/te-ci L0 L1

Failed JAX ci is unrelated to this PR: B200: L40:

Description
~~Add an NVTE_KEEP_BACKWARD_UNQUANTIZED env var for quantized fprop + high precision wgrad & dgrad.~~
~~Add NVTE_BACKWARD_MODE=default|unquant|dequant env var~~
Add `NVTE_BACKWARD_OVERRIDE=high_precision|dequantized` env var:

- `high_precision`: quantized fprop + high-precision wgrad & dgrad using the unquantized activation and weight
- `dequantized`: quantized fprop + high-precision wgrad & dgrad using the activation and weight dequantized directly from the fprop quantized values

The motivation for this `dequantized` design is RL. Unlike pre-training, which only needs to preserve a coarse optimization direction and convergence, RL gradients are noisy and useful updates are small and delicate. If gradient quantization and chain-rule violation are present, noise dominates the true and fragile update signal and the model will collapse. This `dequantized` design avoids gradient quantization and effectively preserves the chain rule.
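As a sketch of the env-var contract described above, a consumer of `NVTE_BACKWARD_OVERRIDE` might validate it as below. The helper function itself is hypothetical; only the variable name and its two values come from this PR.

```python
import os

VALID_OVERRIDES = ("high_precision", "dequantized")

def read_backward_override(env=os.environ):
    """Hypothetical validator for NVTE_BACKWARD_OVERRIDE; only the
    variable name and its two values come from this PR."""
    value = env.get("NVTE_BACKWARD_OVERRIDE")
    if value is None:
        return None  # unset: default quantized backward, no override
    if value not in VALID_OVERRIDES:
        raise ValueError(
            f"NVTE_BACKWARD_OVERRIDE must be one of {VALID_OVERRIDES}, got {value!r}"
        )
    return value
```

Failing on an unrecognized value matches the "fail loudly" preference expressed in the review thread.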
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: