
Add NVTE_BACKWARD_OVERRIDE=high_precision|dequantized #2644

Open
zianglih wants to merge 63 commits into NVIDIA:main from zianglih:keep-bwd

Conversation


@zianglih zianglih commented Feb 3, 2026

Description


~~Add an NVTE_KEEP_BACKWARD_UNQUANTIZED env var for quantized fprop + high precision wgrad & dgrad.~~

~~Add NVTE_BACKWARD_MODE=default|unquant|dequant env var~~

Add an NVTE_BACKWARD_OVERRIDE=high_precision|dequantized env var:

  • Not set: existing default quantization behavior
  • high_precision: quantized fprop + high-precision wgrad & dgrad using the unquantized activation and weight
  • dequantized: quantized fprop + high-precision wgrad & dgrad using the activation and weight dequantized directly from the fprop quantized values

The motivation for the dequantized design is RL. Unlike pre-training, which only needs to preserve the coarse optimization direction and convergence, RL gradients are noisy and the useful updates are small and delicate. If gradient quantization and chain-rule violations are present, noise dominates the true, fragile update signal and the model collapses. The dequantized design avoids gradient quantization and effectively preserves the chain rule.
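For example, from an RL training script the override can be selected before any quantized modules run (a minimal sketch; only the env var name and its values come from this PR):

```python
import os

# Select the backward override before the quantization recipe is read.
# "high_precision" saves the original unquantized activation/weight for the
# backward GEMMs; "dequantized" reuses the fprop-quantized tensors,
# dequantized back to high precision. Leaving it unset keeps the default
# fully-quantized backward.
os.environ["NVTE_BACKWARD_OVERRIDE"] = "dequantized"

print(os.environ["NVTE_BACKWARD_OVERRIDE"])  # → dequantized
```

Because the override lives in the environment rather than the model config, an RL framework can flip it without plumbing a new option through Megatron, which is the stated motivation for the env-var interface.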

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes


greptile-apps bot commented Feb 3, 2026

Greptile Summary

This PR adds NVTE_BACKWARD_OVERRIDE=high_precision|dequantized support, enabling high-precision backward passes (wgrad & dgrad) in combination with quantized fprop. The high_precision mode saves original unquantized activations/weights; dequantized mode saves quantized fprop tensors and dequantizes them before backward GEMMs to avoid gradient quantization errors during RL fine-tuning.

LayerNormMLP and DelayedScaling recipe intentionally reject NVTE_BACKWARD_OVERRIDE with clear assertion messages and guidance to use LayerNormLinear + Linear as an alternative. All other previously flagged concerns (duplicate recipe fields, recipe None crash, unnecessary saved tensors) appear resolved in this revision.

Confidence Score: 5/5

Safe to merge — all previously flagged blocking issues are resolved; remaining findings are P2 style suggestions.

All previously raised P0/P1 concerns (duplicate recipe fields, recipe None crash, LayerNormMLP assertion message quality, unnecessary saved tensors, DelayedScaling interaction) are addressed in this revision. The feature is guarded by explicit assertions with clear error messages for unsupported combinations. The only new findings are a defensive getattr suggestion in fuser.py and a minor asymmetry in empty-tensor guards for MXFP8 storage, both P2. Comprehensive test coverage was added.

transformer_engine/pytorch/ops/fuser.py (minor: direct attribute access on recipe.backward_override), transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py (minor: asymmetric empty-tensor guard in dequantize vs _FromMXFP8Func.forward)

Important Files Changed

| Filename | Overview |
| --- | --- |
| `transformer_engine/common/recipe/__init__.py` | Adds `backward_override` field to all recipe dataclasses. DelayedScaling asserts `backward_override` is None with a clear error message. Previously flagged duplicate fields in Float8CurrentScaling are gone. |
| `transformer_engine/pytorch/module/linear.py` | Adds `backward_override` detection; sets `save_original_input=True` for high_precision mode and saves unquantized/quantized operands accordingly. Properly asserts against Float8Quantizer (DelayedScaling) when `save_original_input` is true. |
| `transformer_engine/pytorch/ops/basic/basic_linear.py` | Functional forward sets `columnwise=False` for quantized operands when `backward_override` is not None. `op_forward` chooses between saving original (high_precision) or quantized (dequantized) tensors. Backward pass dispatches dequantization accordingly. |
| `transformer_engine/pytorch/module/layernorm_linear.py` | Correctly implements both high_precision and dequantized modes. Disables `optimize_for_gemm` for MXFP8/NVFP4 dequantized mode. `save_for_backward` order matches test expectations. |
| `transformer_engine/pytorch/module/layernorm_mlp.py` | Explicitly asserts `backward_override` is None with a clear error message directing users to LayerNormLinear + Linear. Intentional documented limitation. |
| `transformer_engine/pytorch/module/grouped_linear.py` | Adds `backward_override` support; when set, disables FP8/UB/debug context in backward. Correctly handles both override modes in wgrad/dgrad GEMMs. |
| `transformer_engine/pytorch/ops/fuser.py` | Adds `backward_override` to the fusion cache key so fused op graphs are rebuilt when the override changes. Direct attribute access `recipe.backward_override` could AttributeError for custom Recipe subclasses not defining the field. |
| `transformer_engine/pytorch/ops/basic/quantize.py` | Disables backward quantization when `recipe.backward_override` is not None. `get_fp8_recipe()` is guarded by the `fp8_enabled` check and always returns a recipe, so no None crash. |
| `transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py` | Adds empty-tensor early returns in both `_FromMXFP8Func.forward` (rowwise and columnwise) and `MXFP8TensorStorage.dequantize` (rowwise only). Asymmetry is functionally safe but inconsistent. |
| `transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py` | Adds empty-tensor early returns consistent with the MXFP8 pattern. Functionally safe for the dequantized backward use case. |
| `transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py` | Adds an early return for an empty rowwise tensor in `dequantize`, preventing errors when empty-token grouped-linear chunks are encountered in dequantized mode. |
| `tests/pytorch/test_backward_override.py` | New comprehensive test file covering both override modes across all recipe types, module types, shapes, and fused op patterns. Layout invariant checks guard against hidden requantization during backward. |

Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[Forward Pass - quantized fprop] --> B{NVTE_BACKWARD_OVERRIDE}
    B -->|None| C[Default: save rowwise+columnwise quantized tensors]
    B -->|high_precision| D[Save original unquantized input and weight]
    B -->|dequantized| E[Save rowwise-only quantized tensors]
    C --> F[Backward: quantized dgrad and wgrad GEMMs]
    D --> G[Backward: high-precision dgrad and wgrad using original fp16/bf16/fp32 operands]
    E --> H[Backward: dequantize saved tensors then high-precision GEMMs]
    subgraph Supported
        L[Linear]
        M[LayerNormLinear]
        N[GroupedLinear]
        O[ops.Linear / fused ops]
    end
    subgraph Unsupported - assertion error with clear message
        P[LayerNormMLP]
        Q[DelayedScaling recipe]
    end
```

Reviews (43). Last reviewed commit: "Merge branch 'main' into keep-bwd"

@greptile-apps bot left a comment

17 files reviewed, 3 comments

@greptile-apps bot left a comment

6 files reviewed, no comments


zianglih commented Feb 3, 2026

I'll work on potential unit test breakage.

@greptile-apps bot left a comment

5 files reviewed, no comments

@greptile-apps bot left a comment

5 files reviewed, 4 comments

@greptile-apps bot left a comment

4 files reviewed, 1 comment

@greptile-apps bot left a comment

5 files reviewed, 1 comment

@greptile-apps bot left a comment

5 files reviewed, 2 comments

@greptile-apps bot left a comment

4 files reviewed, no comments

@greptile-apps bot left a comment

5 files reviewed, no comments

@greptile-apps bot left a comment

5 files reviewed, no comments

@greptile-apps bot left a comment

5 files reviewed, no comments

```diff
 # Note: dgrad GEMM requires row-wise usage, wgrad GEMM
 # requires column-wise usage
-if ctx.grad_output_quantizer is not None:
+if ctx.grad_output_quantizer is not None and use_fp8_bwd:
```

Collaborator:

this line seems redundant since you already skip the quantization step in base.py grad_output_preprocess?

```diff
 not ctx.use_bias
 and not ctx.requires_wgrad
 and ctx.grad_output_quantizer is not None
+and use_fp8_bwd
```

Collaborator:

same comment as above

```python
recipe = cls.get_fp8_recipe()
if recipe is not None and recipe.delayed():
    # Ignore NVTE_KEEP_BACKWARD_UNQUANTIZED when delayed scaling is used
    return False
```

Collaborator:

Maybe it's better to assert an error for delayed scaling? Okay with both.

Collaborator:

I agree. If the user specifies an unsupported combination, I think it's better to fail loudly than to secretly disobey their instructions.

```diff
 # Note: dgrad GEMM requires row-wise usage, wgrad GEMM
 # requires column-wise usage
-if ctx.grad_output_quantizer is not None:
+if ctx.grad_output_quantizer is not None and use_fp8_bwd:
```

Collaborator:

this seems redundant too if we skip quant in grad_output_preprocess

@zhongbozhu
Collaborator

  1. Not a fan of NVTE_BACKWARD_MODE, it's too generic. I am still not sure if this feature should be allowed via environment toggle. It's easy for the users, but we should make it explicitly configurable via the recipe API and not an envvar.
  2. Is there a reason to have the dequant mode? Is it just for memory saving? Can't imagine it being numerically better than unquant. Either way, dequantized and high_precision might be better names for these features.

On the naming part I agree, but I have no strong opinion.


zianglih commented Mar 13, 2026

Hi @ksivaman , thanks for reviewing!

we should make it explicitly configurable via recipe API and not envvar

Currently the backward_mode is a configurable recipe member, not a global toggle. It is set by the NVTE_BACKWARD_MODE envvar. I can work on a better interface.

Is there a reason to have the dequant mode?

Yes, we have very good reasons in RL use cases, since it best preserves the chain rule and serves as an STE (straight-through estimator). Our experiments showed clearly more stable gradient curves compared with the default and unquant modes. unquant has good numerics but violates the chain rule more, which is acceptable in pre-training but not in RL.
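In symbols (notation mine, not from the PR): with $\hat{x} = \mathrm{dq}(Q(x))$ and $\hat{w} = \mathrm{dq}(Q(w))$, the quantized forward effectively computes

$$
y = \hat{x}\,\hat{w}^{\top},
$$

and the dequantized backward produces exactly the chain-rule gradients of that forward,

$$
\frac{\partial L}{\partial x} = g\,\hat{w}, \qquad
\frac{\partial L}{\partial w} = g^{\top}\hat{x}, \qquad
g = \frac{\partial L}{\partial y},
$$

treating $Q \circ \mathrm{dq}$ as a straight-through identity, whereas the default path re-quantizes the operands and quantizes $g$ before the backward GEMMs.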

dequantized and high_precision might be better names for these features

Yes I can change naming to default|high_precision|dequantized.

@zhongbozhu
Collaborator

Can you clarify the dequant method here? For fprop, we quantize and get input_fp8 and weight_fp8, and then for the dequantized backward you dequantize both, is that right?


zianglih commented Mar 13, 2026

Hi @zhongbozhu ,

For fprop, we quantize and get input_fp8, and weight_fp8, and then for dequantize you also dequantize both

This is exactly right. The fprop uses quantized compute as specified by the quantization recipe, with no behavioral changes. In bwd, input_fp8 is dequantized for a high-precision wgrad, weight_fp8 is dequantized for a high-precision dgrad, the gradient is always kept in high precision, and gradient quantization never happens.
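A minimal NumPy sketch of that backward (function and argument names are mine, not the TE API; `dequantize` stands in for the tensor-storage dequantize methods):

```python
import numpy as np

def dequantized_backward(grad_output, input_fp8, weight_fp8, dequantize):
    # grad_output stays in high precision; gradient quantization never happens.
    x_hp = dequantize(input_fp8)   # activation, dequantized from the fprop value
    w_hp = dequantize(weight_fp8)  # weight, dequantized from the fprop value
    dgrad = grad_output @ w_hp     # high-precision dgrad GEMM: (M,N) @ (N,K) -> (M,K)
    wgrad = grad_output.T @ x_hp   # high-precision wgrad GEMM: (N,M) @ (M,K) -> (N,K)
    return dgrad, wgrad

# Toy check with an identity "dequantize" standing in for real FP8 storage:
g = np.ones((2, 3), dtype=np.float32)  # dL/dy, shape (M, N)
x = np.ones((2, 4), dtype=np.float32)  # activation, shape (M, K)
w = np.ones((3, 4), dtype=np.float32)  # weight, shape (N, K)
dgrad, wgrad = dequantized_backward(g, x, w, lambda t: t)
print(dgrad.shape, wgrad.shape)  # → (2, 4) (3, 4)
```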

The motivation for the dequantized design is RL. Unlike pre-training, which only needs to preserve the coarse optimization direction and convergence, RL gradients are noisy and the useful updates are small and delicate. If gradient quantization and chain-rule violations are present, noise dominates the true, fragile update signal and the model collapses. The dequantized design avoids gradient quantization and effectively preserves the chain rule.


@zianglih zianglih changed the title Add NVTE_BACKWARD_MODE=default|unquant|dequant Add NVTE_BACKWARD_OVERRIDE=high_precision|dequantized Mar 14, 2026

zianglih commented Mar 16, 2026

using "dequantized" in bwd still does not preserve the chain rule 100%, as the quantization in fwd and bwd happens along different dims

@victordion I think you are describing the default TE 1d recipe or requantized behavior.

@victordion

using "dequantized" in bwd still does not preserve the chain rule 100%, as the quantization in fwd and bwd happens along different dims

@victordion I think you are describing the default TE 1d recipe or requantized behavior.

Right. My mistake. My mental model assumed there is requantize happening. Thanks for responding!

@zianglih zianglih requested review from ksivaman and zhongbozhu March 17, 2026 05:42
@zianglih
Author

Regarding the env var design, since this feature is mainly used by RL, there has to be a way for the user to directly override the bwd behavior in RL framework instead of plumbing all the way through Megatron.

@ksivaman
Member

/te-ci L0 L1

@zianglih
Author

All PyTorch CI passed.

Some failed JAX tests are due to FileExistsError: [Errno 17] File exists: '/logs'.

@zhongbozhu
Collaborator

/te-ci L0 L1

@zhongbozhu left a comment

LGTM, pending CI

@ksivaman ksivaman dismissed their stale review March 31, 2026 19:25

Unblocking

@zhongbozhu
Collaborator

/te-ci L0 L1


ptrendx commented Apr 3, 2026

/te-ci L0 L1


zianglih commented Apr 7, 2026

The failed JAX CI is unrelated to this PR:

B200:

../../tests/jax/test_permutation.py::TestHighLevelPermutationAPI::test_sort_chunks_by_index[dtype_float32-8-4096-1280] FAILED

L40:

../../tests/jax/test_permutation.py::TestHighLevelPermutationAPI::test_sort_chunks_by_index[dtype_float32-8-4096-1280] FAILED


Labels

community-contribution PRs from external contributor outside the core maintainers, representing community-driven work.


8 participants