From 8cad9ae18755493533a80536e79685ee5ea3cd74 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 2 Feb 2026 16:45:50 -0800 Subject: [PATCH 01/61] Add NVTE_KEEP_BACKWARD_UNQUANTIZED Signed-off-by: Ziang Li --- transformer_engine/pytorch/module/base.py | 4 +- .../pytorch/module/grouped_linear.py | 36 +++-- .../pytorch/module/layernorm_linear.py | 80 +++++++--- .../pytorch/module/layernorm_mlp.py | 147 +++++++++++------- transformer_engine/pytorch/module/linear.py | 65 +++++--- .../pytorch/ops/basic/basic_linear.py | 48 ++++-- .../pytorch/ops/basic/quantize.py | 6 +- .../ops/fused/backward_activation_bias.py | 7 +- .../fused/forward_linear_bias_activation.py | 18 ++- .../ops/fused/forward_linear_bias_add.py | 18 ++- .../ops/fused/forward_linear_scale_add.py | 18 ++- .../ops/fused/userbuffers_forward_linear.py | 49 +++++- transformer_engine/pytorch/ops/fuser.py | 16 +- transformer_engine/pytorch/quantization.py | 5 + 14 files changed, 375 insertions(+), 142 deletions(-) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 28da4873f0..f4351a3be8 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1184,9 +1184,11 @@ def grad_output_preprocess( grad_output = grad_output.reshape((-1, grad_output.shape[-1])) grad_output = grad_output.contiguous() gather_grad_output = row_parallel_mode and ctx.sequence_parallel + keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) + use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized # Non-FP8 case: bgrad is fused with wgrad for this case. 
- if not ctx.fp8 and not ctx.debug: + if not use_fp8_bwd and not ctx.debug: if gather_grad_output: if not ctx.ub_overlap_ag: # Perform NCCL all-gather grad_output, _ = gather_along_first_dim(grad_output, ctx.tp_group) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index fade2957d5..3f055a2b77 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -97,6 +97,9 @@ def forward( save_original_input, debug, ) = non_tensor_args + keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized() + if keep_backward_unquantized: + save_original_input = True num_gemms = len(m_splits) weights = weights_and_biases[:num_gemms] @@ -291,6 +294,7 @@ def forward( ctx.activation_dtype = activation_dtype ctx.fp8 = fp8 ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None + ctx.keep_backward_unquantized = keep_backward_unquantized ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch @@ -299,7 +303,11 @@ def forward( ctx.inp_shape = inp.shape ctx.requires_dgrad = inp.requires_grad ctx.reduce_and_update_bwd_fp8_tensors = False - if ctx.fp8 and requires_grad(inp, weights[0], biases[0]): + if ( + ctx.fp8 + and not ctx.keep_backward_unquantized + and requires_grad(inp, weights[0], biases[0]) + ): ctx.reduce_and_update_bwd_fp8_tensors = ( ctx.reduce_and_update_bwd_fp8_tensors or FP8GlobalStateManager.is_first_fp8_module() @@ -323,6 +331,8 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], origin_weights = saved_tensors[2 * N : 3 * N] biases = saved_tensors[3 * N : 4 * N] main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs] + keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) + use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized if 
ctx.cpu_offloading: if ctx.grad_added_to_main_grad: @@ -338,7 +348,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], grad_output_view = grad_output.contiguous().view(-1, grad_output.shape[-1]) grad_output = [None] * ctx.num_gemms grad_biases = [None] * ctx.num_gemms - if ctx.fp8 and not ctx.debug: + if use_fp8_bwd and not ctx.debug: if ctx.use_bias: grad_output_mats = torch.split(grad_output_view, ctx.m_splits) recipe = ctx.fp8_recipe @@ -392,7 +402,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if ctx.requires_dgrad: dgrad_gemm_use_split_accumulator = _2X_ACC_DGRAD - if ctx.fp8 or ctx.debug: + if use_fp8_bwd or ctx.debug: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_dgrad"): dgrad_gemm_use_split_accumulator = ( @@ -403,13 +413,15 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], dtype=ctx.activation_dtype, device=ctx.device, ) - # Make sure weights are available in column-wise format - # for dgrad computation. - for weight in weights: - if isinstance(weight, QuantizedTensorStorage): - weight.update_usage(columnwise_usage=True) + weights_for_dgrad = weights if use_fp8_bwd else origin_weights + if use_fp8_bwd: + # Make sure weights are available in column-wise format + # for dgrad computation. 
+ for weight in weights_for_dgrad: + if isinstance(weight, QuantizedTensorStorage): + weight.update_usage(columnwise_usage=True) general_grouped_gemm( - weights, + weights_for_dgrad, grad_output, [dgrad], ctx.grad_input_quantizers, @@ -423,7 +435,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if ctx.weights_requires_grad: wgrad_gemm_use_split_accumulator = _2X_ACC_WGRAD - if ctx.fp8: + if use_fp8_bwd: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_wgrad"): wgrad_gemm_use_split_accumulator = ( @@ -451,7 +463,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], else: input_quantizer.set_usage(rowwise=False, columnwise=True) inputmats: list - if ctx.fp8 and not ctx.debug: + if use_fp8_bwd and not ctx.debug: inputmats = tex.split_quantize(inp_view, ctx.m_splits, ctx.input_quantizers) elif ctx.debug: inputmats = DebugQuantizer.multi_tensor_quantize( @@ -528,7 +540,7 @@ def handle_custom_ddp_from_mcore(weight, wgrad): if not ctx.use_bias or ( ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute() - and not ctx.fp8 + and not use_fp8_bwd ): grad_biases = [None] * ctx.num_gemms diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index a90105477c..e26dae54fa 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -140,6 +140,7 @@ def forward( symmetric_ar_type, debug, ) = non_tensor_args + keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized() # NVTX label for profiling nvtx_label = "transformer_engine._LayerNormLinear.forward" @@ -198,7 +199,10 @@ def forward( if fp8: if input_quantizer is None: raise ValueError("Missing quantizer for input tensor") - input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input) + input_quantizer.set_usage( + rowwise=True, + columnwise=backward_needs_input and not 
keep_backward_unquantized, + ) if with_input_all_gather and input_quantizer.supports_only_rowwise_all_gather(): # All-gather is not supported with FP8 column-wise data input_quantizer.set_usage(columnwise=False) @@ -211,6 +215,7 @@ def forward( and not debug and not return_layernorm_output and not return_layernorm_output_gathered + and not keep_backward_unquantized and not custom # TODO(negvet): and not FP8GlobalStateManager.get_fp8_recipe().custom() ) @@ -234,6 +239,7 @@ def forward( ln_out_return = None if return_layernorm_output or return_layernorm_output_gathered: ln_out_return = ln_out + ln_out_hp = ln_out if keep_backward_unquantized else None # ------------------------------------------------------ # Prepare GEMM input tensor @@ -408,13 +414,14 @@ def forward( # ------------------------------------------------------ if is_grad_enabled: + ln_out_to_save = ln_out_hp if keep_backward_unquantized else ln_out ctx.weight_quantizer = weight_quantizer ctx.ln_out_needs_gather = ( weight.requires_grad and parallel_mode == "column" and sequence_parallel ) # Input with column-wise usage is needed for wgrad GEMM. - if backward_needs_input: + if backward_needs_input and not keep_backward_unquantized: if isinstance(ln_out, QuantizedTensorStorage): # For sequence parallel in vanilla FP8, rowwise data is # to gather the input. 
For MXFP8, columnwise only data @@ -426,7 +433,7 @@ def forward( ln_out.update_usage(rowwise_usage=False) if cpu_offloading: - mark_activation_offload(inputmat, mu, rsigma, ln_out) + mark_activation_offload(inputmat, mu, rsigma, ln_out_to_save) # Scatter intermediate/activation tensors saved for the backward pass # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already @@ -438,7 +445,7 @@ def forward( mu, rsigma, weightmat if fp8 and not is_weight_param_quantized else None, - ln_out if weight.requires_grad else None, + ln_out_to_save if weight.requires_grad else None, ) nvtx_range_pop(f"{nvtx_label}.fsdp_scatter") @@ -465,7 +472,7 @@ def forward( weight, bias, ln_weight, - ln_out, + ln_out_to_save, mu, rsigma, ) @@ -492,6 +499,7 @@ def forward( ctx.activation_dtype = activation_dtype ctx.fp8 = fp8 ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None + ctx.keep_backward_unquantized = keep_backward_unquantized ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch @@ -514,7 +522,11 @@ def forward( ctx.requires_dgrad = inp_requires_grad ctx.normalization = normalization ctx.reduce_and_update_bwd_fp8_tensors = False - if ctx.fp8 and requires_grad(inp, ln_weight, ln_bias, weight, bias): + if ( + ctx.fp8 + and not ctx.keep_backward_unquantized + and requires_grad(inp, ln_weight, ln_bias, weight, bias) + ): _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() if in_fp8_activation_recompute_phase(): @@ -591,6 +603,15 @@ def backward( if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation: origin_weight.main_grad = main_grad + keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) + use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized + use_quantized_bwd = use_fp8_bwd or ctx.debug + if keep_backward_unquantized: + ctx.ub_overlap_ag 
= False + ctx.ub_overlap_rs_dgrad = False + ctx.ub_bulk_dgrad = False + ctx.ub_bulk_wgrad = False + # Configure Userbuffers communication (comm+GEMM overlap) ctx.ub_obj_gradout = None ub_obj_dgrad = None @@ -600,23 +621,23 @@ def backward( dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]] if ctx.ub_overlap_ag: # Overlap grad_output all-gather with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", use_fp8_bwd) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.AG elif ctx.ub_overlap_rs_dgrad: # Overlap dgrad reduce-scatter with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", use_fp8_bwd) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.RS else: if ctx.ub_bulk_dgrad: # Overlap inputmat all-gather with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", use_fp8_bwd) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.AG if ctx.ub_bulk_wgrad: # Overlap dgrad reduce-scatter with wgrad compute - ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8) + ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", use_fp8_bwd) ub_type_wgrad = tex.CommOverlapType.RS # -------------------------------------------------- @@ -627,7 +648,7 @@ def backward( # Configure quantizer for grad output tensor # Note: dgrad GEMM requires row-wise usage, wgrad GEMM # requires column-wise usage - if ctx.grad_output_quantizer is not None: + if ctx.grad_output_quantizer is not None and use_quantized_bwd: quantizer = ctx.grad_output_quantizer quantizer.set_usage(rowwise=True, columnwise=True) if ctx.ub_overlap_ag: @@ -664,7 +685,7 @@ def backward( ln_out_total_work = None if ctx.ln_out_needs_gather: quantizer = None - if ctx.input_quantizer is not None: + if ctx.input_quantizer is not None 
and use_quantized_bwd: quantizer = ctx.input_quantizer if quantizer.supports_only_rowwise_all_gather(): # If data is in FP8, we compute FP8 transposes manually @@ -702,18 +723,22 @@ def backward( # Make sure required data is available if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) - if ctx.weight_quantizer is not None and isinstance(weight, QuantizedTensorStorage): + if ( + use_quantized_bwd + and ctx.weight_quantizer is not None + and isinstance(weight, QuantizedTensorStorage) + ): weight.update_usage(columnwise_usage=True) # Choose whether to use GEMM kernel with split accumulator use_split_accumulator = _2X_ACC_DGRAD - if ctx.fp8: + if use_fp8_bwd: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_dgrad"): use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator # Update grad input quantizer - if ctx.grad_input_quantizer is not None: + if ctx.grad_input_quantizer is not None and use_quantized_bwd: ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) # Output buffers for Userbuffers reduce-scatter @@ -729,12 +754,13 @@ def backward( # dgrad GEMM # Note: dx = dy * w nvtx_range_push(f"{nvtx_label}.dgrad_gemm") + weight_for_dgrad = weight if use_quantized_bwd else origin_weight gemm_out, *_, reduce_scatter_out = general_gemm( - weight, + weight_for_dgrad, grad_output, layout="NN", grad=True, - quantization_params=ctx.grad_input_quantizer, + quantization_params=ctx.grad_input_quantizer if use_quantized_bwd else None, out=gemm_out, out_dtype=ctx.activation_dtype, use_split_accumulator=use_split_accumulator, @@ -781,7 +807,11 @@ def backward( # Prepare grad output tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available - if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer): + if ( + use_fp8_bwd + and ctx.ub_overlap_ag + and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer) + ): # UB does not support pipelined 
overlapping grad output # all-gather with wgrad GEMM. Also, we can't # convert row-scaled MXFP8 to column-scaled, so we @@ -793,7 +823,7 @@ def backward( dgrad_send_stream, dgrad_recv_stream = ub_obj_dgrad.get_communication_stream() # This object is separate from the ub_obj_wgrad object which is passed to the GEMM - ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8) + ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", use_fp8_bwd) ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) @@ -819,14 +849,14 @@ def backward( if ln_out_total_work is not None: ln_out_total_work.wait() ln_out_total_work = None - if ctx.fp8 or ctx.debug: + if use_quantized_bwd: if isinstance(ln_out_total, QuantizedTensorStorage): ln_out_total.update_usage(columnwise_usage=True) else: ctx.input_quantizer.set_usage(rowwise=False, columnwise=True) ln_out_total = ctx.input_quantizer(ln_out_total) - if ctx.fp8 or ctx.debug: + if use_quantized_bwd: if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: @@ -835,7 +865,7 @@ def backward( # Figure out whether to use split accumulator use_split_accumulator = _2X_ACC_WGRAD - if ctx.fp8: + if use_fp8_bwd: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_wgrad"): use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator @@ -861,7 +891,9 @@ def backward( "out_dtype": ( main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - "quantization_params": ctx.grad_weight_quantizer, + "quantization_params": ( + ctx.grad_weight_quantizer if use_quantized_bwd else None + ), "accumulate": ( accumulate_wgrad_into_param_main_grad if not getattr(weight, "overwrite_main_grad", False) @@ -869,7 +901,7 @@ def backward( ), "layout": "NT", "out": main_grad if ctx.fuse_wgrad_accumulation else None, - "bias": (bias if (grad_bias is None and not ctx.fp8) else None), + "bias": (bias if (grad_bias is None and not use_fp8_bwd) else None), "use_split_accumulator": 
use_split_accumulator, "grad": True, "ub": ub_obj_wgrad, diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 037fb6c858..94e30a2afa 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -234,6 +234,7 @@ def _forward( debug, recompute_for_bwd, ) = non_tensor_args + keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized() # if grad is enabled and this is not the bwd stage, we must save this so bwd knows which path to take if is_grad_enabled and not recompute_for_bwd: @@ -351,8 +352,10 @@ def _forward( # bwd needs fc1 input when grad is enabled, fc1 needs grad, and either # 1) no checkpointing # or 2) doing the recomputation with checkpointing - backwards_needs_fc1_input = fc1_weight.requires_grad and ( - (is_grad_enabled and not checkpoint) or is_recomputation + backwards_needs_fc1_input = ( + fc1_weight.requires_grad + and ((is_grad_enabled and not checkpoint) or is_recomputation) + and not keep_backward_unquantized ) device = inp.device @@ -395,6 +398,7 @@ def _forward( and not debug and not return_layernorm_output and not return_layernorm_output_gathered + and not keep_backward_unquantized and not custom ) @@ -416,6 +420,7 @@ def _forward( # do not return layernorm output unless 1) no checkpointing or 2) checkpointing but not recomputing if (return_layernorm_output or return_layernorm_output_gathered) and not is_recomputation: ln_out_return = ln_out + ln_out_hp = ln_out if keep_backward_unquantized else None # Prepare GEMM input # Note: Cast to expected dtype and perform tensor-parallel communication @@ -613,6 +618,10 @@ def _forward( if fc2_input_quantizer is not None: fc2_input_quantizer.calibrate(act_out) + act_out_hp = act_out + if keep_backward_unquantized and is_grad_enabled and fc1_out is not None: + act_out_hp = activation_func(fc1_out, None, **act_params) + # we want to skip fc2 
computation if we are checkpointing and recomputing, # otherwise we compute fc2 if not (is_recomputation and checkpoint): @@ -688,22 +697,30 @@ def _forward( # if we are not checkpointing, then we must save this if grad is enabled if is_grad_enabled and not save_for_checkpoint: + ln_out_to_save = ln_out_hp if keep_backward_unquantized else ln_out + act_out_to_save = act_out_hp if keep_backward_unquantized else act_out ctx.fc1_weight_quantizer = fc1_weight_quantizer ctx.fc2_weight_quantizer = fc2_weight_quantizer if not fc1_weight.requires_grad: if not return_layernorm_output: - clear_tensor_data(ln_out) - ln_out = None + clear_tensor_data(ln_out_to_save) + ln_out_to_save = None if not fc2_weight.requires_grad: - clear_tensor_data(act_out) - act_out = None + clear_tensor_data(act_out_to_save) + act_out_to_save = None if not checkpoint: # regular path, no selective activation checkpointing if cpu_offloading: mark_activation_offload( - inputmat, mu, rsigma, ln_out, fc1_out, fc1_out_without_bias, act_out + inputmat, + mu, + rsigma, + ln_out_to_save, + fc1_out, + fc1_out_without_bias, + act_out_to_save, ) # Scatter intermediate/activation tensors saved for the backward pass @@ -716,9 +733,9 @@ def _forward( fsdp_group, mu, rsigma, - ln_out, + ln_out_to_save, fc1_out_without_bias if bias_gelu_fusion else fc1_out, - act_out, + act_out_to_save, ( fc1_weight_final if fp8 and not isinstance(fc1_weight, Float8Tensor) @@ -746,13 +763,13 @@ def _forward( tensors_to_save, tensor_objects = prepare_for_saving( inputmat, ln_weight, - ln_out, + ln_out_to_save, fc1_weight_final, fc1_weight, fc1_bias, fc1_out, fc1_out_without_bias, - act_out, + act_out_to_save, fc2_weight_final, fc2_weight, fc2_bias, @@ -800,6 +817,7 @@ def _forward( ctx.activation_params = activation_params ctx.fp8 = fp8 ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None + ctx.keep_backward_unquantized = keep_backward_unquantized ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation 
ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch @@ -828,8 +846,12 @@ def _forward( ) ctx.normalization = normalization ctx.reduce_and_update_bwd_fp8_tensors = False - if ctx.fp8 and requires_grad( - inp, ln_weight, ln_bias, fc1_weight, fc2_weight, fc1_bias, fc2_bias + if ( + ctx.fp8 + and not ctx.keep_backward_unquantized + and requires_grad( + inp, ln_weight, ln_bias, fc1_weight, fc2_weight, fc1_bias, fc2_bias + ) ): _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() @@ -998,6 +1020,16 @@ def backward( origin_fc1_weight.main_grad = fc1_weight_main_grad origin_fc2_weight.main_grad = fc2_weight_main_grad + keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) + use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized + use_quantized_bwd = use_fp8_bwd or ctx.debug + fp8_recipe_bwd = ctx.fp8_recipe if use_fp8_bwd else None + if keep_backward_unquantized: + ctx.ub_overlap_ag = False + ctx.ub_overlap_rs_dgrad = False + ctx.ub_bulk_dgrad = False + ctx.ub_bulk_wgrad = False + # TODO: Fix this # pylint: disable=fixme # Gather saved autograd context tensors when running with FSDP # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already @@ -1017,7 +1049,7 @@ def backward( # Choose whether to use GEMM kernel with split accumulator dgrad_use_split_accumulator = _2X_ACC_DGRAD wgrad_use_split_accumulator = _2X_ACC_WGRAD - if ctx.fp8: + if use_fp8_bwd: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_dgrad"): dgrad_use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator @@ -1031,7 +1063,7 @@ def backward( # Configure quantizer for FC2 grad output tensor # Note: dgrad GEMM requires row-wise usage, wgrad GEMM # requires column-wise usage - if ctx.fc2_grad_output_quantizer is not None: + if ctx.fc2_grad_output_quantizer is not None and use_quantized_bwd: quantizer = 
ctx.fc2_grad_output_quantizer quantizer.set_usage(rowwise=True, columnwise=True) if ctx.ub_overlap_ag: @@ -1044,7 +1076,7 @@ def backward( # Note: Cast to expected dtype and perform tensor-parallel communication ub_obj_fc2_dgrad = None if ctx.ub_overlap_ag: - ub_obj_fc2_dgrad = get_ub("fc2_dgrad", ctx.fp8) + ub_obj_fc2_dgrad = get_ub("fc2_dgrad", use_fp8_bwd) ctx.ub_obj_gradout = ub_obj_fc2_dgrad ( grad_output, @@ -1059,7 +1091,7 @@ def backward( ub_obj_fc1_dgrad = None if ctx.fc1_weight_requires_grad and ctx.tensor_parallel and ctx.sequence_parallel: quantizer = None - if ctx.fp8 or ctx.debug: + if use_quantized_bwd: quantizer = ctx.fc1_input_quantizer if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): # If data is in FP8, we compute FP8 transposes manually @@ -1068,7 +1100,7 @@ def backward( # wgrad GEMM requires input with column-wise usage quantizer.set_usage(rowwise=False, columnwise=True) if ctx.ub_bulk_dgrad: - ub_obj_fc1_dgrad = get_ub("fc1_dgrad", ctx.fp8) + ub_obj_fc1_dgrad = get_ub("fc1_dgrad", use_fp8_bwd) ln_out_total, _ = fill_userbuffers_buffer_for_all_gather( ub_obj_fc1_dgrad, ln_out, @@ -1105,7 +1137,7 @@ def backward( # 5 high-precision unfused: gemm, activation, FC1_bias + FC1_gemm # 6 fp8 unfused: gemm, activation, FC1_bias + FC1_gemm fc2_dgrad_gemm_gelu_fusion = ( - not ctx.fp8 + not use_fp8_bwd and (ctx.activation == "gelu") and (not ctx.bias_gelu_fusion) and (not ctx.debug) @@ -1114,20 +1146,23 @@ def backward( # Make sure required data is available if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) - if ctx.fc2_weight_quantizer is not None and isinstance( - ctx.fc2_weight, QuantizedTensorStorage + if ( + use_quantized_bwd + and ctx.fc2_weight_quantizer is not None + and isinstance(ctx.fc2_weight, QuantizedTensorStorage) ): ctx.fc2_weight.update_usage(columnwise_usage=True) # Perform GEMM + fc2_weight_for_dgrad = fc2_weight if use_fp8_bwd else origin_fc2_weight 
gemm_output, *_ = general_gemm( - fc2_weight, + fc2_weight_for_dgrad, grad_output, layout="NN", grad=True, quantization_params=( ctx.fc1_grad_input_quantizer - if fc2_dgrad_gemm_gelu_fusion or ctx.debug + if (fc2_dgrad_gemm_gelu_fusion or ctx.debug) and use_quantized_bwd else None ), # high precision to activation out_dtype=ctx.activation_dtype, @@ -1159,7 +1194,11 @@ def backward( # Prepare grad output tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available - if ctx.ub_overlap_ag and isinstance(ctx.fc2_grad_output_quantizer, MXFP8Quantizer): + if ( + use_fp8_bwd + and ctx.ub_overlap_ag + and isinstance(ctx.fc2_grad_output_quantizer, MXFP8Quantizer) + ): # UB does not support pipelined overlapping grad output # all-gather with wgrad GEMM. Also, we can't # convert row-scaled MXFP8 to column-scaled, so we @@ -1172,7 +1211,7 @@ def backward( ub_obj_fc2_dgrad.get_communication_stream() ) - ub_obj_fc2_wgrad = get_ub("fc2_wgrad", ctx.fp8) + ub_obj_fc2_wgrad = get_ub("fc2_wgrad", use_fp8_bwd) ctx.fc2_grad_output_quantizer.set_usage(rowwise=False, columnwise=True) @@ -1195,14 +1234,14 @@ def backward( # Prepare input tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available - if ctx.fp8 or ctx.debug: + if use_quantized_bwd: if isinstance(act_out, QuantizedTensorStorage): act_out.update_usage(columnwise_usage=True) else: ctx.fc2_input_quantizer.set_usage(rowwise=False, columnwise=True) act_out = ctx.fc2_input_quantizer(act_out) - if ctx.fp8 or ctx.debug: + if use_quantized_bwd: if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: @@ -1211,7 +1250,7 @@ def backward( # Whether to set grad arg in general_gemm grad_arg = True - if ctx.fp8 and ctx.fp8_recipe.float8_block_scaling(): + if use_fp8_bwd and fp8_recipe_bwd.float8_block_scaling(): grad_arg = False # Arguments to include in wgrad GEMM closure @@ -1221,7 +1260,9 @@ def backward( if 
ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - "quantization_params": ctx.fc2_grad_weight_quantizer, # wgrad in high precision + "quantization_params": ( + ctx.fc2_grad_weight_quantizer if use_quantized_bwd else None + ), # wgrad in high precision "accumulate": ( accumulate_wgrad_into_param_main_grad if not getattr(fc1_weight, "overwrite_main_grad", False) @@ -1258,8 +1299,8 @@ def fc2_wgrad_gemm( # Update grad bias if needed if fc2_bias_grad is None: if ( - ctx.fp8 - and ctx.fp8_recipe.float8_block_scaling() + use_fp8_bwd + and fp8_recipe_bwd.float8_block_scaling() and fc2_bias is not None ): # BGRAD not fused with GEMM for float8 blockwise gemm. @@ -1279,12 +1320,12 @@ def fc2_wgrad_gemm( act_params = ctx.activation_params or {} fc1_bias_grad = None fuse_gemm_and_bias_fc1_wgrad = False - if ctx.fc1_grad_output_quantizer is not None: + if ctx.fc1_grad_output_quantizer is not None and use_quantized_bwd: ctx.fc1_grad_output_quantizer.set_usage(rowwise=True, columnwise=True) if ctx.bias_gelu_fusion: # Fusion: gemm, bias + gelu assert ctx.activation == "gelu" - assert not ctx.fp8 + assert not use_fp8_bwd fc1_bias_grad, dact = bgrad_dgelu_fused(fc2_dgrad, fc1_out_without_bias, fc1_bias) if ctx.fc1_grad_output_quantizer is not None: dact = ctx.fc1_grad_output_quantizer(dact) @@ -1294,13 +1335,10 @@ def fc2_wgrad_gemm( fc1_bias_grad = dact.sum(dim=0) dact = ctx.fc1_grad_output_quantizer(dact) elif ( - _act_func(ctx.activation, ctx.fp8_recipe if ctx.fp8 else None)[2] is not None - and ctx.fp8 + _act_func(ctx.activation, fp8_recipe_bwd)[2] is not None and use_fp8_bwd ): # Fusion: gemm, bias + gelu + quantize - dbias_dact_quantize_func = _act_func( - ctx.activation, ctx.fp8_recipe if ctx.fp8 else None - )[2] + dbias_dact_quantize_func = _act_func(ctx.activation, fp8_recipe_bwd)[2] fc1_bias_grad, dact = dbias_dact_quantize_func( fc2_dgrad, fc1_out.to(ctx.activation_dtype), @@ -1310,18 +1348,16 @@ def fc2_wgrad_gemm( else: # Fusion: gemm + gelu, if not 
fc2_dgrad_gemm_gelu_fusion: - activation_func_bwd = _act_func( - ctx.activation, ctx.fp8_recipe if ctx.fp8 else None - )[1] + activation_func_bwd = _act_func(ctx.activation, fp8_recipe_bwd)[1] dact = activation_func_bwd( fc2_dgrad, fc1_out.to(ctx.activation_dtype), None, **act_params ) # activation in high precision - if ctx.fp8: + if use_fp8_bwd: # TODO float8 blockwise current scaling (as well as custom quantizers) has no bgrad fusion for now if ( isinstance(ctx.fc1_grad_output_quantizer, Float8BlockQuantizer) - or ctx.fp8_recipe.custom() + or fp8_recipe_bwd.custom() ): fc1_bias_grad = dact.view(-1, dact.shape[-1]).sum(dim=0) dact = ctx.fc1_grad_output_quantizer(dact) @@ -1349,16 +1385,16 @@ def fc2_wgrad_gemm( fc1_dgrad_shape = [reduce(multiply_op, inputmat.shape[:-1]), inputmat.shape[-1]] if ctx.ub_overlap_rs_dgrad: # Overlap DGRAD+RS - ub_obj_fc1_dgrad = get_ub("fc1_dgrad", ctx.fp8) + ub_obj_fc1_dgrad = get_ub("fc1_dgrad", use_fp8_bwd) ub_type_fc1_dgrad = tex.CommOverlapType.RS else: if ctx.ub_bulk_dgrad: # Overlap ln_out all-gather with DGRAD compute - ub_obj_fc1_dgrad = get_ub("fc1_dgrad", ctx.fp8) + ub_obj_fc1_dgrad = get_ub("fc1_dgrad", use_fp8_bwd) ub_type_fc1_dgrad = tex.CommOverlapType.AG if ctx.ub_bulk_wgrad: # Overlap FC1 DGRAD reduce-scatter with WGRAD compute - ub_obj_fc1_wgrad = get_ub("fc1_wgrad", ctx.fp8) + ub_obj_fc1_wgrad = get_ub("fc1_wgrad", use_fp8_bwd) ub_type_fc1_wgrad = tex.CommOverlapType.RS # -------------------------------------------------- @@ -1366,8 +1402,10 @@ def fc2_wgrad_gemm( # -------------------------------------------------- # Make sure required data is available - if ctx.fc1_weight_quantizer is not None and isinstance( - ctx.fc1_weight_quantizer, QuantizedTensorStorage + if ( + use_quantized_bwd + and ctx.fc1_weight_quantizer is not None + and isinstance(ctx.fc1_weight_quantizer, QuantizedTensorStorage) ): ctx.fc1_weight.update_usage(columnwise_usage=True) @@ -1382,12 +1420,13 @@ def fc2_wgrad_gemm( gemm_out = 
ub_obj_fc1_wgrad.get_buffer(local_chunk=False) # dgrad GEMM + fc1_weight_for_dgrad = fc1_weight if use_fp8_bwd else origin_fc1_weight gemm_out, *_, reduce_scatter_out = general_gemm( - fc1_weight, + fc1_weight_for_dgrad, dact, out=gemm_out, out_dtype=ctx.activation_dtype, - quantization_params=ctx.fc1_grad_input_quantizer, + quantization_params=ctx.fc1_grad_input_quantizer if use_quantized_bwd else None, layout="NN", grad=True, use_split_accumulator=dgrad_use_split_accumulator, @@ -1436,7 +1475,7 @@ def fc2_wgrad_gemm( if ln_out_total_work is not None: ln_out_total_work.wait() ln_out_total_work = None - if ctx.fp8 or ctx.debug: + if use_quantized_bwd: if isinstance(ln_out_total, QuantizedTensorStorage): ln_out_total.update_usage(columnwise_usage=True) else: @@ -1446,7 +1485,7 @@ def fc2_wgrad_gemm( # Prepare grad output tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available - if ctx.fp8 or ctx.debug: + if use_quantized_bwd: if isinstance(dact, QuantizedTensorStorage): dact.update_usage(columnwise_usage=True) else: @@ -1468,7 +1507,9 @@ def fc2_wgrad_gemm( if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - "quantization_params": ctx.fc1_grad_weight_quantizer, + "quantization_params": ( + ctx.fc1_grad_weight_quantizer if use_quantized_bwd else None + ), "accumulate": ( accumulate_wgrad_into_param_main_grad if not getattr(fc2_weight, "overwrite_main_grad", False) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 1e3eadc405..9f6c07832c 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -128,6 +128,9 @@ def forward( save_original_input, debug, ) = non_tensor_args + keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized() + if keep_backward_unquantized: + save_original_input = True # NVTX label for profiling nvtx_label = "transformer_engine._Linear.forward" @@ -442,6 
+445,7 @@ def forward( ctx.activation_dtype = activation_dtype ctx.fp8 = fp8 ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None + ctx.keep_backward_unquantized = keep_backward_unquantized ctx.input_quantizer = input_quantizer ctx.grad_input_quantizer = grad_input_quantizer ctx.grad_weight_quantizer = grad_weight_quantizer @@ -478,7 +482,7 @@ def forward( ctx.reduce_and_update_bwd_fp8_tensors = False ctx.owns_input = saved_inputmat is not inp - if ctx.fp8 and requires_grad(inp, weight, bias): + if ctx.fp8 and not ctx.keep_backward_unquantized and requires_grad(inp, weight, bias): _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() if in_fp8_activation_recompute_phase(): @@ -535,6 +539,15 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ) nvtx_range_pop(f"{nvtx_label}.fsdp_gather") + keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) + use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized + use_quantized_bwd = use_fp8_bwd or ctx.debug + if keep_backward_unquantized: + ctx.ub_overlap_ag = False + ctx.ub_overlap_rs_dgrad = False + ctx.ub_bulk_dgrad = False + ctx.ub_bulk_wgrad = False + # Configure Userbuffers communication (comm+GEMM overlap) ctx.ub_obj_gradout = None ub_obj_dgrad = None @@ -544,23 +557,23 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]] if ctx.ub_overlap_ag: # Overlap grad_output all-gather with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", use_fp8_bwd) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.AG elif ctx.ub_overlap_rs_dgrad: # Overlap dgrad reduce-scatter with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) + 
ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", use_fp8_bwd) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.RS else: if ctx.ub_bulk_dgrad: # Overlap inputmat all-gather with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", use_fp8_bwd) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.AG if ctx.ub_bulk_wgrad: # Overlap dgrad reduce-scatter with wgrad compute - ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8) + ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", use_fp8_bwd) ub_type_wgrad = tex.CommOverlapType.RS # -------------------------------------------------- @@ -574,7 +587,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Configure quantizer for grad output tensor # Note: dgrad GEMM requires row-wise usage, wgrad GEMM # requires column-wise usage - if ctx.grad_output_quantizer is not None: + if ctx.grad_output_quantizer is not None and use_quantized_bwd: quantizer = ctx.grad_output_quantizer quantizer.set_usage(rowwise=True, columnwise=True) if ctx.ub_overlap_ag: @@ -593,6 +606,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], not ctx.use_bias and not ctx.requires_wgrad and ctx.grad_output_quantizer is not None + and use_quantized_bwd ): ctx.grad_output_quantizer.set_usage(columnwise=False) @@ -622,7 +636,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmat_total = None inputmat_total_work = None if ctx.requires_wgrad: - if ctx.fp8 or ctx.debug: + if use_quantized_bwd: if isinstance(inputmat, QuantizedTensorStorage): # Input tensor is already quantized pass @@ -648,7 +662,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmat = cast_if_needed(inputmat, ctx.activation_dtype) if ctx.backward_input_needs_gather: quantizer = None - if ctx.fp8 or ctx.debug: + if use_quantized_bwd: 
quantizer = ctx.input_quantizer if quantizer.supports_only_rowwise_all_gather(): # If data is in FP8, we compute FP8 transposes manually @@ -689,20 +703,22 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Make sure required data is available if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) - if ctx.weight_quantizer is not None and isinstance( - weight_fp8, QuantizedTensorStorage + if ( + use_quantized_bwd + and ctx.weight_quantizer is not None + and isinstance(weight_fp8, QuantizedTensorStorage) ): weight_fp8.update_usage(columnwise_usage=True) # Choose whether to use GEMM kernel with split accumulator use_split_accumulator = _2X_ACC_DGRAD - if ctx.fp8: + if use_fp8_bwd: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_dgrad"): use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator # Update grad input quantizer - if ctx.grad_input_quantizer is not None: + if ctx.grad_input_quantizer is not None and use_quantized_bwd: ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) # Output buffers for Userbuffers reduce-scatter @@ -719,12 +735,13 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Note: dx = dy * w nvtx_range_push(f"{nvtx_label}.dgrad_gemm") + weight_for_dgrad = weight_fp8 if use_quantized_bwd else weight gemm_out, *_, reduce_scatter_out = general_gemm( - weight_fp8, + weight_for_dgrad, grad_output, layout="NN", grad=True, - quantization_params=ctx.grad_input_quantizer, + quantization_params=ctx.grad_input_quantizer if use_quantized_bwd else None, out=gemm_out, out_dtype=ctx.activation_dtype, use_split_accumulator=use_split_accumulator, @@ -773,7 +790,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if inputmat_total_work is not None: inputmat_total_work.wait() inputmat_total_work = None - if ctx.fp8 or ctx.debug: + if use_quantized_bwd: if isinstance(inputmat_total, 
QuantizedTensorStorage): inputmat_total.update_usage(columnwise_usage=True) else: @@ -783,7 +800,11 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Prepare grad output tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available - if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer): + if ( + use_fp8_bwd + and ctx.ub_overlap_ag + and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer) + ): # UB does not support pipelined overlapping grad output # all-gather with wgrad GEMM. Also, we can't # convert row-scaled MXFP8 to column-scaled, so we @@ -795,7 +816,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], dgrad_send_stream, dgrad_recv_stream = ub_obj_dgrad.get_communication_stream() # This object is separate from the ub_obj_wgrad object which is passed to the GEMM - ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8) + ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", use_fp8_bwd) ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) @@ -815,7 +836,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ub_obj_overlap_wgrad, dgrad_send_stream, dgrad_recv_stream ) - if ctx.fp8 or ctx.debug: + if use_quantized_bwd: if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: @@ -824,7 +845,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Figure out whether to use split accumulator use_split_accumulator = _2X_ACC_WGRAD - if ctx.fp8: + if use_fp8_bwd: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_wgrad"): use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator @@ -850,7 +871,9 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], "out_dtype": ( main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - "quantization_params": 
ctx.grad_weight_quantizer, + "quantization_params": ( + ctx.grad_weight_quantizer if use_quantized_bwd else None + ), "accumulate": ( accumulate_wgrad_into_param_main_grad if not getattr(weight, "overwrite_main_grad", False) @@ -858,7 +881,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ), "layout": "NT", "out": main_grad if ctx.fuse_wgrad_accumulation else None, - "bias": (bias if (grad_bias is None and not ctx.fp8) else None), + "bias": (bias if (grad_bias is None and not use_fp8_bwd) else None), "use_split_accumulator": use_split_accumulator, "grad": True, "ub": ub_obj_wgrad, diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 48376a297f..f2b8ba106e 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -332,12 +332,14 @@ def pre_fuser_forward(self, *, requires_grad: bool) -> None: # Note: We cache the quantized input for backward pass, # but discard the quantized weights. 
weight_requires_grad = requires_grad and self.weight.requires_grad + keep_backward_unquantized = FP8GlobalStateManager.keep_backward_unquantized() + columnwise_usage = weight_requires_grad and not keep_backward_unquantized input_quantizer = self.get_quantizer("forward", 0) weight_quantizer = self.get_quantizer("forward", 1) grad_output_quantizer = self.get_quantizer("backward", 0) - input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) weight_quantizer.set_usage(rowwise=True, columnwise=False) - grad_output_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + grad_output_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: super().reset_recipe_state(recipe=recipe) @@ -420,6 +422,7 @@ def _functional_forward( tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None, sequence_parallel: bool = False, with_quantized_compute: bool = False, + keep_backward_unquantized: bool = False, input_quantizer: Optional[Quantizer] = None, weight_quantizer: Optional[Quantizer] = None, output_quantizer: Optional[Quantizer] = None, @@ -459,6 +462,8 @@ def _functional_forward( distributing along inner dimension (embedding dim) with_quantized_compute: bool, default = False Whether to perform compute with quantized data. + keep_backward_unquantized: bool, default = `False` + Whether to skip quantized backward and use high precision. input_quantizer: Quantizer, optional Builder class for quantized input tensor. 
weight_quantizer: Quantizer, optional @@ -510,7 +515,10 @@ def _functional_forward( if with_quantized_compute: if input_quantizer is None: raise ValueError("Missing quantizer for input tensor") - input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + input_quantizer.set_usage( + rowwise=True, + columnwise=weight_requires_grad and not keep_backward_unquantized, + ) if with_x_all_gather: input_quantizer.set_usage(columnwise=False) x, x_async = gather_along_first_dim( @@ -542,7 +550,10 @@ def _functional_forward( elif with_quantized_compute and not is_quantized_tensor(w): if weight_quantizer is None: raise ValueError("Missing quantizer for weight tensor") - weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) + weight_quantizer.set_usage( + rowwise=True, + columnwise=input_requires_grad and not keep_backward_unquantized, + ) w = weight_quantizer(w) # Check output tensor @@ -611,14 +622,23 @@ def _functional_forward( # Prepare weight tensor for backward pass if input_requires_grad: - if w is not weight and with_quantized_compute and is_quantized_tensor(w): + if ( + w is not weight + and with_quantized_compute + and is_quantized_tensor(w) + and not keep_backward_unquantized + ): w.update_usage(rowwise_usage=False, columnwise_usage=True) else: w = None # Prepare input tensor for backward pass if weight_requires_grad: - if with_quantized_compute and is_quantized_tensor(x_local): + if ( + with_quantized_compute + and is_quantized_tensor(x_local) + and not keep_backward_unquantized + ): if not (isinstance(x_local, Float8TensorStorage) and with_x_all_gather): # FP8 does not support all-gather of transpose data x_local.update_usage(rowwise_usage=False, columnwise_usage=True) @@ -968,6 +988,9 @@ def op_forward( grad_output_quantizer = self.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() + keep_backward_unquantized = ( + 
with_quantized_compute and FP8GlobalStateManager.keep_backward_unquantized() + ) # Get autocast dtype if needed if torch.is_autocast_enabled(): @@ -984,6 +1007,7 @@ def op_forward( tensor_parallel_group=self.tensor_parallel_group, sequence_parallel=self.sequence_parallel, with_quantized_compute=with_quantized_compute, + keep_backward_unquantized=keep_backward_unquantized, input_quantizer=input_quantizer, weight_quantizer=weight_quantizer, output_quantizer=output_quantizer, @@ -993,10 +1017,16 @@ def op_forward( # Save state for backward pass if ctx.requires_grad: + saved_input = input_ if keep_backward_unquantized else x_local + if not weight_requires_grad: + saved_input = None + saved_weight = self.weight if keep_backward_unquantized else w + if not input_requires_grad: + saved_weight = None if is_cpu_offload_enabled(): - mark_activation_offload(x_local) - ctx.save_for_backward(x_local, w) - ctx.with_quantized_compute = with_quantized_compute + mark_activation_offload(saved_input) + ctx.save_for_backward(saved_input, saved_weight) + ctx.with_quantized_compute = with_quantized_compute and not keep_backward_unquantized ctx.input_quantizer = input_quantizer ctx.weight_quantizer = weight_quantizer ctx.grad_output_quantizer = grad_output_quantizer diff --git a/transformer_engine/pytorch/ops/basic/quantize.py b/transformer_engine/pytorch/ops/basic/quantize.py index fa3efc3807..7dd8f1a7ac 100644 --- a/transformer_engine/pytorch/ops/basic/quantize.py +++ b/transformer_engine/pytorch/ops/basic/quantize.py @@ -57,7 +57,11 @@ def op_forward( # Check if FP8 is enabled fp8_enabled = FP8GlobalStateManager.is_fp8_enabled() quantize_forward = fp8_enabled and self._quantize_forward - quantize_backward = fp8_enabled and self._quantize_backward + quantize_backward = ( + fp8_enabled + and self._quantize_backward + and not FP8GlobalStateManager.keep_backward_unquantized() + ) # Quantize if needed out = input_ diff --git 
a/transformer_engine/pytorch/ops/fused/backward_activation_bias.py b/transformer_engine/pytorch/ops/fused/backward_activation_bias.py index 4ab082d32b..59e9af14f4 100644 --- a/transformer_engine/pytorch/ops/fused/backward_activation_bias.py +++ b/transformer_engine/pytorch/ops/fused/backward_activation_bias.py @@ -10,7 +10,7 @@ import torch import transformer_engine_torch as tex -from transformer_engine.pytorch.quantization import Recipe +from transformer_engine.pytorch.quantization import Recipe, FP8GlobalStateManager from transformer_engine.pytorch.ops.basic import Bias from transformer_engine.pytorch.ops.basic.activation import ( _ActivationOperation, @@ -105,7 +105,10 @@ def fuse_backward_ops( """ # Check if recipe supports bias activation fusion - if recipe is None: + if recipe is None or ( + FP8GlobalStateManager.is_fp8_enabled() + and FP8GlobalStateManager.keep_backward_unquantized() + ): return ops # Scan through ops, fusing if possible diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index dfc11a19e7..0a28d00706 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -92,6 +92,9 @@ def fuser_forward( grad_output_quantizer = linear_op.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() + keep_backward_unquantized = ( + with_quantized_compute and FP8GlobalStateManager.keep_backward_unquantized() + ) # Get autocast dtype if needed if torch.is_autocast_enabled(): @@ -109,6 +112,7 @@ def fuser_forward( tensor_parallel_group=linear_op.tensor_parallel_group, sequence_parallel=linear_op.sequence_parallel, with_quantized_compute=with_quantized_compute, + keep_backward_unquantized=keep_backward_unquantized, input_quantizer=input_quantizer, 
weight_quantizer=weight_quantizer, output_quantizer=output_quantizer, @@ -118,10 +122,18 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: + saved_input = input_ if keep_backward_unquantized else x_local + if not weight_requires_grad: + saved_input = None + saved_weight = linear_op.weight if keep_backward_unquantized else w + if not input_requires_grad: + saved_weight = None if is_cpu_offload_enabled(): - mark_activation_offload(x_local) - linear_op_ctx.save_for_backward(x_local, w) - linear_op_ctx.with_quantized_compute = with_quantized_compute + mark_activation_offload(saved_input) + linear_op_ctx.save_for_backward(saved_input, saved_weight) + linear_op_ctx.with_quantized_compute = ( + with_quantized_compute and not keep_backward_unquantized + ) linear_op_ctx.input_quantizer = input_quantizer linear_op_ctx.weight_quantizer = weight_quantizer linear_op_ctx.grad_output_quantizer = grad_output_quantizer diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index 2dfc0566b7..41ae096e54 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -86,6 +86,9 @@ def fuser_forward( grad_output_quantizer = linear_op.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() + keep_backward_unquantized = ( + with_quantized_compute and FP8GlobalStateManager.keep_backward_unquantized() + ) # Get autocast dtype if needed if torch.is_autocast_enabled(): @@ -106,6 +109,7 @@ def fuser_forward( tensor_parallel_group=linear_op.tensor_parallel_group, sequence_parallel=linear_op.sequence_parallel, with_quantized_compute=with_quantized_compute, + keep_backward_unquantized=keep_backward_unquantized, input_quantizer=input_quantizer, weight_quantizer=weight_quantizer, 
output_quantizer=output_quantizer, @@ -115,10 +119,18 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: + saved_input = input_ if keep_backward_unquantized else x_local + if not weight_requires_grad: + saved_input = None + saved_weight = linear_op.weight if keep_backward_unquantized else w + if not input_requires_grad: + saved_weight = None if is_cpu_offload_enabled(): - mark_activation_offload(x_local) - linear_op_ctx.save_for_backward(x_local, w) - linear_op_ctx.with_quantized_compute = with_quantized_compute + mark_activation_offload(saved_input) + linear_op_ctx.save_for_backward(saved_input, saved_weight) + linear_op_ctx.with_quantized_compute = ( + with_quantized_compute and not keep_backward_unquantized + ) linear_op_ctx.input_quantizer = input_quantizer linear_op_ctx.weight_quantizer = weight_quantizer linear_op_ctx.grad_output_quantizer = grad_output_quantizer diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py index ae4bdd4b19..b06f5ad36a 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py @@ -65,6 +65,9 @@ def fuser_forward( grad_output_quantizer = linear_op.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() + keep_backward_unquantized = ( + with_quantized_compute and FP8GlobalStateManager.keep_backward_unquantized() + ) # Get extra input tensor for add operation extra_input = basic_op_extra_inputs[2][0] @@ -87,6 +90,7 @@ def fuser_forward( tensor_parallel_group=linear_op.tensor_parallel_group, sequence_parallel=linear_op.sequence_parallel, with_quantized_compute=with_quantized_compute, + keep_backward_unquantized=keep_backward_unquantized, input_quantizer=input_quantizer, weight_quantizer=weight_quantizer, output_quantizer=output_quantizer, 
@@ -96,10 +100,18 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: + saved_input = input_ if keep_backward_unquantized else x_local + if not weight_requires_grad: + saved_input = None + saved_weight = linear_op.weight if keep_backward_unquantized else w + if not input_requires_grad: + saved_weight = None if is_cpu_offload_enabled(): - mark_activation_offload(x_local) - linear_op_ctx.save_for_backward(x_local, w) - linear_op_ctx.with_quantized_compute = with_quantized_compute + mark_activation_offload(saved_input) + linear_op_ctx.save_for_backward(saved_input, saved_weight) + linear_op_ctx.with_quantized_compute = ( + with_quantized_compute and not keep_backward_unquantized + ) linear_op_ctx.input_quantizer = input_quantizer linear_op_ctx.weight_quantizer = weight_quantizer linear_op_ctx.grad_output_quantizer = grad_output_quantizer diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py index 0d3e1d0416..3e5a389246 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py @@ -94,6 +94,7 @@ def _functional_forward( tensor_parallel_size: Optional[int] = None, sequence_parallel: bool = False, with_quantized_compute: bool = False, + keep_backward_unquantized: bool = False, input_quantizer: Optional[Quantizer] = None, weight_quantizer: Optional[Quantizer] = None, output_quantizer: Optional[Quantizer] = None, @@ -126,6 +127,8 @@ def _functional_forward( distributing along inner dimension (embedding dim) with_quantized_compute: bool, default = False Whether to perform compute with quantized data. + keep_backward_unquantized: bool, default = `False` + Whether to skip quantized backward and use high precision. input_quantizer: Quantizer, optional Builder class for quantized input tensor. 
weight_quantizer: Quantizer, optional @@ -200,7 +203,10 @@ def _functional_forward( if with_ub_all_gather: if input_quantizer is not None: if not is_quantized_tensor(x_local): - input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + input_quantizer.set_usage( + rowwise=True, + columnwise=weight_requires_grad and not keep_backward_unquantized, + ) if isinstance( input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) ): @@ -216,7 +222,10 @@ def _functional_forward( else: if with_quantized_compute: if not is_quantized_tensor(x_local): - input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + input_quantizer.set_usage( + rowwise=True, + columnwise=weight_requires_grad and not keep_backward_unquantized, + ) x_local = input_quantizer(x_local) else: x_local = maybe_dequantize(x_local, dtype) @@ -227,7 +236,10 @@ def _functional_forward( if not with_quantized_compute: w = maybe_dequantize(w, dtype) elif with_quantized_compute and not is_quantized_tensor(w): - weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) + weight_quantizer.set_usage( + rowwise=True, + columnwise=input_requires_grad and not keep_backward_unquantized, + ) w = weight_quantizer(w) # Construct output tensor if needed @@ -257,14 +269,23 @@ def _functional_forward( # Prepare weight tensor for backward pass if input_requires_grad: - if w is not weight and with_quantized_compute and is_quantized_tensor(w): + if ( + w is not weight + and with_quantized_compute + and is_quantized_tensor(w) + and not keep_backward_unquantized + ): w.update_usage(rowwise_usage=False, columnwise_usage=True) else: w = None # Prepare input tensor for backward pass if weight_requires_grad: - if with_quantized_compute and is_quantized_tensor(x_local): + if ( + with_quantized_compute + and is_quantized_tensor(x_local) + and not keep_backward_unquantized + ): if not (isinstance(x_local, Float8TensorStorage) and with_ub_all_gather): # FP8 does not support all-gather 
of transpose data x_local.update_usage(rowwise_usage=False, columnwise_usage=True) @@ -311,6 +332,9 @@ def fuser_forward( grad_output_quantizer = linear_op.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() + keep_backward_unquantized = ( + with_quantized_compute and FP8GlobalStateManager.keep_backward_unquantized() + ) if with_quantized_compute: recipe = FP8GlobalStateManager.get_fp8_recipe() if not any((recipe.delayed(), recipe.float8_current_scaling(), recipe.mxfp8())): @@ -340,6 +364,7 @@ def fuser_forward( tensor_parallel_size=self.tensor_parallel_size, sequence_parallel=self.sequence_parallel, with_quantized_compute=with_quantized_compute, + keep_backward_unquantized=keep_backward_unquantized, input_quantizer=input_quantizer, weight_quantizer=weight_quantizer, output_quantizer=None, # Not supported @@ -352,10 +377,18 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: + saved_input = input_ if keep_backward_unquantized else x_local + if not weight_requires_grad: + saved_input = None + saved_weight = linear_op.weight if keep_backward_unquantized else w + if not input_requires_grad: + saved_weight = None if is_cpu_offload_enabled(): - mark_activation_offload(x_local) - linear_op_ctx.save_for_backward(x_local, w) - linear_op_ctx.with_quantized_compute = with_quantized_compute + mark_activation_offload(saved_input) + linear_op_ctx.save_for_backward(saved_input, saved_weight) + linear_op_ctx.with_quantized_compute = ( + with_quantized_compute and not keep_backward_unquantized + ) linear_op_ctx.input_quantizer = input_quantizer linear_op_ctx.weight_quantizer = weight_quantizer linear_op_ctx.grad_output_quantizer = grad_output_quantizer diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index 80386db2d9..0316213480 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ 
b/transformer_engine/pytorch/ops/fuser.py @@ -109,6 +109,10 @@ def forward( # Apply forward ops x = input_ extra_outputs = [None] * fuser._num_basic_ops + keep_backward_unquantized = ( + FP8GlobalStateManager.is_fp8_enabled() + and FP8GlobalStateManager.keep_backward_unquantized() + ) for op, basic_op_idxs in fuser._forward_ops: # Set if backward op is required @@ -120,7 +124,7 @@ def forward( prev_op_idx = basic_op_idxs[0] - 1 prev_op = fuser._basic_ops[prev_op_idx] if prev_op_idx >= 0 else None prev_op_grad_output_quantizer = None - if prev_op is not None: + if prev_op is not None and not keep_backward_unquantized: prev_op_grad_output_quantizer = prev_op.get_grad_output_quantizer() next_op_idx = basic_op_idxs[-1] + 1 next_op = fuser._basic_ops[next_op_idx] if next_op_idx < fuser._num_basic_ops else None @@ -287,7 +291,15 @@ def backward( grad_extra_inputs_flat.extend(dxs) # Update FP8 scaling factors - if func_ctx.is_first_module and not _is_graph_capturing(): + keep_backward_unquantized = ( + FP8GlobalStateManager.is_fp8_enabled() + and FP8GlobalStateManager.keep_backward_unquantized() + ) + if ( + func_ctx.is_first_module + and not keep_backward_unquantized + and not _is_graph_capturing() + ): FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) return ( diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index 47e6d5c8dc..42e3df6a7c 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -431,6 +431,11 @@ def with_high_precision_init_val(cls) -> bool: """Should the high precision initial values be stored with FP8 parameters""" return cls.HIGH_PRECISION_INIT_VAL + @classmethod + def keep_backward_unquantized(cls) -> bool: + """Should backward skip FP8 quantization and use high precision""" + return bool(int(os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0"))) + @classmethod def fp8_graph_capturing(cls) -> bool: """Is CUDA graph capture under way?""" 
From b4f91ef335afc0978899b3a76a429f63324fd3be Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 3 Feb 2026 00:49:22 +0000 Subject: [PATCH 02/61] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/layernorm_mlp.py | 4 +--- transformer_engine/pytorch/ops/fuser.py | 6 +----- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 94e30a2afa..f5b94afa5e 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1334,9 +1334,7 @@ def fc2_wgrad_gemm( dact = dact_func(fc2_dgrad, fc1_out.to(ctx.activation_dtype), None, **act_params) fc1_bias_grad = dact.sum(dim=0) dact = ctx.fc1_grad_output_quantizer(dact) - elif ( - _act_func(ctx.activation, fp8_recipe_bwd)[2] is not None and use_fp8_bwd - ): + elif _act_func(ctx.activation, fp8_recipe_bwd)[2] is not None and use_fp8_bwd: # Fusion: gemm, bias + gelu + quantize dbias_dact_quantize_func = _act_func(ctx.activation, fp8_recipe_bwd)[2] fc1_bias_grad, dact = dbias_dact_quantize_func( diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index 0316213480..996ad1c67b 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -295,11 +295,7 @@ def backward( FP8GlobalStateManager.is_fp8_enabled() and FP8GlobalStateManager.keep_backward_unquantized() ) - if ( - func_ctx.is_first_module - and not keep_backward_unquantized - and not _is_graph_capturing() - ): + if func_ctx.is_first_module and not keep_backward_unquantized and not _is_graph_capturing(): FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) return ( From cf8ac99ee0761a20741c882909ffd6f505bebb7c Mon Sep 17 00:00:00 2001 From: Ziang Li Date: 
Tue, 3 Feb 2026 09:36:13 -0800 Subject: [PATCH 03/61] Disable ub and clean up Signed-off-by: Ziang Li --- .../pytorch/module/layernorm_linear.py | 9 ++-- .../pytorch/module/layernorm_mlp.py | 13 ++--- transformer_engine/pytorch/module/linear.py | 17 +++---- .../ops/fused/userbuffers_forward_linear.py | 49 +++---------------- 4 files changed, 25 insertions(+), 63 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index e26dae54fa..0ccdd03b10 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -607,6 +607,7 @@ def backward( use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized use_quantized_bwd = use_fp8_bwd or ctx.debug if keep_backward_unquantized: + # Disable Userbuffers communication for backward pass when keep_backward_unquantized is True ctx.ub_overlap_ag = False ctx.ub_overlap_rs_dgrad = False ctx.ub_bulk_dgrad = False @@ -621,23 +622,23 @@ def backward( dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]] if ctx.ub_overlap_ag: # Overlap grad_output all-gather with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", use_fp8_bwd) + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.AG elif ctx.ub_overlap_rs_dgrad: # Overlap dgrad reduce-scatter with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", use_fp8_bwd) + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.RS else: if ctx.ub_bulk_dgrad: # Overlap inputmat all-gather with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", use_fp8_bwd) + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.AG if ctx.ub_bulk_wgrad: # Overlap dgrad 
reduce-scatter with wgrad compute - ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", use_fp8_bwd) + ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8) ub_type_wgrad = tex.CommOverlapType.RS # -------------------------------------------------- diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index f5b94afa5e..4353d368bc 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1025,6 +1025,7 @@ def backward( use_quantized_bwd = use_fp8_bwd or ctx.debug fp8_recipe_bwd = ctx.fp8_recipe if use_fp8_bwd else None if keep_backward_unquantized: + # Disable Userbuffers communication for backward pass when keep_backward_unquantized is True ctx.ub_overlap_ag = False ctx.ub_overlap_rs_dgrad = False ctx.ub_bulk_dgrad = False @@ -1076,7 +1077,7 @@ def backward( # Note: Cast to expected dtype and perform tensor-parallel communication ub_obj_fc2_dgrad = None if ctx.ub_overlap_ag: - ub_obj_fc2_dgrad = get_ub("fc2_dgrad", use_fp8_bwd) + ub_obj_fc2_dgrad = get_ub("fc2_dgrad", ctx.fp8) ctx.ub_obj_gradout = ub_obj_fc2_dgrad ( grad_output, @@ -1100,7 +1101,7 @@ def backward( # wgrad GEMM requires input with column-wise usage quantizer.set_usage(rowwise=False, columnwise=True) if ctx.ub_bulk_dgrad: - ub_obj_fc1_dgrad = get_ub("fc1_dgrad", use_fp8_bwd) + ub_obj_fc1_dgrad = get_ub("fc1_dgrad", ctx.fp8) ln_out_total, _ = fill_userbuffers_buffer_for_all_gather( ub_obj_fc1_dgrad, ln_out, @@ -1194,11 +1195,7 @@ def backward( # Prepare grad output tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available - if ( - use_fp8_bwd - and ctx.ub_overlap_ag - and isinstance(ctx.fc2_grad_output_quantizer, MXFP8Quantizer) - ): + if ctx.ub_overlap_ag and isinstance(ctx.fc2_grad_output_quantizer, MXFP8Quantizer): # UB does not support pipelined overlapping grad output # all-gather with wgrad GEMM. 
Also, we can't # convert row-scaled MXFP8 to column-scaled, so we @@ -1211,7 +1208,7 @@ def backward( ub_obj_fc2_dgrad.get_communication_stream() ) - ub_obj_fc2_wgrad = get_ub("fc2_wgrad", use_fp8_bwd) + ub_obj_fc2_wgrad = get_ub("fc2_wgrad", ctx.fp8) ctx.fc2_grad_output_quantizer.set_usage(rowwise=False, columnwise=True) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 9f6c07832c..3ec87cc9bb 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -543,6 +543,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized use_quantized_bwd = use_fp8_bwd or ctx.debug if keep_backward_unquantized: + # Disable Userbuffers communication for backward pass when keep_backward_unquantized is True ctx.ub_overlap_ag = False ctx.ub_overlap_rs_dgrad = False ctx.ub_bulk_dgrad = False @@ -557,23 +558,23 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]] if ctx.ub_overlap_ag: # Overlap grad_output all-gather with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", use_fp8_bwd) + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.AG elif ctx.ub_overlap_rs_dgrad: # Overlap dgrad reduce-scatter with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", use_fp8_bwd) + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.RS else: if ctx.ub_bulk_dgrad: # Overlap inputmat all-gather with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", use_fp8_bwd) + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.AG if 
ctx.ub_bulk_wgrad: # Overlap dgrad reduce-scatter with wgrad compute - ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", use_fp8_bwd) + ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8) ub_type_wgrad = tex.CommOverlapType.RS # -------------------------------------------------- @@ -800,11 +801,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Prepare grad output tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available - if ( - use_fp8_bwd - and ctx.ub_overlap_ag - and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer) - ): + if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer): # UB does not support pipelined overlapping grad output # all-gather with wgrad GEMM. Also, we can't # convert row-scaled MXFP8 to column-scaled, so we @@ -816,7 +813,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], dgrad_send_stream, dgrad_recv_stream = ub_obj_dgrad.get_communication_stream() # This object is separate from the ub_obj_wgrad object which is passed to the GEMM - ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", use_fp8_bwd) + ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8) ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py index 3e5a389246..0d3e1d0416 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py @@ -94,7 +94,6 @@ def _functional_forward( tensor_parallel_size: Optional[int] = None, sequence_parallel: bool = False, with_quantized_compute: bool = False, - keep_backward_unquantized: bool = False, input_quantizer: Optional[Quantizer] = None, weight_quantizer: Optional[Quantizer] = None, output_quantizer: Optional[Quantizer] = None, @@ -127,8 +126,6 
@@ def _functional_forward( distributing along inner dimension (embedding dim) with_quantized_compute: bool, default = False Whether to perform compute with quantized data. - keep_backward_unquantized: bool, default = `False` - Whether to skip quantized backward and use high precision. input_quantizer: Quantizer, optional Builder class for quantized input tensor. weight_quantizer: Quantizer, optional @@ -203,10 +200,7 @@ def _functional_forward( if with_ub_all_gather: if input_quantizer is not None: if not is_quantized_tensor(x_local): - input_quantizer.set_usage( - rowwise=True, - columnwise=weight_requires_grad and not keep_backward_unquantized, - ) + input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) if isinstance( input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) ): @@ -222,10 +216,7 @@ def _functional_forward( else: if with_quantized_compute: if not is_quantized_tensor(x_local): - input_quantizer.set_usage( - rowwise=True, - columnwise=weight_requires_grad and not keep_backward_unquantized, - ) + input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) x_local = input_quantizer(x_local) else: x_local = maybe_dequantize(x_local, dtype) @@ -236,10 +227,7 @@ def _functional_forward( if not with_quantized_compute: w = maybe_dequantize(w, dtype) elif with_quantized_compute and not is_quantized_tensor(w): - weight_quantizer.set_usage( - rowwise=True, - columnwise=input_requires_grad and not keep_backward_unquantized, - ) + weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) w = weight_quantizer(w) # Construct output tensor if needed @@ -269,23 +257,14 @@ def _functional_forward( # Prepare weight tensor for backward pass if input_requires_grad: - if ( - w is not weight - and with_quantized_compute - and is_quantized_tensor(w) - and not keep_backward_unquantized - ): + if w is not weight and with_quantized_compute and is_quantized_tensor(w): w.update_usage(rowwise_usage=False, 
columnwise_usage=True) else: w = None # Prepare input tensor for backward pass if weight_requires_grad: - if ( - with_quantized_compute - and is_quantized_tensor(x_local) - and not keep_backward_unquantized - ): + if with_quantized_compute and is_quantized_tensor(x_local): if not (isinstance(x_local, Float8TensorStorage) and with_ub_all_gather): # FP8 does not support all-gather of transpose data x_local.update_usage(rowwise_usage=False, columnwise_usage=True) @@ -332,9 +311,6 @@ def fuser_forward( grad_output_quantizer = linear_op.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() - keep_backward_unquantized = ( - with_quantized_compute and FP8GlobalStateManager.keep_backward_unquantized() - ) if with_quantized_compute: recipe = FP8GlobalStateManager.get_fp8_recipe() if not any((recipe.delayed(), recipe.float8_current_scaling(), recipe.mxfp8())): @@ -364,7 +340,6 @@ def fuser_forward( tensor_parallel_size=self.tensor_parallel_size, sequence_parallel=self.sequence_parallel, with_quantized_compute=with_quantized_compute, - keep_backward_unquantized=keep_backward_unquantized, input_quantizer=input_quantizer, weight_quantizer=weight_quantizer, output_quantizer=None, # Not supported @@ -377,18 +352,10 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: - saved_input = input_ if keep_backward_unquantized else x_local - if not weight_requires_grad: - saved_input = None - saved_weight = linear_op.weight if keep_backward_unquantized else w - if not input_requires_grad: - saved_weight = None if is_cpu_offload_enabled(): - mark_activation_offload(saved_input) - linear_op_ctx.save_for_backward(saved_input, saved_weight) - linear_op_ctx.with_quantized_compute = ( - with_quantized_compute and not keep_backward_unquantized - ) + mark_activation_offload(x_local) + linear_op_ctx.save_for_backward(x_local, w) + linear_op_ctx.with_quantized_compute = 
with_quantized_compute linear_op_ctx.input_quantizer = input_quantizer linear_op_ctx.weight_quantizer = weight_quantizer linear_op_ctx.grad_output_quantizer = grad_output_quantizer From a90c4d6a5cb362279f6449fbc4ead64922c7dd26 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 3 Feb 2026 09:37:57 -0800 Subject: [PATCH 04/61] Drop fuser changes Signed-off-by: Ziang Li --- transformer_engine/pytorch/ops/fuser.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index 996ad1c67b..80386db2d9 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -109,10 +109,6 @@ def forward( # Apply forward ops x = input_ extra_outputs = [None] * fuser._num_basic_ops - keep_backward_unquantized = ( - FP8GlobalStateManager.is_fp8_enabled() - and FP8GlobalStateManager.keep_backward_unquantized() - ) for op, basic_op_idxs in fuser._forward_ops: # Set if backward op is required @@ -124,7 +120,7 @@ def forward( prev_op_idx = basic_op_idxs[0] - 1 prev_op = fuser._basic_ops[prev_op_idx] if prev_op_idx >= 0 else None prev_op_grad_output_quantizer = None - if prev_op is not None and not keep_backward_unquantized: + if prev_op is not None: prev_op_grad_output_quantizer = prev_op.get_grad_output_quantizer() next_op_idx = basic_op_idxs[-1] + 1 next_op = fuser._basic_ops[next_op_idx] if next_op_idx < fuser._num_basic_ops else None @@ -291,11 +287,7 @@ def backward( grad_extra_inputs_flat.extend(dxs) # Update FP8 scaling factors - keep_backward_unquantized = ( - FP8GlobalStateManager.is_fp8_enabled() - and FP8GlobalStateManager.keep_backward_unquantized() - ) - if func_ctx.is_first_module and not keep_backward_unquantized and not _is_graph_capturing(): + if func_ctx.is_first_module and not _is_graph_capturing(): FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) return ( From 80917d6641111c9bab0168a7bb6864a3e6d48fce Mon Sep 17 
00:00:00 2001 From: Ziang Li Date: Tue, 3 Feb 2026 09:56:43 -0800 Subject: [PATCH 05/61] Replace use_quantized_bwd with use_fp8_bwd Signed-off-by: Ziang Li --- .../pytorch/module/layernorm_linear.py | 19 +++++++------ .../pytorch/module/layernorm_mlp.py | 27 +++++++++---------- transformer_engine/pytorch/module/linear.py | 23 ++++++++-------- 3 files changed, 33 insertions(+), 36 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 0ccdd03b10..f6664c3981 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -605,7 +605,6 @@ def backward( keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized - use_quantized_bwd = use_fp8_bwd or ctx.debug if keep_backward_unquantized: # Disable Userbuffers communication for backward pass when keep_backward_unquantized is True ctx.ub_overlap_ag = False @@ -649,7 +648,7 @@ def backward( # Configure quantizer for grad output tensor # Note: dgrad GEMM requires row-wise usage, wgrad GEMM # requires column-wise usage - if ctx.grad_output_quantizer is not None and use_quantized_bwd: + if ctx.grad_output_quantizer is not None and use_fp8_bwd: quantizer = ctx.grad_output_quantizer quantizer.set_usage(rowwise=True, columnwise=True) if ctx.ub_overlap_ag: @@ -686,7 +685,7 @@ def backward( ln_out_total_work = None if ctx.ln_out_needs_gather: quantizer = None - if ctx.input_quantizer is not None and use_quantized_bwd: + if ctx.input_quantizer is not None and use_fp8_bwd: quantizer = ctx.input_quantizer if quantizer.supports_only_rowwise_all_gather(): # If data is in FP8, we compute FP8 transposes manually @@ -725,7 +724,7 @@ def backward( if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) if ( - use_quantized_bwd + use_fp8_bwd and ctx.weight_quantizer is not None 
and isinstance(weight, QuantizedTensorStorage) ): @@ -739,7 +738,7 @@ def backward( use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator # Update grad input quantizer - if ctx.grad_input_quantizer is not None and use_quantized_bwd: + if ctx.grad_input_quantizer is not None and use_fp8_bwd: ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) # Output buffers for Userbuffers reduce-scatter @@ -755,13 +754,13 @@ def backward( # dgrad GEMM # Note: dx = dy * w nvtx_range_push(f"{nvtx_label}.dgrad_gemm") - weight_for_dgrad = weight if use_quantized_bwd else origin_weight + weight_for_dgrad = weight if use_fp8_bwd else origin_weight gemm_out, *_, reduce_scatter_out = general_gemm( weight_for_dgrad, grad_output, layout="NN", grad=True, - quantization_params=ctx.grad_input_quantizer if use_quantized_bwd else None, + quantization_params=ctx.grad_input_quantizer if use_fp8_bwd else None, out=gemm_out, out_dtype=ctx.activation_dtype, use_split_accumulator=use_split_accumulator, @@ -850,14 +849,14 @@ def backward( if ln_out_total_work is not None: ln_out_total_work.wait() ln_out_total_work = None - if use_quantized_bwd: + if use_fp8_bwd: if isinstance(ln_out_total, QuantizedTensorStorage): ln_out_total.update_usage(columnwise_usage=True) else: ctx.input_quantizer.set_usage(rowwise=False, columnwise=True) ln_out_total = ctx.input_quantizer(ln_out_total) - if use_quantized_bwd: + if use_fp8_bwd: if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: @@ -893,7 +892,7 @@ def backward( main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), "quantization_params": ( - ctx.grad_weight_quantizer if use_quantized_bwd else None + ctx.grad_weight_quantizer if use_fp8_bwd else None ), "accumulate": ( accumulate_wgrad_into_param_main_grad diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 4353d368bc..0a7baf8759 100644 
--- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1022,7 +1022,6 @@ def backward( keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized - use_quantized_bwd = use_fp8_bwd or ctx.debug fp8_recipe_bwd = ctx.fp8_recipe if use_fp8_bwd else None if keep_backward_unquantized: # Disable Userbuffers communication for backward pass when keep_backward_unquantized is True @@ -1064,7 +1063,7 @@ def backward( # Configure quantizer for FC2 grad output tensor # Note: dgrad GEMM requires row-wise usage, wgrad GEMM # requires column-wise usage - if ctx.fc2_grad_output_quantizer is not None and use_quantized_bwd: + if ctx.fc2_grad_output_quantizer is not None and use_fp8_bwd: quantizer = ctx.fc2_grad_output_quantizer quantizer.set_usage(rowwise=True, columnwise=True) if ctx.ub_overlap_ag: @@ -1092,7 +1091,7 @@ def backward( ub_obj_fc1_dgrad = None if ctx.fc1_weight_requires_grad and ctx.tensor_parallel and ctx.sequence_parallel: quantizer = None - if use_quantized_bwd: + if use_fp8_bwd: quantizer = ctx.fc1_input_quantizer if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): # If data is in FP8, we compute FP8 transposes manually @@ -1148,7 +1147,7 @@ def backward( if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) if ( - use_quantized_bwd + use_fp8_bwd and ctx.fc2_weight_quantizer is not None and isinstance(ctx.fc2_weight, QuantizedTensorStorage) ): @@ -1163,7 +1162,7 @@ def backward( grad=True, quantization_params=( ctx.fc1_grad_input_quantizer - if (fc2_dgrad_gemm_gelu_fusion or ctx.debug) and use_quantized_bwd + if (fc2_dgrad_gemm_gelu_fusion or ctx.debug) and use_fp8_bwd else None ), # high precision to activation out_dtype=ctx.activation_dtype, @@ -1231,14 +1230,14 @@ def backward( # Prepare input tensor # Note: Synchronize tensor-parallel communication and # make 
sure required data is available - if use_quantized_bwd: + if use_fp8_bwd: if isinstance(act_out, QuantizedTensorStorage): act_out.update_usage(columnwise_usage=True) else: ctx.fc2_input_quantizer.set_usage(rowwise=False, columnwise=True) act_out = ctx.fc2_input_quantizer(act_out) - if use_quantized_bwd: + if use_fp8_bwd: if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: @@ -1258,7 +1257,7 @@ def backward( else ctx.activation_dtype ), "quantization_params": ( - ctx.fc2_grad_weight_quantizer if use_quantized_bwd else None + ctx.fc2_grad_weight_quantizer if use_fp8_bwd else None ), # wgrad in high precision "accumulate": ( accumulate_wgrad_into_param_main_grad @@ -1317,7 +1316,7 @@ def fc2_wgrad_gemm( act_params = ctx.activation_params or {} fc1_bias_grad = None fuse_gemm_and_bias_fc1_wgrad = False - if ctx.fc1_grad_output_quantizer is not None and use_quantized_bwd: + if ctx.fc1_grad_output_quantizer is not None and use_fp8_bwd: ctx.fc1_grad_output_quantizer.set_usage(rowwise=True, columnwise=True) if ctx.bias_gelu_fusion: # Fusion: gemm, bias + gelu @@ -1398,7 +1397,7 @@ def fc2_wgrad_gemm( # Make sure required data is available if ( - use_quantized_bwd + use_fp8_bwd and ctx.fc1_weight_quantizer is not None and isinstance(ctx.fc1_weight_quantizer, QuantizedTensorStorage) ): @@ -1421,7 +1420,7 @@ def fc2_wgrad_gemm( dact, out=gemm_out, out_dtype=ctx.activation_dtype, - quantization_params=ctx.fc1_grad_input_quantizer if use_quantized_bwd else None, + quantization_params=ctx.fc1_grad_input_quantizer if use_fp8_bwd else None, layout="NN", grad=True, use_split_accumulator=dgrad_use_split_accumulator, @@ -1470,7 +1469,7 @@ def fc2_wgrad_gemm( if ln_out_total_work is not None: ln_out_total_work.wait() ln_out_total_work = None - if use_quantized_bwd: + if use_fp8_bwd: if isinstance(ln_out_total, QuantizedTensorStorage): ln_out_total.update_usage(columnwise_usage=True) else: @@ -1480,7 +1479,7 @@ def fc2_wgrad_gemm( # 
Prepare grad output tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available - if use_quantized_bwd: + if use_fp8_bwd: if isinstance(dact, QuantizedTensorStorage): dact.update_usage(columnwise_usage=True) else: @@ -1503,7 +1502,7 @@ def fc2_wgrad_gemm( else ctx.activation_dtype ), "quantization_params": ( - ctx.fc1_grad_weight_quantizer if use_quantized_bwd else None + ctx.fc1_grad_weight_quantizer if use_fp8_bwd else None ), "accumulate": ( accumulate_wgrad_into_param_main_grad diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 3ec87cc9bb..0441bf958d 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -541,7 +541,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized - use_quantized_bwd = use_fp8_bwd or ctx.debug if keep_backward_unquantized: # Disable Userbuffers communication for backward pass when keep_backward_unquantized is True ctx.ub_overlap_ag = False @@ -588,7 +587,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Configure quantizer for grad output tensor # Note: dgrad GEMM requires row-wise usage, wgrad GEMM # requires column-wise usage - if ctx.grad_output_quantizer is not None and use_quantized_bwd: + if ctx.grad_output_quantizer is not None and use_fp8_bwd: quantizer = ctx.grad_output_quantizer quantizer.set_usage(rowwise=True, columnwise=True) if ctx.ub_overlap_ag: @@ -607,7 +606,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], not ctx.use_bias and not ctx.requires_wgrad and ctx.grad_output_quantizer is not None - and use_quantized_bwd + and use_fp8_bwd ): ctx.grad_output_quantizer.set_usage(columnwise=False) @@ -637,7 +636,7 @@ def backward(ctx, grad_output: 
torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmat_total = None inputmat_total_work = None if ctx.requires_wgrad: - if use_quantized_bwd: + if use_fp8_bwd: if isinstance(inputmat, QuantizedTensorStorage): # Input tensor is already quantized pass @@ -663,7 +662,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmat = cast_if_needed(inputmat, ctx.activation_dtype) if ctx.backward_input_needs_gather: quantizer = None - if use_quantized_bwd: + if use_fp8_bwd: quantizer = ctx.input_quantizer if quantizer.supports_only_rowwise_all_gather(): # If data is in FP8, we compute FP8 transposes manually @@ -705,7 +704,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) if ( - use_quantized_bwd + use_fp8_bwd and ctx.weight_quantizer is not None and isinstance(weight_fp8, QuantizedTensorStorage) ): @@ -719,7 +718,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator # Update grad input quantizer - if ctx.grad_input_quantizer is not None and use_quantized_bwd: + if ctx.grad_input_quantizer is not None and use_fp8_bwd: ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) # Output buffers for Userbuffers reduce-scatter @@ -736,13 +735,13 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Note: dx = dy * w nvtx_range_push(f"{nvtx_label}.dgrad_gemm") - weight_for_dgrad = weight_fp8 if use_quantized_bwd else weight + weight_for_dgrad = weight_fp8 if use_fp8_bwd else weight gemm_out, *_, reduce_scatter_out = general_gemm( weight_for_dgrad, grad_output, layout="NN", grad=True, - quantization_params=ctx.grad_input_quantizer if use_quantized_bwd else None, + quantization_params=ctx.grad_input_quantizer if use_fp8_bwd else None, out=gemm_out, 
out_dtype=ctx.activation_dtype, use_split_accumulator=use_split_accumulator, @@ -791,7 +790,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if inputmat_total_work is not None: inputmat_total_work.wait() inputmat_total_work = None - if use_quantized_bwd: + if use_fp8_bwd: if isinstance(inputmat_total, QuantizedTensorStorage): inputmat_total.update_usage(columnwise_usage=True) else: @@ -833,7 +832,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ub_obj_overlap_wgrad, dgrad_send_stream, dgrad_recv_stream ) - if use_quantized_bwd: + if use_fp8_bwd: if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: @@ -869,7 +868,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), "quantization_params": ( - ctx.grad_weight_quantizer if use_quantized_bwd else None + ctx.grad_weight_quantizer if use_fp8_bwd else None ), "accumulate": ( accumulate_wgrad_into_param_main_grad From 07cb1df3dd5c4d1983a47c037886a874bf059bb2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 3 Feb 2026 17:57:32 +0000 Subject: [PATCH 06/61] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/layernorm_linear.py | 4 +--- transformer_engine/pytorch/module/layernorm_mlp.py | 4 +--- transformer_engine/pytorch/module/linear.py | 4 +--- 3 files changed, 3 insertions(+), 9 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index f6664c3981..d830328eab 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -891,9 +891,7 @@ def backward( "out_dtype": ( main_grad.dtype if 
ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - "quantization_params": ( - ctx.grad_weight_quantizer if use_fp8_bwd else None - ), + "quantization_params": (ctx.grad_weight_quantizer if use_fp8_bwd else None), "accumulate": ( accumulate_wgrad_into_param_main_grad if not getattr(weight, "overwrite_main_grad", False) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 0a7baf8759..c461baa745 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1501,9 +1501,7 @@ def fc2_wgrad_gemm( if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - "quantization_params": ( - ctx.fc1_grad_weight_quantizer if use_fp8_bwd else None - ), + "quantization_params": (ctx.fc1_grad_weight_quantizer if use_fp8_bwd else None), "accumulate": ( accumulate_wgrad_into_param_main_grad if not getattr(fc2_weight, "overwrite_main_grad", False) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 0441bf958d..3befb71469 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -867,9 +867,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], "out_dtype": ( main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - "quantization_params": ( - ctx.grad_weight_quantizer if use_fp8_bwd else None - ), + "quantization_params": (ctx.grad_weight_quantizer if use_fp8_bwd else None), "accumulate": ( accumulate_wgrad_into_param_main_grad if not getattr(weight, "overwrite_main_grad", False) From dc5d3819df59cf27a46bc6554288c247b42b6c01 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 3 Feb 2026 10:30:04 -0800 Subject: [PATCH 07/61] Ignore keep_backward_unquantized if delayed scaling Signed-off-by: Ziang Li --- transformer_engine/pytorch/module/grouped_linear.py | 1 + 
transformer_engine/pytorch/module/linear.py | 1 + transformer_engine/pytorch/quantization.py | 3 +++ 3 files changed, 5 insertions(+) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 3f055a2b77..bfd4079dcd 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -99,6 +99,7 @@ def forward( ) = non_tensor_args keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized() if keep_backward_unquantized: + # Note, keep_backward_unquantized is ignored when delayed scaling is used save_original_input = True num_gemms = len(m_splits) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 3befb71469..83fb4aa75c 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -130,6 +130,7 @@ def forward( ) = non_tensor_args keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized() if keep_backward_unquantized: + # Note, keep_backward_unquantized is ignored when delayed scaling is used save_original_input = True # NVTX label for profiling diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index 42e3df6a7c..f73d71664b 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -434,6 +434,9 @@ def with_high_precision_init_val(cls) -> bool: @classmethod def keep_backward_unquantized(cls) -> bool: """Should backward skip FP8 quantization and use high precision""" + recipe = cls.get_fp8_recipe() + if recipe.delayed(): + return False return bool(int(os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0"))) @classmethod From 05fb8940a32e0d86d6089f7f48de83bd7625b4a3 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 3 Feb 2026 10:39:02 -0800 Subject: [PATCH 08/61] Refactor ignoring 
NVTE_KEEP_BACKWARD_UNQUANTIZED when delayed scaling is used Signed-off-by: Ziang Li --- transformer_engine/pytorch/module/grouped_linear.py | 2 +- transformer_engine/pytorch/module/linear.py | 2 +- transformer_engine/pytorch/quantization.py | 3 ++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index bfd4079dcd..263149d6ff 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -99,7 +99,7 @@ def forward( ) = non_tensor_args keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized() if keep_backward_unquantized: - # Note, keep_backward_unquantized is ignored when delayed scaling is used + # Note, NVTE_KEEP_BACKWARD_UNQUANTIZED is ignored when delayed scaling is used save_original_input = True num_gemms = len(m_splits) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 83fb4aa75c..3a9558ff05 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -130,7 +130,7 @@ def forward( ) = non_tensor_args keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized() if keep_backward_unquantized: - # Note, keep_backward_unquantized is ignored when delayed scaling is used + # Note, NVTE_KEEP_BACKWARD_UNQUANTIZED is ignored when delayed scaling is used save_original_input = True # NVTX label for profiling diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index f73d71664b..3c4a6b9ffe 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -435,7 +435,8 @@ def with_high_precision_init_val(cls) -> bool: def keep_backward_unquantized(cls) -> bool: """Should backward skip FP8 quantization and use high precision""" recipe = 
cls.get_fp8_recipe() - if recipe.delayed(): + if recipe is not None and recipe.delayed(): + # Ignore NVTE_KEEP_BACKWARD_UNQUANTIZED when delayed scaling is used return False return bool(int(os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0"))) From f78ffa0ff7d5c37a7756f617d9512caf2d0e69c7 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 3 Feb 2026 11:13:57 -0800 Subject: [PATCH 09/61] Add back missing ctx.debug Signed-off-by: Ziang Li --- transformer_engine/pytorch/module/layernorm_linear.py | 4 ++-- transformer_engine/pytorch/module/layernorm_mlp.py | 10 +++++----- transformer_engine/pytorch/module/linear.py | 8 ++++---- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index d830328eab..bdb765876e 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -849,14 +849,14 @@ def backward( if ln_out_total_work is not None: ln_out_total_work.wait() ln_out_total_work = None - if use_fp8_bwd: + if use_fp8_bwd or ctx.debug: if isinstance(ln_out_total, QuantizedTensorStorage): ln_out_total.update_usage(columnwise_usage=True) else: ctx.input_quantizer.set_usage(rowwise=False, columnwise=True) ln_out_total = ctx.input_quantizer(ln_out_total) - if use_fp8_bwd: + if use_fp8_bwd or ctx.debug: if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index c461baa745..5abd5d40aa 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1091,7 +1091,7 @@ def backward( ub_obj_fc1_dgrad = None if ctx.fc1_weight_requires_grad and ctx.tensor_parallel and ctx.sequence_parallel: quantizer = None - if use_fp8_bwd: + if use_fp8_bwd or ctx.debug: quantizer = 
ctx.fc1_input_quantizer if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): # If data is in FP8, we compute FP8 transposes manually @@ -1230,14 +1230,14 @@ def backward( # Prepare input tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available - if use_fp8_bwd: + if use_fp8_bwd or ctx.debug: if isinstance(act_out, QuantizedTensorStorage): act_out.update_usage(columnwise_usage=True) else: ctx.fc2_input_quantizer.set_usage(rowwise=False, columnwise=True) act_out = ctx.fc2_input_quantizer(act_out) - if use_fp8_bwd: + if use_fp8_bwd or ctx.debug: if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: @@ -1469,7 +1469,7 @@ def fc2_wgrad_gemm( if ln_out_total_work is not None: ln_out_total_work.wait() ln_out_total_work = None - if use_fp8_bwd: + if use_fp8_bwd or ctx.debug: if isinstance(ln_out_total, QuantizedTensorStorage): ln_out_total.update_usage(columnwise_usage=True) else: @@ -1479,7 +1479,7 @@ def fc2_wgrad_gemm( # Prepare grad output tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available - if use_fp8_bwd: + if use_fp8_bwd or ctx.debug: if isinstance(dact, QuantizedTensorStorage): dact.update_usage(columnwise_usage=True) else: diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 3a9558ff05..2e376fbf63 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -637,7 +637,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmat_total = None inputmat_total_work = None if ctx.requires_wgrad: - if use_fp8_bwd: + if use_fp8_bwd or ctx.debug: if isinstance(inputmat, QuantizedTensorStorage): # Input tensor is already quantized pass @@ -663,7 +663,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmat = cast_if_needed(inputmat, 
ctx.activation_dtype) if ctx.backward_input_needs_gather: quantizer = None - if use_fp8_bwd: + if use_fp8_bwd or ctx.debug: quantizer = ctx.input_quantizer if quantizer.supports_only_rowwise_all_gather(): # If data is in FP8, we compute FP8 transposes manually @@ -791,7 +791,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if inputmat_total_work is not None: inputmat_total_work.wait() inputmat_total_work = None - if use_fp8_bwd: + if use_fp8_bwd or ctx.debug: if isinstance(inputmat_total, QuantizedTensorStorage): inputmat_total.update_usage(columnwise_usage=True) else: @@ -833,7 +833,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ub_obj_overlap_wgrad, dgrad_send_stream, dgrad_recv_stream ) - if use_fp8_bwd: + if use_fp8_bwd or ctx.debug: if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: From 91b53e14c908d59e9036cc849a9193dcc852a1b3 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 3 Feb 2026 11:43:45 -0800 Subject: [PATCH 10/61] Refactor changes under fused Signed-off-by: Ziang Li --- .../ops/fused/backward_activation_bias.py | 7 ++----- .../ops/fused/forward_linear_bias_activation.py | 17 +++++++++++------ .../ops/fused/forward_linear_bias_add.py | 17 +++++++++++------ .../ops/fused/forward_linear_scale_add.py | 17 +++++++++++------ 4 files changed, 35 insertions(+), 23 deletions(-) diff --git a/transformer_engine/pytorch/ops/fused/backward_activation_bias.py b/transformer_engine/pytorch/ops/fused/backward_activation_bias.py index 59e9af14f4..4ab082d32b 100644 --- a/transformer_engine/pytorch/ops/fused/backward_activation_bias.py +++ b/transformer_engine/pytorch/ops/fused/backward_activation_bias.py @@ -10,7 +10,7 @@ import torch import transformer_engine_torch as tex -from transformer_engine.pytorch.quantization import Recipe, FP8GlobalStateManager +from transformer_engine.pytorch.quantization import Recipe from 
transformer_engine.pytorch.ops.basic import Bias from transformer_engine.pytorch.ops.basic.activation import ( _ActivationOperation, @@ -105,10 +105,7 @@ def fuse_backward_ops( """ # Check if recipe supports bias activation fusion - if recipe is None or ( - FP8GlobalStateManager.is_fp8_enabled() - and FP8GlobalStateManager.keep_backward_unquantized() - ): + if recipe is None: return ops # Scan through ops, fusing if possible diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index 0a28d00706..6e7c85988f 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -122,12 +122,17 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: - saved_input = input_ if keep_backward_unquantized else x_local - if not weight_requires_grad: - saved_input = None - saved_weight = linear_op.weight if keep_backward_unquantized else w - if not input_requires_grad: - saved_weight = None + saved_input = x_local + saved_weight = w + if keep_backward_unquantized: + saved_input = input_ if input_requires_grad else None + saved_weight = linear_op.weight if weight_requires_grad else None + # saved_input = input_ if keep_backward_unquantized else x_local + # if not weight_requires_grad: + # saved_input = None + # saved_weight = linear_op.weight if keep_backward_unquantized else w + # if not input_requires_grad: + # saved_weight = None if is_cpu_offload_enabled(): mark_activation_offload(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index 41ae096e54..f3b4533848 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ 
b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -119,12 +119,17 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: - saved_input = input_ if keep_backward_unquantized else x_local - if not weight_requires_grad: - saved_input = None - saved_weight = linear_op.weight if keep_backward_unquantized else w - if not input_requires_grad: - saved_weight = None + saved_input = x_local + saved_weight = w + if keep_backward_unquantized: + saved_input = input_ if input_requires_grad else None + saved_weight = linear_op.weight if weight_requires_grad else None + # saved_input = input_ if keep_backward_unquantized else x_local + # if not weight_requires_grad: + # saved_input = None + # saved_weight = linear_op.weight if keep_backward_unquantized else w + # if not input_requires_grad: + # saved_weight = None if is_cpu_offload_enabled(): mark_activation_offload(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py index b06f5ad36a..53e7327873 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py @@ -100,12 +100,17 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: - saved_input = input_ if keep_backward_unquantized else x_local - if not weight_requires_grad: - saved_input = None - saved_weight = linear_op.weight if keep_backward_unquantized else w - if not input_requires_grad: - saved_weight = None + saved_input = x_local + saved_weight = w + if keep_backward_unquantized: + saved_input = input_ if input_requires_grad else None + saved_weight = linear_op.weight if weight_requires_grad else None + # saved_input = input_ if keep_backward_unquantized else x_local + # if not weight_requires_grad: + # saved_input = None + # saved_weight = 
linear_op.weight if keep_backward_unquantized else w + # if not input_requires_grad: + # saved_weight = None if is_cpu_offload_enabled(): mark_activation_offload(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) From 4509967bfceb413b5344606ab4a6d45dbe2019f8 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 3 Feb 2026 11:44:30 -0800 Subject: [PATCH 11/61] Clean up Signed-off-by: Ziang Li --- .../pytorch/ops/fused/forward_linear_bias_activation.py | 6 ------ .../pytorch/ops/fused/forward_linear_bias_add.py | 6 ------ .../pytorch/ops/fused/forward_linear_scale_add.py | 6 ------ 3 files changed, 18 deletions(-) diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index 6e7c85988f..2458d4d072 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -127,12 +127,6 @@ def fuser_forward( if keep_backward_unquantized: saved_input = input_ if input_requires_grad else None saved_weight = linear_op.weight if weight_requires_grad else None - # saved_input = input_ if keep_backward_unquantized else x_local - # if not weight_requires_grad: - # saved_input = None - # saved_weight = linear_op.weight if keep_backward_unquantized else w - # if not input_requires_grad: - # saved_weight = None if is_cpu_offload_enabled(): mark_activation_offload(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index f3b4533848..efa543e555 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -124,12 +124,6 @@ def fuser_forward( if keep_backward_unquantized: saved_input = input_ if input_requires_grad else None 
saved_weight = linear_op.weight if weight_requires_grad else None - # saved_input = input_ if keep_backward_unquantized else x_local - # if not weight_requires_grad: - # saved_input = None - # saved_weight = linear_op.weight if keep_backward_unquantized else w - # if not input_requires_grad: - # saved_weight = None if is_cpu_offload_enabled(): mark_activation_offload(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py index 53e7327873..2804534968 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py @@ -105,12 +105,6 @@ def fuser_forward( if keep_backward_unquantized: saved_input = input_ if input_requires_grad else None saved_weight = linear_op.weight if weight_requires_grad else None - # saved_input = input_ if keep_backward_unquantized else x_local - # if not weight_requires_grad: - # saved_input = None - # saved_weight = linear_op.weight if keep_backward_unquantized else w - # if not input_requires_grad: - # saved_weight = None if is_cpu_offload_enabled(): mark_activation_offload(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) From 70be1b5e8f8b1c633ee8bf7f67a0dd2f594c972d Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 3 Feb 2026 14:11:07 -0800 Subject: [PATCH 12/61] Refactor high-precision overwrite if keep_backward_unquantized Signed-off-by: Ziang Li --- .../pytorch/module/grouped_linear.py | 17 ++++++++++------- .../pytorch/module/layernorm_linear.py | 10 ++++++++-- .../pytorch/module/layernorm_mlp.py | 14 +++++++++++--- transformer_engine/pytorch/module/linear.py | 5 ++++- 4 files changed, 33 insertions(+), 13 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 263149d6ff..3ef33ce009 100644 
--- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -414,13 +414,16 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], dtype=ctx.activation_dtype, device=ctx.device, ) - weights_for_dgrad = weights if use_fp8_bwd else origin_weights - if use_fp8_bwd: - # Make sure weights are available in column-wise format - # for dgrad computation. - for weight in weights_for_dgrad: - if isinstance(weight, QuantizedTensorStorage): - weight.update_usage(columnwise_usage=True) + # weights_for_dgrad = weights if use_fp8_bwd else origin_weights + # if use_fp8_bwd: + weights_for_dgrad = weights + if keep_backward_unquantized: + weights_for_dgrad = origin_weights + # Make sure weights are available in column-wise format + # for dgrad computation. + for weight in weights_for_dgrad: + if isinstance(weight, QuantizedTensorStorage): + weight.update_usage(columnwise_usage=True) general_grouped_gemm( weights_for_dgrad, grad_output, diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index bdb765876e..4057a3b229 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -414,7 +414,10 @@ def forward( # ------------------------------------------------------ if is_grad_enabled: - ln_out_to_save = ln_out_hp if keep_backward_unquantized else ln_out + ln_out_to_save = ln_out + if keep_backward_unquantized: + ln_out_to_save = ln_out_hp + # ln_out_to_save = ln_out_hp if keep_backward_unquantized else ln_out ctx.weight_quantizer = weight_quantizer ctx.ln_out_needs_gather = ( weight.requires_grad and parallel_mode == "column" and sequence_parallel @@ -754,7 +757,10 @@ def backward( # dgrad GEMM # Note: dx = dy * w nvtx_range_push(f"{nvtx_label}.dgrad_gemm") - weight_for_dgrad = weight if use_fp8_bwd else origin_weight + # weight_for_dgrad = weight if use_fp8_bwd 
else origin_weight + weight_for_dgrad = weight + if keep_backward_unquantized: + weight_for_dgrad = origin_weight gemm_out, *_, reduce_scatter_out = general_gemm( weight_for_dgrad, grad_output, diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 5abd5d40aa..5c335b1921 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -697,8 +697,13 @@ def _forward( # if we are not checkpointing, then we must save this if grad is enabled if is_grad_enabled and not save_for_checkpoint: - ln_out_to_save = ln_out_hp if keep_backward_unquantized else ln_out - act_out_to_save = act_out_hp if keep_backward_unquantized else act_out + ln_out_to_save = ln_out + act_out_to_save = act_out + if keep_backward_unquantized: + ln_out_to_save = ln_out_hp + act_out_to_save = act_out_hp + # ln_out_to_save = ln_out_hp if keep_backward_unquantized else ln_out + # act_out_to_save = act_out_hp if keep_backward_unquantized else act_out ctx.fc1_weight_quantizer = fc1_weight_quantizer ctx.fc2_weight_quantizer = fc2_weight_quantizer @@ -1154,7 +1159,10 @@ def backward( ctx.fc2_weight.update_usage(columnwise_usage=True) # Perform GEMM - fc2_weight_for_dgrad = fc2_weight if use_fp8_bwd else origin_fc2_weight + fc2_weight_for_dgrad = fc2_weight + if keep_backward_unquantized: + fc2_weight_for_dgrad = origin_fc2_weight + # fc2_weight_for_dgrad = fc2_weight if use_fp8_bwd else origin_fc2_weight gemm_output, *_ = general_gemm( fc2_weight_for_dgrad, grad_output, diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 2e376fbf63..71ffd31a8e 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -736,7 +736,10 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Note: dx = dy * w nvtx_range_push(f"{nvtx_label}.dgrad_gemm") - 
weight_for_dgrad = weight_fp8 if use_fp8_bwd else weight + weight_for_dgrad = weight_fp8 + if keep_backward_unquantized: + weight_for_dgrad = weight + # weight_for_dgrad = weight_fp8 if use_fp8_bwd else weight gemm_out, *_, reduce_scatter_out = general_gemm( weight_for_dgrad, grad_output, From 058ad45bddd64475f484e2894f44f57a464cb46e Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 3 Feb 2026 14:14:22 -0800 Subject: [PATCH 13/61] Clean up Signed-off-by: Ziang Li --- transformer_engine/pytorch/module/grouped_linear.py | 2 -- transformer_engine/pytorch/module/layernorm_linear.py | 2 -- transformer_engine/pytorch/module/layernorm_mlp.py | 3 --- transformer_engine/pytorch/module/linear.py | 1 - 4 files changed, 8 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 3ef33ce009..4a0b4fe060 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -414,8 +414,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], dtype=ctx.activation_dtype, device=ctx.device, ) - # weights_for_dgrad = weights if use_fp8_bwd else origin_weights - # if use_fp8_bwd: weights_for_dgrad = weights if keep_backward_unquantized: weights_for_dgrad = origin_weights diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 4057a3b229..0e955d9e60 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -417,7 +417,6 @@ def forward( ln_out_to_save = ln_out if keep_backward_unquantized: ln_out_to_save = ln_out_hp - # ln_out_to_save = ln_out_hp if keep_backward_unquantized else ln_out ctx.weight_quantizer = weight_quantizer ctx.ln_out_needs_gather = ( weight.requires_grad and parallel_mode == "column" and sequence_parallel @@ -757,7 +756,6 @@ def backward( # dgrad GEMM # Note: dx 
= dy * w nvtx_range_push(f"{nvtx_label}.dgrad_gemm") - # weight_for_dgrad = weight if use_fp8_bwd else origin_weight weight_for_dgrad = weight if keep_backward_unquantized: weight_for_dgrad = origin_weight diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 5c335b1921..b9867c6609 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -702,8 +702,6 @@ def _forward( if keep_backward_unquantized: ln_out_to_save = ln_out_hp act_out_to_save = act_out_hp - # ln_out_to_save = ln_out_hp if keep_backward_unquantized else ln_out - # act_out_to_save = act_out_hp if keep_backward_unquantized else act_out ctx.fc1_weight_quantizer = fc1_weight_quantizer ctx.fc2_weight_quantizer = fc2_weight_quantizer @@ -1162,7 +1160,6 @@ def backward( fc2_weight_for_dgrad = fc2_weight if keep_backward_unquantized: fc2_weight_for_dgrad = origin_fc2_weight - # fc2_weight_for_dgrad = fc2_weight if use_fp8_bwd else origin_fc2_weight gemm_output, *_ = general_gemm( fc2_weight_for_dgrad, grad_output, diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 71ffd31a8e..25e152210f 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -739,7 +739,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], weight_for_dgrad = weight_fp8 if keep_backward_unquantized: weight_for_dgrad = weight - # weight_for_dgrad = weight_fp8 if use_fp8_bwd else weight gemm_out, *_, reduce_scatter_out = general_gemm( weight_for_dgrad, grad_output, From 0f53179d42239e5758f57cfa7ad3c1faccb72b39 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Wed, 4 Feb 2026 10:56:41 -0800 Subject: [PATCH 14/61] Drop redundant fp8_recipe_bwd Signed-off-by: Ziang Li --- .../pytorch/module/layernorm_mlp.py | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 
deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index b9867c6609..b7b4591874 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1025,7 +1025,6 @@ def backward( keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized - fp8_recipe_bwd = ctx.fp8_recipe if use_fp8_bwd else None if keep_backward_unquantized: # Disable Userbuffers communication for backward pass when keep_backward_unquantized is True ctx.ub_overlap_ag = False @@ -1251,7 +1250,7 @@ def backward( # Whether to set grad arg in general_gemm grad_arg = True - if use_fp8_bwd and fp8_recipe_bwd.float8_block_scaling(): + if use_fp8_bwd and ctx.fp8_recipe.float8_block_scaling(): grad_arg = False # Arguments to include in wgrad GEMM closure @@ -1301,7 +1300,7 @@ def fc2_wgrad_gemm( if fc2_bias_grad is None: if ( use_fp8_bwd - and fp8_recipe_bwd.float8_block_scaling() + and ctx.fp8_recipe.float8_block_scaling() and fc2_bias is not None ): # BGRAD not fused with GEMM for float8 blockwise gemm. 
@@ -1335,9 +1334,14 @@ def fc2_wgrad_gemm( dact = dact_func(fc2_dgrad, fc1_out.to(ctx.activation_dtype), None, **act_params) fc1_bias_grad = dact.sum(dim=0) dact = ctx.fc1_grad_output_quantizer(dact) - elif _act_func(ctx.activation, fp8_recipe_bwd)[2] is not None and use_fp8_bwd: + elif ( + _act_func(ctx.activation, ctx.fp8_recipe if ctx.fp8 else None)[2] is not None + and use_fp8_bwd + ): # Fusion: gemm, bias + gelu + quantize - dbias_dact_quantize_func = _act_func(ctx.activation, fp8_recipe_bwd)[2] + dbias_dact_quantize_func = _act_func( + ctx.activation, ctx.fp8_recipe if ctx.fp8 else None + )[2] fc1_bias_grad, dact = dbias_dact_quantize_func( fc2_dgrad, fc1_out.to(ctx.activation_dtype), @@ -1347,7 +1351,9 @@ def fc2_wgrad_gemm( else: # Fusion: gemm + gelu, if not fc2_dgrad_gemm_gelu_fusion: - activation_func_bwd = _act_func(ctx.activation, fp8_recipe_bwd)[1] + activation_func_bwd = _act_func( + ctx.activation, ctx.fp8_recipe if ctx.fp8 else None + )[1] dact = activation_func_bwd( fc2_dgrad, fc1_out.to(ctx.activation_dtype), None, **act_params ) # activation in high precision @@ -1356,7 +1362,7 @@ def fc2_wgrad_gemm( # TODO float8 blockwise current scaling (as well as custom quantizers) has no bgrad fusion for now if ( isinstance(ctx.fc1_grad_output_quantizer, Float8BlockQuantizer) - or fp8_recipe_bwd.custom() + or ctx.fp8_recipe.custom() ): fc1_bias_grad = dact.view(-1, dact.shape[-1]).sum(dim=0) dact = ctx.fc1_grad_output_quantizer(dact) From 6f213dcb5238a65e2d2606dbcf087a2a2e73fdd4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Feb 2026 18:57:29 +0000 Subject: [PATCH 15/61] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/layernorm_mlp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py 
b/transformer_engine/pytorch/module/layernorm_mlp.py index b7b4591874..6db8332d47 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1337,7 +1337,7 @@ def fc2_wgrad_gemm( elif ( _act_func(ctx.activation, ctx.fp8_recipe if ctx.fp8 else None)[2] is not None and use_fp8_bwd - ): + ): # Fusion: gemm, bias + gelu + quantize dbias_dact_quantize_func = _act_func( ctx.activation, ctx.fp8_recipe if ctx.fp8 else None From b4836d166e597f022459c636a4683e923d2c8b9d Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Wed, 4 Feb 2026 11:02:24 -0800 Subject: [PATCH 16/61] Drop redundant ub changes Signed-off-by: Ziang Li --- transformer_engine/pytorch/module/layernorm_mlp.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 6db8332d47..f040ce7d87 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1390,16 +1390,16 @@ def fc2_wgrad_gemm( fc1_dgrad_shape = [reduce(multiply_op, inputmat.shape[:-1]), inputmat.shape[-1]] if ctx.ub_overlap_rs_dgrad: # Overlap DGRAD+RS - ub_obj_fc1_dgrad = get_ub("fc1_dgrad", use_fp8_bwd) + ub_obj_fc1_dgrad = get_ub("fc1_dgrad", ctx.fp8) ub_type_fc1_dgrad = tex.CommOverlapType.RS else: if ctx.ub_bulk_dgrad: # Overlap ln_out all-gather with DGRAD compute - ub_obj_fc1_dgrad = get_ub("fc1_dgrad", use_fp8_bwd) + ub_obj_fc1_dgrad = get_ub("fc1_dgrad", ctx.fp8) ub_type_fc1_dgrad = tex.CommOverlapType.AG if ctx.ub_bulk_wgrad: # Overlap FC1 DGRAD reduce-scatter with WGRAD compute - ub_obj_fc1_wgrad = get_ub("fc1_wgrad", use_fp8_bwd) + ub_obj_fc1_wgrad = get_ub("fc1_wgrad", ctx.fp8) ub_type_fc1_wgrad = tex.CommOverlapType.RS # -------------------------------------------------- From 6b7666a1a0d4b6ce8ecd1917db607a88bfed3eef Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Wed, 4 Feb 2026 11:07:16 
-0800 Subject: [PATCH 17/61] Drop more redundant ub changes Signed-off-by: Ziang Li --- transformer_engine/pytorch/module/layernorm_linear.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 0e955d9e60..fdd03acfd5 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -811,11 +811,7 @@ def backward( # Prepare grad output tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available - if ( - use_fp8_bwd - and ctx.ub_overlap_ag - and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer) - ): + if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer): # UB does not support pipelined overlapping grad output # all-gather with wgrad GEMM. Also, we can't # convert row-scaled MXFP8 to column-scaled, so we @@ -827,7 +823,7 @@ def backward( dgrad_send_stream, dgrad_recv_stream = ub_obj_dgrad.get_communication_stream() # This object is separate from the ub_obj_wgrad object which is passed to the GEMM - ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", use_fp8_bwd) + ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8) ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) From 6a10fd15445c21b9ddd3ead956db3e7db8bc97b1 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Wed, 4 Feb 2026 11:25:01 -0800 Subject: [PATCH 18/61] Drop redundant delayed scaling changes Signed-off-by: Ziang Li --- transformer_engine/pytorch/module/grouped_linear.py | 6 +----- transformer_engine/pytorch/module/layernorm_mlp.py | 6 +----- transformer_engine/pytorch/module/linear.py | 2 +- 3 files changed, 3 insertions(+), 11 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 4a0b4fe060..054c5de04a 100644 --- 
a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -304,11 +304,7 @@ def forward( ctx.inp_shape = inp.shape ctx.requires_dgrad = inp.requires_grad ctx.reduce_and_update_bwd_fp8_tensors = False - if ( - ctx.fp8 - and not ctx.keep_backward_unquantized - and requires_grad(inp, weights[0], biases[0]) - ): + if ctx.fp8 and requires_grad(inp, weights[0], biases[0]): ctx.reduce_and_update_bwd_fp8_tensors = ( ctx.reduce_and_update_bwd_fp8_tensors or FP8GlobalStateManager.is_first_fp8_module() diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index f040ce7d87..bb70081f08 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -849,12 +849,8 @@ def _forward( ) ctx.normalization = normalization ctx.reduce_and_update_bwd_fp8_tensors = False - if ( - ctx.fp8 - and not ctx.keep_backward_unquantized - and requires_grad( + if ctx.fp8 and requires_grad( inp, ln_weight, ln_bias, fc1_weight, fc2_weight, fc1_bias, fc2_bias - ) ): _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 25e152210f..6b96805e05 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -483,7 +483,7 @@ def forward( ctx.reduce_and_update_bwd_fp8_tensors = False ctx.owns_input = saved_inputmat is not inp - if ctx.fp8 and not ctx.keep_backward_unquantized and requires_grad(inp, weight, bias): + if ctx.fp8 and requires_grad(inp, weight, bias): _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() if in_fp8_activation_recompute_phase(): From 
dd9038bba017b69821899ff329489994cef87d75 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Feb 2026 19:25:49 +0000 Subject: [PATCH 19/61] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/layernorm_mlp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index bb70081f08..3866315f99 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -850,7 +850,7 @@ def _forward( ctx.normalization = normalization ctx.reduce_and_update_bwd_fp8_tensors = False if ctx.fp8 and requires_grad( - inp, ln_weight, ln_bias, fc1_weight, fc2_weight, fc1_bias, fc2_bias + inp, ln_weight, ln_bias, fc1_weight, fc2_weight, fc1_bias, fc2_bias ): _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() From c83ed747de0fe0121dd57301fe74df96ff540940 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Wed, 4 Feb 2026 12:01:36 -0800 Subject: [PATCH 20/61] Drop unneeded backwards_needs_fc1_input Signed-off-by: Ziang Li --- transformer_engine/pytorch/module/layernorm_mlp.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 3866315f99..9ec22ca07d 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -352,10 +352,8 @@ def _forward( # bwd needs fc1 input when grad is enabled, fc1 needs grad, and either # 1) no checkpointing # or 2) doing the recomputation with checkpointing - backwards_needs_fc1_input = ( - fc1_weight.requires_grad - and ((is_grad_enabled and not checkpoint) or 
is_recomputation) - and not keep_backward_unquantized + backwards_needs_fc1_input = fc1_weight.requires_grad and ( + (is_grad_enabled and not checkpoint) or is_recomputation ) device = inp.device From bf632e74430447fb72cae86f62cfb1db7f7c531f Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Wed, 4 Feb 2026 14:01:43 -0800 Subject: [PATCH 21/61] Drop and disallow LayerNormMLP implementation Signed-off-by: Ziang Li --- .../pytorch/module/layernorm_mlp.py | 104 ++++++------------ 1 file changed, 34 insertions(+), 70 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 9ec22ca07d..bf28fa47e3 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -235,6 +235,7 @@ def _forward( recompute_for_bwd, ) = non_tensor_args keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized() + assert not keep_backward_unquantized, "NVTE_KEEP_BACKWARD_UNQUANTIZED is not implemented in LayerNormMLP" # if grad is enabled and this is not the bwd stage, we must save this so bwd knows which path to take if is_grad_enabled and not recompute_for_bwd: @@ -396,7 +397,6 @@ def _forward( and not debug and not return_layernorm_output and not return_layernorm_output_gathered - and not keep_backward_unquantized and not custom ) @@ -418,7 +418,6 @@ def _forward( # do not return layernorm output unless 1) no checkpointing or 2) checkpointing but not recomputing if (return_layernorm_output or return_layernorm_output_gathered) and not is_recomputation: ln_out_return = ln_out - ln_out_hp = ln_out if keep_backward_unquantized else None # Prepare GEMM input # Note: Cast to expected dtype and perform tensor-parallel communication @@ -616,10 +615,6 @@ def _forward( if fc2_input_quantizer is not None: fc2_input_quantizer.calibrate(act_out) - act_out_hp = act_out - if keep_backward_unquantized and is_grad_enabled and fc1_out is not None: - 
act_out_hp = activation_func(fc1_out, None, **act_params) - # we want to skip fc2 computation if we are checkpointing and recomputing, # otherwise we compute fc2 if not (is_recomputation and checkpoint): @@ -695,33 +690,22 @@ def _forward( # if we are not checkpointing, then we must save this if grad is enabled if is_grad_enabled and not save_for_checkpoint: - ln_out_to_save = ln_out - act_out_to_save = act_out - if keep_backward_unquantized: - ln_out_to_save = ln_out_hp - act_out_to_save = act_out_hp ctx.fc1_weight_quantizer = fc1_weight_quantizer ctx.fc2_weight_quantizer = fc2_weight_quantizer if not fc1_weight.requires_grad: if not return_layernorm_output: - clear_tensor_data(ln_out_to_save) - ln_out_to_save = None + clear_tensor_data(ln_out) + ln_out = None if not fc2_weight.requires_grad: - clear_tensor_data(act_out_to_save) - act_out_to_save = None + clear_tensor_data(act_out) + act_out = None if not checkpoint: # regular path, no selective activation checkpointing if cpu_offloading: mark_activation_offload( - inputmat, - mu, - rsigma, - ln_out_to_save, - fc1_out, - fc1_out_without_bias, - act_out_to_save, + inputmat, mu, rsigma, ln_out, fc1_out, fc1_out_without_bias, act_out ) # Scatter intermediate/activation tensors saved for the backward pass @@ -734,9 +718,9 @@ def _forward( fsdp_group, mu, rsigma, - ln_out_to_save, + ln_out, fc1_out_without_bias if bias_gelu_fusion else fc1_out, - act_out_to_save, + act_out, ( fc1_weight_final if fp8 and not isinstance(fc1_weight, Float8Tensor) @@ -764,13 +748,13 @@ def _forward( tensors_to_save, tensor_objects = prepare_for_saving( inputmat, ln_weight, - ln_out_to_save, + ln_out, fc1_weight_final, fc1_weight, fc1_bias, fc1_out, fc1_out_without_bias, - act_out_to_save, + act_out, fc2_weight_final, fc2_weight, fc2_bias, @@ -818,7 +802,6 @@ def _forward( ctx.activation_params = activation_params ctx.fp8 = fp8 ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None - ctx.keep_backward_unquantized = 
keep_backward_unquantized ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch @@ -1017,15 +1000,6 @@ def backward( origin_fc1_weight.main_grad = fc1_weight_main_grad origin_fc2_weight.main_grad = fc2_weight_main_grad - keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) - use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized - if keep_backward_unquantized: - # Disable Userbuffers communication for backward pass when keep_backward_unquantized is True - ctx.ub_overlap_ag = False - ctx.ub_overlap_rs_dgrad = False - ctx.ub_bulk_dgrad = False - ctx.ub_bulk_wgrad = False - # TODO: Fix this # pylint: disable=fixme # Gather saved autograd context tensors when running with FSDP # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already @@ -1045,7 +1019,7 @@ def backward( # Choose whether to use GEMM kernel with split accumulator dgrad_use_split_accumulator = _2X_ACC_DGRAD wgrad_use_split_accumulator = _2X_ACC_WGRAD - if use_fp8_bwd: + if ctx.fp8: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_dgrad"): dgrad_use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator @@ -1059,7 +1033,7 @@ def backward( # Configure quantizer for FC2 grad output tensor # Note: dgrad GEMM requires row-wise usage, wgrad GEMM # requires column-wise usage - if ctx.fc2_grad_output_quantizer is not None and use_fp8_bwd: + if ctx.fc2_grad_output_quantizer is not None: quantizer = ctx.fc2_grad_output_quantizer quantizer.set_usage(rowwise=True, columnwise=True) if ctx.ub_overlap_ag: @@ -1087,7 +1061,7 @@ def backward( ub_obj_fc1_dgrad = None if ctx.fc1_weight_requires_grad and ctx.tensor_parallel and ctx.sequence_parallel: quantizer = None - if use_fp8_bwd or ctx.debug: + if ctx.fp8 or ctx.debug: quantizer = ctx.fc1_input_quantizer if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): # If data is in FP8, we compute FP8 transposes 
manually @@ -1133,7 +1107,7 @@ def backward( # 5 high-precision unfused: gemm, activation, FC1_bias + FC1_gemm # 6 fp8 unfused: gemm, activation, FC1_bias + FC1_gemm fc2_dgrad_gemm_gelu_fusion = ( - not use_fp8_bwd + not ctx.fp8 and (ctx.activation == "gelu") and (not ctx.bias_gelu_fusion) and (not ctx.debug) @@ -1142,25 +1116,20 @@ def backward( # Make sure required data is available if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) - if ( - use_fp8_bwd - and ctx.fc2_weight_quantizer is not None - and isinstance(ctx.fc2_weight, QuantizedTensorStorage) + if ctx.fc2_weight_quantizer is not None and isinstance( + ctx.fc2_weight, QuantizedTensorStorage ): ctx.fc2_weight.update_usage(columnwise_usage=True) # Perform GEMM - fc2_weight_for_dgrad = fc2_weight - if keep_backward_unquantized: - fc2_weight_for_dgrad = origin_fc2_weight gemm_output, *_ = general_gemm( - fc2_weight_for_dgrad, + fc2_weight, grad_output, layout="NN", grad=True, quantization_params=( ctx.fc1_grad_input_quantizer - if (fc2_dgrad_gemm_gelu_fusion or ctx.debug) and use_fp8_bwd + if fc2_dgrad_gemm_gelu_fusion or ctx.debug else None ), # high precision to activation out_dtype=ctx.activation_dtype, @@ -1228,14 +1197,14 @@ def backward( # Prepare input tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available - if use_fp8_bwd or ctx.debug: + if ctx.fp8 or ctx.debug: if isinstance(act_out, QuantizedTensorStorage): act_out.update_usage(columnwise_usage=True) else: ctx.fc2_input_quantizer.set_usage(rowwise=False, columnwise=True) act_out = ctx.fc2_input_quantizer(act_out) - if use_fp8_bwd or ctx.debug: + if ctx.fp8 or ctx.debug: if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: @@ -1244,7 +1213,7 @@ def backward( # Whether to set grad arg in general_gemm grad_arg = True - if use_fp8_bwd and ctx.fp8_recipe.float8_block_scaling(): + if ctx.fp8 and 
ctx.fp8_recipe.float8_block_scaling(): grad_arg = False # Arguments to include in wgrad GEMM closure @@ -1254,9 +1223,7 @@ def backward( if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - "quantization_params": ( - ctx.fc2_grad_weight_quantizer if use_fp8_bwd else None - ), # wgrad in high precision + "quantization_params": ctx.fc2_grad_weight_quantizer, # wgrad in high precision "accumulate": ( accumulate_wgrad_into_param_main_grad if not getattr(fc1_weight, "overwrite_main_grad", False) @@ -1293,7 +1260,7 @@ def fc2_wgrad_gemm( # Update grad bias if needed if fc2_bias_grad is None: if ( - use_fp8_bwd + ctx.fp8 and ctx.fp8_recipe.float8_block_scaling() and fc2_bias is not None ): @@ -1314,12 +1281,12 @@ def fc2_wgrad_gemm( act_params = ctx.activation_params or {} fc1_bias_grad = None fuse_gemm_and_bias_fc1_wgrad = False - if ctx.fc1_grad_output_quantizer is not None and use_fp8_bwd: + if ctx.fc1_grad_output_quantizer is not None: ctx.fc1_grad_output_quantizer.set_usage(rowwise=True, columnwise=True) if ctx.bias_gelu_fusion: # Fusion: gemm, bias + gelu assert ctx.activation == "gelu" - assert not use_fp8_bwd + assert not ctx.fp8 fc1_bias_grad, dact = bgrad_dgelu_fused(fc2_dgrad, fc1_out_without_bias, fc1_bias) if ctx.fc1_grad_output_quantizer is not None: dact = ctx.fc1_grad_output_quantizer(dact) @@ -1330,7 +1297,7 @@ def fc2_wgrad_gemm( dact = ctx.fc1_grad_output_quantizer(dact) elif ( _act_func(ctx.activation, ctx.fp8_recipe if ctx.fp8 else None)[2] is not None - and use_fp8_bwd + and ctx.fp8 ): # Fusion: gemm, bias + gelu + quantize dbias_dact_quantize_func = _act_func( @@ -1352,7 +1319,7 @@ def fc2_wgrad_gemm( fc2_dgrad, fc1_out.to(ctx.activation_dtype), None, **act_params ) # activation in high precision - if use_fp8_bwd: + if ctx.fp8: # TODO float8 blockwise current scaling (as well as custom quantizers) has no bgrad fusion for now if ( isinstance(ctx.fc1_grad_output_quantizer, Float8BlockQuantizer) @@ -1401,10 +1368,8 @@ def fc2_wgrad_gemm( # 
-------------------------------------------------- # Make sure required data is available - if ( - use_fp8_bwd - and ctx.fc1_weight_quantizer is not None - and isinstance(ctx.fc1_weight_quantizer, QuantizedTensorStorage) + if ctx.fc1_weight_quantizer is not None and isinstance( + ctx.fc1_weight_quantizer, QuantizedTensorStorage ): ctx.fc1_weight.update_usage(columnwise_usage=True) @@ -1419,13 +1384,12 @@ def fc2_wgrad_gemm( gemm_out = ub_obj_fc1_wgrad.get_buffer(local_chunk=False) # dgrad GEMM - fc1_weight_for_dgrad = fc1_weight if use_fp8_bwd else origin_fc1_weight gemm_out, *_, reduce_scatter_out = general_gemm( - fc1_weight_for_dgrad, + fc1_weight, dact, out=gemm_out, out_dtype=ctx.activation_dtype, - quantization_params=ctx.fc1_grad_input_quantizer if use_fp8_bwd else None, + quantization_params=ctx.fc1_grad_input_quantizer, layout="NN", grad=True, use_split_accumulator=dgrad_use_split_accumulator, @@ -1474,7 +1438,7 @@ def fc2_wgrad_gemm( if ln_out_total_work is not None: ln_out_total_work.wait() ln_out_total_work = None - if use_fp8_bwd or ctx.debug: + if ctx.fp8 or ctx.debug: if isinstance(ln_out_total, QuantizedTensorStorage): ln_out_total.update_usage(columnwise_usage=True) else: @@ -1484,7 +1448,7 @@ def fc2_wgrad_gemm( # Prepare grad output tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available - if use_fp8_bwd or ctx.debug: + if ctx.fp8 or ctx.debug: if isinstance(dact, QuantizedTensorStorage): dact.update_usage(columnwise_usage=True) else: @@ -1506,7 +1470,7 @@ def fc2_wgrad_gemm( if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - "quantization_params": (ctx.fc1_grad_weight_quantizer if use_fp8_bwd else None), + "quantization_params": ctx.fc1_grad_weight_quantizer, "accumulate": ( accumulate_wgrad_into_param_main_grad if not getattr(fc2_weight, "overwrite_main_grad", False) From ae939a14245e1f0982f0081d05496401815addc9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" 
<66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Feb 2026 22:02:31 +0000 Subject: [PATCH 22/61] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/layernorm_mlp.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index bf28fa47e3..b8adcb11e2 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -235,7 +235,9 @@ def _forward( recompute_for_bwd, ) = non_tensor_args keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized() - assert not keep_backward_unquantized, "NVTE_KEEP_BACKWARD_UNQUANTIZED is not implemented in LayerNormMLP" + assert ( + not keep_backward_unquantized + ), "NVTE_KEEP_BACKWARD_UNQUANTIZED is not implemented in LayerNormMLP" # if grad is enabled and this is not the bwd stage, we must save this so bwd knows which path to take if is_grad_enabled and not recompute_for_bwd: From 25902a28dfdfe72431e9bacc55875e4a370e4147 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Thu, 5 Feb 2026 13:10:10 -0800 Subject: [PATCH 23/61] Move interface changes to recipe Signed-off-by: Ziang Li --- transformer_engine/common/recipe/__init__.py | 67 +++++++++++++++++-- .../pytorch/module/grouped_linear.py | 2 +- .../pytorch/module/layernorm_linear.py | 2 +- .../pytorch/module/layernorm_mlp.py | 2 +- transformer_engine/pytorch/module/linear.py | 2 +- .../pytorch/ops/basic/basic_linear.py | 6 +- .../pytorch/ops/basic/quantize.py | 2 +- .../fused/forward_linear_bias_activation.py | 2 +- .../ops/fused/forward_linear_bias_add.py | 2 +- .../ops/fused/forward_linear_scale_add.py | 2 +- transformer_engine/pytorch/quantization.py | 39 +++++++---- 11 files changed, 99 insertions(+), 29 deletions(-) diff --git a/transformer_engine/common/recipe/__init__.py 
b/transformer_engine/common/recipe/__init__.py index 18577b0eb4..341f23972c 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -11,6 +11,11 @@ from pydantic.dataclasses import dataclass +def _default_quantize_backward() -> bool: + """Default backward quantization setting.""" + return not bool(int(os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0"))) + + class _FormatHelper(NamedTuple): """ Stores max FP8 values for fprop and bprop a `Format`. @@ -188,6 +193,11 @@ def scaling_factor_compute(amax: Tensor, `LayerNormLinear (BF16 output) -> (cast to FP8 ) FP8 DPA (cast to BF16) -> Linear`. When `fp8_mha = True, fp8_dpa = True`, it becomes `LayerNormLinear (FP8 output) -> FP8 DPA -> Linear`. + quantize_forward : bool, default = True + Whether to quantize tensors in the forward pass. + quantize_backward : bool, default = True + Whether to quantize tensors in the backward pass. Delayed scaling + always quantizes backward; setting this to False is not supported. Notes ----- @@ -211,6 +221,8 @@ def scaling_factor_compute(amax: Tensor, reduce_amax: bool = True fp8_dpa: bool = False fp8_mha: bool = False + quantize_forward: bool = True + quantize_backward: bool = field(default_factory=_default_quantize_backward) def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." @@ -223,7 +235,9 @@ def __repr__(self) -> str: f"amax_history_len={self.amax_history_len}, " f"reduce_amax={self.reduce_amax}, " f"fp8_dpa={self.fp8_dpa}, " - f"fp8_mha={self.fp8_mha}" + f"fp8_mha={self.fp8_mha}, " + f"quantize_forward={self.quantize_forward}, " + f"quantize_backward={self.quantize_backward}" ) @@ -237,6 +251,10 @@ class Float8CurrentScaling(Recipe): fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.HYBRID Controls the FP8 data format used during forward and backward pass. + quantize_forward : bool, default = True + Whether to quantize tensors in the forward pass. 
+ quantize_backward : bool, default = True + Whether to quantize tensors in the backward pass. """ use_power_2_scales: bool = os.getenv("NVTE_FP8_CURRENT_SCALING_POWER_2_SCALES", "0") == "1" @@ -249,6 +267,8 @@ class Float8CurrentScaling(Recipe): fp8_gemm_wgrad: MMParams = MMParams(use_split_accumulator=True) fp8_dpa: bool = False fp8_mha: bool = False + quantize_forward: bool = True + quantize_backward: bool = field(default_factory=_default_quantize_backward) def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." @@ -264,7 +286,9 @@ def __repr__(self) -> str: f"fp8_gemm_dgrad={self.fp8_gemm_dgrad}, " f"fp8_gemm_wgrad={self.fp8_gemm_wgrad}, " f"fp8_dpa={self.fp8_dpa}, " - f"fp8_mha={self.fp8_mha}" + f"fp8_mha={self.fp8_mha}, " + f"quantize_forward={self.quantize_forward}, " + f"quantize_backward={self.quantize_backward}" ) @@ -291,12 +315,18 @@ class MXFP8BlockScaling(Recipe): fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.E4M3 Controls the FP8 data format used during forward and backward pass. + quantize_forward : bool, default = True + Whether to quantize tensors in the forward pass. + quantize_backward : bool, default = True + Whether to quantize tensors in the backward pass. """ margin: int = 0 fp8_format: Format = Format.E4M3 fp8_dpa: bool = False fp8_mha: bool = False + quantize_forward: bool = True + quantize_backward: bool = field(default_factory=_default_quantize_backward) def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." 
@@ -305,7 +335,9 @@ def __repr__(self) -> str: return ( f"recipe_type={self.__class__.__name__}, " f"margin={self.margin}, " - f"format={str(self.fp8_format).split('.')[1]}" + f"format={str(self.fp8_format).split('.')[1]}, " + f"quantize_forward={self.quantize_forward}, " + f"quantize_backward={self.quantize_backward}" ) @@ -334,6 +366,10 @@ class Float8BlockScaling(Recipe): fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.E4M3 Controls the FP8 data format used during forward and backward pass. + quantize_forward : bool, default = True + Whether to quantize tensors in the forward pass. + quantize_backward : bool, default = True + Whether to quantize tensors in the backward pass. """ use_f32_scales: bool = os.getenv("NVTE_FP8_BLOCK_SCALING_FP32_SCALES", "0") == "1" @@ -386,7 +422,9 @@ def __repr__(self) -> str: f"fp8_gemm_dgrad={self.fp8_gemm_dgrad}, " f"fp8_gemm_wgrad={self.fp8_gemm_wgrad}, " f"fp8_dpa={self.fp8_dpa}, " - f"fp8_mha={self.fp8_mha}" + f"fp8_mha={self.fp8_mha}, " + f"quantize_forward={self.quantize_forward}, " + f"quantize_backward={self.quantize_backward}" ) @@ -435,6 +473,10 @@ class NVFP4BlockScaling(Recipe): If set to `True`, stochastic rounding is disabled during quantization for all tensors. disable_2d_quantization : bool, default = False If set to `True`, 1D block scaling with block size 16 is used for all tensors. + quantize_forward : bool, default = True + Whether to quantize tensors in the forward pass. + quantize_backward : bool, default = True + Whether to quantize tensors in the backward pass. 
""" # Configuration envvars @@ -450,6 +492,8 @@ class NVFP4BlockScaling(Recipe): # Not applying quantization to attention for now fp8_dpa: bool = False fp8_mha: bool = False + quantize_forward: bool = True + quantize_backward: bool = field(default_factory=_default_quantize_backward) def __post_init__(self) -> None: assert self.fp4_format == Format.E2M1, "Only E2M1 is supported for NVFP4 scaling" @@ -481,6 +525,8 @@ def __repr__(self) -> str: f"fp8_format={str(self.fp8_format).split('.')[1]}, " f"fp8_dpa={self.fp8_dpa}, " f"fp8_mha={self.fp8_mha}, " + f"quantize_forward={self.quantize_forward}, " + f"quantize_backward={self.quantize_backward}, " f"fp4_quant_fwd_inp={self.fp4_quant_fwd_inp}, " f"fp4_quant_fwd_weight={self.fp4_quant_fwd_weight}, " f"fp4_quant_bwd_grad={self.fp4_quant_bwd_grad}, " @@ -512,12 +558,23 @@ class CustomRecipe(Recipe): - forward: "linear_input", "linear_weight", "linear_output" - backward: "linear_grad_output", "linear_grad_input" + quantize_forward : bool, default = True + Whether to quantize tensors in the forward pass. + quantize_backward : bool, default = True + Whether to quantize tensors in the backward pass. 
""" qfactory: Callable[..., Any] fp8_dpa: bool = False fp8_mha: bool = False + quantize_forward: bool = True + quantize_backward: bool = field(default_factory=_default_quantize_backward) def __repr__(self) -> str: - return f"recipe_type={self.__class__.__name__}, qfactory={self.qfactory}" + return ( + f"recipe_type={self.__class__.__name__}, " + f"qfactory={self.qfactory}, " + f"quantize_forward={self.quantize_forward}, " + f"quantize_backward={self.quantize_backward}" + ) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 054c5de04a..4d17f4c519 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -97,7 +97,7 @@ def forward( save_original_input, debug, ) = non_tensor_args - keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized() + keep_backward_unquantized = fp8 and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) if keep_backward_unquantized: # Note, NVTE_KEEP_BACKWARD_UNQUANTIZED is ignored when delayed scaling is used save_original_input = True diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index fdd03acfd5..78cbcac50f 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -140,7 +140,7 @@ def forward( symmetric_ar_type, debug, ) = non_tensor_args - keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized() + keep_backward_unquantized = fp8 and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) # NVTX label for profiling nvtx_label = "transformer_engine._LayerNormLinear.forward" diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index b8adcb11e2..a2e002e875 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ 
b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -234,7 +234,7 @@ def _forward( debug, recompute_for_bwd, ) = non_tensor_args - keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized() + keep_backward_unquantized = fp8 and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) assert ( not keep_backward_unquantized ), "NVTE_KEEP_BACKWARD_UNQUANTIZED is not implemented in LayerNormMLP" diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 6b96805e05..90f9778c7e 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -128,7 +128,7 @@ def forward( save_original_input, debug, ) = non_tensor_args - keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized() + keep_backward_unquantized = fp8 and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) if keep_backward_unquantized: # Note, NVTE_KEEP_BACKWARD_UNQUANTIZED is ignored when delayed scaling is used save_original_input = True diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index f2b8ba106e..d73fceeaf0 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -332,7 +332,9 @@ def pre_fuser_forward(self, *, requires_grad: bool) -> None: # Note: We cache the quantized input for backward pass, # but discard the quantized weights. 
weight_requires_grad = requires_grad and self.weight.requires_grad - keep_backward_unquantized = FP8GlobalStateManager.keep_backward_unquantized() + keep_backward_unquantized = ( + FP8GlobalStateManager.is_fp8_enabled() and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) + ) columnwise_usage = weight_requires_grad and not keep_backward_unquantized input_quantizer = self.get_quantizer("forward", 0) weight_quantizer = self.get_quantizer("forward", 1) @@ -989,7 +991,7 @@ def op_forward( grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() keep_backward_unquantized = ( - with_quantized_compute and FP8GlobalStateManager.keep_backward_unquantized() + with_quantized_compute and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) ) # Get autocast dtype if needed diff --git a/transformer_engine/pytorch/ops/basic/quantize.py b/transformer_engine/pytorch/ops/basic/quantize.py index 7dd8f1a7ac..4c67cd8cce 100644 --- a/transformer_engine/pytorch/ops/basic/quantize.py +++ b/transformer_engine/pytorch/ops/basic/quantize.py @@ -60,7 +60,7 @@ def op_forward( quantize_backward = ( fp8_enabled and self._quantize_backward - and not FP8GlobalStateManager.keep_backward_unquantized() + and FP8GlobalStateManager.get_fp8_recipe().quantize_backward ) # Quantize if needed diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index 2458d4d072..80cb5647d7 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -93,7 +93,7 @@ def fuser_forward( grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() keep_backward_unquantized = ( - with_quantized_compute and FP8GlobalStateManager.keep_backward_unquantized() + with_quantized_compute and (not 
FP8GlobalStateManager.get_fp8_recipe().quantize_backward) ) # Get autocast dtype if needed diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index efa543e555..cf29140a20 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -87,7 +87,7 @@ def fuser_forward( grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() keep_backward_unquantized = ( - with_quantized_compute and FP8GlobalStateManager.keep_backward_unquantized() + with_quantized_compute and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) ) # Get autocast dtype if needed diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py index 2804534968..0caae13af9 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py @@ -66,7 +66,7 @@ def fuser_forward( grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() keep_backward_unquantized = ( - with_quantized_compute and FP8GlobalStateManager.keep_backward_unquantized() + with_quantized_compute and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) ) # Get extra input tensor for add operation diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index 3c4a6b9ffe..47d03ab4da 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -87,6 +87,21 @@ def check_fp8_block_scaling_support() -> Tuple[bool, str]: ) +def _validate_recipe_quantization_flags(recipe: Recipe) -> None: + """Validate forward/backward quantization flags on a recipe.""" + quantize_forward = 
getattr(recipe, "quantize_forward", True) + quantize_backward = getattr(recipe, "quantize_backward", True) + if not quantize_forward and quantize_backward: + raise ValueError( + "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True." + ) + if recipe.delayed() and not quantize_backward: + raise ValueError( + "Invalid recipe configuration: delayed scaling does not support " + "quantize_backward=False." + ) + + def check_recipe_support(recipe: Recipe) -> None: """Check if the given recipe is supported.""" recipe_supported = True @@ -431,15 +446,6 @@ def with_high_precision_init_val(cls) -> bool: """Should the high precision initial values be stored with FP8 parameters""" return cls.HIGH_PRECISION_INIT_VAL - @classmethod - def keep_backward_unquantized(cls) -> bool: - """Should backward skip FP8 quantization and use high precision""" - recipe = cls.get_fp8_recipe() - if recipe is not None and recipe.delayed(): - # Ignore NVTE_KEEP_BACKWARD_UNQUANTIZED when delayed scaling is used - return False - return bool(int(os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0"))) - @classmethod def fp8_graph_capturing(cls) -> bool: """Is CUDA graph capture under way?""" @@ -852,16 +858,21 @@ def autocast( are reduced at the end of each training step. """ - if enabled: - check_recipe_support(recipe) + fp8_recipe = get_default_fp8_recipe() if recipe is None else recipe + if enabled or calibrating: + _validate_recipe_quantization_flags(fp8_recipe) + quantize_forward = getattr(fp8_recipe, "quantize_forward", True) + effective_enabled = enabled and quantize_forward + if effective_enabled: + check_recipe_support(fp8_recipe) # Save current state so we always restore it on exit. 
fp8_state = FP8GlobalStateManager.get_autocast_state() FP8GlobalStateManager.autocast_enter( - enabled=enabled, + enabled=effective_enabled, calibrating=calibrating, - fp8_recipe=recipe, + fp8_recipe=fp8_recipe, fp8_group=amax_reduction_group, _graph=_graph, ) @@ -869,7 +880,7 @@ def autocast( yield finally: FP8GlobalStateManager.set_autocast_state(fp8_state) - FP8GlobalStateManager.autocast_exit(enabled, _graph=_graph) + FP8GlobalStateManager.autocast_exit(effective_enabled, _graph=_graph) def _update_amax_history(amax_history: torch.Tensor) -> torch.Tensor: From f10f0bbc7aba3624be2a4130079c0a7f9a9e71ff Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Feb 2026 21:11:01 +0000 Subject: [PATCH 24/61] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/grouped_linear.py | 4 +++- transformer_engine/pytorch/module/layernorm_linear.py | 4 +++- transformer_engine/pytorch/module/layernorm_mlp.py | 4 +++- transformer_engine/pytorch/module/linear.py | 4 +++- transformer_engine/pytorch/ops/basic/basic_linear.py | 8 ++++---- .../pytorch/ops/fused/forward_linear_bias_activation.py | 4 ++-- .../pytorch/ops/fused/forward_linear_bias_add.py | 4 ++-- .../pytorch/ops/fused/forward_linear_scale_add.py | 4 ++-- 8 files changed, 22 insertions(+), 14 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 4d17f4c519..9dacb51f4a 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -97,7 +97,9 @@ def forward( save_original_input, debug, ) = non_tensor_args - keep_backward_unquantized = fp8 and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) + keep_backward_unquantized = fp8 and ( + not FP8GlobalStateManager.get_fp8_recipe().quantize_backward + ) if 
keep_backward_unquantized: # Note, NVTE_KEEP_BACKWARD_UNQUANTIZED is ignored when delayed scaling is used save_original_input = True diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 78cbcac50f..a93d08d467 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -140,7 +140,9 @@ def forward( symmetric_ar_type, debug, ) = non_tensor_args - keep_backward_unquantized = fp8 and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) + keep_backward_unquantized = fp8 and ( + not FP8GlobalStateManager.get_fp8_recipe().quantize_backward + ) # NVTX label for profiling nvtx_label = "transformer_engine._LayerNormLinear.forward" diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index a2e002e875..30920e42fc 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -234,7 +234,9 @@ def _forward( debug, recompute_for_bwd, ) = non_tensor_args - keep_backward_unquantized = fp8 and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) + keep_backward_unquantized = fp8 and ( + not FP8GlobalStateManager.get_fp8_recipe().quantize_backward + ) assert ( not keep_backward_unquantized ), "NVTE_KEEP_BACKWARD_UNQUANTIZED is not implemented in LayerNormMLP" diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 90f9778c7e..904cac1733 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -128,7 +128,9 @@ def forward( save_original_input, debug, ) = non_tensor_args - keep_backward_unquantized = fp8 and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) + keep_backward_unquantized = fp8 and ( + not FP8GlobalStateManager.get_fp8_recipe().quantize_backward + ) if 
keep_backward_unquantized: # Note, NVTE_KEEP_BACKWARD_UNQUANTIZED is ignored when delayed scaling is used save_original_input = True diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index d73fceeaf0..307b2e1624 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -332,8 +332,8 @@ def pre_fuser_forward(self, *, requires_grad: bool) -> None: # Note: We cache the quantized input for backward pass, # but discard the quantized weights. weight_requires_grad = requires_grad and self.weight.requires_grad - keep_backward_unquantized = ( - FP8GlobalStateManager.is_fp8_enabled() and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) + keep_backward_unquantized = FP8GlobalStateManager.is_fp8_enabled() and ( + not FP8GlobalStateManager.get_fp8_recipe().quantize_backward ) columnwise_usage = weight_requires_grad and not keep_backward_unquantized input_quantizer = self.get_quantizer("forward", 0) @@ -990,8 +990,8 @@ def op_forward( grad_output_quantizer = self.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() - keep_backward_unquantized = ( - with_quantized_compute and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) + keep_backward_unquantized = with_quantized_compute and ( + not FP8GlobalStateManager.get_fp8_recipe().quantize_backward ) # Get autocast dtype if needed diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index 80cb5647d7..2bccabb306 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -92,8 +92,8 @@ def fuser_forward( grad_output_quantizer = linear_op.get_quantizer("backward", 0) 
grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() - keep_backward_unquantized = ( - with_quantized_compute and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) + keep_backward_unquantized = with_quantized_compute and ( + not FP8GlobalStateManager.get_fp8_recipe().quantize_backward ) # Get autocast dtype if needed diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index cf29140a20..03e3bff6f3 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -86,8 +86,8 @@ def fuser_forward( grad_output_quantizer = linear_op.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() - keep_backward_unquantized = ( - with_quantized_compute and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) + keep_backward_unquantized = with_quantized_compute and ( + not FP8GlobalStateManager.get_fp8_recipe().quantize_backward ) # Get autocast dtype if needed diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py index 0caae13af9..8cebcec53a 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py @@ -65,8 +65,8 @@ def fuser_forward( grad_output_quantizer = linear_op.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() - keep_backward_unquantized = ( - with_quantized_compute and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) + keep_backward_unquantized = with_quantized_compute and ( + not FP8GlobalStateManager.get_fp8_recipe().quantize_backward ) 
# Get extra input tensor for add operation From 074b83fcf52a10f9a22aea8f3de04b18c236babd Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Thu, 5 Feb 2026 13:43:08 -0800 Subject: [PATCH 25/61] Move ub overrides to fwd Signed-off-by: Ziang Li --- .../pytorch/module/layernorm_linear.py | 14 ++++++++------ transformer_engine/pytorch/module/linear.py | 13 +++++++------ 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index a93d08d467..f3d37900e2 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -538,6 +538,14 @@ def forward( ctx.wgrad_store = wgrad_store ctx.debug = debug + # keep_backward_unquantized overrides + if keep_backward_unquantized: + # Disable Userbuffers communication for backward pass when keep_backward_unquantized is True + ctx.ub_overlap_ag = False + ctx.ub_overlap_rs_dgrad = False + ctx.ub_bulk_dgrad = False + ctx.ub_bulk_wgrad = False + # ------------------------------------------------------ # Cached state for backward pass is ready... 
# ------------------------------------------------------ @@ -609,12 +617,6 @@ def backward( keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized - if keep_backward_unquantized: - # Disable Userbuffers communication for backward pass when keep_backward_unquantized is True - ctx.ub_overlap_ag = False - ctx.ub_overlap_rs_dgrad = False - ctx.ub_bulk_dgrad = False - ctx.ub_bulk_wgrad = False # Configure Userbuffers communication (comm+GEMM overlap) ctx.ub_obj_gradout = None diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 904cac1733..bc56e4d3e0 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -492,6 +492,13 @@ def forward( FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module ctx.wgrad_store = wgrad_store + # keep_backward_unquantized overrides + if keep_backward_unquantized: + ctx.ub_overlap_ag = False + ctx.ub_overlap_rs_dgrad = False + ctx.ub_bulk_dgrad = False + ctx.ub_bulk_wgrad = False + # ------------------------------------------------------ # Cached state for backward pass is ready... 
# ------------------------------------------------------ @@ -544,12 +551,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized - if keep_backward_unquantized: - # Disable Userbuffers communication for backward pass when keep_backward_unquantized is True - ctx.ub_overlap_ag = False - ctx.ub_overlap_rs_dgrad = False - ctx.ub_bulk_dgrad = False - ctx.ub_bulk_wgrad = False # Configure Userbuffers communication (comm+GEMM overlap) ctx.ub_obj_gradout = None From 1a504d44035103f8902ab98206045a9e958d8337 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Thu, 5 Feb 2026 13:44:22 -0800 Subject: [PATCH 26/61] Remove duplication Signed-off-by: Ziang Li --- transformer_engine/common/recipe/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 341f23972c..1307302180 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -269,8 +269,6 @@ class Float8CurrentScaling(Recipe): fp8_mha: bool = False quantize_forward: bool = True quantize_backward: bool = field(default_factory=_default_quantize_backward) - quantize_forward: bool = True - quantize_backward: bool = field(default_factory=_default_quantize_backward) def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." 
From 447677b2b305f5a1f5bea3c43affb0f6ead65390 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Thu, 5 Feb 2026 13:59:39 -0800 Subject: [PATCH 27/61] Simplify use_fp8_bwd logic in bwd Signed-off-by: Ziang Li --- .../pytorch/module/grouped_linear.py | 19 +++++++++---- .../pytorch/module/layernorm_linear.py | 25 ++++++++--------- transformer_engine/pytorch/module/linear.py | 28 +++++++++---------- 3 files changed, 39 insertions(+), 33 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 9dacb51f4a..a3fde744ec 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -315,6 +315,14 @@ def forward( ctx.debug = debug ctx.save_original_input = save_original_input ctx.input_quantizers = input_quantizers + + # keep_backward_unquantized overrides + if keep_backward_unquantized: + ctx.fp8 = ctx.fp8 and not keep_backward_unquantized + ctx.ub_overlap_ag = False + ctx.ub_overlap_rs_dgrad = False + ctx.ub_bulk_dgrad = False + ctx.ub_bulk_wgrad = False # [*, in_features] -> [*, out_features] except first dimension changes for SP return out.view(-1, *inp.shape[1:-1], out.shape[-1]) @@ -331,7 +339,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], biases = saved_tensors[3 * N : 4 * N] main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs] keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) - use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized if ctx.cpu_offloading: if ctx.grad_added_to_main_grad: @@ -347,7 +354,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], grad_output_view = grad_output.contiguous().view(-1, grad_output.shape[-1]) grad_output = [None] * ctx.num_gemms grad_biases = [None] * ctx.num_gemms - if use_fp8_bwd and not ctx.debug: + if ctx.fp8 and not ctx.debug: if ctx.use_bias: grad_output_mats = 
torch.split(grad_output_view, ctx.m_splits) recipe = ctx.fp8_recipe @@ -401,7 +408,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if ctx.requires_dgrad: dgrad_gemm_use_split_accumulator = _2X_ACC_DGRAD - if use_fp8_bwd or ctx.debug: + if ctx.fp8 or ctx.debug: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_dgrad"): dgrad_gemm_use_split_accumulator = ( @@ -435,7 +442,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if ctx.weights_requires_grad: wgrad_gemm_use_split_accumulator = _2X_ACC_WGRAD - if use_fp8_bwd: + if ctx.fp8: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_wgrad"): wgrad_gemm_use_split_accumulator = ( @@ -463,7 +470,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], else: input_quantizer.set_usage(rowwise=False, columnwise=True) inputmats: list - if use_fp8_bwd and not ctx.debug: + if ctx.fp8 and not ctx.debug: inputmats = tex.split_quantize(inp_view, ctx.m_splits, ctx.input_quantizers) elif ctx.debug: inputmats = DebugQuantizer.multi_tensor_quantize( @@ -540,7 +547,7 @@ def handle_custom_ddp_from_mcore(weight, wgrad): if not ctx.use_bias or ( ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute() - and not use_fp8_bwd + and not ctx.fp8 ): grad_biases = [None] * ctx.num_gemms diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index f3d37900e2..4b04522ce9 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -540,7 +540,7 @@ def forward( # keep_backward_unquantized overrides if keep_backward_unquantized: - # Disable Userbuffers communication for backward pass when keep_backward_unquantized is True + ctx.fp8 = ctx.fp8 and not keep_backward_unquantized ctx.ub_overlap_ag = False ctx.ub_overlap_rs_dgrad = False ctx.ub_bulk_dgrad = False @@ -616,7 +616,6 @@ def backward( 
origin_weight.main_grad = main_grad keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) - use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized # Configure Userbuffers communication (comm+GEMM overlap) ctx.ub_obj_gradout = None @@ -654,7 +653,7 @@ def backward( # Configure quantizer for grad output tensor # Note: dgrad GEMM requires row-wise usage, wgrad GEMM # requires column-wise usage - if ctx.grad_output_quantizer is not None and use_fp8_bwd: + if ctx.grad_output_quantizer is not None and ctx.fp8: quantizer = ctx.grad_output_quantizer quantizer.set_usage(rowwise=True, columnwise=True) if ctx.ub_overlap_ag: @@ -691,7 +690,7 @@ def backward( ln_out_total_work = None if ctx.ln_out_needs_gather: quantizer = None - if ctx.input_quantizer is not None and use_fp8_bwd: + if ctx.input_quantizer is not None and ctx.fp8: quantizer = ctx.input_quantizer if quantizer.supports_only_rowwise_all_gather(): # If data is in FP8, we compute FP8 transposes manually @@ -730,7 +729,7 @@ def backward( if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) if ( - use_fp8_bwd + ctx.fp8 and ctx.weight_quantizer is not None and isinstance(weight, QuantizedTensorStorage) ): @@ -738,13 +737,13 @@ def backward( # Choose whether to use GEMM kernel with split accumulator use_split_accumulator = _2X_ACC_DGRAD - if use_fp8_bwd: + if ctx.fp8: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_dgrad"): use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator # Update grad input quantizer - if ctx.grad_input_quantizer is not None and use_fp8_bwd: + if ctx.grad_input_quantizer is not None and ctx.fp8: ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) # Output buffers for Userbuffers reduce-scatter @@ -768,7 +767,7 @@ def backward( grad_output, layout="NN", grad=True, - quantization_params=ctx.grad_input_quantizer if use_fp8_bwd else None, + quantization_params=ctx.grad_input_quantizer if ctx.fp8 else 
None, out=gemm_out, out_dtype=ctx.activation_dtype, use_split_accumulator=use_split_accumulator, @@ -853,14 +852,14 @@ def backward( if ln_out_total_work is not None: ln_out_total_work.wait() ln_out_total_work = None - if use_fp8_bwd or ctx.debug: + if ctx.fp8 or ctx.debug: if isinstance(ln_out_total, QuantizedTensorStorage): ln_out_total.update_usage(columnwise_usage=True) else: ctx.input_quantizer.set_usage(rowwise=False, columnwise=True) ln_out_total = ctx.input_quantizer(ln_out_total) - if use_fp8_bwd or ctx.debug: + if ctx.fp8 or ctx.debug: if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: @@ -869,7 +868,7 @@ def backward( # Figure out whether to use split accumulator use_split_accumulator = _2X_ACC_WGRAD - if use_fp8_bwd: + if ctx.fp8: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_wgrad"): use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator @@ -895,7 +894,7 @@ def backward( "out_dtype": ( main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - "quantization_params": (ctx.grad_weight_quantizer if use_fp8_bwd else None), + "quantization_params": (ctx.grad_weight_quantizer if ctx.fp8 else None), "accumulate": ( accumulate_wgrad_into_param_main_grad if not getattr(weight, "overwrite_main_grad", False) @@ -903,7 +902,7 @@ def backward( ), "layout": "NT", "out": main_grad if ctx.fuse_wgrad_accumulation else None, - "bias": (bias if (grad_bias is None and not use_fp8_bwd) else None), + "bias": (bias if (grad_bias is None and not ctx.fp8) else None), "use_split_accumulator": use_split_accumulator, "grad": True, "ub": ub_obj_wgrad, diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index bc56e4d3e0..1c6a4e7e39 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -494,6 +494,7 @@ def forward( # keep_backward_unquantized overrides if keep_backward_unquantized: + 
ctx.fp8 = ctx.fp8 and not keep_backward_unquantized ctx.ub_overlap_ag = False ctx.ub_overlap_rs_dgrad = False ctx.ub_bulk_dgrad = False @@ -550,7 +551,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], nvtx_range_pop(f"{nvtx_label}.fsdp_gather") keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) - use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized # Configure Userbuffers communication (comm+GEMM overlap) ctx.ub_obj_gradout = None @@ -591,7 +591,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Configure quantizer for grad output tensor # Note: dgrad GEMM requires row-wise usage, wgrad GEMM # requires column-wise usage - if ctx.grad_output_quantizer is not None and use_fp8_bwd: + if ctx.grad_output_quantizer is not None and ctx.fp8: quantizer = ctx.grad_output_quantizer quantizer.set_usage(rowwise=True, columnwise=True) if ctx.ub_overlap_ag: @@ -610,7 +610,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], not ctx.use_bias and not ctx.requires_wgrad and ctx.grad_output_quantizer is not None - and use_fp8_bwd + and ctx.fp8 ): ctx.grad_output_quantizer.set_usage(columnwise=False) @@ -640,7 +640,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmat_total = None inputmat_total_work = None if ctx.requires_wgrad: - if use_fp8_bwd or ctx.debug: + if ctx.fp8 or ctx.debug: if isinstance(inputmat, QuantizedTensorStorage): # Input tensor is already quantized pass @@ -666,7 +666,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmat = cast_if_needed(inputmat, ctx.activation_dtype) if ctx.backward_input_needs_gather: quantizer = None - if use_fp8_bwd or ctx.debug: + if ctx.fp8 or ctx.debug: quantizer = ctx.input_quantizer if quantizer.supports_only_rowwise_all_gather(): # If data is in FP8, we compute FP8 transposes manually @@ -708,7 +708,7 @@ def backward(ctx, 
grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) if ( - use_fp8_bwd + ctx.fp8 and ctx.weight_quantizer is not None and isinstance(weight_fp8, QuantizedTensorStorage) ): @@ -716,13 +716,13 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Choose whether to use GEMM kernel with split accumulator use_split_accumulator = _2X_ACC_DGRAD - if use_fp8_bwd: + if ctx.fp8: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_dgrad"): use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator # Update grad input quantizer - if ctx.grad_input_quantizer is not None and use_fp8_bwd: + if ctx.grad_input_quantizer is not None and ctx.fp8: ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) # Output buffers for Userbuffers reduce-scatter @@ -747,7 +747,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], grad_output, layout="NN", grad=True, - quantization_params=ctx.grad_input_quantizer if use_fp8_bwd else None, + quantization_params=ctx.grad_input_quantizer if ctx.fp8 else None, out=gemm_out, out_dtype=ctx.activation_dtype, use_split_accumulator=use_split_accumulator, @@ -796,7 +796,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if inputmat_total_work is not None: inputmat_total_work.wait() inputmat_total_work = None - if use_fp8_bwd or ctx.debug: + if ctx.fp8 or ctx.debug: if isinstance(inputmat_total, QuantizedTensorStorage): inputmat_total.update_usage(columnwise_usage=True) else: @@ -838,7 +838,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ub_obj_overlap_wgrad, dgrad_send_stream, dgrad_recv_stream ) - if use_fp8_bwd or ctx.debug: + if ctx.fp8 or ctx.debug: if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: @@ -847,7 +847,7 @@ def backward(ctx, 
grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Figure out whether to use split accumulator use_split_accumulator = _2X_ACC_WGRAD - if use_fp8_bwd: + if ctx.fp8: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_wgrad"): use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator @@ -873,7 +873,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], "out_dtype": ( main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - "quantization_params": (ctx.grad_weight_quantizer if use_fp8_bwd else None), + "quantization_params": (ctx.grad_weight_quantizer if ctx.fp8 else None), "accumulate": ( accumulate_wgrad_into_param_main_grad if not getattr(weight, "overwrite_main_grad", False) @@ -881,7 +881,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ), "layout": "NT", "out": main_grad if ctx.fuse_wgrad_accumulation else None, - "bias": (bias if (grad_bias is None and not use_fp8_bwd) else None), + "bias": (bias if (grad_bias is None and not ctx.fp8) else None), "use_split_accumulator": use_split_accumulator, "grad": True, "ub": ub_obj_wgrad, From 81e7febc0bf3e4b706dc7d0194a6ec99e2ec9360 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Feb 2026 22:00:24 +0000 Subject: [PATCH 28/61] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/grouped_linear.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index a3fde744ec..08fc9d6de8 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -315,7 +315,7 @@ def forward( ctx.debug = debug ctx.save_original_input = save_original_input ctx.input_quantizers = input_quantizers - + # 
keep_backward_unquantized overrides if keep_backward_unquantized: ctx.fp8 = ctx.fp8 and not keep_backward_unquantized From 435859b8da184733427545fb1f8300a710910302 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Thu, 5 Feb 2026 14:28:06 -0800 Subject: [PATCH 29/61] Set grad quantizers to none if keep bwd unquantized Signed-off-by: Ziang Li --- transformer_engine/pytorch/module/grouped_linear.py | 3 +++ .../pytorch/module/layernorm_linear.py | 11 +++++++---- transformer_engine/pytorch/module/linear.py | 13 ++++++++----- 3 files changed, 18 insertions(+), 9 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 08fc9d6de8..bead74d9a6 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -323,6 +323,9 @@ def forward( ctx.ub_overlap_rs_dgrad = False ctx.ub_bulk_dgrad = False ctx.ub_bulk_wgrad = False + ctx.grad_input_quantizer = None + ctx.grad_weight_quantizer = None + ctx.grad_output_quantizer = None # [*, in_features] -> [*, out_features] except first dimension changes for SP return out.view(-1, *inp.shape[1:-1], out.shape[-1]) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 4b04522ce9..2b69ad9157 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -545,6 +545,9 @@ def forward( ctx.ub_overlap_rs_dgrad = False ctx.ub_bulk_dgrad = False ctx.ub_bulk_wgrad = False + ctx.grad_input_quantizer = None + ctx.grad_weight_quantizer = None + ctx.grad_output_quantizer = None # ------------------------------------------------------ # Cached state for backward pass is ready... 
@@ -653,7 +656,7 @@ def backward( # Configure quantizer for grad output tensor # Note: dgrad GEMM requires row-wise usage, wgrad GEMM # requires column-wise usage - if ctx.grad_output_quantizer is not None and ctx.fp8: + if ctx.grad_output_quantizer is not None: quantizer = ctx.grad_output_quantizer quantizer.set_usage(rowwise=True, columnwise=True) if ctx.ub_overlap_ag: @@ -743,7 +746,7 @@ def backward( use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator # Update grad input quantizer - if ctx.grad_input_quantizer is not None and ctx.fp8: + if ctx.grad_input_quantizer is not None: ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) # Output buffers for Userbuffers reduce-scatter @@ -767,7 +770,7 @@ def backward( grad_output, layout="NN", grad=True, - quantization_params=ctx.grad_input_quantizer if ctx.fp8 else None, + quantization_params=ctx.grad_input_quantizer, out=gemm_out, out_dtype=ctx.activation_dtype, use_split_accumulator=use_split_accumulator, @@ -894,7 +897,7 @@ def backward( "out_dtype": ( main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - "quantization_params": (ctx.grad_weight_quantizer if ctx.fp8 else None), + "quantization_params": ctx.grad_weight_quantizer, "accumulate": ( accumulate_wgrad_into_param_main_grad if not getattr(weight, "overwrite_main_grad", False) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 1c6a4e7e39..eead5229fa 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -499,6 +499,10 @@ def forward( ctx.ub_overlap_rs_dgrad = False ctx.ub_bulk_dgrad = False ctx.ub_bulk_wgrad = False + ctx.grad_input_quantizer = None + ctx.grad_weight_quantizer = None + ctx.grad_output_quantizer = None + # ------------------------------------------------------ # Cached state for backward pass is ready... 
@@ -591,7 +595,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Configure quantizer for grad output tensor # Note: dgrad GEMM requires row-wise usage, wgrad GEMM # requires column-wise usage - if ctx.grad_output_quantizer is not None and ctx.fp8: + if ctx.grad_output_quantizer is not None: quantizer = ctx.grad_output_quantizer quantizer.set_usage(rowwise=True, columnwise=True) if ctx.ub_overlap_ag: @@ -610,7 +614,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], not ctx.use_bias and not ctx.requires_wgrad and ctx.grad_output_quantizer is not None - and ctx.fp8 ): ctx.grad_output_quantizer.set_usage(columnwise=False) @@ -722,7 +725,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator # Update grad input quantizer - if ctx.grad_input_quantizer is not None and ctx.fp8: + if ctx.grad_input_quantizer is not None: ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) # Output buffers for Userbuffers reduce-scatter @@ -747,7 +750,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], grad_output, layout="NN", grad=True, - quantization_params=ctx.grad_input_quantizer if ctx.fp8 else None, + quantization_params=ctx.grad_input_quantizer, out=gemm_out, out_dtype=ctx.activation_dtype, use_split_accumulator=use_split_accumulator, @@ -873,7 +876,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], "out_dtype": ( main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - "quantization_params": (ctx.grad_weight_quantizer if ctx.fp8 else None), + "quantization_params": ctx.grad_weight_quantizer, "accumulate": ( accumulate_wgrad_into_param_main_grad if not getattr(weight, "overwrite_main_grad", False) From a191345d6fa264ea6cf37c7b52d9a2eb0c64cbfa Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" 
<66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Feb 2026 22:28:55 +0000 Subject: [PATCH 30/61] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/linear.py | 1 - 1 file changed, 1 deletion(-) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index eead5229fa..b5c0ceb2b7 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -502,7 +502,6 @@ def forward( ctx.grad_input_quantizer = None ctx.grad_weight_quantizer = None ctx.grad_output_quantizer = None - # ------------------------------------------------------ # Cached state for backward pass is ready... From 1fbc22a3ee2ecbcfc5eedf39ff514996a3eec556 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Thu, 5 Feb 2026 17:28:04 -0800 Subject: [PATCH 31/61] Drop delayed scaling change Signed-off-by: Ziang Li --- transformer_engine/pytorch/module/layernorm_linear.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 2b69ad9157..e097c0bad7 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -526,11 +526,7 @@ def forward( ctx.requires_dgrad = inp_requires_grad ctx.normalization = normalization ctx.reduce_and_update_bwd_fp8_tensors = False - if ( - ctx.fp8 - and not ctx.keep_backward_unquantized - and requires_grad(inp, ln_weight, ln_bias, weight, bias) - ): + if ctx.fp8 and requires_grad(inp, ln_weight, ln_bias, weight, bias): _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() if in_fp8_activation_recompute_phase(): From 7da1af5dda1856c144385751c4aaf0152e037da4 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 9 
Feb 2026 11:29:24 -0800 Subject: [PATCH 32/61] Simplify env var logic Signed-off-by: Ziang Li --- transformer_engine/common/recipe/__init__.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 1307302180..e76256cce3 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -11,11 +11,6 @@ from pydantic.dataclasses import dataclass -def _default_quantize_backward() -> bool: - """Default backward quantization setting.""" - return not bool(int(os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0"))) - - class _FormatHelper(NamedTuple): """ Stores max FP8 values for fprop and bprop a `Format`. @@ -222,7 +217,7 @@ def scaling_factor_compute(amax: Tensor, fp8_dpa: bool = False fp8_mha: bool = False quantize_forward: bool = True - quantize_backward: bool = field(default_factory=_default_quantize_backward) + quantize_backward: bool = not (os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") == "1") def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." @@ -268,7 +263,7 @@ class Float8CurrentScaling(Recipe): fp8_dpa: bool = False fp8_mha: bool = False quantize_forward: bool = True - quantize_backward: bool = field(default_factory=_default_quantize_backward) + quantize_backward: bool = not (os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") == "1") def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." 
@@ -324,7 +319,7 @@ class MXFP8BlockScaling(Recipe): fp8_dpa: bool = False fp8_mha: bool = False quantize_forward: bool = True - quantize_backward: bool = field(default_factory=_default_quantize_backward) + quantize_backward: bool = not (os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") == "1") def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." @@ -491,7 +486,7 @@ class NVFP4BlockScaling(Recipe): fp8_dpa: bool = False fp8_mha: bool = False quantize_forward: bool = True - quantize_backward: bool = field(default_factory=_default_quantize_backward) + quantize_backward: bool = not (os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") == "1") def __post_init__(self) -> None: assert self.fp4_format == Format.E2M1, "Only E2M1 is supported for NVFP4 scaling" @@ -567,7 +562,7 @@ class CustomRecipe(Recipe): fp8_dpa: bool = False fp8_mha: bool = False quantize_forward: bool = True - quantize_backward: bool = field(default_factory=_default_quantize_backward) + quantize_backward: bool = not (os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") == "1") def __repr__(self) -> str: return ( From 442297cc8f50712598a5f3e23d27a0012c94eae3 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 9 Feb 2026 11:41:01 -0800 Subject: [PATCH 33/61] Move validation check to recipe Signed-off-by: Ziang Li --- transformer_engine/common/recipe/__init__.py | 18 ++++++++++++++++++ transformer_engine/pytorch/quantization.py | 17 ----------------- 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index e76256cce3..b7cdbe818a 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -221,6 +221,12 @@ def scaling_factor_compute(amax: Tensor, def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." 
+ assert not ( + not self.quantize_forward and self.quantize_backward + ), "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True." + assert ( + not self.quantize_backward + ), "Delayed scaling does not support quantize_backward=False." def __repr__(self) -> str: return ( @@ -267,6 +273,9 @@ class Float8CurrentScaling(Recipe): def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." + assert not ( + not self.quantize_forward and self.quantize_backward + ), "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True." def __repr__(self) -> str: return ( @@ -323,6 +332,9 @@ class MXFP8BlockScaling(Recipe): def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." + assert not ( + not self.quantize_forward and self.quantize_backward + ), "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True." def __repr__(self) -> str: return ( @@ -400,6 +412,9 @@ def __post_init__(self) -> None: not self.fp8_dpa and not self.fp8_mha ), "FP8 attention is not supported for Float8BlockScaling." assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." + assert not ( + not self.quantize_forward and self.quantize_backward + ), "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True." def __repr__(self) -> str: return ( @@ -491,6 +506,9 @@ class NVFP4BlockScaling(Recipe): def __post_init__(self) -> None: assert self.fp4_format == Format.E2M1, "Only E2M1 is supported for NVFP4 scaling" assert self.fp8_format == Format.E4M3, "Only E4M3 is supported for NVFP4 scaling" + assert not ( + not self.quantize_forward and self.quantize_backward + ), "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True." 
# Quantization params # Note: RHT is currently only applied to column-wise usage so that diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index 47d03ab4da..9364ffe5bd 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -87,21 +87,6 @@ def check_fp8_block_scaling_support() -> Tuple[bool, str]: ) -def _validate_recipe_quantization_flags(recipe: Recipe) -> None: - """Validate forward/backward quantization flags on a recipe.""" - quantize_forward = getattr(recipe, "quantize_forward", True) - quantize_backward = getattr(recipe, "quantize_backward", True) - if not quantize_forward and quantize_backward: - raise ValueError( - "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True." - ) - if recipe.delayed() and not quantize_backward: - raise ValueError( - "Invalid recipe configuration: delayed scaling does not support " - "quantize_backward=False." - ) - - def check_recipe_support(recipe: Recipe) -> None: """Check if the given recipe is supported.""" recipe_supported = True @@ -859,8 +844,6 @@ def autocast( """ fp8_recipe = get_default_fp8_recipe() if recipe is None else recipe - if enabled or calibrating: - _validate_recipe_quantization_flags(fp8_recipe) quantize_forward = getattr(fp8_recipe, "quantize_forward", True) effective_enabled = enabled and quantize_forward if effective_enabled: From 9e9e94f525aa2ba7b7ab837d114c9c22bbc04c9f Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 9 Feb 2026 11:55:28 -0800 Subject: [PATCH 34/61] Simplify effective_enabled Signed-off-by: Ziang Li --- transformer_engine/pytorch/quantization.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index 9364ffe5bd..134781101a 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -843,11 +843,9 
@@ def autocast( are reduced at the end of each training step. """ - fp8_recipe = get_default_fp8_recipe() if recipe is None else recipe - quantize_forward = getattr(fp8_recipe, "quantize_forward", True) - effective_enabled = enabled and quantize_forward + effective_enabled = enabled and getattr(recipe, "quantize_forward", True) if effective_enabled: - check_recipe_support(fp8_recipe) + check_recipe_support(recipe) # Save current state so we always restore it on exit. fp8_state = FP8GlobalStateManager.get_autocast_state() @@ -855,7 +853,7 @@ def autocast( FP8GlobalStateManager.autocast_enter( enabled=effective_enabled, calibrating=calibrating, - fp8_recipe=fp8_recipe, + fp8_recipe=recipe, fp8_group=amax_reduction_group, _graph=_graph, ) From e269b85bc6ea1dd13fed551bd4d4177844d04a2d Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 9 Feb 2026 11:56:33 -0800 Subject: [PATCH 35/61] Fix inverted assertion logic Signed-off-by: Ziang Li --- transformer_engine/common/recipe/__init__.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index b7cdbe818a..1c14e5e42c 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -224,9 +224,7 @@ def __post_init__(self) -> None: assert not ( not self.quantize_forward and self.quantize_backward ), "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True." - assert ( - not self.quantize_backward - ), "Delayed scaling does not support quantize_backward=False." + assert self.quantize_backward, "Delayed scaling does not support quantize_backward=False." 
def __repr__(self) -> str: return ( From 153b1a8ad584aa7bb547dc8eceb2d11056e0d3c2 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 9 Feb 2026 12:33:38 -0800 Subject: [PATCH 36/61] Simplify changes under ops Signed-off-by: Ziang Li --- transformer_engine/pytorch/ops/basic/basic_linear.py | 4 ---- transformer_engine/pytorch/ops/basic/quantize.py | 11 ++++++----- .../ops/fused/forward_linear_bias_activation.py | 7 ++----- .../pytorch/ops/fused/forward_linear_bias_add.py | 7 ++----- .../pytorch/ops/fused/forward_linear_scale_add.py | 7 ++----- 5 files changed, 12 insertions(+), 24 deletions(-) diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 307b2e1624..15a6815d2e 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -1020,11 +1020,7 @@ def op_forward( # Save state for backward pass if ctx.requires_grad: saved_input = input_ if keep_backward_unquantized else x_local - if not weight_requires_grad: - saved_input = None saved_weight = self.weight if keep_backward_unquantized else w - if not input_requires_grad: - saved_weight = None if is_cpu_offload_enabled(): mark_activation_offload(saved_input) ctx.save_for_backward(saved_input, saved_weight) diff --git a/transformer_engine/pytorch/ops/basic/quantize.py b/transformer_engine/pytorch/ops/basic/quantize.py index 4c67cd8cce..9dcd33f9b3 100644 --- a/transformer_engine/pytorch/ops/basic/quantize.py +++ b/transformer_engine/pytorch/ops/basic/quantize.py @@ -57,11 +57,12 @@ def op_forward( # Check if FP8 is enabled fp8_enabled = FP8GlobalStateManager.is_fp8_enabled() quantize_forward = fp8_enabled and self._quantize_forward - quantize_backward = ( - fp8_enabled - and self._quantize_backward - and FP8GlobalStateManager.get_fp8_recipe().quantize_backward - ) + quantize_backward = fp8_enabled and self._quantize_backward + + # Recipe quantize overrides + if 
FP8GlobalStateManager.get_fp8_recipe() is not None: + quantize_forward = quantize_forward and FP8GlobalStateManager.get_fp8_recipe().quantize_forward + quantize_backward = quantize_backward and FP8GlobalStateManager.get_fp8_recipe().quantize_backward # Quantize if needed out = input_ diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index 2bccabb306..860407904c 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -122,11 +122,8 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: - saved_input = x_local - saved_weight = w - if keep_backward_unquantized: - saved_input = input_ if input_requires_grad else None - saved_weight = linear_op.weight if weight_requires_grad else None + saved_input = input_ if keep_backward_unquantized else x_local + saved_weight = linear_op.weight if keep_backward_unquantized else w if is_cpu_offload_enabled(): mark_activation_offload(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index 03e3bff6f3..0729291d55 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -119,11 +119,8 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: - saved_input = x_local - saved_weight = w - if keep_backward_unquantized: - saved_input = input_ if input_requires_grad else None - saved_weight = linear_op.weight if weight_requires_grad else None + saved_input = input_ if keep_backward_unquantized else x_local + saved_weight = linear_op.weight if keep_backward_unquantized else w if is_cpu_offload_enabled(): 
mark_activation_offload(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py index 8cebcec53a..dfdd11a231 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py @@ -100,11 +100,8 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: - saved_input = x_local - saved_weight = w - if keep_backward_unquantized: - saved_input = input_ if input_requires_grad else None - saved_weight = linear_op.weight if weight_requires_grad else None + saved_input = input_ if keep_backward_unquantized else x_local + saved_weight = linear_op.weight if keep_backward_unquantized else w if is_cpu_offload_enabled(): mark_activation_offload(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) From e5eb2edb5cc975f358251a3ca648e57c4488fdf9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 9 Feb 2026 20:34:39 +0000 Subject: [PATCH 37/61] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/ops/basic/quantize.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/ops/basic/quantize.py b/transformer_engine/pytorch/ops/basic/quantize.py index 9dcd33f9b3..b2a36d1daa 100644 --- a/transformer_engine/pytorch/ops/basic/quantize.py +++ b/transformer_engine/pytorch/ops/basic/quantize.py @@ -61,8 +61,12 @@ def op_forward( # Recipe quantize overrides if FP8GlobalStateManager.get_fp8_recipe() is not None: - quantize_forward = quantize_forward and FP8GlobalStateManager.get_fp8_recipe().quantize_forward - quantize_backward = quantize_backward and FP8GlobalStateManager.get_fp8_recipe().quantize_backward + 
quantize_forward = ( + quantize_forward and FP8GlobalStateManager.get_fp8_recipe().quantize_forward + ) + quantize_backward = ( + quantize_backward and FP8GlobalStateManager.get_fp8_recipe().quantize_backward + ) # Quantize if needed out = input_ From c233a6daa4f52cac74c609fe6d2570def9ef0215 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 9 Feb 2026 12:52:01 -0800 Subject: [PATCH 38/61] Simplify ctx.keep_backward_unquantized Signed-off-by: Ziang Li --- transformer_engine/pytorch/module/base.py | 3 +-- transformer_engine/pytorch/module/grouped_linear.py | 3 +-- transformer_engine/pytorch/module/layernorm_linear.py | 4 +--- transformer_engine/pytorch/module/linear.py | 4 +--- 4 files changed, 4 insertions(+), 10 deletions(-) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index f4351a3be8..2ecd5f77c8 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1184,8 +1184,7 @@ def grad_output_preprocess( grad_output = grad_output.reshape((-1, grad_output.shape[-1])) grad_output = grad_output.contiguous() gather_grad_output = row_parallel_mode and ctx.sequence_parallel - keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) - use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized + use_fp8_bwd = ctx.fp8 and not ctx.keep_backward_unquantized # Non-FP8 case: bgrad is fused with wgrad for this case. 
if not use_fp8_bwd and not ctx.debug: diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index bead74d9a6..b8ebd521a9 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -341,7 +341,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], origin_weights = saved_tensors[2 * N : 3 * N] biases = saved_tensors[3 * N : 4 * N] main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs] - keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) if ctx.cpu_offloading: if ctx.grad_added_to_main_grad: @@ -423,7 +422,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], device=ctx.device, ) weights_for_dgrad = weights - if keep_backward_unquantized: + if ctx.keep_backward_unquantized: weights_for_dgrad = origin_weights # Make sure weights are available in column-wise format # for dgrad computation. 
diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index e097c0bad7..10dbb8b0f8 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -614,8 +614,6 @@ def backward( if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation: origin_weight.main_grad = main_grad - keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) - # Configure Userbuffers communication (comm+GEMM overlap) ctx.ub_obj_gradout = None ub_obj_dgrad = None @@ -759,7 +757,7 @@ def backward( # Note: dx = dy * w nvtx_range_push(f"{nvtx_label}.dgrad_gemm") weight_for_dgrad = weight - if keep_backward_unquantized: + if ctx.keep_backward_unquantized: weight_for_dgrad = origin_weight gemm_out, *_, reduce_scatter_out = general_gemm( weight_for_dgrad, diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index b5c0ceb2b7..5eda5886f1 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -553,8 +553,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ) nvtx_range_pop(f"{nvtx_label}.fsdp_gather") - keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) - # Configure Userbuffers communication (comm+GEMM overlap) ctx.ub_obj_gradout = None ub_obj_dgrad = None @@ -742,7 +740,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], nvtx_range_push(f"{nvtx_label}.dgrad_gemm") weight_for_dgrad = weight_fp8 - if keep_backward_unquantized: + if ctx.keep_backward_unquantized: weight_for_dgrad = weight gemm_out, *_, reduce_scatter_out = general_gemm( weight_for_dgrad, From b96762819f8632d4fafdf907f1da97d22b697d11 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 9 Feb 2026 15:07:48 -0800 Subject: [PATCH 39/61] Fix missing attribute Signed-off-by: Ziang Li --- 
transformer_engine/common/recipe/__init__.py | 2 ++ transformer_engine/pytorch/module/layernorm_mlp.py | 1 + 2 files changed, 3 insertions(+) diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 1c14e5e42c..46a19652f1 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -389,6 +389,8 @@ class Float8BlockScaling(Recipe): fp8_gemm_wgrad: MMParams = MMParams(use_split_accumulator=True) fp8_dpa: bool = False fp8_mha: bool = False + quantize_forward: bool = True + quantize_backward: bool = not (os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") == "1") def __post_init__(self) -> None: assert self.x_block_scaling_dim in [1, 2], "Only 1D or 2D blocks supported for x" diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 30920e42fc..cb40bff1ae 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -786,6 +786,7 @@ def _forward( ctx.fc2_main_grad_func = lambda: fc2_weight.main_grad ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None + ctx.keep_backward_unquantized = keep_backward_unquantized ctx.fc1_grad_input_quantizer = fc1_grad_input_quantizer ctx.fc1_grad_weight_quantizer = fc1_grad_weight_quantizer ctx.fc1_grad_output_quantizer = fc1_grad_output_quantizer From 8ff02a2cb671613e19bd750fe8fa5017b77d5204 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 10 Feb 2026 14:02:10 -0800 Subject: [PATCH 40/61] Add unit tests Signed-off-by: Ziang Li --- qa/L0_pytorch_unittest/test.sh | 1 + .../pytorch/test_keep_backward_unquantized.py | 701 ++++++++++++++++++ 2 files changed, 702 insertions(+) create mode 100644 tests/pytorch/test_keep_backward_unquantized.py diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index f2b0b07fed..c5cce521d4 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ 
b/qa/L0_pytorch_unittest/test.sh @@ -42,6 +42,7 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py" +NVTE_KEEP_BACKWARD_UNQUANTIZED=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_keep_backward_unquantized.xml $TE_PATH/tests/pytorch/test_keep_backward_unquantized.py || test_fail "test_keep_backward_unquantized.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py" diff --git a/tests/pytorch/test_keep_backward_unquantized.py b/tests/pytorch/test_keep_backward_unquantized.py new file mode 100644 index 0000000000..a5ef00e34c --- /dev/null +++ b/tests/pytorch/test_keep_backward_unquantized.py @@ -0,0 +1,701 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +from __future__ import annotations + +from contextlib import nullcontext +import os +from typing import Optional + +import pytest +import torch + +import transformer_engine.pytorch as te +import transformer_engine.pytorch.ops as te_ops +from transformer_engine.common import recipe +from transformer_engine.pytorch.ops.fused import ( + BackwardActivationBias, + ForwardLinearBiasActivation, + ForwardLinearBiasAdd, + ForwardLinearScaleAdd, +) + +from utils import quantization_tols, reset_rng_states + + +fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) +mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = te.is_fp8_block_scaling_available( + return_reason=True +) +nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True) + +# This file is intended to run in dedicated keep-backward-unquantized mode. +pytestmark = pytest.mark.skipif( + os.environ.get("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") != "1", + reason="Requires NVTE_KEEP_BACKWARD_UNQUANTIZED=1", +) + + +_quantized_numerics_recipe_list = [ + pytest.param( + "fp8_current_scaling", + marks=pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8), + id="Float8CurrentScaling", + ), + pytest.param( + "mxfp8", + marks=pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8), + id="MXFP8BlockScaling", + ), + pytest.param( + "fp8_block_scaling", + marks=pytest.mark.skipif( + not fp8_block_scaling_available, reason=reason_for_no_fp8_block_scaling + ), + id="Float8BlockScaling", + ), + pytest.param( + "nvfp4", + marks=pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4), + id="NVFP4BlockScaling", + ), +] + +_shape_test_cases = [ + pytest.param((32, 64), 64, id="2d_m32_k64_n64"), + pytest.param((8, 4, 64), 128, id="3d_m32_k64_n128"), + pytest.param((16, 2, 128), 64, id="3d_m32_k128_n64"), +] + +_bias_activation_shape_cases = [ + pytest.param((32, 64), 
id="2d_m32_k64"), + pytest.param((8, 4, 64), id="3d_m32_k64"), +] + + +def _make_recipe(recipe_name: str, quantize_backward: Optional[bool]) -> recipe.Recipe: + kwargs = {} + if quantize_backward is not None: + kwargs = {"quantize_forward": True, "quantize_backward": quantize_backward} + + if recipe_name == "fp8_current_scaling": + return recipe.Float8CurrentScaling(fp8_format=recipe.Format.E4M3, **kwargs) + if recipe_name == "mxfp8": + return recipe.MXFP8BlockScaling(fp8_format=recipe.Format.E4M3, **kwargs) + if recipe_name == "fp8_block_scaling": + return recipe.Float8BlockScaling(fp8_format=recipe.Format.E4M3, **kwargs) + if recipe_name == "nvfp4": + return recipe.NVFP4BlockScaling( + disable_rht=True, + disable_stochastic_rounding=True, + disable_2d_quantization=True, + **kwargs, + ) + + raise ValueError(f"Unsupported recipe for keep-backward-unquantized test: {recipe_name}") + + +def _build_keep_backward_unquantized_recipe(recipe_name: str) -> recipe.Recipe: + fp8_recipe = _make_recipe(recipe_name, quantize_backward=None) + assert fp8_recipe.quantize_forward + assert not fp8_recipe.quantize_backward + return fp8_recipe + + +def _build_quantized_reference_recipe(recipe_name: str) -> recipe.Recipe: + return _make_recipe(recipe_name, quantize_backward=True) + + +def _copy_named_parameters(src_module: torch.nn.Module, dst_module: torch.nn.Module) -> None: + src_params = dict(src_module.named_parameters()) + with torch.no_grad(): + for name, dst_param in dst_module.named_parameters(): + if name not in src_params: + raise RuntimeError(f"Parameter {name} missing in source module") + dst_param.copy_(src_params[name]) + + +def _fprop_tolerances(recipe_name: str) -> dict[str, float]: + if recipe_name == "mxfp8": + return quantization_tols("mxfp8") + if recipe_name in ("fp8_current_scaling", "fp8_block_scaling"): + return quantization_tols("fp8_current_scaling") + if recipe_name == "nvfp4": + return quantization_tols("nvfp4") + raise ValueError(f"Unsupported recipe for 
keep-backward-unquantized test: {recipe_name}") + + +def _make_linear_like_module( + module_type: str, + in_features: int, + out_features: int, + dtype: torch.dtype, + bias: bool = False, +) -> torch.nn.Module: + if module_type == "linear": + return te.Linear( + in_features, + out_features, + bias=bias, + params_dtype=dtype, + device="cuda", + ) + if module_type == "layernorm_linear": + return te.LayerNormLinear( + in_features, + out_features, + bias=bias, + params_dtype=dtype, + device="cuda", + ) + if module_type == "ops_linear": + return te_ops.Linear( + in_features, + out_features, + bias=bias, + dtype=dtype, + device="cuda", + ) + raise ValueError(f"Unsupported module type: {module_type}") + + +def _maybe_skip_unsupported_recipe_module_combo(recipe_name: str, module_type: str) -> None: + if module_type == "ops_linear" and recipe_name == "fp8_block_scaling": + pytest.skip("Fusible ops (te_ops.Linear) do not support Float8BlockScaling recipe") + + +def _run_single_step( + module: torch.nn.Module, + x: torch.Tensor, + dy: torch.Tensor, + fp8_recipe: Optional[recipe.Recipe], +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + module.zero_grad(set_to_none=True) + x_run = x.detach().clone().requires_grad_(True) + autocast_ctx = ( + te.autocast(enabled=True, recipe=fp8_recipe) if fp8_recipe is not None else nullcontext() + ) + with autocast_ctx: + y = module(x_run) + if isinstance(y, tuple): + y = y[0] + y.backward(dy) + assert x_run.grad is not None + assert module.weight.grad is not None + return ( + y.detach().clone(), + x_run.grad.detach().clone(), + module.weight.grad.detach().clone(), + ) + + +def _extract_bias_grad(module: torch.nn.Module) -> Optional[torch.Tensor]: + bias = getattr(module, "bias", None) + if bias is None or bias.grad is None: + return None + return bias.grad.detach().clone() + + +def _run_grouped_linear_single_step( + module: te.GroupedLinear, + x: torch.Tensor, + m_splits: list[int], + dy: torch.Tensor, + fp8_recipe: 
Optional[recipe.Recipe], +) -> tuple[torch.Tensor, torch.Tensor, list[torch.Tensor], list[Optional[torch.Tensor]]]: + module.zero_grad(set_to_none=True) + x_run = x.detach().clone().requires_grad_(True) + autocast_ctx = ( + te.autocast(enabled=True, recipe=fp8_recipe) if fp8_recipe is not None else nullcontext() + ) + with autocast_ctx: + y = module(x_run, m_splits) + y.backward(dy) + assert x_run.grad is not None + weight_grads = [getattr(module, f"weight{i}").grad.detach().clone() for i in range(module.num_gemms)] + bias_grads: list[Optional[torch.Tensor]] = [] + for i in range(module.num_gemms): + if module.use_bias: + bias_grads.append(getattr(module, f"bias{i}").grad.detach().clone()) + else: + bias_grads.append(None) + return y.detach().clone(), x_run.grad.detach().clone(), weight_grads, bias_grads + + +def _make_fused_model( + pattern: str, + in_features: int, + out_features: int, + dtype: torch.dtype, + scale: float = 0.5, +) -> te_ops.Sequential: + if pattern == "bias_activation": + return te_ops.Sequential( + te_ops.Linear(in_features, out_features, bias=True, device="cuda", dtype=dtype), + te_ops.ReLU(), + ) + if pattern == "bias_add": + return te_ops.Sequential( + te_ops.Linear(in_features, out_features, bias=True, device="cuda", dtype=dtype), + te_ops.AddExtraInput(in_place=True), + ) + if pattern == "scale_add": + return te_ops.Sequential( + te_ops.Linear(in_features, out_features, bias=False, device="cuda", dtype=dtype), + te_ops.ConstantScale(scale), + te_ops.AddExtraInput(in_place=True), + ) + raise ValueError(f"Unsupported fused test pattern: {pattern}") + + +def _run_fused_single_step( + pattern: str, + model: te_ops.Sequential, + x1: torch.Tensor, + dy: torch.Tensor, + fp8_recipe: Optional[recipe.Recipe], + x2: Optional[torch.Tensor] = None, +) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor, Optional[torch.Tensor]]: + model.zero_grad(set_to_none=True) + x1_run = x1.detach().clone().requires_grad_(True) + x2_run = 
x2.detach().clone().requires_grad_(True) if x2 is not None else None + autocast_ctx = ( + te.autocast(enabled=True, recipe=fp8_recipe) if fp8_recipe is not None else nullcontext() + ) + with autocast_ctx: + if pattern in ("bias_add", "scale_add"): + assert x2_run is not None + y = model(x1_run, x2_run) + else: + y = model(x1_run) + y.backward(dy) + assert x1_run.grad is not None + weight_grad = model[0].weight.grad.detach().clone() + bias_grad = None + if getattr(model[0], "bias", None) is not None and model[0].bias.grad is not None: + bias_grad = model[0].bias.grad.detach().clone() + x2_grad = x2_run.grad.detach().clone() if x2_run is not None and x2_run.grad is not None else None + return y.detach().clone(), x1_run.grad.detach().clone(), x2_grad, weight_grad, bias_grad + + +def _run_quantize_op_single_step( + model: te_ops.Sequential, + x: torch.Tensor, + dy: torch.Tensor, + fp8_recipe: Optional[recipe.Recipe], +) -> tuple[torch.Tensor, torch.Tensor]: + x_run = x.detach().clone().requires_grad_(True) + autocast_ctx = ( + te.autocast(enabled=True, recipe=fp8_recipe) if fp8_recipe is not None else nullcontext() + ) + with autocast_ctx: + y = model(x_run) + y.backward(dy) + assert x_run.grad is not None + return y.detach().clone(), x_run.grad.detach().clone() + + +@pytest.mark.parametrize( + "recipe_name", + _quantized_numerics_recipe_list, +) +def test_keep_backward_unquantized_recipe_defaults(recipe_name: str): + _ = _build_keep_backward_unquantized_recipe(recipe_name) + + +@pytest.mark.parametrize( + "recipe_name", + _quantized_numerics_recipe_list, +) +@pytest.mark.parametrize( + "module_type", + ("linear", "layernorm_linear", "ops_linear"), +) +@pytest.mark.parametrize( + "input_shape,out_features", + _shape_test_cases, +) +@pytest.mark.parametrize("use_bias", (False, True), ids=("no_bias", "bias")) +def test_keep_backward_unquantized_matches_quantized_fprop_and_unquantized_grads( + recipe_name: str, + module_type: str, + input_shape: tuple[int, ...], + 
out_features: int, + use_bias: bool, +): + reset_rng_states() + _maybe_skip_unsupported_recipe_module_combo(recipe_name, module_type) + dtype = torch.bfloat16 + in_features = input_shape[-1] + + module_quantized_ref = _make_linear_like_module( + module_type, in_features, out_features, dtype, bias=use_bias + ) + module_keep_bwd_hp = _make_linear_like_module( + module_type, in_features, out_features, dtype, bias=use_bias + ) + module_unquantized_ref = _make_linear_like_module( + module_type, in_features, out_features, dtype, bias=use_bias + ) + + # Start all runs from identical parameters. + _copy_named_parameters(module_quantized_ref, module_keep_bwd_hp) + _copy_named_parameters(module_quantized_ref, module_unquantized_ref) + + output_shape = input_shape[:-1] + (out_features,) + x = torch.randn(*input_shape, dtype=dtype, device="cuda") + dy = torch.randn(*output_shape, dtype=dtype, device="cuda") + + quantized_ref_recipe = _build_quantized_reference_recipe(recipe_name) + keep_bwd_hp_recipe = _build_keep_backward_unquantized_recipe(recipe_name) + + y_quantized_ref, _, _ = _run_single_step(module_quantized_ref, x, dy, quantized_ref_recipe) + y_keep_bwd_hp, dx_keep_bwd_hp, dw_keep_bwd_hp = _run_single_step( + module_keep_bwd_hp, x, dy, keep_bwd_hp_recipe + ) + _, dx_unquantized_ref, dw_unquantized_ref = _run_single_step(module_unquantized_ref, x, dy, None) + + # Forward pass should still match quantized reference when only backward is unquantized. + torch.testing.assert_close( + y_keep_bwd_hp, + y_quantized_ref, + **_fprop_tolerances(recipe_name), + ) + + # Backward pass should match unquantized reference for dgrad and wgrad. 
+ torch.testing.assert_close(dx_keep_bwd_hp, dx_unquantized_ref, rtol=0, atol=0) + torch.testing.assert_close(dw_keep_bwd_hp, dw_unquantized_ref, rtol=0, atol=0) + if use_bias: + bgrad_keep = _extract_bias_grad(module_keep_bwd_hp) + bgrad_unquantized = _extract_bias_grad(module_unquantized_ref) + assert bgrad_keep is not None + assert bgrad_unquantized is not None + torch.testing.assert_close(bgrad_keep, bgrad_unquantized, rtol=0, atol=0) + + +@pytest.mark.parametrize( + "recipe_name", + _quantized_numerics_recipe_list, +) +@pytest.mark.parametrize("use_bias", (False, True), ids=("no_bias", "bias")) +@pytest.mark.parametrize( + "m_splits", + ([32, 32, 32, 32], [64, 0, 32, 32]), + ids=("uniform_splits", "with_empty_split"), +) +def test_keep_backward_unquantized_grouped_linear_matches_quantized_fprop_and_unquantized_grads( + recipe_name: str, + use_bias: bool, + m_splits: list[int], +): + if recipe_name == "nvfp4": + pytest.skip("NVFP4 not supported for grouped linear") + + reset_rng_states() + dtype = torch.bfloat16 + in_features = 64 + out_features = 64 + num_gemms = len(m_splits) + num_tokens = sum(m_splits) + + module_quantized_ref = te.GroupedLinear( + num_gemms, + in_features, + out_features, + bias=use_bias, + params_dtype=dtype, + device="cuda", + ) + module_keep_bwd_hp = te.GroupedLinear( + num_gemms, + in_features, + out_features, + bias=use_bias, + params_dtype=dtype, + device="cuda", + ) + module_unquantized_ref = te.GroupedLinear( + num_gemms, + in_features, + out_features, + bias=use_bias, + params_dtype=dtype, + device="cuda", + ) + + _copy_named_parameters(module_quantized_ref, module_keep_bwd_hp) + _copy_named_parameters(module_quantized_ref, module_unquantized_ref) + + x = torch.randn(num_tokens, in_features, dtype=dtype, device="cuda") + dy = torch.randn(num_tokens, out_features, dtype=dtype, device="cuda") + + quantized_ref_recipe = _build_quantized_reference_recipe(recipe_name) + keep_bwd_hp_recipe = 
_build_keep_backward_unquantized_recipe(recipe_name) + + y_quantized_ref, _, _, _ = _run_grouped_linear_single_step( + module_quantized_ref, x, m_splits, dy, quantized_ref_recipe + ) + y_keep_bwd_hp, dx_keep_bwd_hp, dw_keep_bwd_hp, db_keep_bwd_hp = _run_grouped_linear_single_step( + module_keep_bwd_hp, x, m_splits, dy, keep_bwd_hp_recipe + ) + _, dx_unquantized_ref, dw_unquantized_ref, db_unquantized_ref = _run_grouped_linear_single_step( + module_unquantized_ref, x, m_splits, dy, None + ) + + torch.testing.assert_close( + y_keep_bwd_hp, + y_quantized_ref, + **_fprop_tolerances(recipe_name), + ) + torch.testing.assert_close(dx_keep_bwd_hp, dx_unquantized_ref, rtol=0, atol=0) + for test_dw, ref_dw in zip(dw_keep_bwd_hp, dw_unquantized_ref): + torch.testing.assert_close(test_dw, ref_dw, rtol=0, atol=0) + if use_bias: + for test_db, ref_db in zip(db_keep_bwd_hp, db_unquantized_ref): + assert test_db is not None + assert ref_db is not None + torch.testing.assert_close(test_db, ref_db, rtol=0, atol=0) + +@pytest.mark.parametrize( + "recipe_name", + _quantized_numerics_recipe_list, +) +@pytest.mark.parametrize( + "fused_pattern,expected_fused_op", + ( + ("bias_add", ForwardLinearBiasAdd), + ("scale_add", ForwardLinearScaleAdd), + ), +) +def test_keep_backward_unquantized_fused_linear_paths( + recipe_name: str, + fused_pattern: str, + expected_fused_op: type, +): + # Fused linear op path is based on te_ops.Linear and shares its recipe constraints. 
+ _maybe_skip_unsupported_recipe_module_combo(recipe_name, "ops_linear") + + reset_rng_states() + dtype = torch.bfloat16 + in_features = 64 + out_features = 64 + m = 32 + + model_quantized_ref = _make_fused_model(fused_pattern, in_features, out_features, dtype) + model_keep_bwd_hp = _make_fused_model(fused_pattern, in_features, out_features, dtype) + model_unquantized_ref = _make_fused_model(fused_pattern, in_features, out_features, dtype) + + _copy_named_parameters(model_quantized_ref, model_keep_bwd_hp) + _copy_named_parameters(model_quantized_ref, model_unquantized_ref) + + x1 = torch.randn(m, in_features, dtype=dtype, device="cuda") + x2 = None + if fused_pattern in ("bias_add", "scale_add"): + x2 = torch.randn(m, out_features, dtype=dtype, device="cuda") + dy = torch.randn(m, out_features, dtype=dtype, device="cuda") + + quantized_ref_recipe = _build_quantized_reference_recipe(recipe_name) + keep_bwd_hp_recipe = _build_keep_backward_unquantized_recipe(recipe_name) + + y_quantized_ref, _, _, _, _ = _run_fused_single_step( + fused_pattern, model_quantized_ref, x1, dy, quantized_ref_recipe, x2=x2 + ) + y_keep_bwd_hp, dx1_keep_bwd_hp, dx2_keep_bwd_hp, dw_keep_bwd_hp, db_keep_bwd_hp = ( + _run_fused_single_step( + fused_pattern, + model_keep_bwd_hp, + x1, + dy, + keep_bwd_hp_recipe, + x2=x2, + ) + ) + _, dx1_unquantized_ref, dx2_unquantized_ref, dw_unquantized_ref, db_unquantized_ref = ( + _run_fused_single_step( + fused_pattern, + model_unquantized_ref, + x1, + dy, + None, + x2=x2, + ) + ) + + # Ensure this test executes the fused path changed by the keep-bwd feature. 
+ fused_ops = model_keep_bwd_hp._module_groups[0]._forward_ops + assert len(fused_ops) >= 1 + assert isinstance(fused_ops[0][0], expected_fused_op) + + torch.testing.assert_close( + y_keep_bwd_hp, + y_quantized_ref, + **_fprop_tolerances(recipe_name), + ) + torch.testing.assert_close(dx1_keep_bwd_hp, dx1_unquantized_ref, rtol=0, atol=0) + torch.testing.assert_close(dw_keep_bwd_hp, dw_unquantized_ref, rtol=0, atol=0) + if dx2_keep_bwd_hp is not None and dx2_unquantized_ref is not None: + torch.testing.assert_close(dx2_keep_bwd_hp, dx2_unquantized_ref, rtol=0, atol=0) + if db_keep_bwd_hp is not None and db_unquantized_ref is not None: + torch.testing.assert_close(db_keep_bwd_hp, db_unquantized_ref, rtol=0, atol=0) + + +@pytest.mark.parametrize( + "recipe_name", + _quantized_numerics_recipe_list, +) +@pytest.mark.parametrize("input_shape", _bias_activation_shape_cases) +def test_keep_backward_unquantized_fused_bias_activation_matches_masked_linear_backward( + recipe_name: str, + input_shape: tuple[int, ...], +): + # Fused linear op path is based on te_ops.Linear and shares its recipe constraints. 
+ _maybe_skip_unsupported_recipe_module_combo(recipe_name, "ops_linear") + + reset_rng_states() + dtype = torch.bfloat16 + in_features = input_shape[-1] + out_features = 64 + + model_quantized_ref = _make_fused_model("bias_activation", in_features, out_features, dtype) + model_keep_bwd_hp = _make_fused_model("bias_activation", in_features, out_features, dtype) + linear_unquantized_ref = _make_linear_like_module( + "ops_linear", in_features, out_features, dtype, bias=True + ) + + _copy_named_parameters(model_quantized_ref, model_keep_bwd_hp) + _copy_named_parameters(model_keep_bwd_hp[0], linear_unquantized_ref) + + x1 = torch.randn(*input_shape, dtype=dtype, device="cuda") + out_shape = x1.shape[:-1] + (out_features,) + dy = torch.randn(*out_shape, dtype=dtype, device="cuda") + + quantized_ref_recipe = _build_quantized_reference_recipe(recipe_name) + keep_bwd_hp_recipe = _build_keep_backward_unquantized_recipe(recipe_name) + + y_quantized_ref, _, _, _, _ = _run_fused_single_step( + "bias_activation", model_quantized_ref, x1, dy, quantized_ref_recipe + ) + y_keep_bwd_hp, dx1_keep_bwd_hp, _, dw_keep_bwd_hp, db_keep_bwd_hp = _run_fused_single_step( + "bias_activation", model_keep_bwd_hp, x1, dy, keep_bwd_hp_recipe + ) + + # Ensure this test executes the fused path changed by the keep-bwd feature. + fused_ops = model_keep_bwd_hp._module_groups[0]._forward_ops + assert len(fused_ops) >= 1 + assert isinstance(fused_ops[0][0], ForwardLinearBiasActivation) + + # keep-bwd mode should disable backward-activation+bias fusion, while quantized + # reference should still use it. 
+ keep_bwd_backward_ops = model_keep_bwd_hp._module_groups[0]._backward_ops + assert not any( + isinstance(op, BackwardActivationBias) for op, _ in keep_bwd_backward_ops + ) + quantized_ref_backward_ops = model_quantized_ref._module_groups[0]._backward_ops + assert any( + isinstance(op, BackwardActivationBias) for op, _ in quantized_ref_backward_ops + ) + + torch.testing.assert_close( + y_keep_bwd_hp, + y_quantized_ref, + **_fprop_tolerances(recipe_name), + ) + + # In keep-backward-unquantized mode, backward should behave as high-precision linear backward + # given the ReLU mask induced by quantized forward activations. + dy_after_activation = dy * (y_keep_bwd_hp > 0).to(dy.dtype) + _, dx1_expected, dw_expected = _run_single_step(linear_unquantized_ref, x1, dy_after_activation, None) + db_expected = _extract_bias_grad(linear_unquantized_ref) + assert db_keep_bwd_hp is not None + assert db_expected is not None + + torch.testing.assert_close(dx1_keep_bwd_hp, dx1_expected, rtol=0, atol=0) + torch.testing.assert_close(dw_keep_bwd_hp, dw_expected, rtol=0, atol=0) + torch.testing.assert_close(db_keep_bwd_hp, db_expected, rtol=0, atol=0) + + +def test_keep_backward_unquantized_autocast_respects_quantize_forward_flag(): + reset_rng_states() + dtype = torch.bfloat16 + in_features = 64 + out_features = 64 + + module_quantization_disabled = _make_linear_like_module( + "linear", in_features, out_features, dtype, bias=True + ) + module_unquantized_ref = _make_linear_like_module("linear", in_features, out_features, dtype, bias=True) + _copy_named_parameters(module_quantization_disabled, module_unquantized_ref) + + x = torch.randn(32, in_features, dtype=dtype, device="cuda") + dy = torch.randn(32, out_features, dtype=dtype, device="cuda") + + recipe_no_fwd_quant = recipe.Float8CurrentScaling( + fp8_format=recipe.Format.E4M3, + quantize_forward=False, + quantize_backward=False, + ) + + y_test, dx_test, dw_test = _run_single_step( + module_quantization_disabled, x, dy, 
recipe_no_fwd_quant + ) + y_ref, dx_ref, dw_ref = _run_single_step(module_unquantized_ref, x, dy, None) + + torch.testing.assert_close(y_test, y_ref, rtol=0, atol=0) + torch.testing.assert_close(dx_test, dx_ref, rtol=0, atol=0) + torch.testing.assert_close(dw_test, dw_ref, rtol=0, atol=0) + bgrad_test = _extract_bias_grad(module_quantization_disabled) + bgrad_ref = _extract_bias_grad(module_unquantized_ref) + assert bgrad_test is not None + assert bgrad_ref is not None + torch.testing.assert_close(bgrad_test, bgrad_ref, rtol=0, atol=0) + + +def test_keep_backward_unquantized_quantize_op_respects_recipe_overrides(): + reset_rng_states() + dtype = torch.bfloat16 + x = torch.randn(32, 64, dtype=dtype, device="cuda") + dy = torch.randn(32, 64, dtype=dtype, device="cuda") + + model_override = te_ops.Sequential(te_ops.Quantize(forward=True, backward=True)) + model_ref = te_ops.Sequential(te_ops.Quantize(forward=True, backward=True)) + + recipe_no_quant = recipe.Float8CurrentScaling( + fp8_format=recipe.Format.E4M3, + quantize_forward=False, + quantize_backward=False, + ) + y_override, dx_override = _run_quantize_op_single_step(model_override, x, dy, recipe_no_quant) + y_ref, dx_ref = _run_quantize_op_single_step(model_ref, x, dy, None) + + torch.testing.assert_close(y_override, y_ref, rtol=0, atol=0) + torch.testing.assert_close(dx_override, dx_ref, rtol=0, atol=0) + + +def test_keep_backward_unquantized_is_invalid_for_delayed_scaling(): + with pytest.raises( + (AssertionError, ValueError), + match="Delayed scaling does not support quantize_backward=False", + ): + _ = recipe.DelayedScaling() + + +@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) +def test_keep_backward_unquantized_not_implemented_for_layernorm_mlp(): + reset_rng_states() + layer = te.LayerNormMLP( + hidden_size=64, + ffn_hidden_size=64, + params_dtype=torch.bfloat16, + bias=False, + device="cuda", + ) + x = torch.randn(32, 64, dtype=torch.bfloat16, device="cuda", requires_grad=True) + 
keep_bwd_hp_recipe = _build_keep_backward_unquantized_recipe("fp8_current_scaling") + + with pytest.raises( + AssertionError, match="NVTE_KEEP_BACKWARD_UNQUANTIZED is not implemented in LayerNormMLP" + ): + with te.autocast(enabled=True, recipe=keep_bwd_hp_recipe): + _ = layer(x) From 01af85529b7406beda25fd8ddc12a3e2abb63962 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 10 Feb 2026 14:03:02 -0800 Subject: [PATCH 41/61] Fix bias errors in unit test Signed-off-by: Ziang Li --- transformer_engine/pytorch/ops/basic/bias.py | 8 +++++++- .../pytorch/ops/fused/backward_activation_bias.py | 5 +++-- .../pytorch/ops/fused/forward_linear_bias_activation.py | 4 +++- .../pytorch/ops/fused/forward_linear_bias_add.py | 4 +++- 4 files changed, 16 insertions(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/ops/basic/bias.py b/transformer_engine/pytorch/ops/basic/bias.py index d580f84866..8bcd84b441 100644 --- a/transformer_engine/pytorch/ops/basic/bias.py +++ b/transformer_engine/pytorch/ops/basic/bias.py @@ -10,6 +10,7 @@ import torch import transformer_engine_torch as tex +from ...quantization import FP8GlobalStateManager from ..op import BasicOperation, OperationContext from ...utils import canonicalize_device, canonicalize_dtype from ...tensor import Quantizer @@ -123,7 +124,12 @@ def op_forward( b = self.bias.view([1] * (x.dim() - 1) + [self.local_size]) if ctx.requires_grad: - ctx.grad_input_quantizer = prev_op_grad_output_quantizer + keep_backward_unquantized = FP8GlobalStateManager.is_fp8_enabled() and ( + not FP8GlobalStateManager.get_fp8_recipe().quantize_backward + ) + ctx.grad_input_quantizer = ( + None if keep_backward_unquantized else prev_op_grad_output_quantizer + ) return x + b diff --git a/transformer_engine/pytorch/ops/fused/backward_activation_bias.py b/transformer_engine/pytorch/ops/fused/backward_activation_bias.py index 4ab082d32b..395a9dbd67 100644 --- a/transformer_engine/pytorch/ops/fused/backward_activation_bias.py +++ 
b/transformer_engine/pytorch/ops/fused/backward_activation_bias.py @@ -104,8 +104,9 @@ def fuse_backward_ops( """ - # Check if recipe supports bias activation fusion - if recipe is None: + # Check if recipe supports bias activation fusion. + # keep-backward-unquantized mode should use unfused backward ops. + if recipe is None or not recipe.quantize_backward: return ops # Scan through ops, fusing if possible diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index 860407904c..42f459a41e 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -138,7 +138,9 @@ def fuser_forward( linear_op_ctx.input_requires_grad = input_requires_grad linear_op_ctx.weight_requires_grad = weight_requires_grad if bias_op is not None and bias_op_ctx.requires_grad: - bias_op_ctx.grad_input_quantizer = linear_op.get_grad_output_quantizer() + bias_op_ctx.grad_input_quantizer = ( + None if keep_backward_unquantized else linear_op.get_grad_output_quantizer() + ) return output, [() for _ in range(len(self.basic_ops))] diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index 0729291d55..75d58fd5cc 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -135,7 +135,9 @@ def fuser_forward( linear_op_ctx.input_requires_grad = input_requires_grad linear_op_ctx.weight_requires_grad = weight_requires_grad if bias_op is not None and bias_op_ctx.requires_grad: - bias_op_ctx.grad_input_quantizer = linear_op.get_grad_output_quantizer() + bias_op_ctx.grad_input_quantizer = ( + None if keep_backward_unquantized else linear_op.get_grad_output_quantizer() + ) return output, [() for _ in range(len(self.basic_ops))] 
From 7198af2ab41d00e70e57e3ac9d12cfb2cab0b71a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 10 Feb 2026 22:03:50 +0000 Subject: [PATCH 42/61] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../pytorch/test_keep_backward_unquantized.py | 33 ++++++++++++------- 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/tests/pytorch/test_keep_backward_unquantized.py b/tests/pytorch/test_keep_backward_unquantized.py index a5ef00e34c..fe11bfcd3a 100644 --- a/tests/pytorch/test_keep_backward_unquantized.py +++ b/tests/pytorch/test_keep_backward_unquantized.py @@ -214,7 +214,9 @@ def _run_grouped_linear_single_step( y = module(x_run, m_splits) y.backward(dy) assert x_run.grad is not None - weight_grads = [getattr(module, f"weight{i}").grad.detach().clone() for i in range(module.num_gemms)] + weight_grads = [ + getattr(module, f"weight{i}").grad.detach().clone() for i in range(module.num_gemms) + ] bias_grads: list[Optional[torch.Tensor]] = [] for i in range(module.num_gemms): if module.use_bias: @@ -257,7 +259,9 @@ def _run_fused_single_step( dy: torch.Tensor, fp8_recipe: Optional[recipe.Recipe], x2: Optional[torch.Tensor] = None, -) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor, Optional[torch.Tensor]]: +) -> tuple[ + torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor, Optional[torch.Tensor] +]: model.zero_grad(set_to_none=True) x1_run = x1.detach().clone().requires_grad_(True) x2_run = x2.detach().clone().requires_grad_(True) if x2 is not None else None @@ -276,7 +280,9 @@ def _run_fused_single_step( bias_grad = None if getattr(model[0], "bias", None) is not None and model[0].bias.grad is not None: bias_grad = model[0].bias.grad.detach().clone() - x2_grad = x2_run.grad.detach().clone() if x2_run is not None and x2_run.grad is not None else None + x2_grad = ( + x2_run.grad.detach().clone() if 
x2_run is not None and x2_run.grad is not None else None + ) return y.detach().clone(), x1_run.grad.detach().clone(), x2_grad, weight_grad, bias_grad @@ -355,7 +361,9 @@ def test_keep_backward_unquantized_matches_quantized_fprop_and_unquantized_grads y_keep_bwd_hp, dx_keep_bwd_hp, dw_keep_bwd_hp = _run_single_step( module_keep_bwd_hp, x, dy, keep_bwd_hp_recipe ) - _, dx_unquantized_ref, dw_unquantized_ref = _run_single_step(module_unquantized_ref, x, dy, None) + _, dx_unquantized_ref, dw_unquantized_ref = _run_single_step( + module_unquantized_ref, x, dy, None + ) # Forward pass should still match quantized reference when only backward is unquantized. torch.testing.assert_close( @@ -458,6 +466,7 @@ def test_keep_backward_unquantized_grouped_linear_matches_quantized_fprop_and_un assert ref_db is not None torch.testing.assert_close(test_db, ref_db, rtol=0, atol=0) + @pytest.mark.parametrize( "recipe_name", _quantized_numerics_recipe_list, @@ -589,13 +598,9 @@ def test_keep_backward_unquantized_fused_bias_activation_matches_masked_linear_b # keep-bwd mode should disable backward-activation+bias fusion, while quantized # reference should still use it. 
keep_bwd_backward_ops = model_keep_bwd_hp._module_groups[0]._backward_ops - assert not any( - isinstance(op, BackwardActivationBias) for op, _ in keep_bwd_backward_ops - ) + assert not any(isinstance(op, BackwardActivationBias) for op, _ in keep_bwd_backward_ops) quantized_ref_backward_ops = model_quantized_ref._module_groups[0]._backward_ops - assert any( - isinstance(op, BackwardActivationBias) for op, _ in quantized_ref_backward_ops - ) + assert any(isinstance(op, BackwardActivationBias) for op, _ in quantized_ref_backward_ops) torch.testing.assert_close( y_keep_bwd_hp, @@ -606,7 +611,9 @@ def test_keep_backward_unquantized_fused_bias_activation_matches_masked_linear_b # In keep-backward-unquantized mode, backward should behave as high-precision linear backward # given the ReLU mask induced by quantized forward activations. dy_after_activation = dy * (y_keep_bwd_hp > 0).to(dy.dtype) - _, dx1_expected, dw_expected = _run_single_step(linear_unquantized_ref, x1, dy_after_activation, None) + _, dx1_expected, dw_expected = _run_single_step( + linear_unquantized_ref, x1, dy_after_activation, None + ) db_expected = _extract_bias_grad(linear_unquantized_ref) assert db_keep_bwd_hp is not None assert db_expected is not None @@ -625,7 +632,9 @@ def test_keep_backward_unquantized_autocast_respects_quantize_forward_flag(): module_quantization_disabled = _make_linear_like_module( "linear", in_features, out_features, dtype, bias=True ) - module_unquantized_ref = _make_linear_like_module("linear", in_features, out_features, dtype, bias=True) + module_unquantized_ref = _make_linear_like_module( + "linear", in_features, out_features, dtype, bias=True + ) _copy_named_parameters(module_quantization_disabled, module_unquantized_ref) x = torch.randn(32, in_features, dtype=dtype, device="cuda") From 8d985c360ce28eed9041d05e894f8383b86e9737 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 10 Feb 2026 15:22:40 -0800 Subject: [PATCH 43/61] Add more shapes to unit test Signed-off-by: 
Ziang Li --- .../pytorch/test_keep_backward_unquantized.py | 54 +++++++++++++++++-- 1 file changed, 50 insertions(+), 4 deletions(-) diff --git a/tests/pytorch/test_keep_backward_unquantized.py b/tests/pytorch/test_keep_backward_unquantized.py index fe11bfcd3a..f5c3339a71 100644 --- a/tests/pytorch/test_keep_backward_unquantized.py +++ b/tests/pytorch/test_keep_backward_unquantized.py @@ -5,6 +5,7 @@ from __future__ import annotations from contextlib import nullcontext +import math import os from typing import Optional @@ -64,7 +65,9 @@ ] _shape_test_cases = [ + pytest.param((1, 64), 64, id="2d_m1_k64_n64"), pytest.param((32, 64), 64, id="2d_m32_k64_n64"), + pytest.param((32, 1, 64), 64, id="3d_m32_s1_k64_n64"), pytest.param((8, 4, 64), 128, id="3d_m32_k64_n128"), pytest.param((16, 2, 128), 64, id="3d_m32_k128_n64"), ] @@ -166,6 +169,46 @@ def _maybe_skip_unsupported_recipe_module_combo(recipe_name: str, module_type: s pytest.skip("Fusible ops (te_ops.Linear) do not support Float8BlockScaling recipe") +def _maybe_skip_unsupported_recipe_shape( + recipe_name: str, + input_shape: tuple[int, ...], + module_type: str, +) -> None: + flat_first_dim = math.prod(input_shape[:-1]) + last_dim = input_shape[-1] + + # TE Linear / LayerNormLinear FP8 kernels require FP8-GEMM-compatible dimensions. + if module_type in ("linear", "layernorm_linear"): + if flat_first_dim % 8 != 0 or last_dim % 16 != 0: + pytest.skip( + "Linear/LayerNormLinear FP8 execution requires prod(shape[:-1]) divisible by 8 " + "and shape[-1] divisible by 16." + ) + return + + # te_ops.Linear (fusible ops) has stricter constraints for some block-scaled recipes. + if module_type == "ops_linear": + if recipe_name == "mxfp8" and (flat_first_dim % 32 != 0 or last_dim % 32 != 0): + pytest.skip( + "te_ops.Linear + MXFP8 requires prod(shape[:-1]) and shape[-1] divisible by 32." 
+ ) + if recipe_name == "nvfp4" and (flat_first_dim % 16 != 0 or last_dim % 16 != 0): + pytest.skip( + "te_ops.Linear + NVFP4 requires prod(shape[:-1]) and shape[-1] divisible by 16." + ) + + +def _maybe_skip_unsupported_grouped_splits(recipe_name: str, m_splits: list[int]) -> None: + # Grouped GEMM paths enforce additional split-alignment constraints for block-scaled recipes. + non_empty_splits = [m for m in m_splits if m > 0] + if recipe_name == "mxfp8" and any(m % 32 != 0 for m in non_empty_splits): + pytest.skip("GroupedLinear + MXFP8 requires each non-empty m_split divisible by 32.") + if recipe_name == "fp8_block_scaling" and any(m % 4 != 0 for m in non_empty_splits): + pytest.skip( + "GroupedLinear + Float8BlockScaling requires each non-empty m_split divisible by 4." + ) + + def _run_single_step( module: torch.nn.Module, x: torch.Tensor, @@ -333,6 +376,7 @@ def test_keep_backward_unquantized_matches_quantized_fprop_and_unquantized_grads ): reset_rng_states() _maybe_skip_unsupported_recipe_module_combo(recipe_name, module_type) + _maybe_skip_unsupported_recipe_shape(recipe_name, input_shape, module_type) dtype = torch.bfloat16 in_features = input_shape[-1] @@ -390,8 +434,8 @@ def test_keep_backward_unquantized_matches_quantized_fprop_and_unquantized_grads @pytest.mark.parametrize("use_bias", (False, True), ids=("no_bias", "bias")) @pytest.mark.parametrize( "m_splits", - ([32, 32, 32, 32], [64, 0, 32, 32]), - ids=("uniform_splits", "with_empty_split"), + ([32, 32, 32, 32], [64, 0, 32, 32], [1, 31, 0, 96]), + ids=("uniform_splits", "with_empty_split", "small_and_empty_splits"), ) def test_keep_backward_unquantized_grouped_linear_matches_quantized_fprop_and_unquantized_grads( recipe_name: str, @@ -400,6 +444,7 @@ def test_keep_backward_unquantized_grouped_linear_matches_quantized_fprop_and_un ): if recipe_name == "nvfp4": pytest.skip("NVFP4 not supported for grouped linear") + _maybe_skip_unsupported_grouped_splits(recipe_name, m_splits) reset_rng_states() dtype 
= torch.bfloat16 @@ -478,10 +523,12 @@ def test_keep_backward_unquantized_grouped_linear_matches_quantized_fprop_and_un ("scale_add", ForwardLinearScaleAdd), ), ) +@pytest.mark.parametrize("m", (1, 32), ids=("m1", "m32")) def test_keep_backward_unquantized_fused_linear_paths( recipe_name: str, fused_pattern: str, expected_fused_op: type, + m: int, ): # Fused linear op path is based on te_ops.Linear and shares its recipe constraints. _maybe_skip_unsupported_recipe_module_combo(recipe_name, "ops_linear") @@ -490,8 +537,7 @@ def test_keep_backward_unquantized_fused_linear_paths( dtype = torch.bfloat16 in_features = 64 out_features = 64 - m = 32 - + _maybe_skip_unsupported_recipe_shape(recipe_name, (m, in_features), "ops_linear") model_quantized_ref = _make_fused_model(fused_pattern, in_features, out_features, dtype) model_keep_bwd_hp = _make_fused_model(fused_pattern, in_features, out_features, dtype) model_unquantized_ref = _make_fused_model(fused_pattern, in_features, out_features, dtype) From e35da447d49e416c25b16ba794ca5a02682a1e9d Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 24 Feb 2026 15:10:49 -0800 Subject: [PATCH 44/61] Refactor interface to `NVTE_BACKWARD_MODE=default|unquant|dequant` Signed-off-by: Ziang Li --- qa/L0_pytorch_unittest/test.sh | 2 +- tests/pytorch/test_backward_mode.py | 1446 +++++++++++++++++ .../pytorch/test_keep_backward_unquantized.py | 756 --------- transformer_engine/common/recipe/__init__.py | 129 +- transformer_engine/pytorch/module/base.py | 2 +- .../pytorch/module/grouped_linear.py | 74 +- .../pytorch/module/layernorm_linear.py | 47 +- .../pytorch/module/layernorm_mlp.py | 13 +- transformer_engine/pytorch/module/linear.py | 51 +- .../pytorch/ops/basic/basic_linear.py | 44 +- transformer_engine/pytorch/ops/basic/bias.py | 11 +- .../pytorch/ops/basic/quantize.py | 12 +- .../ops/fused/backward_activation_bias.py | 4 +- .../fused/forward_linear_bias_activation.py | 23 +- .../ops/fused/forward_linear_bias_add.py | 19 +- 
.../ops/fused/forward_linear_scale_add.py | 17 +- .../ops/fused/userbuffers_forward_linear.py | 13 + transformer_engine/pytorch/ops/fuser.py | 14 +- transformer_engine/pytorch/quantization.py | 7 +- 19 files changed, 1749 insertions(+), 935 deletions(-) create mode 100644 tests/pytorch/test_backward_mode.py delete mode 100644 tests/pytorch/test_keep_backward_unquantized.py diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index c5cce521d4..a9df8a1bb6 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -42,7 +42,7 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py" -NVTE_KEEP_BACKWARD_UNQUANTIZED=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_keep_backward_unquantized.xml $TE_PATH/tests/pytorch/test_keep_backward_unquantized.py || test_fail "test_keep_backward_unquantized.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_backward_mode.xml $TE_PATH/tests/pytorch/test_backward_mode.py || test_fail "test_backward_mode.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml 
$TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py" diff --git a/tests/pytorch/test_backward_mode.py b/tests/pytorch/test_backward_mode.py new file mode 100644 index 0000000000..300d860496 --- /dev/null +++ b/tests/pytorch/test_backward_mode.py @@ -0,0 +1,1446 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +from __future__ import annotations + +from contextlib import nullcontext +import math +from typing import Optional + +import pytest +import torch + +import transformer_engine.pytorch as te +import transformer_engine.pytorch.ops as te_ops +from transformer_engine.common import recipe +from transformer_engine.pytorch.cpp_extensions import general_gemm, layernorm_bwd +from transformer_engine.pytorch.quantization import FP8GlobalStateManager +from transformer_engine.pytorch.ops.fused import ( + BackwardActivationBias, + ForwardLinearBiasActivation, + ForwardLinearBiasAdd, + ForwardLinearScaleAdd, + UserbuffersForwardLinear, +) +from transformer_engine.pytorch.quantized_tensor import restore_from_saved + +from utils import quantization_tols, reset_rng_states + + +# -------------------------- +# Mode and capability config +# -------------------------- + +_NON_QUANT_BACKWARD_MODES = ("unquant", "dequant") + +fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) +mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = te.is_fp8_block_scaling_available( + return_reason=True +) +nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True) +bf16_available, reason_for_no_bf16 = te.is_bf16_available(return_reason=True) + +# Broad dtype coverage for modules touched by this change. 
+_core_dtypes = [torch.float16, torch.float32] +if bf16_available: + _core_dtypes.insert(1, torch.bfloat16) + +# Fused GEMM+bias+activation requires FP16/BF16 output. +_fused_dtypes = [torch.float16] +if bf16_available: + _fused_dtypes.append(torch.bfloat16) + + +@pytest.fixture(autouse=True) +def _reset_global_fp8_state(): + """Avoid global FP8-state leakage between parametrized cases.""" + yield + FP8GlobalStateManager.reset() + + +@pytest.fixture(params=_NON_QUANT_BACKWARD_MODES, ids=lambda mode: f"mode_{mode}") +def backward_mode(request: pytest.FixtureRequest) -> str: + """Backward mode under test.""" + return request.param + + +# -------------------------- +# Shared helpers +# -------------------------- + + +def _assert_exact(test: torch.Tensor, ref: torch.Tensor) -> None: + torch.testing.assert_close(test, ref, rtol=0, atol=0) + + +def _assert_forward_matches_quantized_ref( + test: torch.Tensor, + ref: torch.Tensor, + recipe_name: str, +) -> None: + torch.testing.assert_close(test, ref, **_fprop_tolerances(recipe_name)) + + +def _restore_saved_operands(output: torch.Tensor) -> list[Optional[torch.Tensor]]: + if output.grad_fn is None: + raise RuntimeError("Output tensor has no grad_fn; cannot inspect saved operands") + if not hasattr(output.grad_fn, "tensor_objects"): + raise RuntimeError("grad_fn does not expose tensor_objects for saved operand restoration") + return restore_from_saved(output.grad_fn.tensor_objects, list(output.grad_fn.saved_tensors)) + + +def _extract_linear_saved_operands( + saved_operands: list[Optional[torch.Tensor]], + *, + context: str, +) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: + if len(saved_operands) < 2: + raise RuntimeError( + f"Insufficient saved operands for {context} dequant reference " + f"(got {len(saved_operands)}, expected at least 2)." 
+ ) + return saved_operands[0], saved_operands[1] + + +def _dequantize_saved_operand( + saved_operand: Optional[torch.Tensor], + dtype: torch.dtype, +) -> torch.Tensor: + if saved_operand is None: + raise RuntimeError("Expected saved operand but got None") + # In dequant mode we must consume the fprop-saved quantized payload directly. + # If row-wise payload is missing, the tensor was retargeted to a transpose-only + # layout and no longer represents the original fprop operand. + if ( + not isinstance(saved_operand, torch.Tensor) + and hasattr(saved_operand, "_rowwise_data") + and getattr(saved_operand, "_rowwise_data") is None + ): + raise RuntimeError( + "Saved dequant operand lost row-wise fprop payload (likely usage retarget)." + ) + if isinstance(saved_operand, torch.Tensor): + return saved_operand.to(dtype) + if not hasattr(saved_operand, "dequantize"): + raise RuntimeError(f"Unsupported saved operand type: {type(saved_operand)}") + return saved_operand.dequantize(dtype=dtype) + + +def _assert_saved_quantized_operand_uses_rowwise_only( + saved_operand: Optional[torch.Tensor], + *, + name: str, +) -> None: + if saved_operand is None: + raise RuntimeError(f"Expected quantized saved {name} operand but got None") + if isinstance(saved_operand, torch.Tensor): + raise RuntimeError( + f"Dequant reference expects quantized saved {name} operand, got torch.Tensor." + ) + if not hasattr(saved_operand, "dequantize"): + raise RuntimeError(f"Unsupported saved {name} operand type: {type(saved_operand)}") + if hasattr(saved_operand, "_rowwise_data") and getattr(saved_operand, "_rowwise_data") is None: + raise RuntimeError( + f"Saved dequant {name} operand lost row-wise fprop payload (likely usage retarget)." + ) + if ( + hasattr(saved_operand, "_columnwise_data") + and getattr(saved_operand, "_columnwise_data") is not None + ): + raise RuntimeError( + f"Saved dequant {name} operand unexpectedly carries column-wise payload." 
+ ) + + +def _snapshot_saved_quantized_operand_layout( + saved_operand: Optional[torch.Tensor], + *, + name: str, +) -> dict[str, object]: + _assert_saved_quantized_operand_uses_rowwise_only(saved_operand, name=name) + rowwise_present = None + columnwise_present = None + rowwise_obj_id = None + if hasattr(saved_operand, "_rowwise_data"): + rowwise_data = getattr(saved_operand, "_rowwise_data") + rowwise_present = rowwise_data is not None + if rowwise_data is not None: + rowwise_obj_id = id(rowwise_data) + if hasattr(saved_operand, "_columnwise_data"): + columnwise_present = getattr(saved_operand, "_columnwise_data") is not None + return { + "name": name, + "saved_operand": saved_operand, + "rowwise_present": rowwise_present, + "columnwise_present": columnwise_present, + "rowwise_obj_id": rowwise_obj_id, + } + + +def _assert_saved_quantized_operand_layout_unchanged(snapshot: dict[str, object]) -> None: + name = snapshot.get("name") + if not isinstance(name, str): + raise RuntimeError(f"Invalid saved operand snapshot name: {name!r}") + saved_operand = snapshot.get("saved_operand") + _assert_saved_quantized_operand_uses_rowwise_only(saved_operand, name=name) + + rowwise_present = snapshot.get("rowwise_present") + if isinstance(rowwise_present, bool): + rowwise_data_now = getattr(saved_operand, "_rowwise_data", None) + rowwise_now = rowwise_data_now is not None + if rowwise_now != rowwise_present: + raise RuntimeError( + f"Saved dequant {name} operand row-wise payload presence changed " + f"from {rowwise_present} to {rowwise_now}." + ) + # Guard against hidden requantization that swaps in a new row-wise payload. + rowwise_obj_id = snapshot.get("rowwise_obj_id") + if ( + isinstance(rowwise_obj_id, int) + and rowwise_now + and id(rowwise_data_now) != rowwise_obj_id + ): + raise RuntimeError( + f"Saved dequant {name} operand row-wise payload identity changed " + "(likely rewritten/requantized)." 
+ ) + + columnwise_present = snapshot.get("columnwise_present") + if isinstance(columnwise_present, bool): + columnwise_now = getattr(saved_operand, "_columnwise_data", None) is not None + if columnwise_now != columnwise_present: + raise RuntimeError( + f"Saved dequant {name} operand column-wise payload presence changed " + f"from {columnwise_present} to {columnwise_now}." + ) + + +def _snapshot_layout_invariants( + guard_operands: list[tuple[str, Optional[torch.Tensor]]], +) -> list[dict[str, object]]: + """Capture saved-operand layout invariants before backward runs.""" + return [ + _snapshot_saved_quantized_operand_layout(saved_operand, name=name) + for name, saved_operand in guard_operands + ] + + +def _assert_layout_invariants_unchanged(layout_invariants: list[dict[str, object]]) -> None: + """Validate saved-operand layout invariants after backward runs.""" + for layout_invariant in layout_invariants: + _assert_saved_quantized_operand_layout_unchanged(layout_invariant) + + +def _raise_if_ref_failed(ref_exc: Optional[Exception]) -> None: + """Re-raise deferred reference exceptions after layout checks.""" + if ref_exc is not None: + raise ref_exc + + +def _compute_linear_backward_reference_from_saved_operands( + saved_input: Optional[torch.Tensor], + saved_weight: Optional[torch.Tensor], + dy: torch.Tensor, + *, + dequant_dtype: torch.dtype, + out_dtype: torch.dtype, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # Dequant reference path: + # 1) use the exact operands saved by quantized forward, + # 2) dequantize them to the active high-precision compute dtype, + # 3) run backward GEMMs in high precision and compare exactly. + for name, saved_operand in (("input", saved_input), ("weight", saved_weight)): + _assert_saved_quantized_operand_uses_rowwise_only(saved_operand, name=name) + dy_mat = dy.reshape(-1, dy.shape[-1]) + + # Empty-token chunks can happen in grouped/fused paths. Reference should be zeros. 
    if dy_mat.shape[0] == 0:
        out_features = dy_mat.shape[-1]
        if saved_input is None:
            raise RuntimeError("Expected saved input operand for empty-chunk dequant reference.")
        in_features = saved_input.size(-1)
        dx_ref = torch.zeros(*dy.shape[:-1], in_features, dtype=out_dtype, device=dy.device)
        dw_ref = torch.zeros(out_features, in_features, dtype=out_dtype, device=dy.device)
        db_ref = torch.zeros(out_features, dtype=out_dtype, device=dy.device)
        return dx_ref, dw_ref, db_ref

    x_ref_full = _dequantize_saved_operand(saved_input, dequant_dtype)
    x_ref = x_ref_full.reshape(-1, x_ref_full.shape[-1])
    w_ref = _dequantize_saved_operand(saved_weight, dequant_dtype)

    # dgrad: dx = dy @ w (layout "NN" with grad=True).
    dx_ref_2d, *_ = general_gemm(
        w_ref,
        dy_mat,
        out_dtype=out_dtype,
        layout="NN",
        grad=True,
    )
    # Derive db from the same GEMM primitive used by runtime wgrad. This avoids
    # tiny reduction-order drift vs. a standalone dy.sum() path in FP32 cases.
    db_seed = torch.empty(dy_mat.shape[-1], dtype=out_dtype, device=dy_mat.device)
    dw_ref, db_ref, *_ = general_gemm(
        x_ref,
        dy_mat,
        out_dtype=out_dtype,
        layout="NT",
        grad=True,
        bias=db_seed,
    )
    # Fallback when the GEMM did not produce a fused bias gradient.
    if db_ref is None:
        db_ref = dy_mat.sum(dim=0).to(out_dtype)
    dx_ref = dx_ref_2d.view(*dy.shape[:-1], dx_ref_2d.shape[-1])
    return dx_ref, dw_ref, db_ref


# Recipes under test; each is skipped when the hardware/driver lacks support.
_quantized_numerics_recipe_list = [
    pytest.param(
        "fp8_current_scaling",
        marks=pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8),
        id="Float8CurrentScaling",
    ),
    pytest.param(
        "mxfp8",
        marks=pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8),
        id="MXFP8BlockScaling",
    ),
    pytest.param(
        "fp8_block_scaling",
        marks=pytest.mark.skipif(
            not fp8_block_scaling_available,
            reason=reason_for_no_fp8_block_scaling,
        ),
        id="Float8BlockScaling",
    ),
    pytest.param(
        "nvfp4",
        marks=pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4),
        id="NVFP4BlockScaling",
    ),
]

# (input_shape, out_features) pairs exercising 2D and 3D inputs.
_shape_test_cases = [
    pytest.param((1, 64), 64, id="2d_m1_k64_n64"),
    pytest.param((32, 64), 64, id="2d_m32_k64_n64"),
    pytest.param((32, 1, 64), 64, id="3d_m32_s1_k64_n64"),
    pytest.param((8, 4, 64), 128, id="3d_m32_k64_n128"),
    pytest.param((16, 2, 128), 64, id="3d_m32_k128_n64"),
]

_bias_activation_shape_cases = [
    pytest.param((32, 64), id="2d_m32_k64"),
    pytest.param((8, 4, 64), id="3d_m32_k64"),
]


def _make_recipe(recipe_name: str, *, backward_mode: str) -> recipe.Recipe:
    """Instantiate the named quantization recipe with the given backward_mode."""
    kwargs = {"backward_mode": backward_mode}
    if recipe_name == "fp8_current_scaling":
        return recipe.Float8CurrentScaling(fp8_format=recipe.Format.E4M3, **kwargs)
    if recipe_name == "mxfp8":
        return recipe.MXFP8BlockScaling(fp8_format=recipe.Format.E4M3, **kwargs)
    if recipe_name == "fp8_block_scaling":
        return recipe.Float8BlockScaling(fp8_format=recipe.Format.E4M3, **kwargs)
    if recipe_name == "nvfp4":
        # Randomization features are disabled so runs stay bitwise reproducible.
        return recipe.NVFP4BlockScaling(
            disable_rht=True,
            disable_stochastic_rounding=True,
            disable_2d_quantization=True,
            **kwargs,
        )
    raise ValueError(f"Unsupported recipe for backward-mode test: {recipe_name}")


def _copy_named_parameters(src_module: torch.nn.Module, dst_module: torch.nn.Module) -> None:
    """Copy parameters by name from src to dst so both modules start identical."""
    src_params = dict(src_module.named_parameters())
    with torch.no_grad():
        for name, dst_param in dst_module.named_parameters():
            if name not in src_params:
                raise RuntimeError(f"Parameter {name} missing in source module")
            dst_param.copy_(src_params[name])


def _fprop_tolerances(recipe_name: str) -> dict[str, float]:
    """Forward-output comparison tolerances per recipe."""
    if recipe_name == "mxfp8":
        return quantization_tols("mxfp8")
    # NOTE(review): block scaling reuses current-scaling tolerances — presumably
    # intentionally similar numerics; confirm if tolerances ever diverge.
    if recipe_name in ("fp8_current_scaling", "fp8_block_scaling"):
        return quantization_tols("fp8_current_scaling")
    if recipe_name == "nvfp4":
        return quantization_tols("nvfp4")
    raise ValueError(f"Unsupported recipe for backward-mode test: {recipe_name}")


def _maybe_skip_recipe_dtype(recipe_name: str, dtype: torch.dtype, backward_mode: str) -> None:
    """Skip combinations the recipe/dtype cannot run (backward_mode currently unused here)."""
    if dtype == torch.bfloat16 and not bf16_available:
        pytest.skip(reason_for_no_bf16)
    if recipe_name == "nvfp4" and dtype != torch.bfloat16:
        pytest.skip("NVFP4 is only supported with BF16 in this test")


def _make_linear_like_module(
    module_type: str,
    in_features: int,
    out_features: int,
    dtype: torch.dtype,
    *,
    bias: bool,
) -> torch.nn.Module:
    """Construct the requested linear-like TE module on CUDA."""
    if module_type == "linear":
        return te.Linear(
            in_features,
            out_features,
            bias=bias,
            params_dtype=dtype,
            device="cuda",
        )
    if module_type == "layernorm_linear":
        return te.LayerNormLinear(
            in_features,
            out_features,
            bias=bias,
            params_dtype=dtype,
            device="cuda",
        )
    if module_type == "ops_linear":
        return te_ops.Linear(
            in_features,
            out_features,
            bias=bias,
            dtype=dtype,
            device="cuda",
        )
    raise ValueError(f"Unsupported module type: {module_type}")


def _maybe_skip_unsupported_recipe_module_combo(recipe_name: str, module_type: str) -> None:
    """Skip module/recipe pairings that TE does not support."""
    if module_type == "ops_linear" and recipe_name == "fp8_block_scaling":
        pytest.skip("Fusible ops (te_ops.Linear) do not support Float8BlockScaling recipe")


def _maybe_skip_unsupported_recipe_shape(
    recipe_name: str,
    input_shape: tuple[int, ...],
    module_type: str,
) -> None:
    """Skip shapes that violate per-recipe divisibility requirements."""
    flat_first_dim = math.prod(input_shape[:-1])
    last_dim = input_shape[-1]

    if module_type in ("linear", "layernorm_linear"):
        if flat_first_dim % 8 != 0 or last_dim % 16 != 0:
            pytest.skip(
                "Linear/LayerNormLinear FP8 execution requires prod(shape[:-1]) divisible by 8 "
                "and shape[-1] divisible by 16."
            )
        return

    if module_type == "ops_linear":
        if recipe_name == "mxfp8" and (flat_first_dim % 32 != 0 or last_dim % 32 != 0):
            pytest.skip(
                "te_ops.Linear + MXFP8 requires prod(shape[:-1]) and shape[-1] divisible by 32."
            )
        if recipe_name == "nvfp4" and (flat_first_dim % 16 != 0 or last_dim % 16 != 0):
            pytest.skip(
                "te_ops.Linear + NVFP4 requires prod(shape[:-1]) and shape[-1] divisible by 16."
            )


def _maybe_skip_unsupported_grouped_splits(recipe_name: str, m_splits: list[int]) -> None:
    """Skip grouped-linear split patterns that violate per-recipe divisibility."""
    non_empty_splits = [m for m in m_splits if m > 0]
    if recipe_name == "mxfp8" and any(m % 32 != 0 for m in non_empty_splits):
        pytest.skip("GroupedLinear + MXFP8 requires each non-empty m_split divisible by 32.")
    if recipe_name == "fp8_block_scaling" and any(m % 4 != 0 for m in non_empty_splits):
        pytest.skip(
            "GroupedLinear + Float8BlockScaling requires each non-empty m_split divisible by 4."
        )


def _run_single_step(
    module: torch.nn.Module,
    x: torch.Tensor,
    dy: torch.Tensor,
    fp8_recipe: Optional[recipe.Recipe],
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Run one fwd+bwd step; returns (y, dx, dw, db). No recipe => plain precision."""
    module.zero_grad(set_to_none=True)
    x_run = x.detach().clone().requires_grad_(True)
    autocast_ctx = (
        te.autocast(enabled=True, recipe=fp8_recipe) if fp8_recipe is not None else nullcontext()
    )
    with autocast_ctx:
        y = module(x_run)
        # Some TE modules return (output, aux); keep only the output.
        if isinstance(y, tuple):
            y = y[0]
    y.backward(dy)
    assert x_run.grad is not None
    assert module.weight.grad is not None
    bgrad = _extract_bias_grad(module)
    return (
        y.detach().clone(),
        x_run.grad.detach().clone(),
        module.weight.grad.detach().clone(),
        bgrad,
    )


def _run_single_step_with_saved_operands(
    module: torch.nn.Module,
    x: torch.Tensor,
    fp8_recipe: recipe.Recipe,
) -> tuple[
    torch.Tensor,
    torch.Tensor,
    list[Optional[torch.Tensor]],
]:
    """Run forward only and expose the operands autograd saved for backward."""
    module.zero_grad(set_to_none=True)
    x_run = x.detach().clone().requires_grad_(True)
    with te.autocast(enabled=True, recipe=fp8_recipe):
        y = module(x_run)
        if isinstance(y, tuple):
            y = y[0]
    saved_operands = _restore_saved_operands(y)
    return y, x_run, saved_operands


def _extract_bias_grad(module: torch.nn.Module) -> Optional[torch.Tensor]:
    """Return a detached copy of the bias gradient, or None when absent."""
    bias = getattr(module, "bias", None)
    if bias is None or bias.grad is None:
        return None
    return bias.grad.detach().clone()


def _run_grouped_linear_single_step(
    module: te.GroupedLinear,
    x: torch.Tensor,
    m_splits: list[int],
    dy: torch.Tensor,
    fp8_recipe: Optional[recipe.Recipe],
) -> tuple[torch.Tensor, torch.Tensor, list[torch.Tensor], list[Optional[torch.Tensor]]]:
    """Run one GroupedLinear fwd+bwd step; returns (y, dx, per-gemm dw, per-gemm db)."""
    module.zero_grad(set_to_none=True)
    x_run = x.detach().clone().requires_grad_(True)
    autocast_ctx = (
        te.autocast(enabled=True, recipe=fp8_recipe) if fp8_recipe is not None else nullcontext()
    )
    with autocast_ctx:
        y = module(x_run, m_splits)
    y.backward(dy)
    assert x_run.grad is not None

    dw = [getattr(module, f"weight{i}").grad.detach().clone() for i in range(module.num_gemms)]
    db: list[Optional[torch.Tensor]] = []
    for i in range(module.num_gemms):
        if module.use_bias:
            db.append(getattr(module, f"bias{i}").grad.detach().clone())
        else:
            db.append(None)
    return y.detach().clone(), x_run.grad.detach().clone(), dw, db


def _run_grouped_linear_step_with_saved_operands(
    module: te.GroupedLinear,
    x: torch.Tensor,
    m_splits: list[int],
    fp8_recipe: recipe.Recipe,
) -> tuple[
    torch.Tensor,
    torch.Tensor,
    list[Optional[torch.Tensor]],
]:
    """Forward-only GroupedLinear step exposing autograd-saved operands."""
    module.zero_grad(set_to_none=True)
    x_run = x.detach().clone().requires_grad_(True)
    with te.autocast(enabled=True, recipe=fp8_recipe):
        y = module(x_run, m_splits)
    saved_operands = _restore_saved_operands(y)
    return y, x_run, saved_operands


def _make_fused_model(
    pattern: str,
    in_features: int,
    out_features: int,
    dtype: torch.dtype,
    *,
    scale: float = 0.5,
) -> te_ops.Sequential:
    """Build the te_ops.Sequential matching a fused-forward test pattern."""
    if pattern == "bias_activation":
        return te_ops.Sequential(
            te_ops.Linear(in_features, out_features, bias=True, device="cuda", dtype=dtype),
            te_ops.ReLU(),
        )
    if pattern == "bias_add":
        return te_ops.Sequential(
            te_ops.Linear(in_features, out_features, bias=True, device="cuda", dtype=dtype),
            te_ops.AddExtraInput(in_place=True),
        )
    if pattern == "scale_add":
        return te_ops.Sequential(
            te_ops.Linear(in_features, out_features, bias=False, device="cuda", dtype=dtype),
            te_ops.ConstantScale(scale),
            te_ops.AddExtraInput(in_place=True),
        )
    raise ValueError(f"Unsupported fused test pattern: {pattern}")


def _run_fused_single_step(
    pattern: str,
    model: te_ops.Sequential,
    x1: torch.Tensor,
    dy: torch.Tensor,
    fp8_recipe: Optional[recipe.Recipe],
    *,
    x2: Optional[torch.Tensor] = None,
) -> tuple[
    torch.Tensor,
    torch.Tensor,
    Optional[torch.Tensor],
    torch.Tensor,
    Optional[torch.Tensor],
]:
    """Run one fwd+bwd step of a fused model; returns (y, dx1, dx2, dw, db)."""
    model.zero_grad(set_to_none=True)
    x1_run = x1.detach().clone().requires_grad_(True)
    x2_run = x2.detach().clone().requires_grad_(True) if x2 is not None else None
    autocast_ctx = (
        te.autocast(enabled=True, recipe=fp8_recipe) if fp8_recipe is not None else nullcontext()
    )
    with autocast_ctx:
        # bias_add / scale_add patterns consume a second (extra) input.
        if pattern in ("bias_add", "scale_add"):
            assert x2_run is not None
            y = model(x1_run, x2_run)
        else:
            y = model(x1_run)
    y.backward(dy)
    assert x1_run.grad is not None

    dw = model[0].weight.grad.detach().clone()
    db = None
    if getattr(model[0], "bias", None) is not None and model[0].bias.grad is not None:
        db = model[0].bias.grad.detach().clone()
    dx2 = x2_run.grad.detach().clone() if x2_run is not None and x2_run.grad is not None else None
    return y.detach().clone(), x1_run.grad.detach().clone(), dx2, dw, db


def _run_fused_single_step_with_saved_operands(
    pattern: str,
    model: te_ops.Sequential,
    x1: torch.Tensor,
    fp8_recipe: recipe.Recipe,
    *,
    x2: Optional[torch.Tensor] = None,
) -> tuple[
    torch.Tensor,
    torch.Tensor,
    Optional[torch.Tensor],
    list[Optional[torch.Tensor]],
]:
    """Forward-only fused-model step exposing autograd-saved operands."""
    model.zero_grad(set_to_none=True)
    x1_run = x1.detach().clone().requires_grad_(True)
    x2_run = x2.detach().clone().requires_grad_(True) if x2 is not None else None
    with te.autocast(enabled=True, recipe=fp8_recipe):
        if pattern in ("bias_add", "scale_add"):
            assert x2_run is not None
            y = model(x1_run, x2_run)
        else:
            y = model(x1_run)
    saved_operands = _restore_saved_operands(y)
    return y, x1_run, x2_run, saved_operands


def _run_quantize_op_single_step(
    model: te_ops.Sequential,
    x: torch.Tensor,
    dy: torch.Tensor,
    fp8_recipe: Optional[recipe.Recipe],
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run one fwd+bwd step through a Quantize-op model; returns (y, dx)."""
    x_run = x.detach().clone().requires_grad_(True)
    autocast_ctx = (
        te.autocast(enabled=True, recipe=fp8_recipe) if fp8_recipe is not None else nullcontext()
    )
    with autocast_ctx:
        y = model(x_run)
    y.backward(dy)
    assert x_run.grad is not None
    return y.detach().clone(), x_run.grad.detach().clone()


def _make_userbuffers_fuser_for_mode_switch_test(
    *,
    dtype: torch.dtype,
) -> tuple[object, torch.Tensor, list[tuple[()]]]:
    """Build a Userbuffers-eligible fuser and representative inputs."""
    in_features = 64
    out_features = 64
    linear = te_ops.BasicLinear(
        in_features,
        out_features,
        device="cuda",
        dtype=dtype,
        userbuffers_options={"comm_name": "qkv"},
    )
    # Fake a column-parallel, sequence-parallel layout so the fusion is eligible.
    linear.tensor_parallel_mode = "column"
    linear.tensor_parallel_size = 2
    linear.sequence_parallel = True
    bias = te_ops.Bias(out_features, device="cuda", dtype=dtype)
    model = te_ops.Sequential(linear, bias)
    model._module_groups = model._make_module_groups(
        model._modules.values()
    )  # pylint: disable=protected-access
    fuser = model._module_groups[0]
    x = torch.randn(32, in_features, dtype=dtype, device="cuda", requires_grad=True)
    extra_inputs = [() for _ in range(fuser._num_basic_ops)]  # pylint: disable=protected-access
    return fuser, x, extra_inputs


def _has_userbuffers_forward_linear(fuser: object) -> bool:
    """True when the fuser's forward plan contains a UserbuffersForwardLinear op."""
    return any(
        isinstance(op, UserbuffersForwardLinear) for op, _ in fuser._forward_ops
    )  # pylint: disable=protected-access


# --------------------------
# Tests
# --------------------------


@pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list)
def test_backward_mode_recipe_matches_requested_mode(
    recipe_name: str,
    backward_mode: str,
) -> None:
    """Recipe objects should report exactly the backward_mode they were built with."""
    mode_recipe = _make_recipe(recipe_name, backward_mode=backward_mode)
    quant_recipe = _make_recipe(recipe_name, backward_mode="default")
    assert mode_recipe.backward_mode == backward_mode
    assert quant_recipe.backward_mode == "default"


@pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list)
@pytest.mark.parametrize("module_type", ("linear", "layernorm_linear", "ops_linear"))
@pytest.mark.parametrize("input_shape,out_features", _shape_test_cases)
@pytest.mark.parametrize("use_bias", (False, True), ids=("no_bias", "bias"))
@pytest.mark.parametrize("dtype", _core_dtypes, ids=str)
def test_linear_like_backward_mode_matches_reference(
    recipe_name: str,
    module_type: str,
    input_shape: tuple[int, ...],
    out_features: int,
    use_bias: bool,
    dtype: torch.dtype,
    backward_mode: str,
) -> None:
    """Backward-mode grads of linear-like modules must match the mode's reference exactly."""
    reset_rng_states()
    _maybe_skip_recipe_dtype(recipe_name, dtype, backward_mode)
    _maybe_skip_unsupported_recipe_module_combo(recipe_name, module_type)
    _maybe_skip_unsupported_recipe_shape(recipe_name, input_shape, module_type)

    in_features = input_shape[-1]
    quantized_ref_recipe = _make_recipe(recipe_name, backward_mode="default")
    mode_recipe = _make_recipe(recipe_name, backward_mode=backward_mode)

    module_quantized_ref = _make_linear_like_module(
        module_type,
        in_features,
        out_features,
        dtype,
        bias=use_bias,
    )
    module_bwd_mode = _make_linear_like_module(
        module_type,
        in_features,
        out_features,
        dtype,
        bias=use_bias,
    )
    _copy_named_parameters(module_quantized_ref, module_bwd_mode)

    output_shape = input_shape[:-1] + (out_features,)
    x = torch.randn(*input_shape, dtype=dtype, device="cuda")
    dy = torch.randn(*output_shape, dtype=dtype, device="cuda")

    y_quantized_ref, _, _, _ = _run_single_step(module_quantized_ref, x, dy, quantized_ref_recipe)
    if backward_mode == "unquant":
        # Unquant reference path: compare against a plain high-precision backward run
        # (no fp8/autocast), starting from the same params and inputs.
        module_unquantized_ref = _make_linear_like_module(
            module_type,
            in_features,
            out_features,
            dtype,
            bias=use_bias,
        )
        _copy_named_parameters(module_quantized_ref, module_unquantized_ref)
        y_bwd_mode, dx_bwd_mode, dw_bwd_mode, db_bwd_mode = _run_single_step(
            module_bwd_mode,
            x,
            dy,
            mode_recipe,
        )
        _, dx_ref, dw_ref, db_ref = _run_single_step(
            module_unquantized_ref,
            x,
            dy,
            None,
        )
    else:
        # Dequant reference path: capture saved forward operands from the real dequant-mode
        # execution, then rebuild backward reference from those saved operands.
        y_bwd_mode, x_bwd_mode, saved_operands = _run_single_step_with_saved_operands(
            module_bwd_mode, x, mode_recipe
        )
        # Keep a detached copy; y_bwd_mode itself is still needed live for backward.
        y_bwd_mode_detached = y_bwd_mode.detach().clone()

        dx_ref: Optional[torch.Tensor] = None
        dw_ref: Optional[torch.Tensor] = None
        db_ref: Optional[torch.Tensor] = None
        layout_invariants: list[dict[str, object]] = []
        guard_operands: list[tuple[str, Optional[torch.Tensor]]] = []
        ref_exc: Optional[Exception] = None
        try:
            if module_type == "layernorm_linear":
                # LayerNormLinear dequant reference:
                #   1) Compute d(ln_out), dw, db from linear backward with saved operands.
                #   2) Compute exact dx via layernorm_bwd with saved norm statistics.
                # _LayerNormLinear forward saves operands as:
                #   [inputmat, weightmat, origin_weight, bias, ln_weight, ln_out, mu, rsigma, ...]
                if len(saved_operands) < 8:
                    raise RuntimeError(
                        "Insufficient saved operands for layernorm_linear dequant reference "
                        f"(got {len(saved_operands)}, expected at least 8)."
                    )
                saved_input = saved_operands[0]
                saved_weight = saved_operands[1]
                saved_ln_weight = saved_operands[4]
                saved_ln_out = saved_operands[5]
                saved_mu = saved_operands[6]
                saved_rsigma = saved_operands[7]
                guard_operands.extend(
                    [
                        ("layernorm_linear_ln_out", saved_ln_out),
                        ("layernorm_linear_weight", saved_weight),
                    ]
                )
                d_ln_out_ref, dw_ref, db_ref = (
                    _compute_linear_backward_reference_from_saved_operands(
                        saved_ln_out,
                        saved_weight,
                        dy,
                        dequant_dtype=dtype,
                        out_dtype=dtype,
                    )
                )
                input_ref = _dequantize_saved_operand(saved_input, dtype)
                input_ref_2d = input_ref.reshape(-1, input_ref.shape[-1])
                ln_weight_ref = _dequantize_saved_operand(saved_ln_weight, dtype).view(-1)
                if saved_mu is None or saved_rsigma is None:
                    raise RuntimeError("Missing LayerNorm statistics in saved operands")
                if not isinstance(saved_mu, torch.Tensor) or not isinstance(
                    saved_rsigma, torch.Tensor
                ):
                    raise RuntimeError("LayerNorm statistics must be Tensor objects")
                dx_ref, *_ = layernorm_bwd(
                    d_ln_out_ref.reshape(input_ref_2d.shape),
                    input_ref_2d,
                    saved_mu,
                    saved_rsigma,
                    ln_weight_ref,
                    module_bwd_mode.bwd_ln_sm_margin,
                    module_bwd_mode.zero_centered_gamma,
                )
                dx_ref = dx_ref.view_as(x_bwd_mode)
            else:
                saved_input, saved_weight = _extract_linear_saved_operands(
                    saved_operands,
                    context=f"{module_type}",
                )
                guard_operands.extend(
                    [
                        (f"{module_type}_input", saved_input),
                        (f"{module_type}_weight", saved_weight),
                    ]
                )
                dx_ref, dw_ref, db_ref = _compute_linear_backward_reference_from_saved_operands(
                    saved_input,
                    saved_weight,
                    dy,
                    dequant_dtype=dtype,
                    out_dtype=dtype,
                )
                if module_type == "ops_linear" and use_bias:
                    # te_ops bias grad is reduced by the Bias op from incoming dy.
                    db_ref = dy.reshape(-1, dy.shape[-1]).sum(dim=0).to(dtype)
        except Exception as exc:  # pylint: disable=broad-exception-caught
            # Defer reference failures: layout invariants are checked first below.
            ref_exc = exc

        layout_invariants = _snapshot_layout_invariants(guard_operands)

        y_bwd_mode.backward(dy)
        assert x_bwd_mode.grad is not None
        assert module_bwd_mode.weight.grad is not None
        dx_bwd_mode = x_bwd_mode.grad.detach().clone()
        dw_bwd_mode = module_bwd_mode.weight.grad.detach().clone()
        db_bwd_mode = _extract_bias_grad(module_bwd_mode)
        y_bwd_mode = y_bwd_mode_detached

        _assert_layout_invariants_unchanged(layout_invariants)
        _raise_if_ref_failed(ref_exc)
        assert dx_ref is not None and dw_ref is not None and db_ref is not None

    _assert_forward_matches_quantized_ref(y_bwd_mode, y_quantized_ref, recipe_name)
    _assert_exact(dx_bwd_mode, dx_ref)
    _assert_exact(dw_bwd_mode, dw_ref)
    if use_bias:
        assert db_bwd_mode is not None
        assert db_ref is not None
        _assert_exact(db_bwd_mode, db_ref)


@pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list)
@pytest.mark.parametrize("use_bias", (False, True), ids=("no_bias", "bias"))
@pytest.mark.parametrize(
    "m_splits",
    ([32, 32, 32, 32], [64, 0, 32, 32], [1, 31, 0, 96]),
    ids=("uniform_splits", "with_empty_split", "small_and_empty_splits"),
)
@pytest.mark.parametrize("dtype", _core_dtypes, ids=str)
def test_grouped_linear_backward_mode_matches_reference(
    recipe_name: str,
    use_bias: bool,
    m_splits: list[int],
    dtype: torch.dtype,
    backward_mode: str,
) -> None:
    """GroupedLinear backward-mode grads must match the per-gemm reference exactly."""
    if recipe_name == "nvfp4":
        pytest.skip("NVFP4 not supported for grouped linear")

    reset_rng_states()
    _maybe_skip_recipe_dtype(recipe_name, dtype, backward_mode)
    _maybe_skip_unsupported_grouped_splits(recipe_name, m_splits)

    in_features = 64
    out_features = 64
    num_gemms = len(m_splits)
    num_tokens = sum(m_splits)

    quantized_ref_recipe = _make_recipe(recipe_name, backward_mode="default")
    mode_recipe = _make_recipe(recipe_name, backward_mode=backward_mode)

    module_quantized_ref = te.GroupedLinear(
        num_gemms,
        in_features,
        out_features,
        bias=use_bias,
        params_dtype=dtype,
        device="cuda",
    )
    module_bwd_mode = te.GroupedLinear(
        num_gemms,
        in_features,
        out_features,
        bias=use_bias,
        params_dtype=dtype,
        device="cuda",
    )
    _copy_named_parameters(module_quantized_ref, module_bwd_mode)

    x = torch.randn(num_tokens, in_features, dtype=dtype, device="cuda")
    dy = torch.randn(num_tokens, out_features, dtype=dtype, device="cuda")

    y_quantized_ref, _, _, _ = _run_grouped_linear_single_step(
        module_quantized_ref,
        x,
        m_splits,
        dy,
        quantized_ref_recipe,
    )
    if backward_mode == "unquant":
        # Unquant reference path: grouped module in plain high precision.
        module_unquantized_ref = te.GroupedLinear(
            num_gemms,
            in_features,
            out_features,
            bias=use_bias,
            params_dtype=dtype,
            device="cuda",
        )
        _copy_named_parameters(module_quantized_ref, module_unquantized_ref)
        y_bwd_mode, dx_bwd_mode, dw_bwd_mode, db_bwd_mode = _run_grouped_linear_single_step(
            module_bwd_mode,
            x,
            m_splits,
            dy,
            mode_recipe,
        )
        _, dx_ref, dw_ref, db_ref = _run_grouped_linear_single_step(
            module_unquantized_ref,
            x,
            m_splits,
            dy,
            None,
        )
    else:
        # Dequant reference path for grouped GEMMs:
        # each GEMM restores its own saved input/weight pair and computes its own ref grads.
        y_bwd_mode, x_bwd_mode, saved_operands = _run_grouped_linear_step_with_saved_operands(
            module_bwd_mode, x, m_splits, mode_recipe
        )
        y_bwd_mode_detached = y_bwd_mode.detach().clone()

        dx_ref: Optional[torch.Tensor] = None
        dw_ref: list[torch.Tensor] = []
        db_ref: list[Optional[torch.Tensor]] = []
        layout_invariants: list[dict[str, object]] = []
        guard_operands: list[tuple[str, Optional[torch.Tensor]]] = []
        ref_exc: Optional[Exception] = None
        try:
            # GroupedLinear saves inputs for all gemms first, then weights.
            if len(saved_operands) < 2 * num_gemms:
                raise RuntimeError(
                    "Insufficient saved operands for GroupedLinear dequant reference "
                    f"(got {len(saved_operands)}, expected at least {2 * num_gemms})."
                )

            saved_inputs = saved_operands[:num_gemms]
            saved_weights = saved_operands[num_gemms : 2 * num_gemms]
            for i, (saved_input, saved_weight) in enumerate(zip(saved_inputs, saved_weights)):
                guard_operands.extend(
                    [
                        (f"grouped_input{i}", saved_input),
                        (f"grouped_weight{i}", saved_weight),
                    ]
                )
            dy_chunks = torch.split(dy, m_splits)

            dx_chunks = []
            dw_ref = []
            db_ref = []
            for dy_chunk, saved_input, saved_weight in zip(dy_chunks, saved_inputs, saved_weights):
                dx_i, dw_i, db_i = _compute_linear_backward_reference_from_saved_operands(
                    saved_input,
                    saved_weight,
                    dy_chunk,
                    dequant_dtype=dtype,
                    out_dtype=dtype,
                )
                dx_chunks.append(dx_i)
                dw_ref.append(dw_i)
                db_ref.append(db_i if use_bias else None)
            dx_ref = torch.cat(dx_chunks, dim=0)
        except Exception as exc:  # pylint: disable=broad-exception-caught
            # Defer reference failures: layout invariants are checked first below.
            ref_exc = exc

        layout_invariants = _snapshot_layout_invariants(guard_operands)

        y_bwd_mode.backward(dy)
        assert x_bwd_mode.grad is not None
        dx_bwd_mode = x_bwd_mode.grad.detach().clone()
        dw_bwd_mode = [
            getattr(module_bwd_mode, f"weight{i}").grad.detach().clone()
            for i in range(module_bwd_mode.num_gemms)
        ]
        db_bwd_mode = []
        for i in range(module_bwd_mode.num_gemms):
            if module_bwd_mode.use_bias:
                db_bwd_mode.append(getattr(module_bwd_mode, f"bias{i}").grad.detach().clone())
            else:
                db_bwd_mode.append(None)
        y_bwd_mode = y_bwd_mode_detached

        _assert_layout_invariants_unchanged(layout_invariants)
        _raise_if_ref_failed(ref_exc)
        assert dx_ref is not None

    _assert_forward_matches_quantized_ref(y_bwd_mode, y_quantized_ref, recipe_name)
    _assert_exact(dx_bwd_mode, dx_ref)
    for test_dw, ref_dw in zip(dw_bwd_mode, dw_ref):
        _assert_exact(test_dw, ref_dw)
    if use_bias:
        for test_db, ref_db_i in zip(db_bwd_mode, db_ref):
            assert test_db is not None
            assert ref_db_i is not None
            _assert_exact(test_db, ref_db_i)


@pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list)
@pytest.mark.parametrize(
    "fused_pattern,expected_fused_op",
    (
        ("bias_add", ForwardLinearBiasAdd),
        ("scale_add", ForwardLinearScaleAdd),
    ),
)
@pytest.mark.parametrize("m", (1, 32), ids=("m1", "m32"))
@pytest.mark.parametrize("dtype", _fused_dtypes, ids=str)
def test_fused_linear_paths_match_backward_mode_reference(
    recipe_name: str,
    fused_pattern: str,
    expected_fused_op: type,
    m: int,
    dtype: torch.dtype,
    backward_mode: str,
) -> None:
    """Fused forward-linear paths must keep exact backward-mode gradient parity."""
    _maybe_skip_recipe_dtype(recipe_name, dtype, backward_mode)
    _maybe_skip_unsupported_recipe_module_combo(recipe_name, "ops_linear")
    _maybe_skip_unsupported_recipe_shape(recipe_name, (m, 64), "ops_linear")

    reset_rng_states()
    in_features = 64
    out_features = 64

    quantized_ref_recipe = _make_recipe(recipe_name, backward_mode="default")
    mode_recipe = _make_recipe(recipe_name, backward_mode=backward_mode)

    model_quantized_ref = _make_fused_model(fused_pattern, in_features, out_features, dtype)
    model_bwd_mode = _make_fused_model(fused_pattern, in_features, out_features, dtype)
    _copy_named_parameters(model_quantized_ref, model_bwd_mode)

    x1 = torch.randn(m, in_features, dtype=dtype, device="cuda")
    x2 = None
    if fused_pattern in ("bias_add", "scale_add"):
        x2 = torch.randn(m, out_features, dtype=dtype, device="cuda")
    dy = torch.randn(m, out_features, dtype=dtype, device="cuda")

    y_quantized_ref, _, _, _, _ = _run_fused_single_step(
        fused_pattern,
        model_quantized_ref,
        x1,
        dy,
        quantized_ref_recipe,
        x2=x2,
    )

    if backward_mode == "unquant":
        # Unquant reference path: replay the same fused model structure in plain
        # high precision and compare backward outputs exactly.
        model_unquantized_ref = _make_fused_model(fused_pattern, in_features, out_features, dtype)
        _copy_named_parameters(model_quantized_ref, model_unquantized_ref)

        y_bwd_mode, dx1_bwd_mode, dx2_bwd_mode, dw_bwd_mode, db_bwd_mode = _run_fused_single_step(
            fused_pattern,
            model_bwd_mode,
            x1,
            dy,
            mode_recipe,
            x2=x2,
        )
        _, dx1_ref, dx2_ref, dw_ref, db_ref = _run_fused_single_step(
            fused_pattern,
            model_unquantized_ref,
            x1,
            dy,
            None,
            x2=x2,
        )
    else:
        # Dequant reference path: compute backward reference from saved quantized
        # linear operands (with branch-specific dy handling for fused epilogues).
        y_bwd_mode, x1_bwd_mode, x2_bwd_mode_ref, saved_operands = (
            _run_fused_single_step_with_saved_operands(
                fused_pattern,
                model_bwd_mode,
                x1,
                mode_recipe,
                x2=x2,
            )
        )
        y_bwd_mode_detached = y_bwd_mode.detach().clone()
        dx1_ref: Optional[torch.Tensor] = None
        dx2_ref: Optional[torch.Tensor] = None
        dw_ref: Optional[torch.Tensor] = None
        db_ref: Optional[torch.Tensor] = None
        layout_invariants: list[dict[str, object]] = []
        guard_operands: list[tuple[str, Optional[torch.Tensor]]] = []
        ref_exc: Optional[Exception] = None
        try:
            saved_input, saved_weight = _extract_linear_saved_operands(
                saved_operands,
                context=f"fused_{fused_pattern}",
            )
            guard_operands.extend(
                [
                    (f"fused_{fused_pattern}_input", saved_input),
                    (f"fused_{fused_pattern}_weight", saved_weight),
                ]
            )
            # scale_add applies ConstantScale(0.5) after the linear, so the linear
            # sees dy scaled by the same constant (must match _make_fused_model's default).
            dy_for_linear = dy * 0.5 if fused_pattern == "scale_add" else dy
            dx1_ref, dw_ref, db_ref = _compute_linear_backward_reference_from_saved_operands(
                saved_input,
                saved_weight,
                dy_for_linear,
                dequant_dtype=dtype,
                out_dtype=dtype,
            )
            # The extra-input add passes dy straight through.
            dx2_ref = dy if x2 is not None else None
        except Exception as exc:  # pylint: disable=broad-exception-caught
            ref_exc = exc

        layout_invariants = _snapshot_layout_invariants(guard_operands)

        y_bwd_mode.backward(dy)
        assert x1_bwd_mode.grad is not None
        dx1_bwd_mode = x1_bwd_mode.grad.detach().clone()
        dx2_bwd_mode = (
            x2_bwd_mode_ref.grad.detach().clone()
            if x2_bwd_mode_ref is not None and x2_bwd_mode_ref.grad is not None
            else None
        )
        dw_bwd_mode = model_bwd_mode[0].weight.grad.detach().clone()
        db_bwd_mode = None
        if (
            getattr(model_bwd_mode[0], "bias", None) is not None
            and model_bwd_mode[0].bias.grad is not None
        ):
            db_bwd_mode = model_bwd_mode[0].bias.grad.detach().clone()
        y_bwd_mode = y_bwd_mode_detached

        _assert_layout_invariants_unchanged(layout_invariants)
        _raise_if_ref_failed(ref_exc)
        assert dx1_ref is not None and dw_ref is not None

    fused_ops = model_bwd_mode._module_groups[0]._forward_ops
    assert len(fused_ops) >= 1
    assert isinstance(fused_ops[0][0], expected_fused_op)

    _assert_forward_matches_quantized_ref(y_bwd_mode, y_quantized_ref, recipe_name)
    _assert_exact(dx1_bwd_mode, dx1_ref)
    _assert_exact(dw_bwd_mode, dw_ref)
    if dx2_bwd_mode is not None and dx2_ref is not None:
        _assert_exact(dx2_bwd_mode, dx2_ref)
    if db_bwd_mode is not None and db_ref is not None:
        _assert_exact(db_bwd_mode, db_ref)


@pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list)
@pytest.mark.parametrize("input_shape", _bias_activation_shape_cases)
@pytest.mark.parametrize("dtype", _fused_dtypes, ids=str)
def test_fused_bias_activation_matches_masked_linear_backward(
    recipe_name: str,
    input_shape: tuple[int, ...],
    dtype: torch.dtype,
    backward_mode: str,
) -> None:
    """Linear+bias+ReLU fusion must match a ReLU-masked linear backward reference."""
    _maybe_skip_recipe_dtype(recipe_name, dtype, backward_mode)
    _maybe_skip_unsupported_recipe_module_combo(recipe_name, "ops_linear")

    reset_rng_states()
    in_features = input_shape[-1]
    out_features = 64

    quantized_ref_recipe = _make_recipe(recipe_name, backward_mode="default")
    mode_recipe = _make_recipe(recipe_name, backward_mode=backward_mode)

    model_quantized_ref = _make_fused_model("bias_activation", in_features, out_features, dtype)
    model_bwd_mode = _make_fused_model("bias_activation", in_features, out_features, dtype)
    _copy_named_parameters(model_quantized_ref, model_bwd_mode)

    x1 = torch.randn(*input_shape, dtype=dtype, device="cuda")
    dy = torch.randn(*((*x1.shape[:-1], out_features)), dtype=dtype, device="cuda")

    y_quantized_ref, _, _, _, _ = _run_fused_single_step(
        "bias_activation",
        model_quantized_ref,
        x1,
        dy,
        quantized_ref_recipe,
    )

    if backward_mode == "unquant":
        # Unquant reference path: build a plain linear reference and apply the
        # same activation mask (from quantized forward output) before backward.
        linear_unquantized_ref = _make_linear_like_module(
            "ops_linear",
            in_features,
            out_features,
            dtype,
            bias=True,
        )
        _copy_named_parameters(model_bwd_mode[0], linear_unquantized_ref)

        y_bwd_mode, dx1_bwd_mode, _, dw_bwd_mode, db_bwd_mode = _run_fused_single_step(
            "bias_activation",
            model_bwd_mode,
            x1,
            dy,
            mode_recipe,
        )
        # ReLU backward: zero out dy wherever the forward output was non-positive.
        dy_after_activation = dy * (y_bwd_mode > 0).to(dy.dtype)
        _, dx1_ref, dw_ref, db_ref = _run_single_step(
            linear_unquantized_ref,
            x1,
            dy_after_activation,
            None,
        )
    else:
        # Dequant reference path: restore saved linear operands from fused forward,
        # apply the same activation mask, then run linear backward reference.
        y_bwd_mode, x1_bwd_mode, _, saved_operands = _run_fused_single_step_with_saved_operands(
            "bias_activation",
            model_bwd_mode,
            x1,
            mode_recipe,
        )
        y_bwd_mode_detached = y_bwd_mode.detach().clone()
        dy_after_activation = dy * (y_bwd_mode > 0).to(dy.dtype)
        dx1_ref: Optional[torch.Tensor] = None
        dw_ref: Optional[torch.Tensor] = None
        db_ref: Optional[torch.Tensor] = None
        layout_invariants: list[dict[str, object]] = []
        guard_operands: list[tuple[str, Optional[torch.Tensor]]] = []
        ref_exc: Optional[Exception] = None
        try:
            saved_input, saved_weight = _extract_linear_saved_operands(
                saved_operands,
                context="fused_bias_activation",
            )
            guard_operands.extend(
                [
                    ("fused_bias_activation_input", saved_input),
                    ("fused_bias_activation_weight", saved_weight),
                ]
            )
            dx1_ref, dw_ref, db_ref = _compute_linear_backward_reference_from_saved_operands(
                saved_input,
                saved_weight,
                dy_after_activation,
                dequant_dtype=dtype,
                out_dtype=dtype,
            )
        except Exception as exc:  # pylint: disable=broad-exception-caught
            ref_exc = exc

        layout_invariants = _snapshot_layout_invariants(guard_operands)

        y_bwd_mode.backward(dy)
        assert x1_bwd_mode.grad is not None
        dx1_bwd_mode = x1_bwd_mode.grad.detach().clone()
        dw_bwd_mode = model_bwd_mode[0].weight.grad.detach().clone()
        db_bwd_mode = (
            model_bwd_mode[0].bias.grad.detach().clone()
            if model_bwd_mode[0].bias.grad is not None
            else None
        )
        y_bwd_mode = y_bwd_mode_detached

        _assert_layout_invariants_unchanged(layout_invariants)
        _raise_if_ref_failed(ref_exc)
        assert dx1_ref is not None and dw_ref is not None and db_ref is not None

    fused_ops = model_bwd_mode._module_groups[0]._forward_ops
    assert len(fused_ops) >= 1
    assert isinstance(fused_ops[0][0], ForwardLinearBiasActivation)

    # In unquant/dequant modes, backward-activation+bias fusion should be disabled.
    bwd_mode_backward_ops = model_bwd_mode._module_groups[0]._backward_ops
    assert not any(isinstance(op, BackwardActivationBias) for op, _ in bwd_mode_backward_ops)

    # Quantized reference should still use fused backward path.
    quantized_ref_backward_ops = model_quantized_ref._module_groups[0]._backward_ops
    assert any(isinstance(op, BackwardActivationBias) for op, _ in quantized_ref_backward_ops)

    _assert_forward_matches_quantized_ref(y_bwd_mode, y_quantized_ref, recipe_name)
    _assert_exact(dx1_bwd_mode, dx1_ref)
    _assert_exact(dw_bwd_mode, dw_ref)
    assert db_bwd_mode is not None
    assert db_ref is not None
    _assert_exact(db_bwd_mode, db_ref)


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list)
@pytest.mark.parametrize("dtype", _core_dtypes, ids=str)
def test_operation_fuser_rebuilds_userbuffers_fusion_on_backward_mode_switch(
    recipe_name: str,
    dtype: torch.dtype,
    backward_mode: str,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Switching backward_mode on the active recipe must refresh the fuser's plan."""
    # Simulate a distributed setup to exercise Userbuffers fusion eligibility
    # without launching a multi-rank job.
+ monkeypatch.setattr(torch.distributed, "is_initialized", lambda: True) + monkeypatch.setattr(torch.distributed, "get_world_size", lambda *_args, **_kwargs: 2) + + # Use a mutable recipe holder so we can switch fusion behavior on the same + # fuser object and verify that the cached fusion plan is refreshed. + current_recipe = {"value": _make_recipe(recipe_name, backward_mode="default")} + monkeypatch.setattr(FP8GlobalStateManager, "get_fp8_recipe", lambda: current_recipe["value"]) + + reset_rng_states() + _maybe_skip_unsupported_recipe_module_combo(recipe_name, "ops_linear") + fuser, x, extra_inputs = _make_userbuffers_fuser_for_mode_switch_test(dtype=dtype) + + quant_recipe = _make_recipe(recipe_name, backward_mode="default") + fuser.maybe_fuse_ops( + is_grad_enabled=True, + recipe=quant_recipe, + input_=x, + extra_inputs=extra_inputs, + ) + assert _has_userbuffers_forward_linear(fuser) + + non_quant_recipe = _make_recipe(recipe_name, backward_mode=backward_mode) + current_recipe["value"] = non_quant_recipe + fuser.maybe_fuse_ops( + is_grad_enabled=True, + recipe=non_quant_recipe, + input_=x, + extra_inputs=extra_inputs, + ) + assert not _has_userbuffers_forward_linear(fuser) + + +@pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list) +@pytest.mark.parametrize("dtype", _core_dtypes, ids=str) +def test_quantize_op_respects_backward_mode( + recipe_name: str, + dtype: torch.dtype, + backward_mode: str, +) -> None: + _maybe_skip_recipe_dtype(recipe_name, dtype, backward_mode) + _maybe_skip_unsupported_recipe_module_combo(recipe_name, "ops_linear") + reset_rng_states() + + x = torch.randn(32, 64, dtype=dtype, device="cuda") + dy = torch.randn(32, 64, dtype=dtype, device="cuda") + + model_override = te_ops.Sequential(te_ops.Quantize(forward=True, backward=True)) + model_ref = te_ops.Sequential(te_ops.Quantize(forward=True, backward=False)) + + mode_recipe = _make_recipe(recipe_name, backward_mode=backward_mode) + + y_override, dx_override = 
_run_quantize_op_single_step(model_override, x, dy, mode_recipe) + y_ref, dx_ref = _run_quantize_op_single_step(model_ref, x, dy, mode_recipe) + + _assert_exact(y_override, y_ref) + _assert_exact(dx_override, dx_ref) + + +def test_delayed_scaling_rejects_non_quant_backward_mode(backward_mode: str) -> None: + with pytest.raises( + (AssertionError, ValueError), + match="Delayed scaling only supports backward_mode=default", + ): + _ = recipe.DelayedScaling(backward_mode=backward_mode) + + +@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) +@pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list) +@pytest.mark.parametrize("dtype", _fused_dtypes, ids=str) +def test_layernorm_mlp_not_implemented_for_unquantized_backward_mode( + recipe_name: str, + dtype: torch.dtype, + backward_mode: str, +) -> None: + _maybe_skip_recipe_dtype(recipe_name, dtype, backward_mode) + reset_rng_states() + + layer = te.LayerNormMLP( + hidden_size=64, + ffn_hidden_size=64, + params_dtype=dtype, + bias=False, + device="cuda", + ) + x = torch.randn(32, 64, dtype=dtype, device="cuda", requires_grad=True) + mode_recipe = _make_recipe(recipe_name, backward_mode=backward_mode) + + with pytest.raises( + AssertionError, + match="NVTE_BACKWARD_MODE=unquant/dequant is not implemented in LayerNormMLP", + ): + with te.autocast(enabled=True, recipe=mode_recipe): + _ = layer(x) diff --git a/tests/pytorch/test_keep_backward_unquantized.py b/tests/pytorch/test_keep_backward_unquantized.py deleted file mode 100644 index f5c3339a71..0000000000 --- a/tests/pytorch/test_keep_backward_unquantized.py +++ /dev/null @@ -1,756 +0,0 @@ -# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. 
- -from __future__ import annotations - -from contextlib import nullcontext -import math -import os -from typing import Optional - -import pytest -import torch - -import transformer_engine.pytorch as te -import transformer_engine.pytorch.ops as te_ops -from transformer_engine.common import recipe -from transformer_engine.pytorch.ops.fused import ( - BackwardActivationBias, - ForwardLinearBiasActivation, - ForwardLinearBiasAdd, - ForwardLinearScaleAdd, -) - -from utils import quantization_tols, reset_rng_states - - -fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) -mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) -fp8_block_scaling_available, reason_for_no_fp8_block_scaling = te.is_fp8_block_scaling_available( - return_reason=True -) -nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True) - -# This file is intended to run in dedicated keep-backward-unquantized mode. -pytestmark = pytest.mark.skipif( - os.environ.get("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") != "1", - reason="Requires NVTE_KEEP_BACKWARD_UNQUANTIZED=1", -) - - -_quantized_numerics_recipe_list = [ - pytest.param( - "fp8_current_scaling", - marks=pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8), - id="Float8CurrentScaling", - ), - pytest.param( - "mxfp8", - marks=pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8), - id="MXFP8BlockScaling", - ), - pytest.param( - "fp8_block_scaling", - marks=pytest.mark.skipif( - not fp8_block_scaling_available, reason=reason_for_no_fp8_block_scaling - ), - id="Float8BlockScaling", - ), - pytest.param( - "nvfp4", - marks=pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4), - id="NVFP4BlockScaling", - ), -] - -_shape_test_cases = [ - pytest.param((1, 64), 64, id="2d_m1_k64_n64"), - pytest.param((32, 64), 64, id="2d_m32_k64_n64"), - pytest.param((32, 1, 64), 64, id="3d_m32_s1_k64_n64"), - pytest.param((8, 4, 64), 128, id="3d_m32_k64_n128"), - 
pytest.param((16, 2, 128), 64, id="3d_m32_k128_n64"), -] - -_bias_activation_shape_cases = [ - pytest.param((32, 64), id="2d_m32_k64"), - pytest.param((8, 4, 64), id="3d_m32_k64"), -] - - -def _make_recipe(recipe_name: str, quantize_backward: Optional[bool]) -> recipe.Recipe: - kwargs = {} - if quantize_backward is not None: - kwargs = {"quantize_forward": True, "quantize_backward": quantize_backward} - - if recipe_name == "fp8_current_scaling": - return recipe.Float8CurrentScaling(fp8_format=recipe.Format.E4M3, **kwargs) - if recipe_name == "mxfp8": - return recipe.MXFP8BlockScaling(fp8_format=recipe.Format.E4M3, **kwargs) - if recipe_name == "fp8_block_scaling": - return recipe.Float8BlockScaling(fp8_format=recipe.Format.E4M3, **kwargs) - if recipe_name == "nvfp4": - return recipe.NVFP4BlockScaling( - disable_rht=True, - disable_stochastic_rounding=True, - disable_2d_quantization=True, - **kwargs, - ) - - raise ValueError(f"Unsupported recipe for keep-backward-unquantized test: {recipe_name}") - - -def _build_keep_backward_unquantized_recipe(recipe_name: str) -> recipe.Recipe: - fp8_recipe = _make_recipe(recipe_name, quantize_backward=None) - assert fp8_recipe.quantize_forward - assert not fp8_recipe.quantize_backward - return fp8_recipe - - -def _build_quantized_reference_recipe(recipe_name: str) -> recipe.Recipe: - return _make_recipe(recipe_name, quantize_backward=True) - - -def _copy_named_parameters(src_module: torch.nn.Module, dst_module: torch.nn.Module) -> None: - src_params = dict(src_module.named_parameters()) - with torch.no_grad(): - for name, dst_param in dst_module.named_parameters(): - if name not in src_params: - raise RuntimeError(f"Parameter {name} missing in source module") - dst_param.copy_(src_params[name]) - - -def _fprop_tolerances(recipe_name: str) -> dict[str, float]: - if recipe_name == "mxfp8": - return quantization_tols("mxfp8") - if recipe_name in ("fp8_current_scaling", "fp8_block_scaling"): - return 
quantization_tols("fp8_current_scaling") - if recipe_name == "nvfp4": - return quantization_tols("nvfp4") - raise ValueError(f"Unsupported recipe for keep-backward-unquantized test: {recipe_name}") - - -def _make_linear_like_module( - module_type: str, - in_features: int, - out_features: int, - dtype: torch.dtype, - bias: bool = False, -) -> torch.nn.Module: - if module_type == "linear": - return te.Linear( - in_features, - out_features, - bias=bias, - params_dtype=dtype, - device="cuda", - ) - if module_type == "layernorm_linear": - return te.LayerNormLinear( - in_features, - out_features, - bias=bias, - params_dtype=dtype, - device="cuda", - ) - if module_type == "ops_linear": - return te_ops.Linear( - in_features, - out_features, - bias=bias, - dtype=dtype, - device="cuda", - ) - raise ValueError(f"Unsupported module type: {module_type}") - - -def _maybe_skip_unsupported_recipe_module_combo(recipe_name: str, module_type: str) -> None: - if module_type == "ops_linear" and recipe_name == "fp8_block_scaling": - pytest.skip("Fusible ops (te_ops.Linear) do not support Float8BlockScaling recipe") - - -def _maybe_skip_unsupported_recipe_shape( - recipe_name: str, - input_shape: tuple[int, ...], - module_type: str, -) -> None: - flat_first_dim = math.prod(input_shape[:-1]) - last_dim = input_shape[-1] - - # TE Linear / LayerNormLinear FP8 kernels require FP8-GEMM-compatible dimensions. - if module_type in ("linear", "layernorm_linear"): - if flat_first_dim % 8 != 0 or last_dim % 16 != 0: - pytest.skip( - "Linear/LayerNormLinear FP8 execution requires prod(shape[:-1]) divisible by 8 " - "and shape[-1] divisible by 16." - ) - return - - # te_ops.Linear (fusible ops) has stricter constraints for some block-scaled recipes. - if module_type == "ops_linear": - if recipe_name == "mxfp8" and (flat_first_dim % 32 != 0 or last_dim % 32 != 0): - pytest.skip( - "te_ops.Linear + MXFP8 requires prod(shape[:-1]) and shape[-1] divisible by 32." 
- ) - if recipe_name == "nvfp4" and (flat_first_dim % 16 != 0 or last_dim % 16 != 0): - pytest.skip( - "te_ops.Linear + NVFP4 requires prod(shape[:-1]) and shape[-1] divisible by 16." - ) - - -def _maybe_skip_unsupported_grouped_splits(recipe_name: str, m_splits: list[int]) -> None: - # Grouped GEMM paths enforce additional split-alignment constraints for block-scaled recipes. - non_empty_splits = [m for m in m_splits if m > 0] - if recipe_name == "mxfp8" and any(m % 32 != 0 for m in non_empty_splits): - pytest.skip("GroupedLinear + MXFP8 requires each non-empty m_split divisible by 32.") - if recipe_name == "fp8_block_scaling" and any(m % 4 != 0 for m in non_empty_splits): - pytest.skip( - "GroupedLinear + Float8BlockScaling requires each non-empty m_split divisible by 4." - ) - - -def _run_single_step( - module: torch.nn.Module, - x: torch.Tensor, - dy: torch.Tensor, - fp8_recipe: Optional[recipe.Recipe], -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - module.zero_grad(set_to_none=True) - x_run = x.detach().clone().requires_grad_(True) - autocast_ctx = ( - te.autocast(enabled=True, recipe=fp8_recipe) if fp8_recipe is not None else nullcontext() - ) - with autocast_ctx: - y = module(x_run) - if isinstance(y, tuple): - y = y[0] - y.backward(dy) - assert x_run.grad is not None - assert module.weight.grad is not None - return ( - y.detach().clone(), - x_run.grad.detach().clone(), - module.weight.grad.detach().clone(), - ) - - -def _extract_bias_grad(module: torch.nn.Module) -> Optional[torch.Tensor]: - bias = getattr(module, "bias", None) - if bias is None or bias.grad is None: - return None - return bias.grad.detach().clone() - - -def _run_grouped_linear_single_step( - module: te.GroupedLinear, - x: torch.Tensor, - m_splits: list[int], - dy: torch.Tensor, - fp8_recipe: Optional[recipe.Recipe], -) -> tuple[torch.Tensor, torch.Tensor, list[torch.Tensor], list[Optional[torch.Tensor]]]: - module.zero_grad(set_to_none=True) - x_run = 
x.detach().clone().requires_grad_(True) - autocast_ctx = ( - te.autocast(enabled=True, recipe=fp8_recipe) if fp8_recipe is not None else nullcontext() - ) - with autocast_ctx: - y = module(x_run, m_splits) - y.backward(dy) - assert x_run.grad is not None - weight_grads = [ - getattr(module, f"weight{i}").grad.detach().clone() for i in range(module.num_gemms) - ] - bias_grads: list[Optional[torch.Tensor]] = [] - for i in range(module.num_gemms): - if module.use_bias: - bias_grads.append(getattr(module, f"bias{i}").grad.detach().clone()) - else: - bias_grads.append(None) - return y.detach().clone(), x_run.grad.detach().clone(), weight_grads, bias_grads - - -def _make_fused_model( - pattern: str, - in_features: int, - out_features: int, - dtype: torch.dtype, - scale: float = 0.5, -) -> te_ops.Sequential: - if pattern == "bias_activation": - return te_ops.Sequential( - te_ops.Linear(in_features, out_features, bias=True, device="cuda", dtype=dtype), - te_ops.ReLU(), - ) - if pattern == "bias_add": - return te_ops.Sequential( - te_ops.Linear(in_features, out_features, bias=True, device="cuda", dtype=dtype), - te_ops.AddExtraInput(in_place=True), - ) - if pattern == "scale_add": - return te_ops.Sequential( - te_ops.Linear(in_features, out_features, bias=False, device="cuda", dtype=dtype), - te_ops.ConstantScale(scale), - te_ops.AddExtraInput(in_place=True), - ) - raise ValueError(f"Unsupported fused test pattern: {pattern}") - - -def _run_fused_single_step( - pattern: str, - model: te_ops.Sequential, - x1: torch.Tensor, - dy: torch.Tensor, - fp8_recipe: Optional[recipe.Recipe], - x2: Optional[torch.Tensor] = None, -) -> tuple[ - torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor, Optional[torch.Tensor] -]: - model.zero_grad(set_to_none=True) - x1_run = x1.detach().clone().requires_grad_(True) - x2_run = x2.detach().clone().requires_grad_(True) if x2 is not None else None - autocast_ctx = ( - te.autocast(enabled=True, recipe=fp8_recipe) if fp8_recipe is not 
None else nullcontext() - ) - with autocast_ctx: - if pattern in ("bias_add", "scale_add"): - assert x2_run is not None - y = model(x1_run, x2_run) - else: - y = model(x1_run) - y.backward(dy) - assert x1_run.grad is not None - weight_grad = model[0].weight.grad.detach().clone() - bias_grad = None - if getattr(model[0], "bias", None) is not None and model[0].bias.grad is not None: - bias_grad = model[0].bias.grad.detach().clone() - x2_grad = ( - x2_run.grad.detach().clone() if x2_run is not None and x2_run.grad is not None else None - ) - return y.detach().clone(), x1_run.grad.detach().clone(), x2_grad, weight_grad, bias_grad - - -def _run_quantize_op_single_step( - model: te_ops.Sequential, - x: torch.Tensor, - dy: torch.Tensor, - fp8_recipe: Optional[recipe.Recipe], -) -> tuple[torch.Tensor, torch.Tensor]: - x_run = x.detach().clone().requires_grad_(True) - autocast_ctx = ( - te.autocast(enabled=True, recipe=fp8_recipe) if fp8_recipe is not None else nullcontext() - ) - with autocast_ctx: - y = model(x_run) - y.backward(dy) - assert x_run.grad is not None - return y.detach().clone(), x_run.grad.detach().clone() - - -@pytest.mark.parametrize( - "recipe_name", - _quantized_numerics_recipe_list, -) -def test_keep_backward_unquantized_recipe_defaults(recipe_name: str): - _ = _build_keep_backward_unquantized_recipe(recipe_name) - - -@pytest.mark.parametrize( - "recipe_name", - _quantized_numerics_recipe_list, -) -@pytest.mark.parametrize( - "module_type", - ("linear", "layernorm_linear", "ops_linear"), -) -@pytest.mark.parametrize( - "input_shape,out_features", - _shape_test_cases, -) -@pytest.mark.parametrize("use_bias", (False, True), ids=("no_bias", "bias")) -def test_keep_backward_unquantized_matches_quantized_fprop_and_unquantized_grads( - recipe_name: str, - module_type: str, - input_shape: tuple[int, ...], - out_features: int, - use_bias: bool, -): - reset_rng_states() - _maybe_skip_unsupported_recipe_module_combo(recipe_name, module_type) - 
_maybe_skip_unsupported_recipe_shape(recipe_name, input_shape, module_type) - dtype = torch.bfloat16 - in_features = input_shape[-1] - - module_quantized_ref = _make_linear_like_module( - module_type, in_features, out_features, dtype, bias=use_bias - ) - module_keep_bwd_hp = _make_linear_like_module( - module_type, in_features, out_features, dtype, bias=use_bias - ) - module_unquantized_ref = _make_linear_like_module( - module_type, in_features, out_features, dtype, bias=use_bias - ) - - # Start all runs from identical parameters. - _copy_named_parameters(module_quantized_ref, module_keep_bwd_hp) - _copy_named_parameters(module_quantized_ref, module_unquantized_ref) - - output_shape = input_shape[:-1] + (out_features,) - x = torch.randn(*input_shape, dtype=dtype, device="cuda") - dy = torch.randn(*output_shape, dtype=dtype, device="cuda") - - quantized_ref_recipe = _build_quantized_reference_recipe(recipe_name) - keep_bwd_hp_recipe = _build_keep_backward_unquantized_recipe(recipe_name) - - y_quantized_ref, _, _ = _run_single_step(module_quantized_ref, x, dy, quantized_ref_recipe) - y_keep_bwd_hp, dx_keep_bwd_hp, dw_keep_bwd_hp = _run_single_step( - module_keep_bwd_hp, x, dy, keep_bwd_hp_recipe - ) - _, dx_unquantized_ref, dw_unquantized_ref = _run_single_step( - module_unquantized_ref, x, dy, None - ) - - # Forward pass should still match quantized reference when only backward is unquantized. - torch.testing.assert_close( - y_keep_bwd_hp, - y_quantized_ref, - **_fprop_tolerances(recipe_name), - ) - - # Backward pass should match unquantized reference for dgrad and wgrad. 
- torch.testing.assert_close(dx_keep_bwd_hp, dx_unquantized_ref, rtol=0, atol=0) - torch.testing.assert_close(dw_keep_bwd_hp, dw_unquantized_ref, rtol=0, atol=0) - if use_bias: - bgrad_keep = _extract_bias_grad(module_keep_bwd_hp) - bgrad_unquantized = _extract_bias_grad(module_unquantized_ref) - assert bgrad_keep is not None - assert bgrad_unquantized is not None - torch.testing.assert_close(bgrad_keep, bgrad_unquantized, rtol=0, atol=0) - - -@pytest.mark.parametrize( - "recipe_name", - _quantized_numerics_recipe_list, -) -@pytest.mark.parametrize("use_bias", (False, True), ids=("no_bias", "bias")) -@pytest.mark.parametrize( - "m_splits", - ([32, 32, 32, 32], [64, 0, 32, 32], [1, 31, 0, 96]), - ids=("uniform_splits", "with_empty_split", "small_and_empty_splits"), -) -def test_keep_backward_unquantized_grouped_linear_matches_quantized_fprop_and_unquantized_grads( - recipe_name: str, - use_bias: bool, - m_splits: list[int], -): - if recipe_name == "nvfp4": - pytest.skip("NVFP4 not supported for grouped linear") - _maybe_skip_unsupported_grouped_splits(recipe_name, m_splits) - - reset_rng_states() - dtype = torch.bfloat16 - in_features = 64 - out_features = 64 - num_gemms = len(m_splits) - num_tokens = sum(m_splits) - - module_quantized_ref = te.GroupedLinear( - num_gemms, - in_features, - out_features, - bias=use_bias, - params_dtype=dtype, - device="cuda", - ) - module_keep_bwd_hp = te.GroupedLinear( - num_gemms, - in_features, - out_features, - bias=use_bias, - params_dtype=dtype, - device="cuda", - ) - module_unquantized_ref = te.GroupedLinear( - num_gemms, - in_features, - out_features, - bias=use_bias, - params_dtype=dtype, - device="cuda", - ) - - _copy_named_parameters(module_quantized_ref, module_keep_bwd_hp) - _copy_named_parameters(module_quantized_ref, module_unquantized_ref) - - x = torch.randn(num_tokens, in_features, dtype=dtype, device="cuda") - dy = torch.randn(num_tokens, out_features, dtype=dtype, device="cuda") - - quantized_ref_recipe = 
_build_quantized_reference_recipe(recipe_name) - keep_bwd_hp_recipe = _build_keep_backward_unquantized_recipe(recipe_name) - - y_quantized_ref, _, _, _ = _run_grouped_linear_single_step( - module_quantized_ref, x, m_splits, dy, quantized_ref_recipe - ) - y_keep_bwd_hp, dx_keep_bwd_hp, dw_keep_bwd_hp, db_keep_bwd_hp = _run_grouped_linear_single_step( - module_keep_bwd_hp, x, m_splits, dy, keep_bwd_hp_recipe - ) - _, dx_unquantized_ref, dw_unquantized_ref, db_unquantized_ref = _run_grouped_linear_single_step( - module_unquantized_ref, x, m_splits, dy, None - ) - - torch.testing.assert_close( - y_keep_bwd_hp, - y_quantized_ref, - **_fprop_tolerances(recipe_name), - ) - torch.testing.assert_close(dx_keep_bwd_hp, dx_unquantized_ref, rtol=0, atol=0) - for test_dw, ref_dw in zip(dw_keep_bwd_hp, dw_unquantized_ref): - torch.testing.assert_close(test_dw, ref_dw, rtol=0, atol=0) - if use_bias: - for test_db, ref_db in zip(db_keep_bwd_hp, db_unquantized_ref): - assert test_db is not None - assert ref_db is not None - torch.testing.assert_close(test_db, ref_db, rtol=0, atol=0) - - -@pytest.mark.parametrize( - "recipe_name", - _quantized_numerics_recipe_list, -) -@pytest.mark.parametrize( - "fused_pattern,expected_fused_op", - ( - ("bias_add", ForwardLinearBiasAdd), - ("scale_add", ForwardLinearScaleAdd), - ), -) -@pytest.mark.parametrize("m", (1, 32), ids=("m1", "m32")) -def test_keep_backward_unquantized_fused_linear_paths( - recipe_name: str, - fused_pattern: str, - expected_fused_op: type, - m: int, -): - # Fused linear op path is based on te_ops.Linear and shares its recipe constraints. 
- _maybe_skip_unsupported_recipe_module_combo(recipe_name, "ops_linear") - - reset_rng_states() - dtype = torch.bfloat16 - in_features = 64 - out_features = 64 - _maybe_skip_unsupported_recipe_shape(recipe_name, (m, in_features), "ops_linear") - model_quantized_ref = _make_fused_model(fused_pattern, in_features, out_features, dtype) - model_keep_bwd_hp = _make_fused_model(fused_pattern, in_features, out_features, dtype) - model_unquantized_ref = _make_fused_model(fused_pattern, in_features, out_features, dtype) - - _copy_named_parameters(model_quantized_ref, model_keep_bwd_hp) - _copy_named_parameters(model_quantized_ref, model_unquantized_ref) - - x1 = torch.randn(m, in_features, dtype=dtype, device="cuda") - x2 = None - if fused_pattern in ("bias_add", "scale_add"): - x2 = torch.randn(m, out_features, dtype=dtype, device="cuda") - dy = torch.randn(m, out_features, dtype=dtype, device="cuda") - - quantized_ref_recipe = _build_quantized_reference_recipe(recipe_name) - keep_bwd_hp_recipe = _build_keep_backward_unquantized_recipe(recipe_name) - - y_quantized_ref, _, _, _, _ = _run_fused_single_step( - fused_pattern, model_quantized_ref, x1, dy, quantized_ref_recipe, x2=x2 - ) - y_keep_bwd_hp, dx1_keep_bwd_hp, dx2_keep_bwd_hp, dw_keep_bwd_hp, db_keep_bwd_hp = ( - _run_fused_single_step( - fused_pattern, - model_keep_bwd_hp, - x1, - dy, - keep_bwd_hp_recipe, - x2=x2, - ) - ) - _, dx1_unquantized_ref, dx2_unquantized_ref, dw_unquantized_ref, db_unquantized_ref = ( - _run_fused_single_step( - fused_pattern, - model_unquantized_ref, - x1, - dy, - None, - x2=x2, - ) - ) - - # Ensure this test executes the fused path changed by the keep-bwd feature. 
- fused_ops = model_keep_bwd_hp._module_groups[0]._forward_ops - assert len(fused_ops) >= 1 - assert isinstance(fused_ops[0][0], expected_fused_op) - - torch.testing.assert_close( - y_keep_bwd_hp, - y_quantized_ref, - **_fprop_tolerances(recipe_name), - ) - torch.testing.assert_close(dx1_keep_bwd_hp, dx1_unquantized_ref, rtol=0, atol=0) - torch.testing.assert_close(dw_keep_bwd_hp, dw_unquantized_ref, rtol=0, atol=0) - if dx2_keep_bwd_hp is not None and dx2_unquantized_ref is not None: - torch.testing.assert_close(dx2_keep_bwd_hp, dx2_unquantized_ref, rtol=0, atol=0) - if db_keep_bwd_hp is not None and db_unquantized_ref is not None: - torch.testing.assert_close(db_keep_bwd_hp, db_unquantized_ref, rtol=0, atol=0) - - -@pytest.mark.parametrize( - "recipe_name", - _quantized_numerics_recipe_list, -) -@pytest.mark.parametrize("input_shape", _bias_activation_shape_cases) -def test_keep_backward_unquantized_fused_bias_activation_matches_masked_linear_backward( - recipe_name: str, - input_shape: tuple[int, ...], -): - # Fused linear op path is based on te_ops.Linear and shares its recipe constraints. 
- _maybe_skip_unsupported_recipe_module_combo(recipe_name, "ops_linear") - - reset_rng_states() - dtype = torch.bfloat16 - in_features = input_shape[-1] - out_features = 64 - - model_quantized_ref = _make_fused_model("bias_activation", in_features, out_features, dtype) - model_keep_bwd_hp = _make_fused_model("bias_activation", in_features, out_features, dtype) - linear_unquantized_ref = _make_linear_like_module( - "ops_linear", in_features, out_features, dtype, bias=True - ) - - _copy_named_parameters(model_quantized_ref, model_keep_bwd_hp) - _copy_named_parameters(model_keep_bwd_hp[0], linear_unquantized_ref) - - x1 = torch.randn(*input_shape, dtype=dtype, device="cuda") - out_shape = x1.shape[:-1] + (out_features,) - dy = torch.randn(*out_shape, dtype=dtype, device="cuda") - - quantized_ref_recipe = _build_quantized_reference_recipe(recipe_name) - keep_bwd_hp_recipe = _build_keep_backward_unquantized_recipe(recipe_name) - - y_quantized_ref, _, _, _, _ = _run_fused_single_step( - "bias_activation", model_quantized_ref, x1, dy, quantized_ref_recipe - ) - y_keep_bwd_hp, dx1_keep_bwd_hp, _, dw_keep_bwd_hp, db_keep_bwd_hp = _run_fused_single_step( - "bias_activation", model_keep_bwd_hp, x1, dy, keep_bwd_hp_recipe - ) - - # Ensure this test executes the fused path changed by the keep-bwd feature. - fused_ops = model_keep_bwd_hp._module_groups[0]._forward_ops - assert len(fused_ops) >= 1 - assert isinstance(fused_ops[0][0], ForwardLinearBiasActivation) - - # keep-bwd mode should disable backward-activation+bias fusion, while quantized - # reference should still use it. 
- keep_bwd_backward_ops = model_keep_bwd_hp._module_groups[0]._backward_ops - assert not any(isinstance(op, BackwardActivationBias) for op, _ in keep_bwd_backward_ops) - quantized_ref_backward_ops = model_quantized_ref._module_groups[0]._backward_ops - assert any(isinstance(op, BackwardActivationBias) for op, _ in quantized_ref_backward_ops) - - torch.testing.assert_close( - y_keep_bwd_hp, - y_quantized_ref, - **_fprop_tolerances(recipe_name), - ) - - # In keep-backward-unquantized mode, backward should behave as high-precision linear backward - # given the ReLU mask induced by quantized forward activations. - dy_after_activation = dy * (y_keep_bwd_hp > 0).to(dy.dtype) - _, dx1_expected, dw_expected = _run_single_step( - linear_unquantized_ref, x1, dy_after_activation, None - ) - db_expected = _extract_bias_grad(linear_unquantized_ref) - assert db_keep_bwd_hp is not None - assert db_expected is not None - - torch.testing.assert_close(dx1_keep_bwd_hp, dx1_expected, rtol=0, atol=0) - torch.testing.assert_close(dw_keep_bwd_hp, dw_expected, rtol=0, atol=0) - torch.testing.assert_close(db_keep_bwd_hp, db_expected, rtol=0, atol=0) - - -def test_keep_backward_unquantized_autocast_respects_quantize_forward_flag(): - reset_rng_states() - dtype = torch.bfloat16 - in_features = 64 - out_features = 64 - - module_quantization_disabled = _make_linear_like_module( - "linear", in_features, out_features, dtype, bias=True - ) - module_unquantized_ref = _make_linear_like_module( - "linear", in_features, out_features, dtype, bias=True - ) - _copy_named_parameters(module_quantization_disabled, module_unquantized_ref) - - x = torch.randn(32, in_features, dtype=dtype, device="cuda") - dy = torch.randn(32, out_features, dtype=dtype, device="cuda") - - recipe_no_fwd_quant = recipe.Float8CurrentScaling( - fp8_format=recipe.Format.E4M3, - quantize_forward=False, - quantize_backward=False, - ) - - y_test, dx_test, dw_test = _run_single_step( - module_quantization_disabled, x, dy, 
recipe_no_fwd_quant - ) - y_ref, dx_ref, dw_ref = _run_single_step(module_unquantized_ref, x, dy, None) - - torch.testing.assert_close(y_test, y_ref, rtol=0, atol=0) - torch.testing.assert_close(dx_test, dx_ref, rtol=0, atol=0) - torch.testing.assert_close(dw_test, dw_ref, rtol=0, atol=0) - bgrad_test = _extract_bias_grad(module_quantization_disabled) - bgrad_ref = _extract_bias_grad(module_unquantized_ref) - assert bgrad_test is not None - assert bgrad_ref is not None - torch.testing.assert_close(bgrad_test, bgrad_ref, rtol=0, atol=0) - - -def test_keep_backward_unquantized_quantize_op_respects_recipe_overrides(): - reset_rng_states() - dtype = torch.bfloat16 - x = torch.randn(32, 64, dtype=dtype, device="cuda") - dy = torch.randn(32, 64, dtype=dtype, device="cuda") - - model_override = te_ops.Sequential(te_ops.Quantize(forward=True, backward=True)) - model_ref = te_ops.Sequential(te_ops.Quantize(forward=True, backward=True)) - - recipe_no_quant = recipe.Float8CurrentScaling( - fp8_format=recipe.Format.E4M3, - quantize_forward=False, - quantize_backward=False, - ) - y_override, dx_override = _run_quantize_op_single_step(model_override, x, dy, recipe_no_quant) - y_ref, dx_ref = _run_quantize_op_single_step(model_ref, x, dy, None) - - torch.testing.assert_close(y_override, y_ref, rtol=0, atol=0) - torch.testing.assert_close(dx_override, dx_ref, rtol=0, atol=0) - - -def test_keep_backward_unquantized_is_invalid_for_delayed_scaling(): - with pytest.raises( - (AssertionError, ValueError), - match="Delayed scaling does not support quantize_backward=False", - ): - _ = recipe.DelayedScaling() - - -@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) -def test_keep_backward_unquantized_not_implemented_for_layernorm_mlp(): - reset_rng_states() - layer = te.LayerNormMLP( - hidden_size=64, - ffn_hidden_size=64, - params_dtype=torch.bfloat16, - bias=False, - device="cuda", - ) - x = torch.randn(32, 64, dtype=torch.bfloat16, device="cuda", requires_grad=True) - 
keep_bwd_hp_recipe = _build_keep_backward_unquantized_recipe("fp8_current_scaling") - - with pytest.raises( - AssertionError, match="NVTE_KEEP_BACKWARD_UNQUANTIZED is not implemented in LayerNormMLP" - ): - with te.autocast(enabled=True, recipe=keep_bwd_hp_recipe): - _ = layer(x) diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 46a19652f1..9058f155c4 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -11,6 +11,20 @@ from pydantic.dataclasses import dataclass +_BACKWARD_MODES = ("default", "unquant", "dequant") + + +def _resolve_backward_mode(mode: Optional[str] = None) -> str: + """Return validated backward mode from argument or NVTE_BACKWARD_MODE env.""" + if mode is None: + mode = os.getenv("NVTE_BACKWARD_MODE", "default") + mode = mode.lower() + assert ( + mode in _BACKWARD_MODES + ), f"Invalid NVTE_BACKWARD_MODE value {mode!r}. Supported values are: default|unquant|dequant." + return mode + + class _FormatHelper(NamedTuple): """ Stores max FP8 values for fprop and bprop a `Format`. @@ -188,11 +202,8 @@ def scaling_factor_compute(amax: Tensor, `LayerNormLinear (BF16 output) -> (cast to FP8 ) FP8 DPA (cast to BF16) -> Linear`. When `fp8_mha = True, fp8_dpa = True`, it becomes `LayerNormLinear (FP8 output) -> FP8 DPA -> Linear`. - quantize_forward : bool, default = True - Whether to quantize tensors in the forward pass. - quantize_backward : bool, default = True - Whether to quantize tensors in the backward pass. Delayed scaling - always quantizes backward; setting this to False is not supported. + backward_mode : {'default', 'unquant', 'dequant'}, default = 'default' + Backward precision mode. Delayed scaling only supports `default`. 
Notes ----- @@ -216,15 +227,14 @@ def scaling_factor_compute(amax: Tensor, reduce_amax: bool = True fp8_dpa: bool = False fp8_mha: bool = False - quantize_forward: bool = True - quantize_backward: bool = not (os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") == "1") + backward_mode: str = field(default_factory=_resolve_backward_mode) def __post_init__(self) -> None: + self.backward_mode = _resolve_backward_mode(self.backward_mode) assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." - assert not ( - not self.quantize_forward and self.quantize_backward - ), "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True." - assert self.quantize_backward, "Delayed scaling does not support quantize_backward=False." + assert ( + self.backward_mode == "default" + ), "Delayed scaling only supports backward_mode=default." def __repr__(self) -> str: return ( @@ -235,8 +245,7 @@ def __repr__(self) -> str: f"reduce_amax={self.reduce_amax}, " f"fp8_dpa={self.fp8_dpa}, " f"fp8_mha={self.fp8_mha}, " - f"quantize_forward={self.quantize_forward}, " - f"quantize_backward={self.quantize_backward}" + f"backward_mode={self.backward_mode}" ) @@ -250,10 +259,11 @@ class Float8CurrentScaling(Recipe): fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.HYBRID Controls the FP8 data format used during forward and backward pass. - quantize_forward : bool, default = True - Whether to quantize tensors in the forward pass. - quantize_backward : bool, default = True - Whether to quantize tensors in the backward pass. + backward_mode : {'default', 'unquant', 'dequant'}, default = 'default' + Backward precision mode. `default` performs quantized backward, + `unquant` keeps original high-precision operands for backward, + and `dequant` dequantizes saved operands to the active high-precision + compute dtype (e.g. BF16/FP16/FP32) for backward. 
""" use_power_2_scales: bool = os.getenv("NVTE_FP8_CURRENT_SCALING_POWER_2_SCALES", "0") == "1" @@ -266,14 +276,11 @@ class Float8CurrentScaling(Recipe): fp8_gemm_wgrad: MMParams = MMParams(use_split_accumulator=True) fp8_dpa: bool = False fp8_mha: bool = False - quantize_forward: bool = True - quantize_backward: bool = not (os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") == "1") + backward_mode: str = field(default_factory=_resolve_backward_mode) def __post_init__(self) -> None: + self.backward_mode = _resolve_backward_mode(self.backward_mode) assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." - assert not ( - not self.quantize_forward and self.quantize_backward - ), "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True." def __repr__(self) -> str: return ( @@ -287,8 +294,7 @@ def __repr__(self) -> str: f"fp8_gemm_wgrad={self.fp8_gemm_wgrad}, " f"fp8_dpa={self.fp8_dpa}, " f"fp8_mha={self.fp8_mha}, " - f"quantize_forward={self.quantize_forward}, " - f"quantize_backward={self.quantize_backward}" + f"backward_mode={self.backward_mode}" ) @@ -315,32 +321,29 @@ class MXFP8BlockScaling(Recipe): fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.E4M3 Controls the FP8 data format used during forward and backward pass. - quantize_forward : bool, default = True - Whether to quantize tensors in the forward pass. - quantize_backward : bool, default = True - Whether to quantize tensors in the backward pass. + backward_mode : {'default', 'unquant', 'dequant'}, default = 'default' + Backward precision mode. `default` performs quantized backward, + `unquant` keeps original high-precision operands for backward, + and `dequant` dequantizes saved operands to the active high-precision + compute dtype (e.g. BF16/FP16/FP32) for backward. 
""" margin: int = 0 fp8_format: Format = Format.E4M3 fp8_dpa: bool = False fp8_mha: bool = False - quantize_forward: bool = True - quantize_backward: bool = not (os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") == "1") + backward_mode: str = field(default_factory=_resolve_backward_mode) def __post_init__(self) -> None: + self.backward_mode = _resolve_backward_mode(self.backward_mode) assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." - assert not ( - not self.quantize_forward and self.quantize_backward - ), "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True." def __repr__(self) -> str: return ( f"recipe_type={self.__class__.__name__}, " f"margin={self.margin}, " f"format={str(self.fp8_format).split('.')[1]}, " - f"quantize_forward={self.quantize_forward}, " - f"quantize_backward={self.quantize_backward}" + f"backward_mode={self.backward_mode}" ) @@ -369,10 +372,11 @@ class Float8BlockScaling(Recipe): fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.E4M3 Controls the FP8 data format used during forward and backward pass. - quantize_forward : bool, default = True - Whether to quantize tensors in the forward pass. - quantize_backward : bool, default = True - Whether to quantize tensors in the backward pass. + backward_mode : {'default', 'unquant', 'dequant'}, default = 'default' + Backward precision mode. `default` performs quantized backward, + `unquant` keeps original high-precision operands for backward, + and `dequant` dequantizes saved operands to the active high-precision + compute dtype (e.g. BF16/FP16/FP32) for backward. 
""" use_f32_scales: bool = os.getenv("NVTE_FP8_BLOCK_SCALING_FP32_SCALES", "0") == "1" @@ -389,10 +393,10 @@ class Float8BlockScaling(Recipe): fp8_gemm_wgrad: MMParams = MMParams(use_split_accumulator=True) fp8_dpa: bool = False fp8_mha: bool = False - quantize_forward: bool = True - quantize_backward: bool = not (os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") == "1") + backward_mode: str = field(default_factory=_resolve_backward_mode) def __post_init__(self) -> None: + self.backward_mode = _resolve_backward_mode(self.backward_mode) assert self.x_block_scaling_dim in [1, 2], "Only 1D or 2D blocks supported for x" assert self.w_block_scaling_dim in [1, 2], "Only 1D or 2D blocks supported for w" assert self.grad_block_scaling_dim in [1, 2], "Only 1D or 2D blocks supported for grad" @@ -412,9 +416,6 @@ def __post_init__(self) -> None: not self.fp8_dpa and not self.fp8_mha ), "FP8 attention is not supported for Float8BlockScaling." assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." - assert not ( - not self.quantize_forward and self.quantize_backward - ), "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True." def __repr__(self) -> str: return ( @@ -431,8 +432,7 @@ def __repr__(self) -> str: f"fp8_gemm_wgrad={self.fp8_gemm_wgrad}, " f"fp8_dpa={self.fp8_dpa}, " f"fp8_mha={self.fp8_mha}, " - f"quantize_forward={self.quantize_forward}, " - f"quantize_backward={self.quantize_backward}" + f"backward_mode={self.backward_mode}" ) @@ -481,10 +481,11 @@ class NVFP4BlockScaling(Recipe): If set to `True`, stochastic rounding is disabled during quantization for all tensors. disable_2d_quantization : bool, default = False If set to `True`, 1D block scaling with block size 16 is used for all tensors. - quantize_forward : bool, default = True - Whether to quantize tensors in the forward pass. - quantize_backward : bool, default = True - Whether to quantize tensors in the backward pass. 
+ backward_mode : {'default', 'unquant', 'dequant'}, default = 'default' + Backward precision mode. `default` performs quantized backward, + `unquant` keeps original high-precision operands for backward, + and `dequant` dequantizes saved operands to the active high-precision + compute dtype (e.g. BF16/FP16/FP32) for backward. """ # Configuration envvars @@ -500,15 +501,12 @@ class NVFP4BlockScaling(Recipe): # Not applying quantization to attention for now fp8_dpa: bool = False fp8_mha: bool = False - quantize_forward: bool = True - quantize_backward: bool = not (os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") == "1") + backward_mode: str = field(default_factory=_resolve_backward_mode) def __post_init__(self) -> None: + self.backward_mode = _resolve_backward_mode(self.backward_mode) assert self.fp4_format == Format.E2M1, "Only E2M1 is supported for NVFP4 scaling" assert self.fp8_format == Format.E4M3, "Only E4M3 is supported for NVFP4 scaling" - assert not ( - not self.quantize_forward and self.quantize_backward - ), "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True." # Quantization params # Note: RHT is currently only applied to column-wise usage so that @@ -536,8 +534,7 @@ def __repr__(self) -> str: f"fp8_format={str(self.fp8_format).split('.')[1]}, " f"fp8_dpa={self.fp8_dpa}, " f"fp8_mha={self.fp8_mha}, " - f"quantize_forward={self.quantize_forward}, " - f"quantize_backward={self.quantize_backward}, " + f"backward_mode={self.backward_mode}, " f"fp4_quant_fwd_inp={self.fp4_quant_fwd_inp}, " f"fp4_quant_fwd_weight={self.fp4_quant_fwd_weight}, " f"fp4_quant_bwd_grad={self.fp4_quant_bwd_grad}, " @@ -569,23 +566,25 @@ class CustomRecipe(Recipe): - forward: "linear_input", "linear_weight", "linear_output" - backward: "linear_grad_output", "linear_grad_input" - quantize_forward : bool, default = True - Whether to quantize tensors in the forward pass. 
- quantize_backward : bool, default = True - Whether to quantize tensors in the backward pass. + backward_mode : {'default', 'unquant', 'dequant'}, default = 'default' + Backward precision mode. `default` performs quantized backward, + `unquant` keeps original high-precision operands for backward, + and `dequant` dequantizes saved operands to the active high-precision + compute dtype (e.g. BF16/FP16/FP32) for backward. """ qfactory: Callable[..., Any] fp8_dpa: bool = False fp8_mha: bool = False - quantize_forward: bool = True - quantize_backward: bool = not (os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") == "1") + backward_mode: str = field(default_factory=_resolve_backward_mode) + + def __post_init__(self) -> None: + self.backward_mode = _resolve_backward_mode(self.backward_mode) def __repr__(self) -> str: return ( f"recipe_type={self.__class__.__name__}, " f"qfactory={self.qfactory}, " - f"quantize_forward={self.quantize_forward}, " - f"quantize_backward={self.quantize_backward}" + f"backward_mode={self.backward_mode}" ) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 2ecd5f77c8..2ca1f1ace2 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1184,7 +1184,7 @@ def grad_output_preprocess( grad_output = grad_output.reshape((-1, grad_output.shape[-1])) grad_output = grad_output.contiguous() gather_grad_output = row_parallel_mode and ctx.sequence_parallel - use_fp8_bwd = ctx.fp8 and not ctx.keep_backward_unquantized + use_fp8_bwd = ctx.fp8 and ctx.backward_mode == "default" # Non-FP8 case: bgrad is fused with wgrad for this case. 
if not use_fp8_bwd and not ctx.debug: diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index b8ebd521a9..8126951ad7 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -97,11 +97,12 @@ def forward( save_original_input, debug, ) = non_tensor_args - keep_backward_unquantized = fp8 and ( - not FP8GlobalStateManager.get_fp8_recipe().quantize_backward - ) - if keep_backward_unquantized: - # Note, NVTE_KEEP_BACKWARD_UNQUANTIZED is ignored when delayed scaling is used + if fp8: + backward_mode = FP8GlobalStateManager.get_fp8_recipe().backward_mode + else: + backward_mode = "default" + if backward_mode == "unquant": + # Note, NVTE_BACKWARD_MODE=unquant is ignored when delayed scaling is used. save_original_input = True num_gemms = len(m_splits) @@ -118,10 +119,15 @@ def forward( input_quantizer.set_usage( rowwise=True, columnwise=( - is_grad_enabled and weight_requires_grad and not save_original_input + is_grad_enabled + and weight_requires_grad + and not save_original_input + and backward_mode == "default" ), ) columnwise_usage = is_grad_enabled and inp.requires_grad + if backward_mode in ("unquant", "dequant"): + columnwise_usage = False if not columnwise_usage: columnwise_usage = ( is_fp8_activation_recompute_enabled() @@ -246,7 +252,12 @@ def forward( else: for inputmat in inputmats: if isinstance(inputmat, QuantizedTensorStorage): - inputmat.update_usage(rowwise_usage=False, columnwise_usage=True) + if backward_mode in ("unquant", "dequant"): + # In dequant mode we should dequantize directly from + # fprop quantized layouts without retargeting usage. 
+ inputmat.update_usage(rowwise_usage=True, columnwise_usage=False) + else: + inputmat.update_usage(rowwise_usage=False, columnwise_usage=True) else: inputmats = [None] * num_gemms @@ -297,7 +308,7 @@ def forward( ctx.activation_dtype = activation_dtype ctx.fp8 = fp8 ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None - ctx.keep_backward_unquantized = keep_backward_unquantized + ctx.backward_mode = backward_mode ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch @@ -316,9 +327,9 @@ def forward( ctx.save_original_input = save_original_input ctx.input_quantizers = input_quantizers - # keep_backward_unquantized overrides - if keep_backward_unquantized: - ctx.fp8 = ctx.fp8 and not keep_backward_unquantized + # Non-quantized backward mode overrides + if backward_mode in ("unquant", "dequant"): + ctx.fp8 = False ctx.ub_overlap_ag = False ctx.ub_overlap_rs_dgrad = False ctx.ub_bulk_dgrad = False @@ -422,7 +433,16 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], device=ctx.device, ) weights_for_dgrad = weights - if ctx.keep_backward_unquantized: + if ctx.backward_mode == "dequant": + weights_for_dgrad = [ + ( + weight.dequantize(dtype=ctx.activation_dtype) + if isinstance(weight, QuantizedTensorStorage) + else cast_if_needed(weight, ctx.activation_dtype) + ) + for weight in weights + ] + elif ctx.backward_mode == "unquant": weights_for_dgrad = origin_weights # Make sure weights are available in column-wise format # for dgrad computation. 
@@ -485,6 +505,30 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmats = torch.split( cast_if_needed(inp_view, ctx.activation_dtype), ctx.m_splits ) + elif ctx.backward_mode == "dequant": + inputmats_dequant = [] + for m_split, inputmat in zip(ctx.m_splits, inputmats): + if isinstance(inputmat, QuantizedTensorStorage): + if m_split == 0: + # Dequant kernels for some quantized storage formats + # (e.g. MXFP8/Float8BlockScaling) do not accept empty + # M-dimension inputs. For empty grouped splits, materialize + # an explicit empty high-precision matrix instead of invoking + # dequantize(). + inputmats_dequant.append( + torch.empty( + (0, ctx.weights_shape_1), + dtype=ctx.activation_dtype, + device=ctx.device, + ) + ) + else: + inputmats_dequant.append( + inputmat.dequantize(dtype=ctx.activation_dtype) + ) + else: + inputmats_dequant.append(cast_if_needed(inputmat, ctx.activation_dtype)) + inputmats = inputmats_dequant grouped_gemm_wgrad = functools.partial( general_grouped_gemm, quantization_params=ctx.grad_weight_quantizers, @@ -1094,6 +1138,12 @@ def _get_quantizers(self): for i in range(self.num_gemms): grad_output_quantizers[i].internal = True grad_output_quantizers[i].optimize_for_gemm = True + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + if fp8_recipe.backward_mode == "dequant" and (fp8_recipe.mxfp8() or fp8_recipe.nvfp4()): + for input_quantizer in input_quantizers: + input_quantizer.optimize_for_gemm = False + for grad_output_quantizer in grad_output_quantizers: + grad_output_quantizer.optimize_for_gemm = False return ( input_quantizers, weight_quantizers, diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 10dbb8b0f8..75d7802143 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -140,9 +140,10 @@ def forward( symmetric_ar_type, debug, ) = non_tensor_args - 
keep_backward_unquantized = fp8 and ( - not FP8GlobalStateManager.get_fp8_recipe().quantize_backward - ) + if fp8: + backward_mode = FP8GlobalStateManager.get_fp8_recipe().backward_mode + else: + backward_mode = "default" # NVTX label for profiling nvtx_label = "transformer_engine._LayerNormLinear.forward" @@ -203,7 +204,7 @@ def forward( raise ValueError("Missing quantizer for input tensor") input_quantizer.set_usage( rowwise=True, - columnwise=backward_needs_input and not keep_backward_unquantized, + columnwise=backward_needs_input and backward_mode == "default", ) if with_input_all_gather and input_quantizer.supports_only_rowwise_all_gather(): # All-gather is not supported with FP8 column-wise data @@ -217,7 +218,7 @@ def forward( and not debug and not return_layernorm_output and not return_layernorm_output_gathered - and not keep_backward_unquantized + and backward_mode == "default" and not custom # TODO(negvet): and not FP8GlobalStateManager.get_fp8_recipe().custom() ) @@ -241,7 +242,7 @@ def forward( ln_out_return = None if return_layernorm_output or return_layernorm_output_gathered: ln_out_return = ln_out - ln_out_hp = ln_out if keep_backward_unquantized else None + ln_out_hp = ln_out if backward_mode == "unquant" else None # ------------------------------------------------------ # Prepare GEMM input tensor @@ -303,7 +304,10 @@ def forward( if is_weight_param_quantized and not debug: weight_quantizer = weight._quantizer elif weight_quantizer is not None: - weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled) + weight_quantizer.set_usage( + rowwise=True, + columnwise=is_grad_enabled and backward_mode == "default", + ) # Get quantized weight update_workspace = is_first_microbatch is None or is_first_microbatch @@ -417,7 +421,7 @@ def forward( if is_grad_enabled: ln_out_to_save = ln_out - if keep_backward_unquantized: + if backward_mode == "unquant": ln_out_to_save = ln_out_hp ctx.weight_quantizer = weight_quantizer ctx.ln_out_needs_gather = ( 
@@ -425,7 +429,7 @@ def forward( ) # Input with column-wise usage is needed for wgrad GEMM. - if backward_needs_input and not keep_backward_unquantized: + if backward_needs_input and backward_mode == "default": if isinstance(ln_out, QuantizedTensorStorage): # For sequence parallel in vanilla FP8, rowwise data is # to gather the input. For MXFP8, columnwise only data @@ -503,7 +507,7 @@ def forward( ctx.activation_dtype = activation_dtype ctx.fp8 = fp8 ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None - ctx.keep_backward_unquantized = keep_backward_unquantized + ctx.backward_mode = backward_mode ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch @@ -534,9 +538,9 @@ def forward( ctx.wgrad_store = wgrad_store ctx.debug = debug - # keep_backward_unquantized overrides - if keep_backward_unquantized: - ctx.fp8 = ctx.fp8 and not keep_backward_unquantized + # Non-quantized backward mode overrides + if backward_mode in ("unquant", "dequant"): + ctx.fp8 = False ctx.ub_overlap_ag = False ctx.ub_overlap_rs_dgrad = False ctx.ub_bulk_dgrad = False @@ -685,6 +689,11 @@ def backward( # -------------------------------------------------- ln_out_total = None ln_out_total_work = None + if ctx.backward_mode == "dequant": + if isinstance(ln_out, QuantizedTensorStorage): + ln_out = ln_out.dequantize(dtype=ctx.activation_dtype) + else: + ln_out = cast_if_needed(ln_out, ctx.activation_dtype) if ctx.ln_out_needs_gather: quantizer = None if ctx.input_quantizer is not None and ctx.fp8: @@ -757,7 +766,12 @@ def backward( # Note: dx = dy * w nvtx_range_push(f"{nvtx_label}.dgrad_gemm") weight_for_dgrad = weight - if ctx.keep_backward_unquantized: + if ctx.backward_mode == "dequant": + if isinstance(weight_for_dgrad, QuantizedTensorStorage): + weight_for_dgrad = weight_for_dgrad.dequantize(dtype=ctx.activation_dtype) + else: + weight_for_dgrad = cast_if_needed(weight_for_dgrad, 
ctx.activation_dtype) + elif ctx.backward_mode == "unquant": weight_for_dgrad = origin_weight gemm_out, *_, reduce_scatter_out = general_gemm( weight_for_dgrad, @@ -1657,6 +1671,11 @@ def _get_quantizers(self, fp8_output, fp8_grad, is_grad_enabled): grad_output_quantizer.optimize_for_gemm = True if fp8_grad: grad_input_quantizer = self.quantizers["scaling_bwd"][FP8BwdTensorIdx.GRAD_INPUT1] + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + if fp8_recipe.backward_mode == "dequant" and (fp8_recipe.mxfp8() or fp8_recipe.nvfp4()): + input_quantizer.optimize_for_gemm = False + if grad_output_quantizer is not None: + grad_output_quantizer.optimize_for_gemm = False return ( input_quantizer, diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index cb40bff1ae..49b3ee8b9d 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -234,12 +234,13 @@ def _forward( debug, recompute_for_bwd, ) = non_tensor_args - keep_backward_unquantized = fp8 and ( - not FP8GlobalStateManager.get_fp8_recipe().quantize_backward - ) + if fp8: + backward_mode = FP8GlobalStateManager.get_fp8_recipe().backward_mode + else: + backward_mode = "default" assert ( - not keep_backward_unquantized - ), "NVTE_KEEP_BACKWARD_UNQUANTIZED is not implemented in LayerNormMLP" + backward_mode == "default" + ), "NVTE_BACKWARD_MODE=unquant/dequant is not implemented in LayerNormMLP" # if grad is enabled and this is not the bwd stage, we must save this so bwd knows which path to take if is_grad_enabled and not recompute_for_bwd: @@ -786,7 +787,7 @@ def _forward( ctx.fc2_main_grad_func = lambda: fc2_weight.main_grad ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None - ctx.keep_backward_unquantized = keep_backward_unquantized + ctx.backward_mode = backward_mode ctx.fc1_grad_input_quantizer = fc1_grad_input_quantizer ctx.fc1_grad_weight_quantizer = 
fc1_grad_weight_quantizer ctx.fc1_grad_output_quantizer = fc1_grad_output_quantizer diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 5eda5886f1..b87ad71823 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -128,11 +128,12 @@ def forward( save_original_input, debug, ) = non_tensor_args - keep_backward_unquantized = fp8 and ( - not FP8GlobalStateManager.get_fp8_recipe().quantize_backward - ) - if keep_backward_unquantized: - # Note, NVTE_KEEP_BACKWARD_UNQUANTIZED is ignored when delayed scaling is used + if fp8: + backward_mode = FP8GlobalStateManager.get_fp8_recipe().backward_mode + else: + backward_mode = "default" + if backward_mode == "unquant": + # Note, NVTE_BACKWARD_MODE=unquant is ignored when delayed scaling is used. save_original_input = True # NVTX label for profiling @@ -193,7 +194,10 @@ def forward( raise ValueError("Missing quantizer for input tensor") if not isinstance(inputmat, QuantizedTensorStorage) and not custom: own_quantized_input = True - input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input) + input_quantizer.set_usage( + rowwise=True, + columnwise=backward_needs_input and backward_mode == "default", + ) if isinstance( input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) ): @@ -235,7 +239,12 @@ def forward( if input_quantizer is None: raise ValueError("Missing quantizer for input tensor") input_quantizer.set_usage( - rowwise=True, columnwise=backward_needs_input and not save_original_input + rowwise=True, + columnwise=( + backward_needs_input + and not save_original_input + and backward_mode == "default" + ), ) inputmat = input_quantizer(inputmat) own_quantized_input = True @@ -260,6 +269,8 @@ def forward( # for debug mode we create quantizer every iteration, thus we need to set the quantizer states if weight_quantizer is not None and (not isinstance(weight, QuantizedTensor) or debug): 
columnwise_usage = is_grad_enabled and inp.requires_grad + if backward_mode in ("unquant", "dequant"): + columnwise_usage = False if not columnwise_usage: columnwise_usage = ( is_fp8_activation_recompute_enabled() @@ -393,7 +404,11 @@ def forward( and own_quantized_input and isinstance(inputmat, QuantizedTensorStorage) ): - if ( + if backward_mode in ("unquant", "dequant"): + # In dequant mode we should dequantize directly from the + # fprop quantized tensor layout without retargeting usage. + inputmat.update_usage(rowwise_usage=True, columnwise_usage=False) + elif ( ctx.backward_input_needs_gather and weight_quantizer.supports_only_rowwise_all_gather() ): @@ -448,7 +463,7 @@ def forward( ctx.activation_dtype = activation_dtype ctx.fp8 = fp8 ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None - ctx.keep_backward_unquantized = keep_backward_unquantized + ctx.backward_mode = backward_mode ctx.input_quantizer = input_quantizer ctx.grad_input_quantizer = grad_input_quantizer ctx.grad_weight_quantizer = grad_weight_quantizer @@ -492,9 +507,9 @@ def forward( FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module ctx.wgrad_store = wgrad_store - # keep_backward_unquantized overrides - if keep_backward_unquantized: - ctx.fp8 = ctx.fp8 and not keep_backward_unquantized + # Non-quantized backward mode overrides + if backward_mode in ("unquant", "dequant"): + ctx.fp8 = False ctx.ub_overlap_ag = False ctx.ub_overlap_rs_dgrad = False ctx.ub_bulk_dgrad = False @@ -740,7 +755,12 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], nvtx_range_push(f"{nvtx_label}.dgrad_gemm") weight_for_dgrad = weight_fp8 - if ctx.keep_backward_unquantized: + if ctx.backward_mode == "dequant": + if isinstance(weight_for_dgrad, QuantizedTensorStorage): + weight_for_dgrad = weight_for_dgrad.dequantize(dtype=ctx.activation_dtype) + else: + weight_for_dgrad = cast_if_needed(weight_for_dgrad, ctx.activation_dtype) + elif ctx.backward_mode == 
"unquant": weight_for_dgrad = weight gemm_out, *_, reduce_scatter_out = general_gemm( weight_for_dgrad, @@ -1518,6 +1538,11 @@ def _get_quantizers(self, fp8_output, fp8_grad, is_grad_enabled): grad_output_quantizer.optimize_for_gemm = True if fp8_grad: grad_input_quantizer = self.quantizers["scaling_bwd"][FP8BwdTensorIdx.GRAD_INPUT1] + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + if fp8_recipe.backward_mode == "dequant" and (fp8_recipe.mxfp8() or fp8_recipe.nvfp4()): + input_quantizer.optimize_for_gemm = False + if grad_output_quantizer is not None: + grad_output_quantizer.optimize_for_gemm = False return ( input_quantizer, weight_quantizer, diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 15a6815d2e..7f21cd9331 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -332,10 +332,9 @@ def pre_fuser_forward(self, *, requires_grad: bool) -> None: # Note: We cache the quantized input for backward pass, # but discard the quantized weights. 
weight_requires_grad = requires_grad and self.weight.requires_grad - keep_backward_unquantized = FP8GlobalStateManager.is_fp8_enabled() and ( - not FP8GlobalStateManager.get_fp8_recipe().quantize_backward - ) - columnwise_usage = weight_requires_grad and not keep_backward_unquantized + columnwise_usage = weight_requires_grad + if FP8GlobalStateManager.get_fp8_recipe().backward_mode in ("unquant", "dequant"): + columnwise_usage = False input_quantizer = self.get_quantizer("forward", 0) weight_quantizer = self.get_quantizer("forward", 1) grad_output_quantizer = self.get_quantizer("backward", 0) @@ -359,6 +358,13 @@ def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: grad_output_quantizer.internal = True if not (self.tensor_parallel_mode == "row" and self.sequence_parallel): grad_output_quantizer.optimize_for_gemm = True + if FP8GlobalStateManager.is_fp8_enabled(): + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + if fp8_recipe.backward_mode == "dequant" and (fp8_recipe.mxfp8() or fp8_recipe.nvfp4()): + if input_quantizer is not None: + input_quantizer.optimize_for_gemm = False + if grad_output_quantizer is not None: + grad_output_quantizer.optimize_for_gemm = False # Configure weight quantizer # Note: This function may be called in base class constructor, @@ -424,7 +430,7 @@ def _functional_forward( tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None, sequence_parallel: bool = False, with_quantized_compute: bool = False, - keep_backward_unquantized: bool = False, + backward_mode: str = "default", input_quantizer: Optional[Quantizer] = None, weight_quantizer: Optional[Quantizer] = None, output_quantizer: Optional[Quantizer] = None, @@ -464,8 +470,8 @@ def _functional_forward( distributing along inner dimension (embedding dim) with_quantized_compute: bool, default = False Whether to perform compute with quantized data. 
- keep_backward_unquantized: bool, default = `False` - Whether to skip quantized backward and use high precision. + backward_mode: {`"default"`, `"unquant"`, `"dequant"`}, default = `"default"` + Backward-mode policy for quantized compute. input_quantizer: Quantizer, optional Builder class for quantized input tensor. weight_quantizer: Quantizer, optional @@ -519,7 +525,7 @@ def _functional_forward( raise ValueError("Missing quantizer for input tensor") input_quantizer.set_usage( rowwise=True, - columnwise=weight_requires_grad and not keep_backward_unquantized, + columnwise=weight_requires_grad and backward_mode == "default", ) if with_x_all_gather: input_quantizer.set_usage(columnwise=False) @@ -554,7 +560,7 @@ def _functional_forward( raise ValueError("Missing quantizer for weight tensor") weight_quantizer.set_usage( rowwise=True, - columnwise=input_requires_grad and not keep_backward_unquantized, + columnwise=input_requires_grad and backward_mode == "default", ) w = weight_quantizer(w) @@ -628,7 +634,7 @@ def _functional_forward( w is not weight and with_quantized_compute and is_quantized_tensor(w) - and not keep_backward_unquantized + and backward_mode == "default" ): w.update_usage(rowwise_usage=False, columnwise_usage=True) else: @@ -639,7 +645,7 @@ def _functional_forward( if ( with_quantized_compute and is_quantized_tensor(x_local) - and not keep_backward_unquantized + and backward_mode == "default" ): if not (isinstance(x_local, Float8TensorStorage) and with_x_all_gather): # FP8 does not support all-gather of transpose data @@ -990,9 +996,10 @@ def op_forward( grad_output_quantizer = self.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() - keep_backward_unquantized = with_quantized_compute and ( - not FP8GlobalStateManager.get_fp8_recipe().quantize_backward - ) + if with_quantized_compute: + backward_mode = FP8GlobalStateManager.get_fp8_recipe().backward_mode + 
else: + backward_mode = "default" # Get autocast dtype if needed if torch.is_autocast_enabled(): @@ -1009,7 +1016,7 @@ def op_forward( tensor_parallel_group=self.tensor_parallel_group, sequence_parallel=self.sequence_parallel, with_quantized_compute=with_quantized_compute, - keep_backward_unquantized=keep_backward_unquantized, + backward_mode=backward_mode, input_quantizer=input_quantizer, weight_quantizer=weight_quantizer, output_quantizer=output_quantizer, @@ -1019,12 +1026,13 @@ def op_forward( # Save state for backward pass if ctx.requires_grad: - saved_input = input_ if keep_backward_unquantized else x_local - saved_weight = self.weight if keep_backward_unquantized else w + saved_input = input_ if backward_mode == "unquant" else x_local + saved_weight = self.weight if backward_mode == "unquant" else w if is_cpu_offload_enabled(): mark_activation_offload(saved_input) ctx.save_for_backward(saved_input, saved_weight) - ctx.with_quantized_compute = with_quantized_compute and not keep_backward_unquantized + ctx.with_quantized_compute = with_quantized_compute and backward_mode == "default" + ctx.backward_mode = backward_mode ctx.input_quantizer = input_quantizer ctx.weight_quantizer = weight_quantizer ctx.grad_output_quantizer = grad_output_quantizer diff --git a/transformer_engine/pytorch/ops/basic/bias.py b/transformer_engine/pytorch/ops/basic/bias.py index 8bcd84b441..ad147a8d85 100644 --- a/transformer_engine/pytorch/ops/basic/bias.py +++ b/transformer_engine/pytorch/ops/basic/bias.py @@ -124,12 +124,11 @@ def op_forward( b = self.bias.view([1] * (x.dim() - 1) + [self.local_size]) if ctx.requires_grad: - keep_backward_unquantized = FP8GlobalStateManager.is_fp8_enabled() and ( - not FP8GlobalStateManager.get_fp8_recipe().quantize_backward - ) - ctx.grad_input_quantizer = ( - None if keep_backward_unquantized else prev_op_grad_output_quantizer - ) + ctx.grad_input_quantizer = prev_op_grad_output_quantizer + if FP8GlobalStateManager.is_fp8_enabled(): + fp8_recipe = 
FP8GlobalStateManager.get_fp8_recipe() + if fp8_recipe.backward_mode in ("unquant", "dequant"): + ctx.grad_input_quantizer = None return x + b diff --git a/transformer_engine/pytorch/ops/basic/quantize.py b/transformer_engine/pytorch/ops/basic/quantize.py index b2a36d1daa..c5474c18a0 100644 --- a/transformer_engine/pytorch/ops/basic/quantize.py +++ b/transformer_engine/pytorch/ops/basic/quantize.py @@ -59,14 +59,10 @@ def op_forward( quantize_forward = fp8_enabled and self._quantize_forward quantize_backward = fp8_enabled and self._quantize_backward - # Recipe quantize overrides - if FP8GlobalStateManager.get_fp8_recipe() is not None: - quantize_forward = ( - quantize_forward and FP8GlobalStateManager.get_fp8_recipe().quantize_forward - ) - quantize_backward = ( - quantize_backward and FP8GlobalStateManager.get_fp8_recipe().quantize_backward - ) + # Backward quantization is controlled by recipe backward mode. + if fp8_enabled: + recipe = FP8GlobalStateManager.get_fp8_recipe() + quantize_backward = quantize_backward and recipe.backward_mode == "default" # Quantize if needed out = input_ diff --git a/transformer_engine/pytorch/ops/fused/backward_activation_bias.py b/transformer_engine/pytorch/ops/fused/backward_activation_bias.py index 395a9dbd67..7b3025c03e 100644 --- a/transformer_engine/pytorch/ops/fused/backward_activation_bias.py +++ b/transformer_engine/pytorch/ops/fused/backward_activation_bias.py @@ -105,8 +105,8 @@ def fuse_backward_ops( """ # Check if recipe supports bias activation fusion. - # keep-backward-unquantized mode should use unfused backward ops. - if recipe is None or not recipe.quantize_backward: + # unquant/dequant backward modes should use unfused backward ops. 
+ if recipe is None or recipe.backward_mode in ("unquant", "dequant"): return ops # Scan through ops, fusing if possible diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index 42f459a41e..7584891384 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -92,9 +92,10 @@ def fuser_forward( grad_output_quantizer = linear_op.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() - keep_backward_unquantized = with_quantized_compute and ( - not FP8GlobalStateManager.get_fp8_recipe().quantize_backward - ) + if with_quantized_compute: + backward_mode = FP8GlobalStateManager.get_fp8_recipe().backward_mode + else: + backward_mode = "default" # Get autocast dtype if needed if torch.is_autocast_enabled(): @@ -112,7 +113,7 @@ def fuser_forward( tensor_parallel_group=linear_op.tensor_parallel_group, sequence_parallel=linear_op.sequence_parallel, with_quantized_compute=with_quantized_compute, - keep_backward_unquantized=keep_backward_unquantized, + backward_mode=backward_mode, input_quantizer=input_quantizer, weight_quantizer=weight_quantizer, output_quantizer=output_quantizer, @@ -122,14 +123,16 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: - saved_input = input_ if keep_backward_unquantized else x_local - saved_weight = linear_op.weight if keep_backward_unquantized else w + saved_input, saved_weight = x_local, w + if backward_mode == "unquant": + saved_input, saved_weight = input_, linear_op.weight if is_cpu_offload_enabled(): mark_activation_offload(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) linear_op_ctx.with_quantized_compute = ( - with_quantized_compute and not keep_backward_unquantized + 
with_quantized_compute and backward_mode == "default" ) + linear_op_ctx.backward_mode = backward_mode linear_op_ctx.input_quantizer = input_quantizer linear_op_ctx.weight_quantizer = weight_quantizer linear_op_ctx.grad_output_quantizer = grad_output_quantizer @@ -138,9 +141,9 @@ def fuser_forward( linear_op_ctx.input_requires_grad = input_requires_grad linear_op_ctx.weight_requires_grad = weight_requires_grad if bias_op is not None and bias_op_ctx.requires_grad: - bias_op_ctx.grad_input_quantizer = ( - None if keep_backward_unquantized else linear_op.get_grad_output_quantizer() - ) + bias_op_ctx.grad_input_quantizer = linear_op.get_grad_output_quantizer() + if backward_mode in ("unquant", "dequant"): + bias_op_ctx.grad_input_quantizer = None return output, [() for _ in range(len(self.basic_ops))] diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index 75d58fd5cc..6935330f4e 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -86,9 +86,10 @@ def fuser_forward( grad_output_quantizer = linear_op.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() - keep_backward_unquantized = with_quantized_compute and ( - not FP8GlobalStateManager.get_fp8_recipe().quantize_backward - ) + if with_quantized_compute: + backward_mode = FP8GlobalStateManager.get_fp8_recipe().backward_mode + else: + backward_mode = "default" # Get autocast dtype if needed if torch.is_autocast_enabled(): @@ -109,7 +110,7 @@ def fuser_forward( tensor_parallel_group=linear_op.tensor_parallel_group, sequence_parallel=linear_op.sequence_parallel, with_quantized_compute=with_quantized_compute, - keep_backward_unquantized=keep_backward_unquantized, + backward_mode=backward_mode, input_quantizer=input_quantizer, 
weight_quantizer=weight_quantizer, output_quantizer=output_quantizer, @@ -119,14 +120,16 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: - saved_input = input_ if keep_backward_unquantized else x_local - saved_weight = linear_op.weight if keep_backward_unquantized else w + saved_input, saved_weight = x_local, w + if backward_mode == "unquant": + saved_input, saved_weight = input_, linear_op.weight if is_cpu_offload_enabled(): mark_activation_offload(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) linear_op_ctx.with_quantized_compute = ( - with_quantized_compute and not keep_backward_unquantized + with_quantized_compute and backward_mode == "default" ) + linear_op_ctx.backward_mode = backward_mode linear_op_ctx.input_quantizer = input_quantizer linear_op_ctx.weight_quantizer = weight_quantizer linear_op_ctx.grad_output_quantizer = grad_output_quantizer @@ -136,7 +139,7 @@ def fuser_forward( linear_op_ctx.weight_requires_grad = weight_requires_grad if bias_op is not None and bias_op_ctx.requires_grad: bias_op_ctx.grad_input_quantizer = ( - None if keep_backward_unquantized else linear_op.get_grad_output_quantizer() + None if backward_mode != "default" else linear_op.get_grad_output_quantizer() ) return output, [() for _ in range(len(self.basic_ops))] diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py index dfdd11a231..2358140c88 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py @@ -65,9 +65,10 @@ def fuser_forward( grad_output_quantizer = linear_op.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() - keep_backward_unquantized = with_quantized_compute and ( - not FP8GlobalStateManager.get_fp8_recipe().quantize_backward - ) + 
if with_quantized_compute: + backward_mode = FP8GlobalStateManager.get_fp8_recipe().backward_mode + else: + backward_mode = "default" # Get extra input tensor for add operation extra_input = basic_op_extra_inputs[2][0] @@ -90,7 +91,7 @@ def fuser_forward( tensor_parallel_group=linear_op.tensor_parallel_group, sequence_parallel=linear_op.sequence_parallel, with_quantized_compute=with_quantized_compute, - keep_backward_unquantized=keep_backward_unquantized, + backward_mode=backward_mode, input_quantizer=input_quantizer, weight_quantizer=weight_quantizer, output_quantizer=output_quantizer, @@ -100,14 +101,16 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: - saved_input = input_ if keep_backward_unquantized else x_local - saved_weight = linear_op.weight if keep_backward_unquantized else w + saved_input, saved_weight = x_local, w + if backward_mode == "unquant": + saved_input, saved_weight = input_, linear_op.weight if is_cpu_offload_enabled(): mark_activation_offload(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) linear_op_ctx.with_quantized_compute = ( - with_quantized_compute and not keep_backward_unquantized + with_quantized_compute and backward_mode == "default" ) + linear_op_ctx.backward_mode = backward_mode linear_op_ctx.input_quantizer = input_quantizer linear_op_ctx.weight_quantizer = weight_quantizer linear_op_ctx.grad_output_quantizer = grad_output_quantizer diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py index 0d3e1d0416..54411f650d 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py @@ -388,6 +388,19 @@ def fuse_forward_ops( """ + # Disable Userbuffers for non-quantized backward modes. 
+ # In unquant/dequant modes we want to avoid all UB-specific overlap + # paths and run through the standard non-UB operator sequence instead. + recipe = unused.get("recipe", None) + if recipe is not None: + backward_mode = recipe.backward_mode + elif FP8GlobalStateManager.is_fp8_enabled(): + backward_mode = FP8GlobalStateManager.get_fp8_recipe().backward_mode + else: + backward_mode = "default" + if backward_mode in ("unquant", "dequant"): + return ops + # Return immediately if environment is not distributed if not torch.distributed.is_initialized() or torch.distributed.get_world_size() == 1: return ops diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index 80386db2d9..616c075ad8 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -339,6 +339,7 @@ def __init__( # Cache and detect change of state relevant for fusing operations self.recipe_type = None self.first_op_requiring_backward = 0 + self.backward_mode = "default" self._last_amax_history_len = 0 # Flatten list of parameters @@ -415,9 +416,14 @@ def maybe_fuse_ops( # Early exit if fusion parameters haven't changed need_reset = False recipe_type = type(recipe) - fusion_params = (recipe_type, first_op_requiring_backward) - if fusion_params != (self.recipe_type, self.first_op_requiring_backward): - # Recipe type or grad requirmenets have changed + backward_mode = recipe.backward_mode if recipe is not None else "default" + fusion_params = (recipe_type, first_op_requiring_backward, backward_mode) + if fusion_params != ( + self.recipe_type, + self.first_op_requiring_backward, + self.backward_mode, + ): + # Recipe type, backward mode, or grad requirements have changed need_reset = True elif ( recipe is not None @@ -451,7 +457,7 @@ def maybe_fuse_ops( ) # Save current fusion params - self.recipe_type, self.first_op_requiring_backward = fusion_params + self.recipe_type, self.first_op_requiring_backward, self.backward_mode = 
fusion_params # Save amax history length if isinstance(recipe, DelayedScaling): diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index 134781101a..47e6d5c8dc 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -843,15 +843,14 @@ def autocast( are reduced at the end of each training step. """ - effective_enabled = enabled and getattr(recipe, "quantize_forward", True) - if effective_enabled: + if enabled: check_recipe_support(recipe) # Save current state so we always restore it on exit. fp8_state = FP8GlobalStateManager.get_autocast_state() FP8GlobalStateManager.autocast_enter( - enabled=effective_enabled, + enabled=enabled, calibrating=calibrating, fp8_recipe=recipe, fp8_group=amax_reduction_group, @@ -861,7 +860,7 @@ def autocast( yield finally: FP8GlobalStateManager.set_autocast_state(fp8_state) - FP8GlobalStateManager.autocast_exit(effective_enabled, _graph=_graph) + FP8GlobalStateManager.autocast_exit(enabled, _graph=_graph) def _update_amax_history(amax_history: torch.Tensor) -> torch.Tensor: From 749e3a0749565cdb706ea912cbbbecd68492c299 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Wed, 25 Feb 2026 12:59:01 -0800 Subject: [PATCH 45/61] Fix override and clean up Signed-off-by: Ziang Li --- transformer_engine/pytorch/module/grouped_linear.py | 6 +++--- transformer_engine/pytorch/module/layernorm_mlp.py | 7 ++++--- transformer_engine/pytorch/module/linear.py | 1 - 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 8126951ad7..3b9c3b2949 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -102,7 +102,6 @@ def forward( else: backward_mode = "default" if backward_mode == "unquant": - # Note, NVTE_BACKWARD_MODE=unquant is ignored when delayed scaling is used. 
save_original_input = True num_gemms = len(m_splits) @@ -1142,8 +1141,9 @@ def _get_quantizers(self): if fp8_recipe.backward_mode == "dequant" and (fp8_recipe.mxfp8() or fp8_recipe.nvfp4()): for input_quantizer in input_quantizers: input_quantizer.optimize_for_gemm = False - for grad_output_quantizer in grad_output_quantizers: - grad_output_quantizer.optimize_for_gemm = False + if torch.is_grad_enabled(): + for grad_output_quantizer in grad_output_quantizers: + grad_output_quantizer.optimize_for_gemm = False return ( input_quantizers, weight_quantizers, diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 49b3ee8b9d..4f206c866e 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -238,9 +238,10 @@ def _forward( backward_mode = FP8GlobalStateManager.get_fp8_recipe().backward_mode else: backward_mode = "default" - assert ( - backward_mode == "default" - ), "NVTE_BACKWARD_MODE=unquant/dequant is not implemented in LayerNormMLP" + assert backward_mode == "default", ( + "NVTE_BACKWARD_MODE=unquant/dequant is not implemented in LayerNormMLP. " + "Replace LayerNormMLP with LayerNormLinear + Linear to enable unquant/dequant backward." + ) # if grad is enabled and this is not the bwd stage, we must save this so bwd knows which path to take if is_grad_enabled and not recompute_for_bwd: diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index b87ad71823..de00553225 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -133,7 +133,6 @@ def forward( else: backward_mode = "default" if backward_mode == "unquant": - # Note, NVTE_BACKWARD_MODE=unquant is ignored when delayed scaling is used. 
save_original_input = True # NVTX label for profiling From 03a7fe993cf98c3116c732228ded53dc208f3676 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Thu, 5 Mar 2026 12:55:22 -0800 Subject: [PATCH 46/61] Clean up unit test Signed-off-by: Ziang Li --- tests/pytorch/test_backward_mode.py | 136 ++++++++++++---------------- tests/pytorch/utils.py | 8 +- 2 files changed, 64 insertions(+), 80 deletions(-) diff --git a/tests/pytorch/test_backward_mode.py b/tests/pytorch/test_backward_mode.py index 300d860496..2a744d8278 100644 --- a/tests/pytorch/test_backward_mode.py +++ b/tests/pytorch/test_backward_mode.py @@ -25,7 +25,7 @@ ) from transformer_engine.pytorch.quantized_tensor import restore_from_saved -from utils import quantization_tols, reset_rng_states +from utils import assert_close, make_recipe, quantization_tols, reset_rng_states # -------------------------- @@ -42,15 +42,11 @@ nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True) bf16_available, reason_for_no_bf16 = te.is_bf16_available(return_reason=True) -# Broad dtype coverage for modules touched by this change. _core_dtypes = [torch.float16, torch.float32] -if bf16_available: - _core_dtypes.insert(1, torch.bfloat16) - -# Fused GEMM+bias+activation requires FP16/BF16 output. 
_fused_dtypes = [torch.float16] if bf16_available: - _fused_dtypes.append(torch.bfloat16) + _core_dtypes.insert(1, torch.bfloat16) + _fused_dtypes.insert(1, torch.bfloat16) @pytest.fixture(autouse=True) @@ -71,18 +67,6 @@ def backward_mode(request: pytest.FixtureRequest) -> str: # -------------------------- -def _assert_exact(test: torch.Tensor, ref: torch.Tensor) -> None: - torch.testing.assert_close(test, ref, rtol=0, atol=0) - - -def _assert_forward_matches_quantized_ref( - test: torch.Tensor, - ref: torch.Tensor, - recipe_name: str, -) -> None: - torch.testing.assert_close(test, ref, **_fprop_tolerances(recipe_name)) - - def _restore_saved_operands(output: torch.Tensor) -> list[Optional[torch.Tensor]]: if output.grad_fn is None: raise RuntimeError("Output tensor has no grad_fn; cannot inspect saved operands") @@ -326,32 +310,27 @@ def _compute_linear_backward_reference_from_saved_operands( pytest.param((32, 1, 64), 64, id="3d_m32_s1_k64_n64"), pytest.param((8, 4, 64), 128, id="3d_m32_k64_n128"), pytest.param((16, 2, 128), 64, id="3d_m32_k128_n64"), + pytest.param((160, 64), 64, id="2d_m160_k64_n64"), + pytest.param((5, 64, 64), 64, id="3d_m320_k64_n64"), + pytest.param((3, 5, 32, 64), 96, id="4d_m480_k64_n96"), + pytest.param((2, 5, 16, 128), 64, id="4d_m160_k128_n64"), + # Intentionally unaligned token dimensions to exercise skip/support logic. + pytest.param((3, 64), 64, id="2d_m3_k64_n64_unaligned"), + pytest.param((3, 10, 64), 64, id="3d_m30_k64_n64_unaligned"), ] _bias_activation_shape_cases = [ pytest.param((32, 64), id="2d_m32_k64"), pytest.param((8, 4, 64), id="3d_m32_k64"), + pytest.param((160, 64), id="2d_m160_k64"), + pytest.param((5, 64, 64), id="3d_m320_k64"), + pytest.param((3, 5, 32, 64), id="4d_m480_k64"), + # Intentionally unaligned token dimensions to exercise skip/support logic. 
+ pytest.param((3, 64), id="2d_m3_k64_unaligned"), + pytest.param((3, 10, 64), id="3d_m30_k64_unaligned"), ] -def _make_recipe(recipe_name: str, *, backward_mode: str) -> recipe.Recipe: - kwargs = {"backward_mode": backward_mode} - if recipe_name == "fp8_current_scaling": - return recipe.Float8CurrentScaling(fp8_format=recipe.Format.E4M3, **kwargs) - if recipe_name == "mxfp8": - return recipe.MXFP8BlockScaling(fp8_format=recipe.Format.E4M3, **kwargs) - if recipe_name == "fp8_block_scaling": - return recipe.Float8BlockScaling(fp8_format=recipe.Format.E4M3, **kwargs) - if recipe_name == "nvfp4": - return recipe.NVFP4BlockScaling( - disable_rht=True, - disable_stochastic_rounding=True, - disable_2d_quantization=True, - **kwargs, - ) - raise ValueError(f"Unsupported recipe for backward-mode test: {recipe_name}") - - def _copy_named_parameters(src_module: torch.nn.Module, dst_module: torch.nn.Module) -> None: src_params = dict(src_module.named_parameters()) with torch.no_grad(): @@ -371,7 +350,7 @@ def _fprop_tolerances(recipe_name: str) -> dict[str, float]: raise ValueError(f"Unsupported recipe for backward-mode test: {recipe_name}") -def _maybe_skip_recipe_dtype(recipe_name: str, dtype: torch.dtype, backward_mode: str) -> None: +def _maybe_skip_recipe_dtype(recipe_name: str, dtype: torch.dtype) -> None: if dtype == torch.bfloat16 and not bf16_available: pytest.skip(reason_for_no_bf16) if recipe_name == "nvfp4" and dtype != torch.bfloat16: @@ -705,8 +684,8 @@ def test_backward_mode_recipe_matches_requested_mode( recipe_name: str, backward_mode: str, ) -> None: - mode_recipe = _make_recipe(recipe_name, backward_mode=backward_mode) - quant_recipe = _make_recipe(recipe_name, backward_mode="default") + mode_recipe = make_recipe(recipe_name, backward_mode=backward_mode) + quant_recipe = make_recipe(recipe_name, backward_mode="default") assert mode_recipe.backward_mode == backward_mode assert quant_recipe.backward_mode == "default" @@ -726,13 +705,13 @@ def 
test_linear_like_backward_mode_matches_reference( backward_mode: str, ) -> None: reset_rng_states() - _maybe_skip_recipe_dtype(recipe_name, dtype, backward_mode) + _maybe_skip_recipe_dtype(recipe_name, dtype) _maybe_skip_unsupported_recipe_module_combo(recipe_name, module_type) _maybe_skip_unsupported_recipe_shape(recipe_name, input_shape, module_type) in_features = input_shape[-1] - quantized_ref_recipe = _make_recipe(recipe_name, backward_mode="default") - mode_recipe = _make_recipe(recipe_name, backward_mode=backward_mode) + quantized_ref_recipe = make_recipe(recipe_name, backward_mode="default") + mode_recipe = make_recipe(recipe_name, backward_mode=backward_mode) module_quantized_ref = _make_linear_like_module( module_type, @@ -882,13 +861,13 @@ def test_linear_like_backward_mode_matches_reference( _raise_if_ref_failed(ref_exc) assert dx_ref is not None and dw_ref is not None and db_ref is not None - _assert_forward_matches_quantized_ref(y_bwd_mode, y_quantized_ref, recipe_name) - _assert_exact(dx_bwd_mode, dx_ref) - _assert_exact(dw_bwd_mode, dw_ref) + assert_close(y_bwd_mode, y_quantized_ref, check_dtype=True, **_fprop_tolerances(recipe_name)) + assert_close(dx_bwd_mode, dx_ref, rtol=0, atol=0, check_dtype=True) + assert_close(dw_bwd_mode, dw_ref, rtol=0, atol=0, check_dtype=True) if use_bias: assert db_bwd_mode is not None assert db_ref is not None - _assert_exact(db_bwd_mode, db_ref) + assert_close(db_bwd_mode, db_ref, rtol=0, atol=0, check_dtype=True) @pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list) @@ -910,7 +889,7 @@ def test_grouped_linear_backward_mode_matches_reference( pytest.skip("NVFP4 not supported for grouped linear") reset_rng_states() - _maybe_skip_recipe_dtype(recipe_name, dtype, backward_mode) + _maybe_skip_recipe_dtype(recipe_name, dtype) _maybe_skip_unsupported_grouped_splits(recipe_name, m_splits) in_features = 64 @@ -918,8 +897,8 @@ def test_grouped_linear_backward_mode_matches_reference( num_gemms = len(m_splits) 
num_tokens = sum(m_splits) - quantized_ref_recipe = _make_recipe(recipe_name, backward_mode="default") - mode_recipe = _make_recipe(recipe_name, backward_mode=backward_mode) + quantized_ref_recipe = make_recipe(recipe_name, backward_mode="default") + mode_recipe = make_recipe(recipe_name, backward_mode=backward_mode) module_quantized_ref = te.GroupedLinear( num_gemms, @@ -1045,15 +1024,15 @@ def test_grouped_linear_backward_mode_matches_reference( _raise_if_ref_failed(ref_exc) assert dx_ref is not None - _assert_forward_matches_quantized_ref(y_bwd_mode, y_quantized_ref, recipe_name) - _assert_exact(dx_bwd_mode, dx_ref) + assert_close(y_bwd_mode, y_quantized_ref, check_dtype=True, **_fprop_tolerances(recipe_name)) + assert_close(dx_bwd_mode, dx_ref, rtol=0, atol=0, check_dtype=True) for test_dw, ref_dw in zip(dw_bwd_mode, dw_ref): - _assert_exact(test_dw, ref_dw) + assert_close(test_dw, ref_dw, rtol=0, atol=0, check_dtype=True) if use_bias: for test_db, ref_db_i in zip(db_bwd_mode, db_ref): assert test_db is not None assert ref_db_i is not None - _assert_exact(test_db, ref_db_i) + assert_close(test_db, ref_db_i, rtol=0, atol=0, check_dtype=True) @pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list) @@ -1074,7 +1053,7 @@ def test_fused_linear_paths_match_backward_mode_reference( dtype: torch.dtype, backward_mode: str, ) -> None: - _maybe_skip_recipe_dtype(recipe_name, dtype, backward_mode) + _maybe_skip_recipe_dtype(recipe_name, dtype) _maybe_skip_unsupported_recipe_module_combo(recipe_name, "ops_linear") _maybe_skip_unsupported_recipe_shape(recipe_name, (m, 64), "ops_linear") @@ -1082,8 +1061,8 @@ def test_fused_linear_paths_match_backward_mode_reference( in_features = 64 out_features = 64 - quantized_ref_recipe = _make_recipe(recipe_name, backward_mode="default") - mode_recipe = _make_recipe(recipe_name, backward_mode=backward_mode) + quantized_ref_recipe = make_recipe(recipe_name, backward_mode="default") + mode_recipe = make_recipe(recipe_name, 
backward_mode=backward_mode) model_quantized_ref = _make_fused_model(fused_pattern, in_features, out_features, dtype) model_bwd_mode = _make_fused_model(fused_pattern, in_features, out_features, dtype) @@ -1196,13 +1175,13 @@ def test_fused_linear_paths_match_backward_mode_reference( assert len(fused_ops) >= 1 assert isinstance(fused_ops[0][0], expected_fused_op) - _assert_forward_matches_quantized_ref(y_bwd_mode, y_quantized_ref, recipe_name) - _assert_exact(dx1_bwd_mode, dx1_ref) - _assert_exact(dw_bwd_mode, dw_ref) + assert_close(y_bwd_mode, y_quantized_ref, check_dtype=True, **_fprop_tolerances(recipe_name)) + assert_close(dx1_bwd_mode, dx1_ref, rtol=0, atol=0, check_dtype=True) + assert_close(dw_bwd_mode, dw_ref, rtol=0, atol=0, check_dtype=True) if dx2_bwd_mode is not None and dx2_ref is not None: - _assert_exact(dx2_bwd_mode, dx2_ref) + assert_close(dx2_bwd_mode, dx2_ref, rtol=0, atol=0, check_dtype=True) if db_bwd_mode is not None and db_ref is not None: - _assert_exact(db_bwd_mode, db_ref) + assert_close(db_bwd_mode, db_ref, rtol=0, atol=0, check_dtype=True) @pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list) @@ -1214,15 +1193,16 @@ def test_fused_bias_activation_matches_masked_linear_backward( dtype: torch.dtype, backward_mode: str, ) -> None: - _maybe_skip_recipe_dtype(recipe_name, dtype, backward_mode) + _maybe_skip_recipe_dtype(recipe_name, dtype) _maybe_skip_unsupported_recipe_module_combo(recipe_name, "ops_linear") + _maybe_skip_unsupported_recipe_shape(recipe_name, input_shape, "ops_linear") reset_rng_states() in_features = input_shape[-1] out_features = 64 - quantized_ref_recipe = _make_recipe(recipe_name, backward_mode="default") - mode_recipe = _make_recipe(recipe_name, backward_mode=backward_mode) + quantized_ref_recipe = make_recipe(recipe_name, backward_mode="default") + mode_recipe = make_recipe(recipe_name, backward_mode=backward_mode) model_quantized_ref = _make_fused_model("bias_activation", in_features, out_features, 
dtype) model_bwd_mode = _make_fused_model("bias_activation", in_features, out_features, dtype) @@ -1332,12 +1312,12 @@ def test_fused_bias_activation_matches_masked_linear_backward( quantized_ref_backward_ops = model_quantized_ref._module_groups[0]._backward_ops assert any(isinstance(op, BackwardActivationBias) for op, _ in quantized_ref_backward_ops) - _assert_forward_matches_quantized_ref(y_bwd_mode, y_quantized_ref, recipe_name) - _assert_exact(dx1_bwd_mode, dx1_ref) - _assert_exact(dw_bwd_mode, dw_ref) + assert_close(y_bwd_mode, y_quantized_ref, check_dtype=True, **_fprop_tolerances(recipe_name)) + assert_close(dx1_bwd_mode, dx1_ref, rtol=0, atol=0, check_dtype=True) + assert_close(dw_bwd_mode, dw_ref, rtol=0, atol=0, check_dtype=True) assert db_bwd_mode is not None assert db_ref is not None - _assert_exact(db_bwd_mode, db_ref) + assert_close(db_bwd_mode, db_ref, rtol=0, atol=0, check_dtype=True) @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) @@ -1356,14 +1336,14 @@ def test_operation_fuser_rebuilds_userbuffers_fusion_on_backward_mode_switch( # Use a mutable recipe holder so we can switch fusion behavior on the same # fuser object and verify that the cached fusion plan is refreshed. 
- current_recipe = {"value": _make_recipe(recipe_name, backward_mode="default")} + current_recipe = {"value": make_recipe(recipe_name, backward_mode="default")} monkeypatch.setattr(FP8GlobalStateManager, "get_fp8_recipe", lambda: current_recipe["value"]) reset_rng_states() _maybe_skip_unsupported_recipe_module_combo(recipe_name, "ops_linear") fuser, x, extra_inputs = _make_userbuffers_fuser_for_mode_switch_test(dtype=dtype) - quant_recipe = _make_recipe(recipe_name, backward_mode="default") + quant_recipe = make_recipe(recipe_name, backward_mode="default") fuser.maybe_fuse_ops( is_grad_enabled=True, recipe=quant_recipe, @@ -1372,7 +1352,7 @@ def test_operation_fuser_rebuilds_userbuffers_fusion_on_backward_mode_switch( ) assert _has_userbuffers_forward_linear(fuser) - non_quant_recipe = _make_recipe(recipe_name, backward_mode=backward_mode) + non_quant_recipe = make_recipe(recipe_name, backward_mode=backward_mode) current_recipe["value"] = non_quant_recipe fuser.maybe_fuse_ops( is_grad_enabled=True, @@ -1390,7 +1370,7 @@ def test_quantize_op_respects_backward_mode( dtype: torch.dtype, backward_mode: str, ) -> None: - _maybe_skip_recipe_dtype(recipe_name, dtype, backward_mode) + _maybe_skip_recipe_dtype(recipe_name, dtype) _maybe_skip_unsupported_recipe_module_combo(recipe_name, "ops_linear") reset_rng_states() @@ -1400,13 +1380,13 @@ def test_quantize_op_respects_backward_mode( model_override = te_ops.Sequential(te_ops.Quantize(forward=True, backward=True)) model_ref = te_ops.Sequential(te_ops.Quantize(forward=True, backward=False)) - mode_recipe = _make_recipe(recipe_name, backward_mode=backward_mode) + mode_recipe = make_recipe(recipe_name, backward_mode=backward_mode) y_override, dx_override = _run_quantize_op_single_step(model_override, x, dy, mode_recipe) y_ref, dx_ref = _run_quantize_op_single_step(model_ref, x, dy, mode_recipe) - _assert_exact(y_override, y_ref) - _assert_exact(dx_override, dx_ref) + assert_close(y_override, y_ref, rtol=0, atol=0, 
check_dtype=True) + assert_close(dx_override, dx_ref, rtol=0, atol=0, check_dtype=True) def test_delayed_scaling_rejects_non_quant_backward_mode(backward_mode: str) -> None: @@ -1425,7 +1405,7 @@ def test_layernorm_mlp_not_implemented_for_unquantized_backward_mode( dtype: torch.dtype, backward_mode: str, ) -> None: - _maybe_skip_recipe_dtype(recipe_name, dtype, backward_mode) + _maybe_skip_recipe_dtype(recipe_name, dtype) reset_rng_states() layer = te.LayerNormMLP( @@ -1436,7 +1416,7 @@ def test_layernorm_mlp_not_implemented_for_unquantized_backward_mode( device="cuda", ) x = torch.randn(32, 64, dtype=dtype, device="cuda", requires_grad=True) - mode_recipe = _make_recipe(recipe_name, backward_mode=backward_mode) + mode_recipe = make_recipe(recipe_name, backward_mode=backward_mode) with pytest.raises( AssertionError, diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 317240fb78..dfd0c73738 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -117,7 +117,7 @@ def quantization_tols(name: str) -> dict[str, float]: raise ValueError(f"Unsupported quantization scheme ({name})") -def make_recipe(name: Optional[str]) -> Optional[Recipe]: +def make_recipe(name: Optional[str], **recipe_kwargs: Any) -> Optional[Recipe]: """Make recipe for quantization scheme""" if name is None: return None @@ -125,22 +125,26 @@ def make_recipe(name: Optional[str]) -> Optional[Recipe]: return transformer_engine.common.recipe.DelayedScaling( fp8_format=transformer_engine.common.recipe.Format.E4M3, amax_history_len=8, + **recipe_kwargs, ) if name == "fp8_current_scaling": return transformer_engine.common.recipe.Float8CurrentScaling( fp8_format=transformer_engine.common.recipe.Format.E4M3, + **recipe_kwargs, ) if name == "mxfp8": return transformer_engine.common.recipe.MXFP8BlockScaling( fp8_format=transformer_engine.common.recipe.Format.E4M3, + **recipe_kwargs, ) if name == "fp8_block_scaling": - return transformer_engine.common.recipe.Float8BlockScaling() + return 
transformer_engine.common.recipe.Float8BlockScaling(**recipe_kwargs) if name == "nvfp4": return transformer_engine.common.recipe.NVFP4BlockScaling( disable_rht=True, disable_stochastic_rounding=True, disable_2d_quantization=True, + **recipe_kwargs, ) raise ValueError(f"Unsupported quantization scheme ({name})") From 4ef353fb39d0c22c276044f2304acf6f8647de27 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Thu, 5 Mar 2026 13:51:10 -0800 Subject: [PATCH 47/61] Clean up unit test Signed-off-by: Ziang Li --- tests/pytorch/test_backward_mode.py | 137 +++++++--------------------- 1 file changed, 35 insertions(+), 102 deletions(-) diff --git a/tests/pytorch/test_backward_mode.py b/tests/pytorch/test_backward_mode.py index 2a744d8278..c780c60cff 100644 --- a/tests/pytorch/test_backward_mode.py +++ b/tests/pytorch/test_backward_mode.py @@ -25,7 +25,7 @@ ) from transformer_engine.pytorch.quantized_tensor import restore_from_saved -from utils import assert_close, make_recipe, quantization_tols, reset_rng_states +from utils import assert_close, make_recipe, reset_rng_states # -------------------------- @@ -340,16 +340,6 @@ def _copy_named_parameters(src_module: torch.nn.Module, dst_module: torch.nn.Mod dst_param.copy_(src_params[name]) -def _fprop_tolerances(recipe_name: str) -> dict[str, float]: - if recipe_name == "mxfp8": - return quantization_tols("mxfp8") - if recipe_name in ("fp8_current_scaling", "fp8_block_scaling"): - return quantization_tols("fp8_current_scaling") - if recipe_name == "nvfp4": - return quantization_tols("nvfp4") - raise ValueError(f"Unsupported recipe for backward-mode test: {recipe_name}") - - def _maybe_skip_recipe_dtype(recipe_name: str, dtype: torch.dtype) -> None: if dtype == torch.bfloat16 and not bf16_available: pytest.skip(reason_for_no_bf16) @@ -452,7 +442,8 @@ def _run_single_step( y.backward(dy) assert x_run.grad is not None assert module.weight.grad is not None - bgrad = _extract_bias_grad(module) + bias = getattr(module, "bias", None) + bgrad 
= None if bias is None or bias.grad is None else bias.grad.detach().clone() return ( y.detach().clone(), x_run.grad.detach().clone(), @@ -480,13 +471,6 @@ def _run_single_step_with_saved_operands( return y, x_run, saved_operands -def _extract_bias_grad(module: torch.nn.Module) -> Optional[torch.Tensor]: - bias = getattr(module, "bias", None) - if bias is None or bias.grad is None: - return None - return bias.grad.detach().clone() - - def _run_grouped_linear_single_step( module: te.GroupedLinear, x: torch.Tensor, @@ -640,40 +624,6 @@ def _run_quantize_op_single_step( return y.detach().clone(), x_run.grad.detach().clone() -def _make_userbuffers_fuser_for_mode_switch_test( - *, - dtype: torch.dtype, -) -> tuple[object, torch.Tensor, list[tuple[()]]]: - """Build a Userbuffers-eligible fuser and representative inputs.""" - in_features = 64 - out_features = 64 - linear = te_ops.BasicLinear( - in_features, - out_features, - device="cuda", - dtype=dtype, - userbuffers_options={"comm_name": "qkv"}, - ) - linear.tensor_parallel_mode = "column" - linear.tensor_parallel_size = 2 - linear.sequence_parallel = True - bias = te_ops.Bias(out_features, device="cuda", dtype=dtype) - model = te_ops.Sequential(linear, bias) - model._module_groups = model._make_module_groups( - model._modules.values() - ) # pylint: disable=protected-access - fuser = model._module_groups[0] - x = torch.randn(32, in_features, dtype=dtype, device="cuda", requires_grad=True) - extra_inputs = [() for _ in range(fuser._num_basic_ops)] # pylint: disable=protected-access - return fuser, x, extra_inputs - - -def _has_userbuffers_forward_linear(fuser: object) -> bool: - return any( - isinstance(op, UserbuffersForwardLinear) for op, _ in fuser._forward_ops - ) # pylint: disable=protected-access - - # -------------------------- # Tests # -------------------------- @@ -844,7 +794,7 @@ def test_linear_like_backward_mode_matches_reference( if module_type == "ops_linear" and use_bias: # te_ops bias grad is reduced by 
the Bias op from incoming dy. db_ref = dy.reshape(-1, dy.shape[-1]).sum(dim=0).to(dtype) - except Exception as exc: # pylint: disable=broad-exception-caught + except Exception as exc: ref_exc = exc layout_invariants = _snapshot_layout_invariants(guard_operands) @@ -854,14 +804,15 @@ def test_linear_like_backward_mode_matches_reference( assert module_bwd_mode.weight.grad is not None dx_bwd_mode = x_bwd_mode.grad.detach().clone() dw_bwd_mode = module_bwd_mode.weight.grad.detach().clone() - db_bwd_mode = _extract_bias_grad(module_bwd_mode) + bias = getattr(module_bwd_mode, "bias", None) + db_bwd_mode = None if bias is None or bias.grad is None else bias.grad.detach().clone() y_bwd_mode = y_bwd_mode_detached _assert_layout_invariants_unchanged(layout_invariants) _raise_if_ref_failed(ref_exc) assert dx_ref is not None and dw_ref is not None and db_ref is not None - assert_close(y_bwd_mode, y_quantized_ref, check_dtype=True, **_fprop_tolerances(recipe_name)) + assert_close(y_bwd_mode, y_quantized_ref, rtol=0, atol=0, check_dtype=True) assert_close(dx_bwd_mode, dx_ref, rtol=0, atol=0, check_dtype=True) assert_close(dw_bwd_mode, dw_ref, rtol=0, atol=0, check_dtype=True) if use_bias: @@ -1000,7 +951,7 @@ def test_grouped_linear_backward_mode_matches_reference( dw_ref.append(dw_i) db_ref.append(db_i if use_bias else None) dx_ref = torch.cat(dx_chunks, dim=0) - except Exception as exc: # pylint: disable=broad-exception-caught + except Exception as exc: ref_exc = exc layout_invariants = _snapshot_layout_invariants(guard_operands) @@ -1024,7 +975,7 @@ def test_grouped_linear_backward_mode_matches_reference( _raise_if_ref_failed(ref_exc) assert dx_ref is not None - assert_close(y_bwd_mode, y_quantized_ref, check_dtype=True, **_fprop_tolerances(recipe_name)) + assert_close(y_bwd_mode, y_quantized_ref, rtol=0, atol=0, check_dtype=True) assert_close(dx_bwd_mode, dx_ref, rtol=0, atol=0, check_dtype=True) for test_dw, ref_dw in zip(dw_bwd_mode, dw_ref): assert_close(test_dw, ref_dw, 
rtol=0, atol=0, check_dtype=True) @@ -1145,7 +1096,7 @@ def test_fused_linear_paths_match_backward_mode_reference( out_dtype=dtype, ) dx2_ref = dy if x2 is not None else None - except Exception as exc: # pylint: disable=broad-exception-caught + except Exception as exc: ref_exc = exc layout_invariants = _snapshot_layout_invariants(guard_operands) @@ -1175,7 +1126,7 @@ def test_fused_linear_paths_match_backward_mode_reference( assert len(fused_ops) >= 1 assert isinstance(fused_ops[0][0], expected_fused_op) - assert_close(y_bwd_mode, y_quantized_ref, check_dtype=True, **_fprop_tolerances(recipe_name)) + assert_close(y_bwd_mode, y_quantized_ref, rtol=0, atol=0, check_dtype=True) assert_close(dx1_bwd_mode, dx1_ref, rtol=0, atol=0, check_dtype=True) assert_close(dw_bwd_mode, dw_ref, rtol=0, atol=0, check_dtype=True) if dx2_bwd_mode is not None and dx2_ref is not None: @@ -1280,7 +1231,7 @@ def test_fused_bias_activation_matches_masked_linear_backward( dequant_dtype=dtype, out_dtype=dtype, ) - except Exception as exc: # pylint: disable=broad-exception-caught + except Exception as exc: ref_exc = exc layout_invariants = _snapshot_layout_invariants(guard_operands) @@ -1312,7 +1263,7 @@ def test_fused_bias_activation_matches_masked_linear_backward( quantized_ref_backward_ops = model_quantized_ref._module_groups[0]._backward_ops assert any(isinstance(op, BackwardActivationBias) for op, _ in quantized_ref_backward_ops) - assert_close(y_bwd_mode, y_quantized_ref, check_dtype=True, **_fprop_tolerances(recipe_name)) + assert_close(y_bwd_mode, y_quantized_ref, rtol=0, atol=0, check_dtype=True) assert_close(dx1_bwd_mode, dx1_ref, rtol=0, atol=0, check_dtype=True) assert_close(dw_bwd_mode, dw_ref, rtol=0, atol=0, check_dtype=True) assert db_bwd_mode is not None @@ -1341,7 +1292,26 @@ def test_operation_fuser_rebuilds_userbuffers_fusion_on_backward_mode_switch( reset_rng_states() _maybe_skip_unsupported_recipe_module_combo(recipe_name, "ops_linear") - fuser, x, extra_inputs = 
_make_userbuffers_fuser_for_mode_switch_test(dtype=dtype) + + # Build a Userbuffers-eligible fuser and representative inputs. + in_features = 64 + out_features = 64 + linear = te_ops.BasicLinear( + in_features, + out_features, + device="cuda", + dtype=dtype, + userbuffers_options={"comm_name": "qkv"}, + ) + linear.tensor_parallel_mode = "column" + linear.tensor_parallel_size = 2 + linear.sequence_parallel = True + bias = te_ops.Bias(out_features, device="cuda", dtype=dtype) + model = te_ops.Sequential(linear, bias) + model._module_groups = model._make_module_groups(model._modules.values()) + fuser = model._module_groups[0] + x = torch.randn(32, in_features, dtype=dtype, device="cuda", requires_grad=True) + extra_inputs = [() for _ in range(fuser._num_basic_ops)] quant_recipe = make_recipe(recipe_name, backward_mode="default") fuser.maybe_fuse_ops( @@ -1350,7 +1320,7 @@ def test_operation_fuser_rebuilds_userbuffers_fusion_on_backward_mode_switch( input_=x, extra_inputs=extra_inputs, ) - assert _has_userbuffers_forward_linear(fuser) + assert any(isinstance(op, UserbuffersForwardLinear) for op, _ in fuser._forward_ops) non_quant_recipe = make_recipe(recipe_name, backward_mode=backward_mode) current_recipe["value"] = non_quant_recipe @@ -1360,7 +1330,7 @@ def test_operation_fuser_rebuilds_userbuffers_fusion_on_backward_mode_switch( input_=x, extra_inputs=extra_inputs, ) - assert not _has_userbuffers_forward_linear(fuser) + assert not any(isinstance(op, UserbuffersForwardLinear) for op, _ in fuser._forward_ops) @pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list) @@ -1387,40 +1357,3 @@ def test_quantize_op_respects_backward_mode( assert_close(y_override, y_ref, rtol=0, atol=0, check_dtype=True) assert_close(dx_override, dx_ref, rtol=0, atol=0, check_dtype=True) - - -def test_delayed_scaling_rejects_non_quant_backward_mode(backward_mode: str) -> None: - with pytest.raises( - (AssertionError, ValueError), - match="Delayed scaling only supports 
backward_mode=default", - ): - _ = recipe.DelayedScaling(backward_mode=backward_mode) - - -@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) -@pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list) -@pytest.mark.parametrize("dtype", _fused_dtypes, ids=str) -def test_layernorm_mlp_not_implemented_for_unquantized_backward_mode( - recipe_name: str, - dtype: torch.dtype, - backward_mode: str, -) -> None: - _maybe_skip_recipe_dtype(recipe_name, dtype) - reset_rng_states() - - layer = te.LayerNormMLP( - hidden_size=64, - ffn_hidden_size=64, - params_dtype=dtype, - bias=False, - device="cuda", - ) - x = torch.randn(32, 64, dtype=dtype, device="cuda", requires_grad=True) - mode_recipe = make_recipe(recipe_name, backward_mode=backward_mode) - - with pytest.raises( - AssertionError, - match="NVTE_BACKWARD_MODE=unquant/dequant is not implemented in LayerNormMLP", - ): - with te.autocast(enabled=True, recipe=mode_recipe): - _ = layer(x) From c16ba4b984273b8c6544bdf494b10887137fd2d7 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 9 Mar 2026 10:14:59 -0700 Subject: [PATCH 48/61] Override `ctx.reduce_and_update_bwd_fp8_tensors = False` Signed-off-by: Ziang Li --- transformer_engine/pytorch/module/grouped_linear.py | 1 + transformer_engine/pytorch/module/layernorm_linear.py | 1 + transformer_engine/pytorch/module/linear.py | 1 + 3 files changed, 3 insertions(+) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 3b9c3b2949..5350e7e6f6 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -336,6 +336,7 @@ def forward( ctx.grad_input_quantizer = None ctx.grad_weight_quantizer = None ctx.grad_output_quantizer = None + ctx.reduce_and_update_bwd_fp8_tensors = False # [*, in_features] -> [*, out_features] except first dimension changes for SP return out.view(-1, *inp.shape[1:-1], out.shape[-1]) diff --git 
a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 75d7802143..91c67ce3b0 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -548,6 +548,7 @@ def forward( ctx.grad_input_quantizer = None ctx.grad_weight_quantizer = None ctx.grad_output_quantizer = None + ctx.reduce_and_update_bwd_fp8_tensors = False # ------------------------------------------------------ # Cached state for backward pass is ready... diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index de00553225..f3a35ecade 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -516,6 +516,7 @@ def forward( ctx.grad_input_quantizer = None ctx.grad_weight_quantizer = None ctx.grad_output_quantizer = None + ctx.reduce_and_update_bwd_fp8_tensors = False # ------------------------------------------------------ # Cached state for backward pass is ready... 
From 27e70bcef0ab29d77645f338616e7ecb8a9759b7 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 9 Mar 2026 11:15:28 -0700 Subject: [PATCH 49/61] Expand unit test Signed-off-by: Ziang Li --- tests/pytorch/test_backward_mode.py | 307 ++++++++++++++++++++++++++-- 1 file changed, 293 insertions(+), 14 deletions(-) diff --git a/tests/pytorch/test_backward_mode.py b/tests/pytorch/test_backward_mode.py index c780c60cff..6c0fea2da4 100644 --- a/tests/pytorch/test_backward_mode.py +++ b/tests/pytorch/test_backward_mode.py @@ -307,6 +307,7 @@ def _compute_linear_backward_reference_from_saved_operands( _shape_test_cases = [ pytest.param((1, 64), 64, id="2d_m1_k64_n64"), pytest.param((32, 64), 64, id="2d_m32_k64_n64"), + pytest.param((32, 96), 96, id="2d_m32_k96_n96"), pytest.param((32, 1, 64), 64, id="3d_m32_s1_k64_n64"), pytest.param((8, 4, 64), 128, id="3d_m32_k64_n128"), pytest.param((16, 2, 128), 64, id="3d_m32_k128_n64"), @@ -317,10 +318,12 @@ def _compute_linear_backward_reference_from_saved_operands( # Intentionally unaligned token dimensions to exercise skip/support logic. pytest.param((3, 64), 64, id="2d_m3_k64_n64_unaligned"), pytest.param((3, 10, 64), 64, id="3d_m30_k64_n64_unaligned"), + pytest.param((3, 10, 96), 96, id="3d_m30_k96_n96_unaligned"), ] _bias_activation_shape_cases = [ pytest.param((32, 64), id="2d_m32_k64"), + pytest.param((32, 96), id="2d_m32_k96"), pytest.param((8, 4, 64), id="3d_m32_k64"), pytest.param((160, 64), id="2d_m160_k64"), pytest.param((5, 64, 64), id="3d_m320_k64"), @@ -328,6 +331,30 @@ def _compute_linear_backward_reference_from_saved_operands( # Intentionally unaligned token dimensions to exercise skip/support logic. 
pytest.param((3, 64), id="2d_m3_k64_unaligned"), pytest.param((3, 10, 64), id="3d_m30_k64_unaligned"), + pytest.param((3, 10, 96), id="3d_m30_k96_unaligned"), +] + +_grouped_m_split_cases = [ + pytest.param([32, 32, 32, 32], id="uniform_splits"), + pytest.param([64, 0, 32, 32], id="with_empty_split"), + pytest.param([1, 31, 0, 96], id="small_and_empty_splits"), +] + +_linear_feature_cases = [ + pytest.param(64, 64, id="k64_n64"), + pytest.param(64, 128, id="k64_n128"), + pytest.param(128, 64, id="k128_n64"), + pytest.param(96, 96, id="k96_n96"), + pytest.param(64, 96, id="k64_n96"), + pytest.param(96, 64, id="k96_n64"), + pytest.param(128, 96, id="k128_n96"), + pytest.param(96, 128, id="k96_n128"), +] + +_output_feature_cases = [ + pytest.param(64, id="n64"), + pytest.param(96, id="n96"), + pytest.param(128, id="n128"), ] @@ -624,6 +651,120 @@ def _run_quantize_op_single_step( return y.detach().clone(), x_run.grad.detach().clone() +def _snapshot_backward_ctx_state( + output: torch.Tensor, +) -> tuple[str, bool, object, bool]: + if output.grad_fn is None: + raise RuntimeError("Output tensor has no grad_fn; cannot inspect backward context state.") + required_attrs = ( + "backward_mode", + "fp8", + "grad_output_quantizer", + "reduce_and_update_bwd_fp8_tensors", + ) + missing_attrs = [attr for attr in required_attrs if not hasattr(output.grad_fn, attr)] + if missing_attrs: + raise RuntimeError( + "grad_fn does not expose required backward context attributes: " + f"{', '.join(missing_attrs)}." 
+ ) + return ( + getattr(output.grad_fn, "backward_mode"), + bool(getattr(output.grad_fn, "fp8")), + getattr(output.grad_fn, "grad_output_quantizer"), + bool(getattr(output.grad_fn, "reduce_and_update_bwd_fp8_tensors")), + ) + + +def _run_single_step_with_ctx_state( + module: torch.nn.Module, + x: torch.Tensor, + dy: torch.Tensor, + fp8_recipe: Optional[recipe.Recipe], +) -> tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + Optional[torch.Tensor], + tuple[str, bool, object, bool], +]: + module.zero_grad(set_to_none=True) + x_run = x.detach().clone().requires_grad_(True) + autocast_ctx = ( + te.autocast(enabled=True, recipe=fp8_recipe) if fp8_recipe is not None else nullcontext() + ) + with autocast_ctx: + y = module(x_run) + if isinstance(y, tuple): + y = y[0] + ctx_state = _snapshot_backward_ctx_state(y) + y.backward(dy) + assert x_run.grad is not None + assert module.weight.grad is not None + bias = getattr(module, "bias", None) + bgrad = None if bias is None or bias.grad is None else bias.grad.detach().clone() + return ( + y.detach().clone(), + x_run.grad.detach().clone(), + module.weight.grad.detach().clone(), + bgrad, + ctx_state, + ) + + +def _run_grouped_linear_single_step_with_ctx_state( + module: te.GroupedLinear, + x: torch.Tensor, + m_splits: list[int], + dy: torch.Tensor, + fp8_recipe: Optional[recipe.Recipe], +) -> tuple[ + torch.Tensor, + torch.Tensor, + list[torch.Tensor], + list[Optional[torch.Tensor]], + tuple[str, bool, bool], +]: + module.zero_grad(set_to_none=True) + x_run = x.detach().clone().requires_grad_(True) + autocast_ctx = ( + te.autocast(enabled=True, recipe=fp8_recipe) if fp8_recipe is not None else nullcontext() + ) + with autocast_ctx: + y = module(x_run, m_splits) + if y.grad_fn is None: + raise RuntimeError( + "Output tensor has no grad_fn; cannot inspect grouped backward state." 
+ ) + required_attrs = ( + "backward_mode", + "fp8", + "reduce_and_update_bwd_fp8_tensors", + ) + missing_attrs = [attr for attr in required_attrs if not hasattr(y.grad_fn, attr)] + if missing_attrs: + raise RuntimeError( + "Grouped grad_fn does not expose required backward context attributes: " + f"{', '.join(missing_attrs)}." + ) + ctx_state = ( + getattr(y.grad_fn, "backward_mode"), + bool(getattr(y.grad_fn, "fp8")), + bool(getattr(y.grad_fn, "reduce_and_update_bwd_fp8_tensors")), + ) + y.backward(dy) + assert x_run.grad is not None + + dw = [getattr(module, f"weight{i}").grad.detach().clone() for i in range(module.num_gemms)] + db: list[Optional[torch.Tensor]] = [] + for i in range(module.num_gemms): + if module.use_bias: + db.append(getattr(module, f"bias{i}").grad.detach().clone()) + else: + db.append(None) + return y.detach().clone(), x_run.grad.detach().clone(), dw, db, ctx_state + + # -------------------------- # Tests # -------------------------- @@ -822,15 +963,14 @@ def test_linear_like_backward_mode_matches_reference( @pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list) +@pytest.mark.parametrize("in_features,out_features", _linear_feature_cases) @pytest.mark.parametrize("use_bias", (False, True), ids=("no_bias", "bias")) -@pytest.mark.parametrize( - "m_splits", - ([32, 32, 32, 32], [64, 0, 32, 32], [1, 31, 0, 96]), - ids=("uniform_splits", "with_empty_split", "small_and_empty_splits"), -) +@pytest.mark.parametrize("m_splits", _grouped_m_split_cases) @pytest.mark.parametrize("dtype", _core_dtypes, ids=str) def test_grouped_linear_backward_mode_matches_reference( recipe_name: str, + in_features: int, + out_features: int, use_bias: bool, m_splits: list[int], dtype: torch.dtype, @@ -842,9 +982,6 @@ def test_grouped_linear_backward_mode_matches_reference( reset_rng_states() _maybe_skip_recipe_dtype(recipe_name, dtype) _maybe_skip_unsupported_grouped_splits(recipe_name, m_splits) - - in_features = 64 - out_features = 64 num_gemms = 
len(m_splits) num_tokens = sum(m_splits) @@ -986,6 +1123,145 @@ def test_grouped_linear_backward_mode_matches_reference( assert_close(test_db, ref_db_i, rtol=0, atol=0, check_dtype=True) +@pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list) +@pytest.mark.parametrize("module_type", ("linear", "layernorm_linear")) +@pytest.mark.parametrize("input_shape,out_features", _shape_test_cases) +@pytest.mark.parametrize("use_bias", (False, True), ids=("no_bias", "bias")) +@pytest.mark.parametrize("dtype", _core_dtypes, ids=str) +def test_linear_like_runtime_backward_mode_switch_updates_ctx( + recipe_name: str, + module_type: str, + input_shape: tuple[int, ...], + out_features: int, + use_bias: bool, + dtype: torch.dtype, + backward_mode: str, +) -> None: + reset_rng_states() + _maybe_skip_recipe_dtype(recipe_name, dtype) + _maybe_skip_unsupported_recipe_module_combo(recipe_name, module_type) + _maybe_skip_unsupported_recipe_shape(recipe_name, input_shape, module_type) + + module = _make_linear_like_module( + module_type, + input_shape[-1], + out_features, + dtype, + bias=use_bias, + ) + x = torch.randn(*input_shape, dtype=dtype, device="cuda") + dy = torch.randn(*input_shape[:-1], out_features, dtype=dtype, device="cuda") + + default_recipe = make_recipe(recipe_name, backward_mode="default") + mode_recipe = make_recipe(recipe_name, backward_mode=backward_mode) + + *_, default_ctx = _run_single_step_with_ctx_state(module, x, dy, default_recipe) + ( + default_mode, + default_fp8, + default_grad_output_quantizer, + default_reduce_and_update, + ) = default_ctx + assert default_mode == "default" + assert default_fp8 + assert default_grad_output_quantizer is not None + assert default_reduce_and_update + + *_, switched_ctx = _run_single_step_with_ctx_state(module, x, dy, mode_recipe) + switched_mode, switched_fp8, switched_grad_output_quantizer, switched_reduce_and_update = ( + switched_ctx + ) + assert switched_mode == backward_mode + assert not switched_fp8 + 
assert switched_grad_output_quantizer is None + assert not switched_reduce_and_update + + *_, default_ctx_after = _run_single_step_with_ctx_state(module, x, dy, default_recipe) + ( + default_mode_after, + default_fp8_after, + default_grad_output_quantizer_after, + default_reduce_and_update_after, + ) = default_ctx_after + assert default_mode_after == "default" + assert default_fp8_after + assert default_grad_output_quantizer_after is not None + assert default_reduce_and_update_after + + +@pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list) +@pytest.mark.parametrize("in_features,out_features", _linear_feature_cases) +@pytest.mark.parametrize("m_splits", _grouped_m_split_cases) +@pytest.mark.parametrize("use_bias", (False, True), ids=("no_bias", "bias")) +@pytest.mark.parametrize("dtype", _core_dtypes, ids=str) +def test_grouped_linear_runtime_backward_mode_switch_updates_ctx( + recipe_name: str, + in_features: int, + out_features: int, + m_splits: list[int], + use_bias: bool, + dtype: torch.dtype, + backward_mode: str, +) -> None: + if recipe_name == "nvfp4": + pytest.skip("NVFP4 not supported for grouped linear") + + reset_rng_states() + _maybe_skip_recipe_dtype(recipe_name, dtype) + _maybe_skip_unsupported_grouped_splits(recipe_name, m_splits) + + num_tokens = sum(m_splits) + module = te.GroupedLinear( + len(m_splits), + in_features, + out_features, + bias=use_bias, + params_dtype=dtype, + device="cuda", + ) + x = torch.randn(num_tokens, in_features, dtype=dtype, device="cuda") + dy = torch.randn(num_tokens, out_features, dtype=dtype, device="cuda") + + default_recipe = make_recipe(recipe_name, backward_mode="default") + mode_recipe = make_recipe(recipe_name, backward_mode=backward_mode) + + *_, default_ctx = _run_grouped_linear_single_step_with_ctx_state( + module, + x, + m_splits, + dy, + default_recipe, + ) + default_mode, default_fp8, default_reduce_and_update = default_ctx + assert default_mode == "default" + assert default_fp8 + assert 
default_reduce_and_update + + *_, switched_ctx = _run_grouped_linear_single_step_with_ctx_state( + module, + x, + m_splits, + dy, + mode_recipe, + ) + switched_mode, switched_fp8, switched_reduce_and_update = switched_ctx + assert switched_mode == backward_mode + assert not switched_fp8 + assert not switched_reduce_and_update + + *_, default_ctx_after = _run_grouped_linear_single_step_with_ctx_state( + module, + x, + m_splits, + dy, + default_recipe, + ) + default_mode_after, default_fp8_after, default_reduce_and_update_after = default_ctx_after + assert default_mode_after == "default" + assert default_fp8_after + assert default_reduce_and_update_after + + @pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list) @pytest.mark.parametrize( "fused_pattern,expected_fused_op", @@ -994,23 +1270,24 @@ def test_grouped_linear_backward_mode_matches_reference( ("scale_add", ForwardLinearScaleAdd), ), ) +@pytest.mark.parametrize("in_features,out_features", _linear_feature_cases) @pytest.mark.parametrize("m", (1, 32), ids=("m1", "m32")) @pytest.mark.parametrize("dtype", _fused_dtypes, ids=str) def test_fused_linear_paths_match_backward_mode_reference( recipe_name: str, fused_pattern: str, expected_fused_op: type, + in_features: int, + out_features: int, m: int, dtype: torch.dtype, backward_mode: str, ) -> None: _maybe_skip_recipe_dtype(recipe_name, dtype) _maybe_skip_unsupported_recipe_module_combo(recipe_name, "ops_linear") - _maybe_skip_unsupported_recipe_shape(recipe_name, (m, 64), "ops_linear") + _maybe_skip_unsupported_recipe_shape(recipe_name, (m, in_features), "ops_linear") reset_rng_states() - in_features = 64 - out_features = 64 quantized_ref_recipe = make_recipe(recipe_name, backward_mode="default") mode_recipe = make_recipe(recipe_name, backward_mode=backward_mode) @@ -1137,10 +1414,12 @@ def test_fused_linear_paths_match_backward_mode_reference( @pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list) 
@pytest.mark.parametrize("input_shape", _bias_activation_shape_cases) +@pytest.mark.parametrize("out_features", _output_feature_cases) @pytest.mark.parametrize("dtype", _fused_dtypes, ids=str) def test_fused_bias_activation_matches_masked_linear_backward( recipe_name: str, input_shape: tuple[int, ...], + out_features: int, dtype: torch.dtype, backward_mode: str, ) -> None: @@ -1150,7 +1429,6 @@ def test_fused_bias_activation_matches_masked_linear_backward( reset_rng_states() in_features = input_shape[-1] - out_features = 64 quantized_ref_recipe = make_recipe(recipe_name, backward_mode="default") mode_recipe = make_recipe(recipe_name, backward_mode=backward_mode) @@ -1273,9 +1551,12 @@ def test_fused_bias_activation_matches_masked_linear_backward( @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) @pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list) +@pytest.mark.parametrize("in_features,out_features", _linear_feature_cases) @pytest.mark.parametrize("dtype", _core_dtypes, ids=str) def test_operation_fuser_rebuilds_userbuffers_fusion_on_backward_mode_switch( recipe_name: str, + in_features: int, + out_features: int, dtype: torch.dtype, backward_mode: str, monkeypatch: pytest.MonkeyPatch, @@ -1294,8 +1575,6 @@ def test_operation_fuser_rebuilds_userbuffers_fusion_on_backward_mode_switch( _maybe_skip_unsupported_recipe_module_combo(recipe_name, "ops_linear") # Build a Userbuffers-eligible fuser and representative inputs. 
- in_features = 64 - out_features = 64 linear = te_ops.BasicLinear( in_features, out_features, From 6ac9050b8df52b845a9621cc5d8529af43ecb2e7 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 9 Mar 2026 13:06:59 -0700 Subject: [PATCH 50/61] Add `test_backward_mode_memory_peak_report` Signed-off-by: Ziang Li --- tests/pytorch/test_backward_mode.py | 160 ++++++++++++++++++++++++++++ 1 file changed, 160 insertions(+) diff --git a/tests/pytorch/test_backward_mode.py b/tests/pytorch/test_backward_mode.py index 6c0fea2da4..73625af52b 100644 --- a/tests/pytorch/test_backward_mode.py +++ b/tests/pytorch/test_backward_mode.py @@ -1636,3 +1636,163 @@ def test_quantize_op_respects_backward_mode( assert_close(y_override, y_ref, rtol=0, atol=0, check_dtype=True) assert_close(dx_override, dx_ref, rtol=0, atol=0, check_dtype=True) + + +@pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list) +@pytest.mark.parametrize("module_type", ("linear", "layernorm_linear")) +def test_backward_mode_memory_peak_report( + recipe_name: str, + module_type: str, +) -> None: + """Diagnostic-only memory report for default/unquant/dequant backward modes.""" + reset_rng_states() + dtype = torch.bfloat16 + input_shape = (2048, 2048) + out_features = 2048 * 4 + in_features = input_shape[-1] + use_bias = True + + _maybe_skip_recipe_dtype(recipe_name, dtype) + _maybe_skip_unsupported_recipe_module_combo(recipe_name, module_type) + _maybe_skip_unsupported_recipe_shape(recipe_name, input_shape, module_type) + + base_module = _make_linear_like_module( + module_type, + in_features, + out_features, + dtype, + bias=use_bias, + ) + + x = torch.randn(*input_shape, dtype=dtype, device="cuda") + dy = torch.randn(*input_shape[:-1], out_features, dtype=dtype, device="cuda") + + modes = ("default", "unquant", "dequant") + mode_results: dict[str, dict[str, float] | str] = {} + + for mode in modes: + try: + mode_recipe = make_recipe(recipe_name, backward_mode=mode) + + # Keep params identical across 
modes for a cleaner apples-to-apples read. + module = _make_linear_like_module( + module_type, + in_features, + out_features, + dtype, + bias=use_bias, + ) + _copy_named_parameters(base_module, module) + + # Warmup run to reduce first-use kernel setup noise. + _run_single_step(module, x, dy, mode_recipe) + + module.zero_grad(set_to_none=True) + x_run = x.detach().clone().requires_grad_(True) + autocast_ctx = te.autocast(enabled=True, recipe=mode_recipe) + + torch.cuda.synchronize() + torch.cuda.reset_peak_memory_stats() + fwd_start_mem = torch.cuda.memory_allocated() + with autocast_ctx: + y = module(x_run) + if isinstance(y, tuple): + y = y[0] + torch.cuda.synchronize() + fwd_peak_alloc = float(torch.cuda.max_memory_allocated() - fwd_start_mem) + fwd_peak_reserved = float(torch.cuda.max_memory_reserved()) + + torch.cuda.reset_peak_memory_stats() + bwd_start_mem = torch.cuda.memory_allocated() + y.backward(dy) + torch.cuda.synchronize() + bwd_peak_alloc = float(torch.cuda.max_memory_allocated() - bwd_start_mem) + bwd_peak_reserved = float(torch.cuda.max_memory_reserved()) + + module.zero_grad(set_to_none=True) + x_run = x.detach().clone().requires_grad_(True) + autocast_ctx = te.autocast(enabled=True, recipe=mode_recipe) + + torch.cuda.synchronize() + torch.cuda.reset_peak_memory_stats() + e2e_start_mem = torch.cuda.memory_allocated() + with autocast_ctx: + y = module(x_run) + if isinstance(y, tuple): + y = y[0] + y.backward(dy) + torch.cuda.synchronize() + e2e_peak_alloc = float(torch.cuda.max_memory_allocated() - e2e_start_mem) + e2e_peak_reserved = float(torch.cuda.max_memory_reserved()) + + mode_results[mode] = { + "fwd_peak_alloc_mb": fwd_peak_alloc / (1024**2), + "fwd_peak_reserved_mb": fwd_peak_reserved / (1024**2), + "bwd_peak_alloc_mb": bwd_peak_alloc / (1024**2), + "bwd_peak_reserved_mb": bwd_peak_reserved / (1024**2), + "e2e_peak_alloc_mb": e2e_peak_alloc / (1024**2), + "e2e_peak_reserved_mb": e2e_peak_reserved / (1024**2), + } + except Exception as exc: 
# pragma: no cover - diagnostic reporting path + mode_results[mode] = f"{type(exc).__name__}: {exc}" + + print( + "\n[backward_mode_memory_peak_report] " + f"recipe={recipe_name} module_type={module_type} " + f"dtype={dtype} input_shape={input_shape} out_features={out_features}" + ) + print(" units=MB") + metric_col_width = 9 + delta_col_width = 18 + columns = ( + ("mode", metric_col_width), + ("fwd_alloc", metric_col_width), + ("bwd_alloc", metric_col_width), + ("e2e_alloc", metric_col_width), + ("fwd_resrv", metric_col_width), + ("bwd_resrv", metric_col_width), + ("e2e_resrv", metric_col_width), + ("delta_fwd", delta_col_width), + ("delta_bwd", delta_col_width), + ("delta_e2e", delta_col_width), + ) + print(" | ".join(f"{name:>{width}}" for name, width in columns)) + print("-+-".join("-" * width for _, width in columns)) + + def _format_delta_with_pct(delta: float, base: float) -> str: + if math.isclose(base, 0.0, abs_tol=1e-12): + return f"{delta:+.2f} (n/a)" + pct = 100.0 * delta / base + return f"{delta:+.2f} ({pct:+.2f}%)" + + default_metrics = mode_results.get("default") + for mode in modes: + metrics = mode_results[mode] + if isinstance(metrics, str): + print(f"{mode:>{metric_col_width}} | ERROR: {metrics}") + continue + + if isinstance(default_metrics, dict): + delta_fwd = metrics["fwd_peak_alloc_mb"] - default_metrics["fwd_peak_alloc_mb"] + delta_bwd = metrics["bwd_peak_alloc_mb"] - default_metrics["bwd_peak_alloc_mb"] + delta_e2e = metrics["e2e_peak_alloc_mb"] - default_metrics["e2e_peak_alloc_mb"] + delta_fwd_str = _format_delta_with_pct(delta_fwd, default_metrics["fwd_peak_alloc_mb"]) + delta_bwd_str = _format_delta_with_pct(delta_bwd, default_metrics["bwd_peak_alloc_mb"]) + delta_e2e_str = _format_delta_with_pct(delta_e2e, default_metrics["e2e_peak_alloc_mb"]) + else: + delta_fwd_str = "n/a" + delta_bwd_str = "n/a" + delta_e2e_str = "n/a" + + print( + f"{mode:>{metric_col_width}} | " + f"{metrics['fwd_peak_alloc_mb']:{metric_col_width}.2f} | " + 
f"{metrics['bwd_peak_alloc_mb']:{metric_col_width}.2f} | " + f"{metrics['e2e_peak_alloc_mb']:{metric_col_width}.2f} | " + f"{metrics['fwd_peak_reserved_mb']:{metric_col_width}.2f} | " + f"{metrics['bwd_peak_reserved_mb']:{metric_col_width}.2f} | " + f"{metrics['e2e_peak_reserved_mb']:{metric_col_width}.2f} | " + f"{delta_fwd_str:>{delta_col_width}} | " + f"{delta_bwd_str:>{delta_col_width}} | " + f"{delta_e2e_str:>{delta_col_width}}" + ) From a2b52504cb25e57358b0b24cdefdef52c4fe21f3 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Wed, 11 Mar 2026 20:39:55 -0700 Subject: [PATCH 51/61] Expand test coverage and fix Signed-off-by: Ziang Li --- tests/pytorch/test_backward_mode.py | 14 ++++++- tests/pytorch/test_cpu_offloading.py | 33 +++++++++++++++-- tests/pytorch/test_cuda_graphs.py | 14 ++++++- tests/pytorch/test_sanity.py | 37 +++++++++++++++++-- tests/pytorch/utils.py | 23 ++++++++++++ .../pytorch/module/grouped_linear.py | 9 ++++- .../pytorch/module/layernorm_linear.py | 2 + transformer_engine/pytorch/module/linear.py | 2 + .../pytorch/ops/basic/basic_linear.py | 4 +- .../float8_blockwise_tensor_storage.py | 4 ++ .../tensor/storage/mxfp8_tensor_storage.py | 2 + .../tensor/storage/nvfp4_tensor_storage.py | 2 + 12 files changed, 135 insertions(+), 11 deletions(-) diff --git a/tests/pytorch/test_backward_mode.py b/tests/pytorch/test_backward_mode.py index 73625af52b..3aa47e1166 100644 --- a/tests/pytorch/test_backward_mode.py +++ b/tests/pytorch/test_backward_mode.py @@ -25,7 +25,12 @@ ) from transformer_engine.pytorch.quantized_tensor import restore_from_saved -from utils import assert_close, make_recipe, reset_rng_states +from utils import ( + assert_close, + make_recipe, + reset_rng_states, + skip_unsupported_backward_mode, +) # -------------------------- @@ -803,6 +808,7 @@ def test_linear_like_backward_mode_matches_reference( in_features = input_shape[-1] quantized_ref_recipe = make_recipe(recipe_name, backward_mode="default") mode_recipe = 
make_recipe(recipe_name, backward_mode=backward_mode) + skip_unsupported_backward_mode(module_type, mode_recipe, backward_mode) module_quantized_ref = _make_linear_like_module( module_type, @@ -1154,6 +1160,7 @@ def test_linear_like_runtime_backward_mode_switch_updates_ctx( default_recipe = make_recipe(recipe_name, backward_mode="default") mode_recipe = make_recipe(recipe_name, backward_mode=backward_mode) + skip_unsupported_backward_mode(module_type, mode_recipe, backward_mode) *_, default_ctx = _run_single_step_with_ctx_state(module, x, dy, default_recipe) ( @@ -1291,6 +1298,7 @@ def test_fused_linear_paths_match_backward_mode_reference( quantized_ref_recipe = make_recipe(recipe_name, backward_mode="default") mode_recipe = make_recipe(recipe_name, backward_mode=backward_mode) + skip_unsupported_backward_mode("ops_linear", mode_recipe, backward_mode) model_quantized_ref = _make_fused_model(fused_pattern, in_features, out_features, dtype) model_bwd_mode = _make_fused_model(fused_pattern, in_features, out_features, dtype) @@ -1432,6 +1440,7 @@ def test_fused_bias_activation_matches_masked_linear_backward( quantized_ref_recipe = make_recipe(recipe_name, backward_mode="default") mode_recipe = make_recipe(recipe_name, backward_mode=backward_mode) + skip_unsupported_backward_mode("ops_linear", mode_recipe, backward_mode) model_quantized_ref = _make_fused_model("bias_activation", in_features, out_features, dtype) model_bwd_mode = _make_fused_model("bias_activation", in_features, out_features, dtype) @@ -1593,6 +1602,7 @@ def test_operation_fuser_rebuilds_userbuffers_fusion_on_backward_mode_switch( extra_inputs = [() for _ in range(fuser._num_basic_ops)] quant_recipe = make_recipe(recipe_name, backward_mode="default") + skip_unsupported_backward_mode("ops_linear", quant_recipe, backward_mode) fuser.maybe_fuse_ops( is_grad_enabled=True, recipe=quant_recipe, @@ -1602,6 +1612,7 @@ def test_operation_fuser_rebuilds_userbuffers_fusion_on_backward_mode_switch( assert 
any(isinstance(op, UserbuffersForwardLinear) for op, _ in fuser._forward_ops) non_quant_recipe = make_recipe(recipe_name, backward_mode=backward_mode) + skip_unsupported_backward_mode("ops_linear", non_quant_recipe, backward_mode) current_recipe["value"] = non_quant_recipe fuser.maybe_fuse_ops( is_grad_enabled=True, @@ -1630,6 +1641,7 @@ def test_quantize_op_respects_backward_mode( model_ref = te_ops.Sequential(te_ops.Quantize(forward=True, backward=False)) mode_recipe = make_recipe(recipe_name, backward_mode=backward_mode) + skip_unsupported_backward_mode("ops_linear", mode_recipe, backward_mode) y_override, dx_override = _run_quantize_op_single_step(model_override, x, dy, mode_recipe) y_ref, dx_ref = _run_quantize_op_single_step(model_ref, x, dy, mode_recipe) diff --git a/tests/pytorch/test_cpu_offloading.py b/tests/pytorch/test_cpu_offloading.py index 7da8dcf863..af8e3b884e 100644 --- a/tests/pytorch/test_cpu_offloading.py +++ b/tests/pytorch/test_cpu_offloading.py @@ -6,6 +6,7 @@ import contextlib import pytest import os +import copy import torch from typing import Optional, List from transformer_engine.pytorch.cpu_offload import ( @@ -18,7 +19,7 @@ from transformer_engine.pytorch.fp8 import FP8GlobalStateManager import transformer_engine.pytorch as te from transformer_engine.common import recipe -from utils import ModelConfig +from utils import ModelConfig, skip_unsupported_backward_mode import transformer_engine_torch as tex # Check supported quantization schemes @@ -416,9 +417,14 @@ def test_multiple_tensor_offload(self, recipe): class TestTELayers: @pytest.mark.parametrize("layer_type", Utils.get_layer_names()) @pytest.mark.parametrize("recipe", quantization_recipes) - def test_sanity(self, layer_type, recipe): + @pytest.mark.parametrize("backward_mode", ["default", "unquant", "dequant"]) + def test_sanity(self, layer_type, recipe, backward_mode): Utils.memory_leak_check() + skip_unsupported_backward_mode(layer_type, recipe, backward_mode) + if recipe is 
not None: + recipe = copy.deepcopy(recipe) + recipe.backward_mode = backward_mode # Skip ops-based layers with Float8BlockScaling recipe if ( layer_type in ["linear_op", "layernorm_mlp_ops"] @@ -458,9 +464,15 @@ def test_sanity(self, layer_type, recipe): @pytest.mark.parametrize("layer_type", Utils.get_layer_names()) @pytest.mark.parametrize("recipe", quantization_recipes) - def test_memory(self, layer_type, recipe): + @pytest.mark.parametrize("backward_mode", ["default", "unquant", "dequant"]) + def test_memory(self, layer_type, recipe, backward_mode): Utils.memory_leak_check() + skip_unsupported_backward_mode(layer_type, recipe, backward_mode) + if recipe is not None: + recipe = copy.deepcopy(recipe) + recipe.backward_mode = backward_mode + # Skip ops-based layers with Float8BlockScaling recipe if ( layer_type in ["linear_op", "layernorm_mlp_ops"] @@ -537,9 +549,15 @@ def test_memory(self, layer_type, recipe): @pytest.mark.parametrize("layer_type", Utils.get_layer_names()) @pytest.mark.parametrize("recipe", quantization_recipes) - def test_manual_synchronization(self, recipe, layer_type): + @pytest.mark.parametrize("backward_mode", ["default", "unquant", "dequant"]) + def test_manual_synchronization(self, recipe, layer_type, backward_mode): Utils.memory_leak_check() + skip_unsupported_backward_mode(layer_type, recipe, backward_mode) + if recipe is not None: + recipe = copy.deepcopy(recipe) + recipe.backward_mode = backward_mode + # Skip ops-based layers with Float8BlockScaling recipe if ( layer_type in ["linear_op", "layernorm_mlp_ops"] @@ -600,6 +618,7 @@ def test_manual_synchronization(self, recipe, layer_type): out_2.sum().backward() @pytest.mark.parametrize("recipe", quantization_recipes) + @pytest.mark.parametrize("backward_mode", ["default", "unquant", "dequant"]) @pytest.mark.parametrize("layer_type", Utils.get_layer_names()) @pytest.mark.parametrize("use_cuda_graphs", [True, False]) @pytest.mark.parametrize("retain_pinned_cpu_buffers", [True, False]) @@ 
-607,11 +626,17 @@ def test_manual_synchronization(self, recipe, layer_type): def test_numerics( self, recipe, + backward_mode, layer_type, use_cuda_graphs, backend, retain_pinned_cpu_buffers, ): + skip_unsupported_backward_mode(layer_type, recipe, backward_mode) + if recipe is not None: + recipe = copy.deepcopy(recipe) + recipe.backward_mode = backward_mode + # Skip ops-based layers with Float8BlockScaling recipe if ( layer_type in ["linear_op", "layernorm_mlp_ops"] diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py index 1b9e11792e..bf304dc240 100644 --- a/tests/pytorch/test_cuda_graphs.py +++ b/tests/pytorch/test_cuda_graphs.py @@ -4,6 +4,7 @@ from typing import Callable, Dict, Iterable, List, Tuple, Union import pytest +import copy import torch from transformer_engine.pytorch import ( @@ -24,7 +25,7 @@ from transformer_engine.pytorch.quantization import FP8GlobalStateManager import transformer_engine.pytorch.ops as te_ops from transformer_engine.common import recipe -from utils import ModelConfig, reset_rng_states +from utils import ModelConfig, reset_rng_states, skip_unsupported_backward_mode # Check if FP8 is supported. 
fp8_available = is_fp8_available() @@ -360,6 +361,7 @@ def _test_cuda_graphs( @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("fp8_params", (False, True)) @pytest.mark.parametrize("fp8_recipe", fp8_recipes + [None], ids=lambda r: type(r).__name__) +@pytest.mark.parametrize("backward_mode", ("default", "unquant", "dequant")) def test_make_graphed_callables( *, module: str, @@ -368,10 +370,17 @@ def test_make_graphed_callables( dtype: torch.dtype, fp8_params: bool, fp8_recipe: recipe.Recipe, + backward_mode: str, fp8_weight_caching: bool = False, ) -> None: fp8 = fp8_recipe is not None + + skip_unsupported_backward_mode(module, fp8_recipe, backward_mode) + if fp8: + fp8_recipe = copy.deepcopy(fp8_recipe) + fp8_recipe.backward_mode = backward_mode + if fp8_params and not fp8: pytest.skip("FP8 needed for FP8 parameters.") if fp8_weight_caching and not fp8: @@ -440,18 +449,21 @@ def test_make_graphed_callables( @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("fp8_params", (False, True)) @pytest.mark.parametrize("fp8_recipe", fp8_recipes, ids=lambda r: type(r).__name__) +@pytest.mark.parametrize("backward_mode", ("default", "unquant", "dequant")) def test_make_graphed_callables_with_fp8_weight_caching( *, module: str, dtype: torch.dtype, fp8_params: bool, fp8_recipe: recipe.Recipe, + backward_mode: str, ) -> None: test_make_graphed_callables( module=module, dtype=dtype, fp8_params=fp8_params, fp8_recipe=fp8_recipe, + backward_mode=backward_mode, fp8_weight_caching=True, ) diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 384b6774f6..fd82996cbb 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -7,6 +7,7 @@ import torch import pytest import os +import copy import transformer_engine import transformer_engine.pytorch as te @@ -37,7 +38,7 @@ import transformer_engine_torch as tex from transformer_engine.pytorch.cpp_extensions import general_gemm from 
transformer_engine.pytorch.tensor.utils import replace_raw_data -from utils import ModelConfig +from utils import ModelConfig, skip_unsupported_backward_mode # Only run FP8 tests on supported devices. fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) @@ -383,6 +384,7 @@ def test_sanity_normalization_amp(dtype, model, skip_wgrad, skip_dgrad, normaliz @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("backward_mode", ["default", "unquant", "dequant"]) @pytest.mark.parametrize("model", ["small", "weird"]) @pytest.mark.parametrize("skip_wgrad", all_boolean) @pytest.mark.parametrize("zero_centered_gamma", all_boolean) @@ -392,6 +394,7 @@ def test_sanity_normalization_amp(dtype, model, skip_wgrad, skip_dgrad, normaliz def test_sanity_layernorm_linear( dtype, fp8_recipe, + backward_mode, model, skip_wgrad, zero_centered_gamma, @@ -401,6 +404,11 @@ def test_sanity_layernorm_linear( ): config = model_configs[model] + skip_unsupported_backward_mode("layernorm_linear", fp8_recipe, backward_mode) + if fp8_recipe is not None: + fp8_recipe = copy.deepcopy(fp8_recipe) + fp8_recipe.backward_mode = backward_mode + if fp8_recipe is not None: if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") @@ -424,13 +432,21 @@ def test_sanity_layernorm_linear( @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("backward_mode", ["default", "unquant", "dequant"]) @pytest.mark.parametrize("model", ["small", "weird"]) @pytest.mark.parametrize("skip_wgrad", all_boolean) @pytest.mark.parametrize("skip_dgrad", all_boolean) @pytest.mark.parametrize("microbatching", all_boolean) -def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microbatching): +def test_sanity_linear( + dtype, fp8_recipe, backward_mode, model, skip_wgrad, skip_dgrad, microbatching +): config = model_configs[model] 
+ skip_unsupported_backward_mode("linear", fp8_recipe, backward_mode) + if fp8_recipe is not None: + fp8_recipe = copy.deepcopy(fp8_recipe) + fp8_recipe.backward_mode = backward_mode + if fp8_recipe is not None: if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") @@ -454,13 +470,21 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microba @pytest.mark.parametrize("bs", batch_sizes_with_zero) @pytest.mark.parametrize("model", ["small", "weird"]) @pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("backward_mode", ["default", "unquant", "dequant"]) @pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("use_bias", all_boolean) -def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_params, use_bias): +def test_sanity_linear_with_zero_tokens( + dtype, bs, model, fp8_recipe, backward_mode, fp8_model_params, use_bias +): config = model_configs[model] ffn_hidden_size = 4 * config.hidden_size num_tokens = bs * config.max_seqlen_q + skip_unsupported_backward_mode("linear", fp8_recipe, backward_mode) + if fp8_recipe is not None: + fp8_recipe = copy.deepcopy(fp8_recipe) + fp8_recipe.backward_mode = backward_mode + if fp8_recipe is not None: if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") @@ -487,6 +511,7 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_ @pytest.mark.parametrize("bs", batch_sizes_with_zero) @pytest.mark.parametrize("model", ["small", "weird"]) @pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("backward_mode", ["default", "unquant", "dequant"]) @pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("use_bias", all_boolean) @pytest.mark.parametrize("single_param", all_boolean) @@ -497,6 +522,7 @@ def test_sanity_grouped_linear( bs, model, fp8_recipe, + backward_mode, fp8_model_params, use_bias, 
single_param, @@ -509,6 +535,11 @@ def test_sanity_grouped_linear( bs = bs * 16 num_tokens = bs * config.max_seqlen_q * (num_gemms - 1) + skip_unsupported_backward_mode("grouped_linear", fp8_recipe, backward_mode) + if fp8_recipe is not None: + fp8_recipe = copy.deepcopy(fp8_recipe) + fp8_recipe.backward_mode = backward_mode + if fp8_recipe is not None: if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index dfd0c73738..830ca6eecc 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -10,6 +10,7 @@ from typing import Optional, Tuple, Dict, Any, List from packaging.version import Version as PkgVersion +import pytest import torch import transformer_engine @@ -149,6 +150,28 @@ def make_recipe(name: Optional[str], **recipe_kwargs: Any) -> Optional[Recipe]: raise ValueError(f"Unsupported quantization scheme ({name})") +def skip_unsupported_backward_mode( + layer_type: str, + quant_recipe: Recipe, + backward_mode: str, +) -> None: + """Skip known unsupported layer/recipe/backward-mode combinations used in tests.""" + if backward_mode is None or backward_mode == "default": + return + if quant_recipe is None and backward_mode in ("unquant", "dequant"): + pytest.skip(f"Not a quantized recipe, cannot use backward mode {backward_mode}.") + if quant_recipe.delayed() and backward_mode in ("unquant", "dequant"): + pytest.skip(f"Delayed scaling does not support backward mode {backward_mode}.") + if layer_type in ( + "layernorm_mlp", + "layernorm_mlp_nocheckpoint", + "layernorm_mlp_checkpoint", + "transformer", + "transformer_layer", + ): + pytest.skip(f"{layer_type} does not support NVTE_BACKWARD_MODE={backward_mode}.") + + # Cached RNG state _rng_states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 5350e7e6f6..ef97e72ce2 100644 --- 
a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -443,7 +443,14 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], for weight in weights ] elif ctx.backward_mode == "unquant": - weights_for_dgrad = origin_weights + weights_for_dgrad = [ + ( + weight.dequantize(dtype=ctx.activation_dtype) + if isinstance(weight, QuantizedTensorStorage) + else cast_if_needed(weight, ctx.activation_dtype) + ) + for weight in origin_weights + ] # Make sure weights are available in column-wise format # for dgrad computation. for weight in weights_for_dgrad: diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 91c67ce3b0..31f65a5239 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -774,6 +774,8 @@ def backward( weight_for_dgrad = cast_if_needed(weight_for_dgrad, ctx.activation_dtype) elif ctx.backward_mode == "unquant": weight_for_dgrad = origin_weight + if isinstance(weight_for_dgrad, QuantizedTensorStorage): + weight_for_dgrad = weight_for_dgrad.dequantize(dtype=ctx.activation_dtype) gemm_out, *_, reduce_scatter_out = general_gemm( weight_for_dgrad, grad_output, diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index f3a35ecade..6e54f9de5d 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -762,6 +762,8 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], weight_for_dgrad = cast_if_needed(weight_for_dgrad, ctx.activation_dtype) elif ctx.backward_mode == "unquant": weight_for_dgrad = weight + if isinstance(weight_for_dgrad, QuantizedTensorStorage): + weight_for_dgrad = weight_for_dgrad.dequantize(dtype=ctx.activation_dtype) gemm_out, *_, reduce_scatter_out = general_gemm( weight_for_dgrad, 
grad_output, diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 7f21cd9331..f2e03fa087 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -360,7 +360,9 @@ def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: grad_output_quantizer.optimize_for_gemm = True if FP8GlobalStateManager.is_fp8_enabled(): fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() - if fp8_recipe.backward_mode == "dequant" and (fp8_recipe.mxfp8() or fp8_recipe.nvfp4()): + if fp8_recipe.backward_mode in ("unquant", "dequant") and ( + fp8_recipe.mxfp8() or fp8_recipe.nvfp4() + ): if input_quantizer is not None: input_quantizer.optimize_for_gemm = False if grad_output_quantizer is not None: diff --git a/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py index 52e292125e..2cede7e832 100644 --- a/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py @@ -222,6 +222,10 @@ def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: """ if dtype is None: dtype = self._dtype + + if 0 in self.size(): + return torch.empty(self.size(), dtype=dtype, device=self.device) + block_len = 128 if not self._is_2D_scaled: return self._dequantize_vectorwise(dtype=dtype) diff --git a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py index 7bbe809c9d..34d507dd7e 100644 --- a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py @@ -182,6 +182,8 @@ def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: """Dequantize to a higher precision.""" if 
dtype is None: dtype = self._dtype + if self._rowwise_data is not None and 0 in self._rowwise_data.size(): + return torch.empty(self.size(), dtype=dtype, device=self.device) return _FromMXFP8Func.forward(None, self, dtype) def size(self, *args, **kwargs): diff --git a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py index fb163c9032..9fdbe8d595 100644 --- a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py @@ -213,6 +213,8 @@ def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: """Dequantize to a higher precision.""" if dtype is None: dtype = self._dtype + if self._rowwise_data is not None and 0 in self._rowwise_data.size(): + return torch.empty(self.size(), dtype=dtype, device=self.device) return _FromNVFP4Func.forward(None, self, dtype) def size(self, dim: Optional[int] = None) -> Union[torch.Size, int]: From 0a89acf42351c75220f5748d1383400db4ce852a Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Thu, 12 Mar 2026 23:36:15 -0700 Subject: [PATCH 52/61] Use `numel()` Signed-off-by: Ziang Li --- .../pytorch/tensor/storage/float8_blockwise_tensor_storage.py | 2 +- .../pytorch/tensor/storage/mxfp8_tensor_storage.py | 2 +- .../pytorch/tensor/storage/nvfp4_tensor_storage.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py index 2cede7e832..ca3913762f 100644 --- a/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py @@ -223,7 +223,7 @@ def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: if dtype is None: dtype = self._dtype - if 0 in self.size(): + if self._rowwise_data is not 
None and self._rowwise_data.numel() == 0: return torch.empty(self.size(), dtype=dtype, device=self.device) block_len = 128 diff --git a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py index 34d507dd7e..157df4d40e 100644 --- a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py @@ -182,7 +182,7 @@ def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: """Dequantize to a higher precision.""" if dtype is None: dtype = self._dtype - if self._rowwise_data is not None and 0 in self._rowwise_data.size(): + if self._rowwise_data is not None and self._rowwise_data.numel() == 0: return torch.empty(self.size(), dtype=dtype, device=self.device) return _FromMXFP8Func.forward(None, self, dtype) diff --git a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py index 9fdbe8d595..309a0ca0cf 100644 --- a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py @@ -213,7 +213,7 @@ def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: """Dequantize to a higher precision.""" if dtype is None: dtype = self._dtype - if self._rowwise_data is not None and 0 in self._rowwise_data.size(): + if self._rowwise_data is not None and self._rowwise_data.numel() == 0: return torch.empty(self.size(), dtype=dtype, device=self.device) return _FromNVFP4Func.forward(None, self, dtype) From 70e04ff5f718b42236be1b715aab286b2e84cf3b Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Thu, 12 Mar 2026 23:50:48 -0700 Subject: [PATCH 53/61] Refactor unit test Signed-off-by: Ziang Li --- tests/pytorch/test_backward_mode.py | 611 ++++++++++++++-------------- 1 file changed, 308 insertions(+), 303 deletions(-) diff --git 
a/tests/pytorch/test_backward_mode.py b/tests/pytorch/test_backward_mode.py index 3aa47e1166..6469052c49 100644 --- a/tests/pytorch/test_backward_mode.py +++ b/tests/pytorch/test_backward_mode.py @@ -53,6 +53,32 @@ _core_dtypes.insert(1, torch.bfloat16) _fused_dtypes.insert(1, torch.bfloat16) +_quantized_numerics_recipe_list = [ + pytest.param( + "fp8_current_scaling", + marks=pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8), + id="Float8CurrentScaling", + ), + pytest.param( + "mxfp8", + marks=pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8), + id="MXFP8BlockScaling", + ), + pytest.param( + "fp8_block_scaling", + marks=pytest.mark.skipif( + not fp8_block_scaling_available, + reason=reason_for_no_fp8_block_scaling, + ), + id="Float8BlockScaling", + ), + pytest.param( + "nvfp4", + marks=pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4), + id="NVFP4BlockScaling", + ), +] + @pytest.fixture(autouse=True) def _reset_global_fp8_state(): @@ -67,30 +93,211 @@ def backward_mode(request: pytest.FixtureRequest) -> str: return request.param +# -------------------------- +# Test cases +# -------------------------- + + +_shape_test_cases = [ + pytest.param((1, 64), 64, id="2d_m1_k64_n64"), + pytest.param((32, 64), 64, id="2d_m32_k64_n64"), + pytest.param((32, 96), 96, id="2d_m32_k96_n96"), + pytest.param((32, 1, 64), 64, id="3d_m32_s1_k64_n64"), + pytest.param((8, 4, 64), 128, id="3d_m32_k64_n128"), + pytest.param((16, 2, 128), 64, id="3d_m32_k128_n64"), + pytest.param((160, 64), 64, id="2d_m160_k64_n64"), + pytest.param((5, 64, 64), 64, id="3d_m320_k64_n64"), + pytest.param((3, 5, 32, 64), 96, id="4d_m480_k64_n96"), + pytest.param((2, 5, 16, 128), 64, id="4d_m160_k128_n64"), + # Intentionally unaligned token dimensions to exercise skip/support logic. 
+ pytest.param((3, 64), 64, id="2d_m3_k64_n64_unaligned"), + pytest.param((3, 10, 64), 64, id="3d_m30_k64_n64_unaligned"), + pytest.param((3, 10, 96), 96, id="3d_m30_k96_n96_unaligned"), +] + +_bias_activation_shape_cases = [ + pytest.param((32, 64), id="2d_m32_k64"), + pytest.param((32, 96), id="2d_m32_k96"), + pytest.param((8, 4, 64), id="3d_m32_k64"), + pytest.param((160, 64), id="2d_m160_k64"), + pytest.param((5, 64, 64), id="3d_m320_k64"), + pytest.param((3, 5, 32, 64), id="4d_m480_k64"), + # Intentionally unaligned token dimensions to exercise skip/support logic. + pytest.param((3, 64), id="2d_m3_k64_unaligned"), + pytest.param((3, 10, 64), id="3d_m30_k64_unaligned"), + pytest.param((3, 10, 96), id="3d_m30_k96_unaligned"), +] + +_grouped_m_split_cases = [ + pytest.param([32, 32, 32, 32], id="uniform_splits"), + pytest.param([64, 0, 32, 32], id="with_empty_split"), + pytest.param([1, 31, 0, 96], id="small_and_empty_splits"), + pytest.param([64, 192, 0, 128], id="64_divisible_splits"), +] + +_linear_feature_cases = [ + pytest.param(64, 64, id="k64_n64"), + pytest.param(64, 128, id="k64_n128"), + pytest.param(128, 64, id="k128_n64"), + pytest.param(96, 96, id="k96_n96"), + pytest.param(64, 96, id="k64_n96"), + pytest.param(96, 64, id="k96_n64"), + pytest.param(128, 96, id="k128_n96"), + pytest.param(96, 128, id="k96_n128"), +] + +_output_feature_cases = [ + pytest.param(64, id="n64"), + pytest.param(96, id="n96"), + pytest.param(128, id="n128"), +] + +# -------------------------- +# Skip helpers +# -------------------------- + + +def _maybe_skip_recipe_dtype( + recipe_name: str, + dtype: torch.dtype, + module_type: Optional[str] = None, +) -> None: + if dtype == torch.bfloat16 and not bf16_available: + pytest.skip(reason_for_no_bf16) + if recipe_name == "nvfp4": + if module_type in ("linear", "layernorm_linear") and dtype not in ( + torch.bfloat16, + torch.float32, + ): + pytest.skip(f"NVFP4 only supports BF16 and FP32 for {module_type} in this test") + elif 
module_type in ("ops_linear", "grouped_linear") and dtype != torch.bfloat16: + pytest.skip(f"NVFP4 only supports BF16 for {module_type} in this test") + + +def _maybe_skip_unsupported_recipe_module_combo(recipe_name: str, module_type: str) -> None: + if module_type == "ops_linear" and recipe_name == "fp8_block_scaling": + pytest.skip("Fusible ops (te_ops.Linear) do not support Float8BlockScaling recipe") + + +def _maybe_skip_unsupported_recipe_shape( + recipe_name: str, + input_shape: tuple[int, ...], + module_type: str, +) -> None: + flat_first_dim = math.prod(input_shape[:-1]) + last_dim = input_shape[-1] + + if module_type in ("linear", "layernorm_linear"): + if recipe_name == "mxfp8" and (flat_first_dim % 32 != 0 or last_dim % 32 != 0): + pytest.skip( + "Linear/LayerNormLinear + MXFP8 requires prod(shape[:-1]) and shape[-1] divisible" + " by 32." + ) + return + if recipe_name == "nvfp4" and (flat_first_dim % 16 != 0 or last_dim % 16 != 0): + pytest.skip( + "Linear/LayerNormLinear + NVFP4 requires prod(shape[:-1]) and shape[-1] divisible" + " by 16." + ) + return + if flat_first_dim % 8 != 0 or last_dim % 16 != 0: + pytest.skip( + "Linear/LayerNormLinear FP8 execution requires prod(shape[:-1]) divisible by 8 " + "and shape[-1] divisible by 16." + ) + elif module_type == "ops_linear": + if recipe_name == "mxfp8" and (flat_first_dim % 32 != 0 or last_dim % 32 != 0): + pytest.skip( + "te_ops.Linear + MXFP8 requires prod(shape[:-1]) and shape[-1] divisible by 32." + ) + if recipe_name == "nvfp4" and (flat_first_dim % 16 != 0 or last_dim % 16 != 0): + pytest.skip( + "te_ops.Linear + NVFP4 requires prod(shape[:-1]) and shape[-1] divisible by 16." 
+ ) + + +def _maybe_skip_unsupported_grouped_splits(recipe_name: str, m_splits: list[int]) -> None: + non_empty_splits = [m for m in m_splits if m > 0] + if recipe_name == "mxfp8" and any(m % 32 != 0 for m in non_empty_splits): + pytest.skip("GroupedLinear + MXFP8 requires each non-empty m_split divisible by 32.") + if recipe_name == "nvfp4" and any(m % 16 != 0 for m in non_empty_splits): + pytest.skip("GroupedLinear + NVFP4 requires each non-empty m_split divisible by 16.") + if recipe_name == "nvfp4" and any(m % 64 != 0 for m in non_empty_splits): + pytest.skip( + "GroupedLinear + NVFP4 grouped split_quantize currently requires each non-empty " + "m_split divisible by 64 due to grouped amax kernel constraints." + ) + if recipe_name == "fp8_block_scaling" and any(m % 4 != 0 for m in non_empty_splits): + pytest.skip( + "GroupedLinear + Float8BlockScaling requires each non-empty m_split divisible by 4." + ) + + # -------------------------- # Shared helpers # -------------------------- -def _restore_saved_operands(output: torch.Tensor) -> list[Optional[torch.Tensor]]: - if output.grad_fn is None: - raise RuntimeError("Output tensor has no grad_fn; cannot inspect saved operands") - if not hasattr(output.grad_fn, "tensor_objects"): - raise RuntimeError("grad_fn does not expose tensor_objects for saved operand restoration") - return restore_from_saved(output.grad_fn.tensor_objects, list(output.grad_fn.saved_tensors)) +def _make_linear_like_module( + module_type: str, + in_features: int, + out_features: int, + dtype: torch.dtype, + *, + bias: bool, +) -> torch.nn.Module: + if module_type == "linear": + return te.Linear( + in_features, + out_features, + bias=bias, + params_dtype=dtype, + device="cuda", + ) + if module_type == "layernorm_linear": + return te.LayerNormLinear( + in_features, + out_features, + bias=bias, + params_dtype=dtype, + device="cuda", + ) + if module_type == "ops_linear": + return te_ops.Linear( + in_features, + out_features, + bias=bias, + 
dtype=dtype, + device="cuda", + ) + raise ValueError(f"Unsupported module type: {module_type}") -def _extract_linear_saved_operands( - saved_operands: list[Optional[torch.Tensor]], +def _make_fused_model( + pattern: str, + in_features: int, + out_features: int, + dtype: torch.dtype, *, - context: str, -) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: - if len(saved_operands) < 2: - raise RuntimeError( - f"Insufficient saved operands for {context} dequant reference " - f"(got {len(saved_operands)}, expected at least 2)." + scale: float = 0.5, +) -> te_ops.Sequential: + if pattern == "bias_activation": + return te_ops.Sequential( + te_ops.Linear(in_features, out_features, bias=True, device="cuda", dtype=dtype), + te_ops.ReLU(), + ) + if pattern == "bias_add": + return te_ops.Sequential( + te_ops.Linear(in_features, out_features, bias=True, device="cuda", dtype=dtype), + te_ops.AddExtraInput(in_place=True), + ) + if pattern == "scale_add": + return te_ops.Sequential( + te_ops.Linear(in_features, out_features, bias=False, device="cuda", dtype=dtype), + te_ops.ConstantScale(scale), + te_ops.AddExtraInput(in_place=True), ) - return saved_operands[0], saved_operands[1] + raise ValueError(f"Unsupported fused test pattern: {pattern}") def _dequantize_saved_operand( @@ -117,32 +324,6 @@ def _dequantize_saved_operand( return saved_operand.dequantize(dtype=dtype) -def _assert_saved_quantized_operand_uses_rowwise_only( - saved_operand: Optional[torch.Tensor], - *, - name: str, -) -> None: - if saved_operand is None: - raise RuntimeError(f"Expected quantized saved {name} operand but got None") - if isinstance(saved_operand, torch.Tensor): - raise RuntimeError( - f"Dequant reference expects quantized saved {name} operand, got torch.Tensor." 
- ) - if not hasattr(saved_operand, "dequantize"): - raise RuntimeError(f"Unsupported saved {name} operand type: {type(saved_operand)}") - if hasattr(saved_operand, "_rowwise_data") and getattr(saved_operand, "_rowwise_data") is None: - raise RuntimeError( - f"Saved dequant {name} operand lost row-wise fprop payload (likely usage retarget)." - ) - if ( - hasattr(saved_operand, "_columnwise_data") - and getattr(saved_operand, "_columnwise_data") is not None - ): - raise RuntimeError( - f"Saved dequant {name} operand unexpectedly carries column-wise payload." - ) - - def _snapshot_saved_quantized_operand_layout( saved_operand: Optional[torch.Tensor], *, @@ -168,6 +349,67 @@ def _snapshot_saved_quantized_operand_layout( } +def _snapshot_layout_invariants( + guard_operands: list[tuple[str, Optional[torch.Tensor]]], +) -> list[dict[str, object]]: + """Capture saved-operand layout invariants before backward runs.""" + return [ + _snapshot_saved_quantized_operand_layout(saved_operand, name=name) + for name, saved_operand in guard_operands + ] + + +def _snapshot_backward_ctx_state( + output: torch.Tensor, +) -> tuple[str, bool, object, bool]: + if output.grad_fn is None: + raise RuntimeError("Output tensor has no grad_fn; cannot inspect backward context state.") + required_attrs = ( + "backward_mode", + "fp8", + "grad_output_quantizer", + "reduce_and_update_bwd_fp8_tensors", + ) + missing_attrs = [attr for attr in required_attrs if not hasattr(output.grad_fn, attr)] + if missing_attrs: + raise RuntimeError( + "grad_fn does not expose required backward context attributes: " + f"{', '.join(missing_attrs)}." 
+ ) + return ( + getattr(output.grad_fn, "backward_mode"), + bool(getattr(output.grad_fn, "fp8")), + getattr(output.grad_fn, "grad_output_quantizer"), + bool(getattr(output.grad_fn, "reduce_and_update_bwd_fp8_tensors")), + ) + + +def _assert_saved_quantized_operand_uses_rowwise_only( + saved_operand: Optional[torch.Tensor], + *, + name: str, +) -> None: + if saved_operand is None: + raise RuntimeError(f"Expected quantized saved {name} operand but got None") + if isinstance(saved_operand, torch.Tensor): + raise RuntimeError( + f"Dequant reference expects quantized saved {name} operand, got torch.Tensor." + ) + if not hasattr(saved_operand, "dequantize"): + raise RuntimeError(f"Unsupported saved {name} operand type: {type(saved_operand)}") + if hasattr(saved_operand, "_rowwise_data") and getattr(saved_operand, "_rowwise_data") is None: + raise RuntimeError( + f"Saved dequant {name} operand lost row-wise fprop payload (likely usage retarget)." + ) + if ( + hasattr(saved_operand, "_columnwise_data") + and getattr(saved_operand, "_columnwise_data") is not None + ): + raise RuntimeError( + f"Saved dequant {name} operand unexpectedly carries column-wise payload." 
+ ) + + def _assert_saved_quantized_operand_layout_unchanged(snapshot: dict[str, object]) -> None: name = snapshot.get("name") if not isinstance(name, str): @@ -206,16 +448,6 @@ def _assert_saved_quantized_operand_layout_unchanged(snapshot: dict[str, object] ) -def _snapshot_layout_invariants( - guard_operands: list[tuple[str, Optional[torch.Tensor]]], -) -> list[dict[str, object]]: - """Capture saved-operand layout invariants before backward runs.""" - return [ - _snapshot_saved_quantized_operand_layout(saved_operand, name=name) - for name, saved_operand in guard_operands - ] - - def _assert_layout_invariants_unchanged(layout_invariants: list[dict[str, object]]) -> None: """Validate saved-operand layout invariants after backward runs.""" for layout_invariant in layout_invariants: @@ -228,6 +460,15 @@ def _raise_if_ref_failed(ref_exc: Optional[Exception]) -> None: raise ref_exc +def _copy_named_parameters(src_module: torch.nn.Module, dst_module: torch.nn.Module) -> None: + src_params = dict(src_module.named_parameters()) + with torch.no_grad(): + for name, dst_param in dst_module.named_parameters(): + if name not in src_params: + raise RuntimeError(f"Parameter {name} missing in source module") + dst_param.copy_(src_params[name]) + + def _compute_linear_backward_reference_from_saved_operands( saved_input: Optional[torch.Tensor], saved_weight: Optional[torch.Tensor], @@ -283,179 +524,6 @@ def _compute_linear_backward_reference_from_saved_operands( return dx_ref, dw_ref, db_ref -_quantized_numerics_recipe_list = [ - pytest.param( - "fp8_current_scaling", - marks=pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8), - id="Float8CurrentScaling", - ), - pytest.param( - "mxfp8", - marks=pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8), - id="MXFP8BlockScaling", - ), - pytest.param( - "fp8_block_scaling", - marks=pytest.mark.skipif( - not fp8_block_scaling_available, - reason=reason_for_no_fp8_block_scaling, - ), - id="Float8BlockScaling", - 
), - pytest.param( - "nvfp4", - marks=pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4), - id="NVFP4BlockScaling", - ), -] - -_shape_test_cases = [ - pytest.param((1, 64), 64, id="2d_m1_k64_n64"), - pytest.param((32, 64), 64, id="2d_m32_k64_n64"), - pytest.param((32, 96), 96, id="2d_m32_k96_n96"), - pytest.param((32, 1, 64), 64, id="3d_m32_s1_k64_n64"), - pytest.param((8, 4, 64), 128, id="3d_m32_k64_n128"), - pytest.param((16, 2, 128), 64, id="3d_m32_k128_n64"), - pytest.param((160, 64), 64, id="2d_m160_k64_n64"), - pytest.param((5, 64, 64), 64, id="3d_m320_k64_n64"), - pytest.param((3, 5, 32, 64), 96, id="4d_m480_k64_n96"), - pytest.param((2, 5, 16, 128), 64, id="4d_m160_k128_n64"), - # Intentionally unaligned token dimensions to exercise skip/support logic. - pytest.param((3, 64), 64, id="2d_m3_k64_n64_unaligned"), - pytest.param((3, 10, 64), 64, id="3d_m30_k64_n64_unaligned"), - pytest.param((3, 10, 96), 96, id="3d_m30_k96_n96_unaligned"), -] - -_bias_activation_shape_cases = [ - pytest.param((32, 64), id="2d_m32_k64"), - pytest.param((32, 96), id="2d_m32_k96"), - pytest.param((8, 4, 64), id="3d_m32_k64"), - pytest.param((160, 64), id="2d_m160_k64"), - pytest.param((5, 64, 64), id="3d_m320_k64"), - pytest.param((3, 5, 32, 64), id="4d_m480_k64"), - # Intentionally unaligned token dimensions to exercise skip/support logic. 
- pytest.param((3, 64), id="2d_m3_k64_unaligned"), - pytest.param((3, 10, 64), id="3d_m30_k64_unaligned"), - pytest.param((3, 10, 96), id="3d_m30_k96_unaligned"), -] - -_grouped_m_split_cases = [ - pytest.param([32, 32, 32, 32], id="uniform_splits"), - pytest.param([64, 0, 32, 32], id="with_empty_split"), - pytest.param([1, 31, 0, 96], id="small_and_empty_splits"), -] - -_linear_feature_cases = [ - pytest.param(64, 64, id="k64_n64"), - pytest.param(64, 128, id="k64_n128"), - pytest.param(128, 64, id="k128_n64"), - pytest.param(96, 96, id="k96_n96"), - pytest.param(64, 96, id="k64_n96"), - pytest.param(96, 64, id="k96_n64"), - pytest.param(128, 96, id="k128_n96"), - pytest.param(96, 128, id="k96_n128"), -] - -_output_feature_cases = [ - pytest.param(64, id="n64"), - pytest.param(96, id="n96"), - pytest.param(128, id="n128"), -] - - -def _copy_named_parameters(src_module: torch.nn.Module, dst_module: torch.nn.Module) -> None: - src_params = dict(src_module.named_parameters()) - with torch.no_grad(): - for name, dst_param in dst_module.named_parameters(): - if name not in src_params: - raise RuntimeError(f"Parameter {name} missing in source module") - dst_param.copy_(src_params[name]) - - -def _maybe_skip_recipe_dtype(recipe_name: str, dtype: torch.dtype) -> None: - if dtype == torch.bfloat16 and not bf16_available: - pytest.skip(reason_for_no_bf16) - if recipe_name == "nvfp4" and dtype != torch.bfloat16: - pytest.skip("NVFP4 is only supported with BF16 in this test") - - -def _make_linear_like_module( - module_type: str, - in_features: int, - out_features: int, - dtype: torch.dtype, - *, - bias: bool, -) -> torch.nn.Module: - if module_type == "linear": - return te.Linear( - in_features, - out_features, - bias=bias, - params_dtype=dtype, - device="cuda", - ) - if module_type == "layernorm_linear": - return te.LayerNormLinear( - in_features, - out_features, - bias=bias, - params_dtype=dtype, - device="cuda", - ) - if module_type == "ops_linear": - return 
te_ops.Linear( - in_features, - out_features, - bias=bias, - dtype=dtype, - device="cuda", - ) - raise ValueError(f"Unsupported module type: {module_type}") - - -def _maybe_skip_unsupported_recipe_module_combo(recipe_name: str, module_type: str) -> None: - if module_type == "ops_linear" and recipe_name == "fp8_block_scaling": - pytest.skip("Fusible ops (te_ops.Linear) do not support Float8BlockScaling recipe") - - -def _maybe_skip_unsupported_recipe_shape( - recipe_name: str, - input_shape: tuple[int, ...], - module_type: str, -) -> None: - flat_first_dim = math.prod(input_shape[:-1]) - last_dim = input_shape[-1] - - if module_type in ("linear", "layernorm_linear"): - if flat_first_dim % 8 != 0 or last_dim % 16 != 0: - pytest.skip( - "Linear/LayerNormLinear FP8 execution requires prod(shape[:-1]) divisible by 8 " - "and shape[-1] divisible by 16." - ) - return - - if module_type == "ops_linear": - if recipe_name == "mxfp8" and (flat_first_dim % 32 != 0 or last_dim % 32 != 0): - pytest.skip( - "te_ops.Linear + MXFP8 requires prod(shape[:-1]) and shape[-1] divisible by 32." - ) - if recipe_name == "nvfp4" and (flat_first_dim % 16 != 0 or last_dim % 16 != 0): - pytest.skip( - "te_ops.Linear + NVFP4 requires prod(shape[:-1]) and shape[-1] divisible by 16." - ) - - -def _maybe_skip_unsupported_grouped_splits(recipe_name: str, m_splits: list[int]) -> None: - non_empty_splits = [m for m in m_splits if m > 0] - if recipe_name == "mxfp8" and any(m % 32 != 0 for m in non_empty_splits): - pytest.skip("GroupedLinear + MXFP8 requires each non-empty m_split divisible by 32.") - if recipe_name == "fp8_block_scaling" and any(m % 4 != 0 for m in non_empty_splits): - pytest.skip( - "GroupedLinear + Float8BlockScaling requires each non-empty m_split divisible by 4." 
- ) - - def _run_single_step( module: torch.nn.Module, x: torch.Tensor, @@ -499,7 +567,7 @@ def _run_single_step_with_saved_operands( y = module(x_run) if isinstance(y, tuple): y = y[0] - saved_operands = _restore_saved_operands(y) + saved_operands = restore_from_saved(y.grad_fn.tensor_objects, list(y.grad_fn.saved_tensors)) return y, x_run, saved_operands @@ -544,37 +612,10 @@ def _run_grouped_linear_step_with_saved_operands( x_run = x.detach().clone().requires_grad_(True) with te.autocast(enabled=True, recipe=fp8_recipe): y = module(x_run, m_splits) - saved_operands = _restore_saved_operands(y) + saved_operands = restore_from_saved(y.grad_fn.tensor_objects, list(y.grad_fn.saved_tensors)) return y, x_run, saved_operands -def _make_fused_model( - pattern: str, - in_features: int, - out_features: int, - dtype: torch.dtype, - *, - scale: float = 0.5, -) -> te_ops.Sequential: - if pattern == "bias_activation": - return te_ops.Sequential( - te_ops.Linear(in_features, out_features, bias=True, device="cuda", dtype=dtype), - te_ops.ReLU(), - ) - if pattern == "bias_add": - return te_ops.Sequential( - te_ops.Linear(in_features, out_features, bias=True, device="cuda", dtype=dtype), - te_ops.AddExtraInput(in_place=True), - ) - if pattern == "scale_add": - return te_ops.Sequential( - te_ops.Linear(in_features, out_features, bias=False, device="cuda", dtype=dtype), - te_ops.ConstantScale(scale), - te_ops.AddExtraInput(in_place=True), - ) - raise ValueError(f"Unsupported fused test pattern: {pattern}") - - def _run_fused_single_step( pattern: str, model: te_ops.Sequential, @@ -635,7 +676,7 @@ def _run_fused_single_step_with_saved_operands( y = model(x1_run, x2_run) else: y = model(x1_run) - saved_operands = _restore_saved_operands(y) + saved_operands = restore_from_saved(y.grad_fn.tensor_objects, list(y.grad_fn.saved_tensors)) return y, x1_run, x2_run, saved_operands @@ -656,31 +697,6 @@ def _run_quantize_op_single_step( return y.detach().clone(), x_run.grad.detach().clone() 
-def _snapshot_backward_ctx_state( - output: torch.Tensor, -) -> tuple[str, bool, object, bool]: - if output.grad_fn is None: - raise RuntimeError("Output tensor has no grad_fn; cannot inspect backward context state.") - required_attrs = ( - "backward_mode", - "fp8", - "grad_output_quantizer", - "reduce_and_update_bwd_fp8_tensors", - ) - missing_attrs = [attr for attr in required_attrs if not hasattr(output.grad_fn, attr)] - if missing_attrs: - raise RuntimeError( - "grad_fn does not expose required backward context attributes: " - f"{', '.join(missing_attrs)}." - ) - return ( - getattr(output.grad_fn, "backward_mode"), - bool(getattr(output.grad_fn, "fp8")), - getattr(output.grad_fn, "grad_output_quantizer"), - bool(getattr(output.grad_fn, "reduce_and_update_bwd_fp8_tensors")), - ) - - def _run_single_step_with_ctx_state( module: torch.nn.Module, x: torch.Tensor, @@ -801,7 +817,7 @@ def test_linear_like_backward_mode_matches_reference( backward_mode: str, ) -> None: reset_rng_states() - _maybe_skip_recipe_dtype(recipe_name, dtype) + _maybe_skip_recipe_dtype(recipe_name, dtype, module_type) _maybe_skip_unsupported_recipe_module_combo(recipe_name, module_type) _maybe_skip_unsupported_recipe_shape(recipe_name, input_shape, module_type) @@ -921,10 +937,7 @@ def test_linear_like_backward_mode_matches_reference( ) dx_ref = dx_ref.view_as(x_bwd_mode) else: - saved_input, saved_weight = _extract_linear_saved_operands( - saved_operands, - context=f"{module_type}", - ) + saved_input, saved_weight = saved_operands[0], saved_operands[1] guard_operands.extend( [ (f"{module_type}_input", saved_input), @@ -982,11 +995,10 @@ def test_grouped_linear_backward_mode_matches_reference( dtype: torch.dtype, backward_mode: str, ) -> None: - if recipe_name == "nvfp4": - pytest.skip("NVFP4 not supported for grouped linear") reset_rng_states() - _maybe_skip_recipe_dtype(recipe_name, dtype) + _maybe_skip_recipe_dtype(recipe_name, dtype, "grouped_linear") + 
_maybe_skip_unsupported_recipe_module_combo(recipe_name, "grouped_linear") _maybe_skip_unsupported_grouped_splits(recipe_name, m_splits) num_gemms = len(m_splits) num_tokens = sum(m_splits) @@ -1144,7 +1156,7 @@ def test_linear_like_runtime_backward_mode_switch_updates_ctx( backward_mode: str, ) -> None: reset_rng_states() - _maybe_skip_recipe_dtype(recipe_name, dtype) + _maybe_skip_recipe_dtype(recipe_name, dtype, module_type) _maybe_skip_unsupported_recipe_module_combo(recipe_name, module_type) _maybe_skip_unsupported_recipe_shape(recipe_name, input_shape, module_type) @@ -1210,11 +1222,10 @@ def test_grouped_linear_runtime_backward_mode_switch_updates_ctx( dtype: torch.dtype, backward_mode: str, ) -> None: - if recipe_name == "nvfp4": - pytest.skip("NVFP4 not supported for grouped linear") reset_rng_states() - _maybe_skip_recipe_dtype(recipe_name, dtype) + _maybe_skip_recipe_dtype(recipe_name, dtype, "grouped_linear") + _maybe_skip_unsupported_recipe_module_combo(recipe_name, "grouped_linear") _maybe_skip_unsupported_grouped_splits(recipe_name, m_splits) num_tokens = sum(m_splits) @@ -1290,7 +1301,7 @@ def test_fused_linear_paths_match_backward_mode_reference( dtype: torch.dtype, backward_mode: str, ) -> None: - _maybe_skip_recipe_dtype(recipe_name, dtype) + _maybe_skip_recipe_dtype(recipe_name, dtype, "ops_linear") _maybe_skip_unsupported_recipe_module_combo(recipe_name, "ops_linear") _maybe_skip_unsupported_recipe_shape(recipe_name, (m, in_features), "ops_linear") @@ -1362,10 +1373,7 @@ def test_fused_linear_paths_match_backward_mode_reference( guard_operands: list[tuple[str, Optional[torch.Tensor]]] = [] ref_exc: Optional[Exception] = None try: - saved_input, saved_weight = _extract_linear_saved_operands( - saved_operands, - context=f"fused_{fused_pattern}", - ) + saved_input, saved_weight = saved_operands[0], saved_operands[1] guard_operands.extend( [ (f"fused_{fused_pattern}_input", saved_input), @@ -1431,7 +1439,7 @@ def 
test_fused_bias_activation_matches_masked_linear_backward( dtype: torch.dtype, backward_mode: str, ) -> None: - _maybe_skip_recipe_dtype(recipe_name, dtype) + _maybe_skip_recipe_dtype(recipe_name, dtype, "ops_linear") _maybe_skip_unsupported_recipe_module_combo(recipe_name, "ops_linear") _maybe_skip_unsupported_recipe_shape(recipe_name, input_shape, "ops_linear") @@ -1501,10 +1509,7 @@ def test_fused_bias_activation_matches_masked_linear_backward( guard_operands: list[tuple[str, Optional[torch.Tensor]]] = [] ref_exc: Optional[Exception] = None try: - saved_input, saved_weight = _extract_linear_saved_operands( - saved_operands, - context="fused_bias_activation", - ) + saved_input, saved_weight = saved_operands[0], saved_operands[1] guard_operands.extend( [ ("fused_bias_activation_input", saved_input), @@ -1630,7 +1635,7 @@ def test_quantize_op_respects_backward_mode( dtype: torch.dtype, backward_mode: str, ) -> None: - _maybe_skip_recipe_dtype(recipe_name, dtype) + _maybe_skip_recipe_dtype(recipe_name, dtype, "ops_linear") _maybe_skip_unsupported_recipe_module_combo(recipe_name, "ops_linear") reset_rng_states() @@ -1664,7 +1669,7 @@ def test_backward_mode_memory_peak_report( in_features = input_shape[-1] use_bias = True - _maybe_skip_recipe_dtype(recipe_name, dtype) + _maybe_skip_recipe_dtype(recipe_name, dtype, module_type) _maybe_skip_unsupported_recipe_module_combo(recipe_name, module_type) _maybe_skip_unsupported_recipe_shape(recipe_name, input_shape, module_type) From 0849eb150bada261eff6bcd7645a6f999c064c9d Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Fri, 13 Mar 2026 00:41:36 -0700 Subject: [PATCH 54/61] Fix grouped linear to override `*_quantizers` instead of `*_quantizer` Signed-off-by: Ziang Li --- transformer_engine/pytorch/module/grouped_linear.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index ef97e72ce2..1451aa6ca6 
100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -333,9 +333,9 @@ def forward( ctx.ub_overlap_rs_dgrad = False ctx.ub_bulk_dgrad = False ctx.ub_bulk_wgrad = False - ctx.grad_input_quantizer = None - ctx.grad_weight_quantizer = None - ctx.grad_output_quantizer = None + ctx.grad_input_quantizers = [None] * num_gemms + ctx.grad_weight_quantizers = [None] * num_gemms + ctx.grad_output_quantizers = [None] * num_gemms ctx.reduce_and_update_bwd_fp8_tensors = False # [*, in_features] -> [*, out_features] except first dimension changes for SP From ebe7e13977c91e1e333a4d3335762deb579c8d8d Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Fri, 13 Mar 2026 01:07:37 -0700 Subject: [PATCH 55/61] Only save input/weight when `*_requires_grad` on unquant mode Signed-off-by: Ziang Li --- .../pytorch/ops/fused/forward_linear_bias_activation.py | 3 ++- .../pytorch/ops/fused/forward_linear_bias_add.py | 3 ++- .../pytorch/ops/fused/forward_linear_scale_add.py | 3 ++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index 7584891384..51c010d6da 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -125,7 +125,8 @@ def fuser_forward( if linear_op_ctx.requires_grad: saved_input, saved_weight = x_local, w if backward_mode == "unquant": - saved_input, saved_weight = input_, linear_op.weight + saved_input = input_ if input_requires_grad else None + saved_weight = linear_op.weight if weight_requires_grad else None if is_cpu_offload_enabled(): mark_activation_offload(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py 
b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index 6935330f4e..24557004e7 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -122,7 +122,8 @@ def fuser_forward( if linear_op_ctx.requires_grad: saved_input, saved_weight = x_local, w if backward_mode == "unquant": - saved_input, saved_weight = input_, linear_op.weight + saved_input = input_ if input_requires_grad else None + saved_weight = linear_op.weight if weight_requires_grad else None if is_cpu_offload_enabled(): mark_activation_offload(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py index 2358140c88..63a1031c86 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py @@ -103,7 +103,8 @@ def fuser_forward( if linear_op_ctx.requires_grad: saved_input, saved_weight = x_local, w if backward_mode == "unquant": - saved_input, saved_weight = input_, linear_op.weight + saved_input = input_ if input_requires_grad else None + saved_weight = linear_op.weight if weight_requires_grad else None if is_cpu_offload_enabled(): mark_activation_offload(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) From 21d744c3407111a12a68626ad3702e7f120451aa Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sat, 14 Mar 2026 00:33:59 -0700 Subject: [PATCH 56/61] Fix Blackwell debug ci Signed-off-by: Ziang Li --- transformer_engine/pytorch/module/grouped_linear.py | 1 + transformer_engine/pytorch/module/layernorm_linear.py | 1 + transformer_engine/pytorch/module/linear.py | 1 + .../pytorch/tensor/storage/mxfp8_tensor_storage.py | 5 +++++ .../pytorch/tensor/storage/nvfp4_tensor_storage.py | 4 ++++ 5 files changed, 12 insertions(+) diff --git 
a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 1451aa6ca6..cabbc8930e 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -329,6 +329,7 @@ def forward( # Non-quantized backward mode overrides if backward_mode in ("unquant", "dequant"): ctx.fp8 = False + ctx.debug = False ctx.ub_overlap_ag = False ctx.ub_overlap_rs_dgrad = False ctx.ub_bulk_dgrad = False diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 31f65a5239..e1dd660d50 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -541,6 +541,7 @@ def forward( # Non-quantized backward mode overrides if backward_mode in ("unquant", "dequant"): ctx.fp8 = False + ctx.debug = False ctx.ub_overlap_ag = False ctx.ub_overlap_rs_dgrad = False ctx.ub_bulk_dgrad = False diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 6e54f9de5d..9e142a0b02 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -509,6 +509,7 @@ def forward( # Non-quantized backward mode overrides if backward_mode in ("unquant", "dequant"): ctx.fp8 = False + ctx.debug = False ctx.ub_overlap_ag = False ctx.ub_overlap_rs_dgrad = False ctx.ub_bulk_dgrad = False diff --git a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py index 157df4d40e..842f42838b 100644 --- a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py @@ -30,6 +30,11 @@ def forward( dtype: torch.dtype, ) -> torch.Tensor: # pylint: disable=missing-function-docstring + if tensor._rowwise_data is not None and 
tensor._rowwise_data.numel() == 0: + return torch.empty(tensor.size(), dtype=dtype, device=tensor.device) + if tensor._columnwise_data is not None and tensor._columnwise_data.numel() == 0: + return torch.empty(tensor.size(), dtype=dtype, device=tensor.device) + dtype = torch_to_transformer_engine_dtype[dtype] # Make sure FP8 data is in expected format diff --git a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py index 309a0ca0cf..70699ad71a 100644 --- a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py @@ -42,6 +42,10 @@ def forward( dtype: torch.dtype, ) -> torch.Tensor: # pylint: disable=missing-function-docstring + if tensor._rowwise_data is not None and tensor._rowwise_data.numel() == 0: + return torch.empty(tensor.size(), dtype=dtype, device=tensor.device) + if tensor._columnwise_data is not None and tensor._columnwise_data.numel() == 0: + return torch.empty(tensor.size(), dtype=dtype, device=tensor.device) # Dequantize row-wise data if tensor._rowwise_data is not None: From 52ed18959f9333535bda8d549a152f2ee5690b94 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sat, 14 Mar 2026 02:36:35 -0700 Subject: [PATCH 57/61] Fix sm89 and sm90 tests Signed-off-by: Ziang Li --- tests/pytorch/test_backward_mode.py | 35 +++++++++++++++++++++++++--- tests/pytorch/test_cpu_offloading.py | 8 ++++++- 2 files changed, 39 insertions(+), 4 deletions(-) diff --git a/tests/pytorch/test_backward_mode.py b/tests/pytorch/test_backward_mode.py index 6469052c49..03ebc8c0c3 100644 --- a/tests/pytorch/test_backward_mode.py +++ b/tests/pytorch/test_backward_mode.py @@ -16,6 +16,7 @@ from transformer_engine.common import recipe from transformer_engine.pytorch.cpp_extensions import general_gemm, layernorm_bwd from transformer_engine.pytorch.quantization import FP8GlobalStateManager +from transformer_engine.pytorch.utils 
import is_non_tn_fp8_gemm_supported from transformer_engine.pytorch.ops.fused import ( BackwardActivationBias, ForwardLinearBiasActivation, @@ -206,6 +207,15 @@ def _maybe_skip_unsupported_recipe_shape( "and shape[-1] divisible by 16." ) elif module_type == "ops_linear": + if ( + recipe_name == "fp8_current_scaling" + and not is_non_tn_fp8_gemm_supported() + and flat_first_dim % 16 != 0 + ): + pytest.skip( + "te_ops.Linear + Float8CurrentScaling on pre-Blackwell requires " + "prod(shape[:-1]) divisible by 16 for FP8 NT wgrad GEMM." + ) if recipe_name == "mxfp8" and (flat_first_dim % 32 != 0 or last_dim % 32 != 0): pytest.skip( "te_ops.Linear + MXFP8 requires prod(shape[:-1]) and shape[-1] divisible by 32." @@ -218,6 +228,15 @@ def _maybe_skip_unsupported_recipe_shape( def _maybe_skip_unsupported_grouped_splits(recipe_name: str, m_splits: list[int]) -> None: non_empty_splits = [m for m in m_splits if m > 0] + if ( + recipe_name == "fp8_current_scaling" + and not is_non_tn_fp8_gemm_supported() + and any(m % 16 != 0 for m in non_empty_splits) + ): + pytest.skip( + "GroupedLinear + Float8CurrentScaling on pre-Blackwell requires each " + "non-empty m_split divisible by 16 for FP8 grouped NT wgrad GEMM." 
+ ) if recipe_name == "mxfp8" and any(m % 32 != 0 for m in non_empty_splits): pytest.skip("GroupedLinear + MXFP8 requires each non-empty m_split divisible by 32.") if recipe_name == "nvfp4" and any(m % 16 != 0 for m in non_empty_splits): @@ -476,6 +495,7 @@ def _compute_linear_backward_reference_from_saved_operands( *, dequant_dtype: torch.dtype, out_dtype: torch.dtype, + with_bias: bool = True, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # Dequant reference path: # 1) use the exact operands saved by quantized forward, @@ -506,10 +526,12 @@ def _compute_linear_backward_reference_from_saved_operands( out_dtype=out_dtype, layout="NN", grad=True, + use_split_accumulator=True, + ) + db_seed = ( + torch.empty(dy_mat.shape[-1], dtype=out_dtype, device=dy_mat.device) if with_bias else None ) - # Derive db from the same GEMM primitive used by runtime wgrad. This avoids - # tiny reduction-order drift vs. a standalone dy.sum() path in FP32 cases. - db_seed = torch.empty(dy_mat.shape[-1], dtype=out_dtype, device=dy_mat.device) + # Derive db from the same GEMM primitive used by runtime wgrad when bias exists. 
dw_ref, db_ref, *_ = general_gemm( x_ref, dy_mat, @@ -517,6 +539,7 @@ def _compute_linear_backward_reference_from_saved_operands( layout="NT", grad=True, bias=db_seed, + use_split_accumulator=True, ) if db_ref is None: db_ref = dy_mat.sum(dim=0).to(out_dtype) @@ -915,6 +938,7 @@ def test_linear_like_backward_mode_matches_reference( dy, dequant_dtype=dtype, out_dtype=dtype, + with_bias=use_bias, ) ) input_ref = _dequantize_saved_operand(saved_input, dtype) @@ -944,12 +968,14 @@ def test_linear_like_backward_mode_matches_reference( (f"{module_type}_weight", saved_weight), ] ) + linear_wgrad_with_bias = use_bias and module_type != "ops_linear" dx_ref, dw_ref, db_ref = _compute_linear_backward_reference_from_saved_operands( saved_input, saved_weight, dy, dequant_dtype=dtype, out_dtype=dtype, + with_bias=linear_wgrad_with_bias, ) if module_type == "ops_linear" and use_bias: # te_ops bias grad is reduced by the Bias op from incoming dy. @@ -1101,6 +1127,7 @@ def test_grouped_linear_backward_mode_matches_reference( dy_chunk, dequant_dtype=dtype, out_dtype=dtype, + with_bias=use_bias, ) dx_chunks.append(dx_i) dw_ref.append(dw_i) @@ -1387,6 +1414,7 @@ def test_fused_linear_paths_match_backward_mode_reference( dy_for_linear, dequant_dtype=dtype, out_dtype=dtype, + with_bias=False, ) dx2_ref = dy if x2 is not None else None except Exception as exc: @@ -1522,6 +1550,7 @@ def test_fused_bias_activation_matches_masked_linear_backward( dy_after_activation, dequant_dtype=dtype, out_dtype=dtype, + with_bias=False, ) except Exception as exc: ref_exc = exc diff --git a/tests/pytorch/test_cpu_offloading.py b/tests/pytorch/test_cpu_offloading.py index af8e3b884e..514dcdf501 100644 --- a/tests/pytorch/test_cpu_offloading.py +++ b/tests/pytorch/test_cpu_offloading.py @@ -536,7 +536,13 @@ def test_memory(self, layer_type, recipe, backward_mode): out = out + 1 out = sync_function(out) del inp - assert Utils.get_cuda_memory_mb() == pytest.approx(init_cuda_memory, 0.1) + if backward_mode == 
"default": + assert Utils.get_cuda_memory_mb() == pytest.approx(init_cuda_memory, 0.1) + else: + assert ( + Utils.get_cuda_memory_mb() == pytest.approx(init_cuda_memory, 0.1) + or Utils.get_cuda_memory_mb() <= init_cuda_memory + ) offloaded_memory_cpu = offload_ctx.offload_synchronizer.get_offloaded_total_size_mb() # This assertion verifies that the memory used by tensors on the CPU matches the memory saved from a layer. From 9f475a1f44398199d62a8f4d26821e4d608626e9 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sat, 14 Mar 2026 12:28:48 -0700 Subject: [PATCH 58/61] Fix unquant mode memory saving Signed-off-by: Ziang Li --- transformer_engine/pytorch/ops/basic/basic_linear.py | 8 ++++++-- .../pytorch/ops/fused/forward_linear_bias_activation.py | 8 +++++--- .../pytorch/ops/fused/forward_linear_bias_add.py | 8 +++++--- .../pytorch/ops/fused/forward_linear_scale_add.py | 8 +++++--- 4 files changed, 21 insertions(+), 11 deletions(-) diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index f2e03fa087..accc5fbe6a 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -1028,8 +1028,12 @@ def op_forward( # Save state for backward pass if ctx.requires_grad: - saved_input = input_ if backward_mode == "unquant" else x_local - saved_weight = self.weight if backward_mode == "unquant" else w + if backward_mode == "unquant": + saved_input = input_ if weight_requires_grad else None + saved_weight = self.weight if input_requires_grad else None + else: + saved_input = x_local + saved_weight = w if is_cpu_offload_enabled(): mark_activation_offload(saved_input) ctx.save_for_backward(saved_input, saved_weight) diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index 51c010d6da..19d2f679fb 100644 --- 
a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -123,10 +123,12 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: - saved_input, saved_weight = x_local, w if backward_mode == "unquant": - saved_input = input_ if input_requires_grad else None - saved_weight = linear_op.weight if weight_requires_grad else None + saved_input = input_ if weight_requires_grad else None + saved_weight = linear_op.weight if input_requires_grad else None + else: + saved_input = x_local + saved_weight = w if is_cpu_offload_enabled(): mark_activation_offload(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index 24557004e7..5d2997a50a 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -120,10 +120,12 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: - saved_input, saved_weight = x_local, w if backward_mode == "unquant": - saved_input = input_ if input_requires_grad else None - saved_weight = linear_op.weight if weight_requires_grad else None + saved_input = input_ if weight_requires_grad else None + saved_weight = linear_op.weight if input_requires_grad else None + else: + saved_input = x_local + saved_weight = w if is_cpu_offload_enabled(): mark_activation_offload(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py index 63a1031c86..c1eeac484f 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py @@ 
-101,10 +101,12 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: - saved_input, saved_weight = x_local, w if backward_mode == "unquant": - saved_input = input_ if input_requires_grad else None - saved_weight = linear_op.weight if weight_requires_grad else None + saved_input = input_ if weight_requires_grad else None + saved_weight = linear_op.weight if input_requires_grad else None + else: + saved_input = x_local + saved_weight = w if is_cpu_offload_enabled(): mark_activation_offload(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) From 416490736e22317d25eae397c0419d8d5ec3b3e5 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sat, 14 Mar 2026 14:14:25 -0700 Subject: [PATCH 59/61] Refactor interface to `NVTE_BACKWARD_OVERRIDE=high_precision|dequantized` Signed-off-by: Ziang Li --- tests/pytorch/test_backward_mode.py | 318 +++++++++--------- tests/pytorch/test_cpu_offloading.py | 36 +- tests/pytorch/test_cuda_graphs.py | 16 +- tests/pytorch/test_sanity.py | 34 +- tests/pytorch/utils.py | 20 +- transformer_engine/common/recipe/__init__.py | 100 +++--- transformer_engine/pytorch/module/base.py | 2 +- .../pytorch/module/grouped_linear.py | 30 +- .../pytorch/module/layernorm_linear.py | 32 +- .../pytorch/module/layernorm_mlp.py | 13 +- transformer_engine/pytorch/module/linear.py | 30 +- .../pytorch/ops/basic/basic_linear.py | 30 +- transformer_engine/pytorch/ops/basic/bias.py | 2 +- .../pytorch/ops/basic/quantize.py | 4 +- .../ops/fused/backward_activation_bias.py | 4 +- .../fused/forward_linear_bias_activation.py | 14 +- .../ops/fused/forward_linear_bias_add.py | 14 +- .../ops/fused/forward_linear_scale_add.py | 12 +- .../ops/fused/userbuffers_forward_linear.py | 12 +- transformer_engine/pytorch/ops/fuser.py | 12 +- 20 files changed, 373 insertions(+), 362 deletions(-) diff --git a/tests/pytorch/test_backward_mode.py b/tests/pytorch/test_backward_mode.py index 03ebc8c0c3..ed4f73adbc 100644 --- 
a/tests/pytorch/test_backward_mode.py +++ b/tests/pytorch/test_backward_mode.py @@ -30,7 +30,7 @@ assert_close, make_recipe, reset_rng_states, - skip_unsupported_backward_mode, + skip_unsupported_backward_override, ) @@ -38,7 +38,7 @@ # Mode and capability config # -------------------------- -_NON_QUANT_BACKWARD_MODES = ("unquant", "dequant") +_BACKWARD_OVERRIDES = ("high_precision", "dequantized") fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) @@ -88,9 +88,9 @@ def _reset_global_fp8_state(): FP8GlobalStateManager.reset() -@pytest.fixture(params=_NON_QUANT_BACKWARD_MODES, ids=lambda mode: f"mode_{mode}") -def backward_mode(request: pytest.FixtureRequest) -> str: - """Backward mode under test.""" +@pytest.fixture(params=_BACKWARD_OVERRIDES, ids=lambda mode: f"mode_{mode}") +def backward_override(request: pytest.FixtureRequest) -> str: + """backward override under test.""" return request.param @@ -325,7 +325,7 @@ def _dequantize_saved_operand( ) -> torch.Tensor: if saved_operand is None: raise RuntimeError("Expected saved operand but got None") - # In dequant mode we must consume the fprop-saved quantized payload directly. + # In dequantized mode we must consume the fprop-saved quantized payload directly. # If row-wise payload is missing, the tensor was retargeted to a transpose-only # layout and no longer represents the original fprop operand. if ( @@ -334,7 +334,7 @@ def _dequantize_saved_operand( and getattr(saved_operand, "_rowwise_data") is None ): raise RuntimeError( - "Saved dequant operand lost row-wise fprop payload (likely usage retarget)." + "Saved dequantized operand lost row-wise fprop payload (likely usage retarget)." 
) if isinstance(saved_operand, torch.Tensor): return saved_operand.to(dtype) @@ -384,7 +384,7 @@ def _snapshot_backward_ctx_state( if output.grad_fn is None: raise RuntimeError("Output tensor has no grad_fn; cannot inspect backward context state.") required_attrs = ( - "backward_mode", + "backward_override", "fp8", "grad_output_quantizer", "reduce_and_update_bwd_fp8_tensors", @@ -396,7 +396,7 @@ def _snapshot_backward_ctx_state( f"{', '.join(missing_attrs)}." ) return ( - getattr(output.grad_fn, "backward_mode"), + getattr(output.grad_fn, "backward_override"), bool(getattr(output.grad_fn, "fp8")), getattr(output.grad_fn, "grad_output_quantizer"), bool(getattr(output.grad_fn, "reduce_and_update_bwd_fp8_tensors")), @@ -412,20 +412,20 @@ def _assert_saved_quantized_operand_uses_rowwise_only( raise RuntimeError(f"Expected quantized saved {name} operand but got None") if isinstance(saved_operand, torch.Tensor): raise RuntimeError( - f"Dequant reference expects quantized saved {name} operand, got torch.Tensor." + f"dequantized reference expects quantized saved {name} operand, got torch.Tensor." ) if not hasattr(saved_operand, "dequantize"): raise RuntimeError(f"Unsupported saved {name} operand type: {type(saved_operand)}") if hasattr(saved_operand, "_rowwise_data") and getattr(saved_operand, "_rowwise_data") is None: raise RuntimeError( - f"Saved dequant {name} operand lost row-wise fprop payload (likely usage retarget)." + f"Saved dequantized {name} operand lost row-wise fprop payload (likely usage retarget)." ) if ( hasattr(saved_operand, "_columnwise_data") and getattr(saved_operand, "_columnwise_data") is not None ): raise RuntimeError( - f"Saved dequant {name} operand unexpectedly carries column-wise payload." + f"Saved dequantized {name} operand unexpectedly carries column-wise payload." 
) @@ -442,7 +442,7 @@ def _assert_saved_quantized_operand_layout_unchanged(snapshot: dict[str, object] rowwise_now = rowwise_data_now is not None if rowwise_now != rowwise_present: raise RuntimeError( - f"Saved dequant {name} operand row-wise payload presence changed " + f"Saved dequantized {name} operand row-wise payload presence changed " f"from {rowwise_present} to {rowwise_now}." ) # Guard against hidden requantization that swaps in a new row-wise payload. @@ -453,7 +453,7 @@ def _assert_saved_quantized_operand_layout_unchanged(snapshot: dict[str, object] and id(rowwise_data_now) != rowwise_obj_id ): raise RuntimeError( - f"Saved dequant {name} operand row-wise payload identity changed " + f"Saved dequantized {name} operand row-wise payload identity changed " "(likely rewritten/requantized)." ) @@ -462,7 +462,7 @@ def _assert_saved_quantized_operand_layout_unchanged(snapshot: dict[str, object] columnwise_now = getattr(saved_operand, "_columnwise_data", None) is not None if columnwise_now != columnwise_present: raise RuntimeError( - f"Saved dequant {name} operand column-wise payload presence changed " + f"Saved dequantized {name} operand column-wise payload presence changed " f"from {columnwise_present} to {columnwise_now}." ) @@ -497,7 +497,7 @@ def _compute_linear_backward_reference_from_saved_operands( out_dtype: torch.dtype, with_bias: bool = True, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - # Dequant reference path: + # dequantized reference path: # 1) use the exact operands saved by quantized forward, # 2) dequantize them to the active high-precision compute dtype, # 3) run backward GEMMs in high precision and compare exactly. 
@@ -509,7 +509,9 @@ def _compute_linear_backward_reference_from_saved_operands( if dy_mat.shape[0] == 0: out_features = dy_mat.shape[-1] if saved_input is None: - raise RuntimeError("Expected saved input operand for empty-chunk dequant reference.") + raise RuntimeError( + "Expected saved input operand for empty-chunk dequantized reference." + ) in_features = saved_input.size(-1) dx_ref = torch.zeros(*dy.shape[:-1], in_features, dtype=out_dtype, device=dy.device) dw_ref = torch.zeros(out_features, in_features, dtype=out_dtype, device=dy.device) @@ -781,7 +783,7 @@ def _run_grouped_linear_single_step_with_ctx_state( "Output tensor has no grad_fn; cannot inspect grouped backward state." ) required_attrs = ( - "backward_mode", + "backward_override", "fp8", "reduce_and_update_bwd_fp8_tensors", ) @@ -792,7 +794,7 @@ def _run_grouped_linear_single_step_with_ctx_state( f"{', '.join(missing_attrs)}." ) ctx_state = ( - getattr(y.grad_fn, "backward_mode"), + getattr(y.grad_fn, "backward_override"), bool(getattr(y.grad_fn, "fp8")), bool(getattr(y.grad_fn, "reduce_and_update_bwd_fp8_tensors")), ) @@ -815,14 +817,14 @@ def _run_grouped_linear_single_step_with_ctx_state( @pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list) -def test_backward_mode_recipe_matches_requested_mode( +def test_backward_override_recipe_matches_requested_mode( recipe_name: str, - backward_mode: str, + backward_override: str, ) -> None: - mode_recipe = make_recipe(recipe_name, backward_mode=backward_mode) - quant_recipe = make_recipe(recipe_name, backward_mode="default") - assert mode_recipe.backward_mode == backward_mode - assert quant_recipe.backward_mode == "default" + mode_recipe = make_recipe(recipe_name, backward_override=backward_override) + quant_recipe = make_recipe(recipe_name) + assert mode_recipe.backward_override == backward_override + assert quant_recipe.backward_override is None @pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list) @@ -830,14 +832,14 @@ 
def test_backward_mode_recipe_matches_requested_mode( @pytest.mark.parametrize("input_shape,out_features", _shape_test_cases) @pytest.mark.parametrize("use_bias", (False, True), ids=("no_bias", "bias")) @pytest.mark.parametrize("dtype", _core_dtypes, ids=str) -def test_linear_like_backward_mode_matches_reference( +def test_linear_like_backward_override_matches_reference( recipe_name: str, module_type: str, input_shape: tuple[int, ...], out_features: int, use_bias: bool, dtype: torch.dtype, - backward_mode: str, + backward_override: str, ) -> None: reset_rng_states() _maybe_skip_recipe_dtype(recipe_name, dtype, module_type) @@ -845,9 +847,9 @@ def test_linear_like_backward_mode_matches_reference( _maybe_skip_unsupported_recipe_shape(recipe_name, input_shape, module_type) in_features = input_shape[-1] - quantized_ref_recipe = make_recipe(recipe_name, backward_mode="default") - mode_recipe = make_recipe(recipe_name, backward_mode=backward_mode) - skip_unsupported_backward_mode(module_type, mode_recipe, backward_mode) + quantized_ref_recipe = make_recipe(recipe_name) + mode_recipe = make_recipe(recipe_name, backward_override=backward_override) + skip_unsupported_backward_override(module_type, mode_recipe, backward_override) module_quantized_ref = _make_linear_like_module( module_type, @@ -870,8 +872,8 @@ def test_linear_like_backward_mode_matches_reference( dy = torch.randn(*output_shape, dtype=dtype, device="cuda") y_quantized_ref, _, _, _ = _run_single_step(module_quantized_ref, x, dy, quantized_ref_recipe) - if backward_mode == "unquant": - # Unquant reference path: compare against a plain high-precision backward run + if backward_override == "high_precision": + # high_precision reference path: compare against a plain high-precision backward run # (no fp8/autocast), starting from the same params and inputs. 
module_unquantized_ref = _make_linear_like_module( module_type, @@ -894,7 +896,7 @@ def test_linear_like_backward_mode_matches_reference( None, ) else: - # Dequant reference path: capture saved forward operands from the real dequant-mode + # dequantized reference path: capture saved forward operands from the real dequantized-override # execution, then rebuild backward reference from those saved operands. y_bwd_mode, x_bwd_mode, saved_operands = _run_single_step_with_saved_operands( module_bwd_mode, x, mode_recipe @@ -909,14 +911,14 @@ def test_linear_like_backward_mode_matches_reference( ref_exc: Optional[Exception] = None try: if module_type == "layernorm_linear": - # LayerNormLinear dequant reference: + # LayerNormLinear dequantized reference: # 1) Compute d(ln_out), dw, db from linear backward with saved operands. # 2) Compute exact dx via layernorm_bwd with saved norm statistics. # _LayerNormLinear forward saves operands as: # [inputmat, weightmat, origin_weight, bias, ln_weight, ln_out, mu, rsigma, ...] if len(saved_operands) < 8: raise RuntimeError( - "Insufficient saved operands for layernorm_linear dequant reference " + "Insufficient saved operands for layernorm_linear dequantized reference " f"(got {len(saved_operands)}, expected at least 8)." 
) saved_input = saved_operands[0] @@ -1012,14 +1014,14 @@ def test_linear_like_backward_mode_matches_reference( @pytest.mark.parametrize("use_bias", (False, True), ids=("no_bias", "bias")) @pytest.mark.parametrize("m_splits", _grouped_m_split_cases) @pytest.mark.parametrize("dtype", _core_dtypes, ids=str) -def test_grouped_linear_backward_mode_matches_reference( +def test_grouped_linear_backward_override_matches_reference( recipe_name: str, in_features: int, out_features: int, use_bias: bool, m_splits: list[int], dtype: torch.dtype, - backward_mode: str, + backward_override: str, ) -> None: reset_rng_states() @@ -1029,8 +1031,8 @@ def test_grouped_linear_backward_mode_matches_reference( num_gemms = len(m_splits) num_tokens = sum(m_splits) - quantized_ref_recipe = make_recipe(recipe_name, backward_mode="default") - mode_recipe = make_recipe(recipe_name, backward_mode=backward_mode) + quantized_ref_recipe = make_recipe(recipe_name) + mode_recipe = make_recipe(recipe_name, backward_override=backward_override) module_quantized_ref = te.GroupedLinear( num_gemms, @@ -1060,8 +1062,8 @@ def test_grouped_linear_backward_mode_matches_reference( dy, quantized_ref_recipe, ) - if backward_mode == "unquant": - # Unquant reference path: grouped module in plain high precision. + if backward_override == "high_precision": + # high_precision reference path: grouped module in plain high precision. module_unquantized_ref = te.GroupedLinear( num_gemms, in_features, @@ -1086,7 +1088,7 @@ def test_grouped_linear_backward_mode_matches_reference( None, ) else: - # Dequant reference path for grouped GEMMs: + # dequantized reference path for grouped GEMMs: # each GEMM restores its own saved input/weight pair and computes its own ref grads. 
y_bwd_mode, x_bwd_mode, saved_operands = _run_grouped_linear_step_with_saved_operands( module_bwd_mode, x, m_splits, mode_recipe @@ -1102,7 +1104,7 @@ def test_grouped_linear_backward_mode_matches_reference( try: if len(saved_operands) < 2 * num_gemms: raise RuntimeError( - "Insufficient saved operands for GroupedLinear dequant reference " + "Insufficient saved operands for GroupedLinear dequantized reference " f"(got {len(saved_operands)}, expected at least {2 * num_gemms})." ) @@ -1173,14 +1175,14 @@ def test_grouped_linear_backward_mode_matches_reference( @pytest.mark.parametrize("input_shape,out_features", _shape_test_cases) @pytest.mark.parametrize("use_bias", (False, True), ids=("no_bias", "bias")) @pytest.mark.parametrize("dtype", _core_dtypes, ids=str) -def test_linear_like_runtime_backward_mode_switch_updates_ctx( +def test_linear_like_runtime_backward_override_switch_updates_ctx( recipe_name: str, module_type: str, input_shape: tuple[int, ...], out_features: int, use_bias: bool, dtype: torch.dtype, - backward_mode: str, + backward_override: str, ) -> None: reset_rng_states() _maybe_skip_recipe_dtype(recipe_name, dtype, module_type) @@ -1197,9 +1199,9 @@ def test_linear_like_runtime_backward_mode_switch_updates_ctx( x = torch.randn(*input_shape, dtype=dtype, device="cuda") dy = torch.randn(*input_shape[:-1], out_features, dtype=dtype, device="cuda") - default_recipe = make_recipe(recipe_name, backward_mode="default") - mode_recipe = make_recipe(recipe_name, backward_mode=backward_mode) - skip_unsupported_backward_mode(module_type, mode_recipe, backward_mode) + default_recipe = make_recipe(recipe_name) + mode_recipe = make_recipe(recipe_name, backward_override=backward_override) + skip_unsupported_backward_override(module_type, mode_recipe, backward_override) *_, default_ctx = _run_single_step_with_ctx_state(module, x, dy, default_recipe) ( @@ -1208,7 +1210,7 @@ def test_linear_like_runtime_backward_mode_switch_updates_ctx( default_grad_output_quantizer, 
default_reduce_and_update, ) = default_ctx - assert default_mode == "default" + assert default_mode is None assert default_fp8 assert default_grad_output_quantizer is not None assert default_reduce_and_update @@ -1217,7 +1219,7 @@ def test_linear_like_runtime_backward_mode_switch_updates_ctx( switched_mode, switched_fp8, switched_grad_output_quantizer, switched_reduce_and_update = ( switched_ctx ) - assert switched_mode == backward_mode + assert switched_mode == backward_override assert not switched_fp8 assert switched_grad_output_quantizer is None assert not switched_reduce_and_update @@ -1229,7 +1231,7 @@ def test_linear_like_runtime_backward_mode_switch_updates_ctx( default_grad_output_quantizer_after, default_reduce_and_update_after, ) = default_ctx_after - assert default_mode_after == "default" + assert default_mode_after is None assert default_fp8_after assert default_grad_output_quantizer_after is not None assert default_reduce_and_update_after @@ -1240,14 +1242,14 @@ def test_linear_like_runtime_backward_mode_switch_updates_ctx( @pytest.mark.parametrize("m_splits", _grouped_m_split_cases) @pytest.mark.parametrize("use_bias", (False, True), ids=("no_bias", "bias")) @pytest.mark.parametrize("dtype", _core_dtypes, ids=str) -def test_grouped_linear_runtime_backward_mode_switch_updates_ctx( +def test_grouped_linear_runtime_backward_override_switch_updates_ctx( recipe_name: str, in_features: int, out_features: int, m_splits: list[int], use_bias: bool, dtype: torch.dtype, - backward_mode: str, + backward_override: str, ) -> None: reset_rng_states() @@ -1267,8 +1269,8 @@ def test_grouped_linear_runtime_backward_mode_switch_updates_ctx( x = torch.randn(num_tokens, in_features, dtype=dtype, device="cuda") dy = torch.randn(num_tokens, out_features, dtype=dtype, device="cuda") - default_recipe = make_recipe(recipe_name, backward_mode="default") - mode_recipe = make_recipe(recipe_name, backward_mode=backward_mode) + default_recipe = make_recipe(recipe_name) + 
mode_recipe = make_recipe(recipe_name, backward_override=backward_override) *_, default_ctx = _run_grouped_linear_single_step_with_ctx_state( module, @@ -1278,7 +1280,7 @@ def test_grouped_linear_runtime_backward_mode_switch_updates_ctx( default_recipe, ) default_mode, default_fp8, default_reduce_and_update = default_ctx - assert default_mode == "default" + assert default_mode is None assert default_fp8 assert default_reduce_and_update @@ -1290,7 +1292,7 @@ def test_grouped_linear_runtime_backward_mode_switch_updates_ctx( mode_recipe, ) switched_mode, switched_fp8, switched_reduce_and_update = switched_ctx - assert switched_mode == backward_mode + assert switched_mode == backward_override assert not switched_fp8 assert not switched_reduce_and_update @@ -1302,7 +1304,7 @@ def test_grouped_linear_runtime_backward_mode_switch_updates_ctx( default_recipe, ) default_mode_after, default_fp8_after, default_reduce_and_update_after = default_ctx_after - assert default_mode_after == "default" + assert default_mode_after is None assert default_fp8_after assert default_reduce_and_update_after @@ -1318,7 +1320,7 @@ def test_grouped_linear_runtime_backward_mode_switch_updates_ctx( @pytest.mark.parametrize("in_features,out_features", _linear_feature_cases) @pytest.mark.parametrize("m", (1, 32), ids=("m1", "m32")) @pytest.mark.parametrize("dtype", _fused_dtypes, ids=str) -def test_fused_linear_paths_match_backward_mode_reference( +def test_fused_linear_paths_match_backward_override_reference( recipe_name: str, fused_pattern: str, expected_fused_op: type, @@ -1326,7 +1328,7 @@ def test_fused_linear_paths_match_backward_mode_reference( out_features: int, m: int, dtype: torch.dtype, - backward_mode: str, + backward_override: str, ) -> None: _maybe_skip_recipe_dtype(recipe_name, dtype, "ops_linear") _maybe_skip_unsupported_recipe_module_combo(recipe_name, "ops_linear") @@ -1334,9 +1336,9 @@ def test_fused_linear_paths_match_backward_mode_reference( reset_rng_states() - 
quantized_ref_recipe = make_recipe(recipe_name, backward_mode="default") - mode_recipe = make_recipe(recipe_name, backward_mode=backward_mode) - skip_unsupported_backward_mode("ops_linear", mode_recipe, backward_mode) + quantized_ref_recipe = make_recipe(recipe_name) + mode_recipe = make_recipe(recipe_name, backward_override=backward_override) + skip_unsupported_backward_override("ops_linear", mode_recipe, backward_override) model_quantized_ref = _make_fused_model(fused_pattern, in_features, out_features, dtype) model_bwd_mode = _make_fused_model(fused_pattern, in_features, out_features, dtype) @@ -1357,8 +1359,8 @@ def test_fused_linear_paths_match_backward_mode_reference( x2=x2, ) - if backward_mode == "unquant": - # Unquant reference path: replay the same fused model structure in plain + if backward_override == "high_precision": + # high_precision reference path: replay the same fused model structure in plain # high precision and compare backward outputs exactly. model_unquantized_ref = _make_fused_model(fused_pattern, in_features, out_features, dtype) _copy_named_parameters(model_quantized_ref, model_unquantized_ref) @@ -1380,7 +1382,7 @@ def test_fused_linear_paths_match_backward_mode_reference( x2=x2, ) else: - # Dequant reference path: compute backward reference from saved quantized + # dequantized reference path: compute backward reference from saved quantized # linear operands (with branch-specific dy handling for fused epilogues). 
y_bwd_mode, x1_bwd_mode, x2_bwd_mode_ref, saved_operands = ( _run_fused_single_step_with_saved_operands( @@ -1465,7 +1467,7 @@ def test_fused_bias_activation_matches_masked_linear_backward( input_shape: tuple[int, ...], out_features: int, dtype: torch.dtype, - backward_mode: str, + backward_override: str, ) -> None: _maybe_skip_recipe_dtype(recipe_name, dtype, "ops_linear") _maybe_skip_unsupported_recipe_module_combo(recipe_name, "ops_linear") @@ -1474,9 +1476,9 @@ def test_fused_bias_activation_matches_masked_linear_backward( reset_rng_states() in_features = input_shape[-1] - quantized_ref_recipe = make_recipe(recipe_name, backward_mode="default") - mode_recipe = make_recipe(recipe_name, backward_mode=backward_mode) - skip_unsupported_backward_mode("ops_linear", mode_recipe, backward_mode) + quantized_ref_recipe = make_recipe(recipe_name) + mode_recipe = make_recipe(recipe_name, backward_override=backward_override) + skip_unsupported_backward_override("ops_linear", mode_recipe, backward_override) model_quantized_ref = _make_fused_model("bias_activation", in_features, out_features, dtype) model_bwd_mode = _make_fused_model("bias_activation", in_features, out_features, dtype) @@ -1493,8 +1495,8 @@ def test_fused_bias_activation_matches_masked_linear_backward( quantized_ref_recipe, ) - if backward_mode == "unquant": - # Unquant reference path: build a plain linear reference and apply the + if backward_override == "high_precision": + # high_precision reference path: build a plain linear reference and apply the # same activation mask (from quantized forward output) before backward. linear_unquantized_ref = _make_linear_like_module( "ops_linear", @@ -1520,7 +1522,7 @@ def test_fused_bias_activation_matches_masked_linear_backward( None, ) else: - # Dequant reference path: restore saved linear operands from fused forward, + # dequantized reference path: restore saved linear operands from fused forward, # apply the same activation mask, then run linear backward reference. 
y_bwd_mode, x1_bwd_mode, _, saved_operands = _run_fused_single_step_with_saved_operands( "bias_activation", @@ -1576,7 +1578,7 @@ def test_fused_bias_activation_matches_masked_linear_backward( assert len(fused_ops) >= 1 assert isinstance(fused_ops[0][0], ForwardLinearBiasActivation) - # In unquant/dequant modes, backward-activation+bias fusion should be disabled. + # In high_precision/dequantized modes, backward-activation+bias fusion should be disabled. bwd_mode_backward_ops = model_bwd_mode._module_groups[0]._backward_ops assert not any(isinstance(op, BackwardActivationBias) for op, _ in bwd_mode_backward_ops) @@ -1596,12 +1598,12 @@ def test_fused_bias_activation_matches_masked_linear_backward( @pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list) @pytest.mark.parametrize("in_features,out_features", _linear_feature_cases) @pytest.mark.parametrize("dtype", _core_dtypes, ids=str) -def test_operation_fuser_rebuilds_userbuffers_fusion_on_backward_mode_switch( +def test_operation_fuser_rebuilds_userbuffers_fusion_on_backward_override_switch( recipe_name: str, in_features: int, out_features: int, dtype: torch.dtype, - backward_mode: str, + backward_override: str, monkeypatch: pytest.MonkeyPatch, ) -> None: # Simulate a distributed setup to exercise Userbuffers fusion eligibility @@ -1611,7 +1613,7 @@ def test_operation_fuser_rebuilds_userbuffers_fusion_on_backward_mode_switch( # Use a mutable recipe holder so we can switch fusion behavior on the same # fuser object and verify that the cached fusion plan is refreshed. 
- current_recipe = {"value": make_recipe(recipe_name, backward_mode="default")} + current_recipe = {"value": make_recipe(recipe_name)} monkeypatch.setattr(FP8GlobalStateManager, "get_fp8_recipe", lambda: current_recipe["value"]) reset_rng_states() @@ -1635,8 +1637,8 @@ def test_operation_fuser_rebuilds_userbuffers_fusion_on_backward_mode_switch( x = torch.randn(32, in_features, dtype=dtype, device="cuda", requires_grad=True) extra_inputs = [() for _ in range(fuser._num_basic_ops)] - quant_recipe = make_recipe(recipe_name, backward_mode="default") - skip_unsupported_backward_mode("ops_linear", quant_recipe, backward_mode) + quant_recipe = make_recipe(recipe_name) + skip_unsupported_backward_override("ops_linear", quant_recipe, backward_override) fuser.maybe_fuse_ops( is_grad_enabled=True, recipe=quant_recipe, @@ -1645,8 +1647,8 @@ def test_operation_fuser_rebuilds_userbuffers_fusion_on_backward_mode_switch( ) assert any(isinstance(op, UserbuffersForwardLinear) for op, _ in fuser._forward_ops) - non_quant_recipe = make_recipe(recipe_name, backward_mode=backward_mode) - skip_unsupported_backward_mode("ops_linear", non_quant_recipe, backward_mode) + non_quant_recipe = make_recipe(recipe_name, backward_override=backward_override) + skip_unsupported_backward_override("ops_linear", non_quant_recipe, backward_override) current_recipe["value"] = non_quant_recipe fuser.maybe_fuse_ops( is_grad_enabled=True, @@ -1659,10 +1661,10 @@ def test_operation_fuser_rebuilds_userbuffers_fusion_on_backward_mode_switch( @pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list) @pytest.mark.parametrize("dtype", _core_dtypes, ids=str) -def test_quantize_op_respects_backward_mode( +def test_quantize_op_respects_backward_override( recipe_name: str, dtype: torch.dtype, - backward_mode: str, + backward_override: str, ) -> None: _maybe_skip_recipe_dtype(recipe_name, dtype, "ops_linear") _maybe_skip_unsupported_recipe_module_combo(recipe_name, "ops_linear") @@ -1674,8 +1676,8 @@ 
def test_quantize_op_respects_backward_mode( model_override = te_ops.Sequential(te_ops.Quantize(forward=True, backward=True)) model_ref = te_ops.Sequential(te_ops.Quantize(forward=True, backward=False)) - mode_recipe = make_recipe(recipe_name, backward_mode=backward_mode) - skip_unsupported_backward_mode("ops_linear", mode_recipe, backward_mode) + mode_recipe = make_recipe(recipe_name, backward_override=backward_override) + skip_unsupported_backward_override("ops_linear", mode_recipe, backward_override) y_override, dx_override = _run_quantize_op_single_step(model_override, x, dy, mode_recipe) y_ref, dx_ref = _run_quantize_op_single_step(model_ref, x, dy, mode_recipe) @@ -1686,11 +1688,11 @@ def test_quantize_op_respects_backward_mode( @pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list) @pytest.mark.parametrize("module_type", ("linear", "layernorm_linear")) -def test_backward_mode_memory_peak_report( +def test_backward_override_memory_peak_report( recipe_name: str, module_type: str, ) -> None: - """Diagnostic-only memory report for default/unquant/dequant backward modes.""" + """Diagnostic-only memory report for None/high_precision/dequantized backward overrides.""" reset_rng_states() dtype = torch.bfloat16 input_shape = (2048, 2048) @@ -1713,77 +1715,78 @@ def test_backward_mode_memory_peak_report( x = torch.randn(*input_shape, dtype=dtype, device="cuda") dy = torch.randn(*input_shape[:-1], out_features, dtype=dtype, device="cuda") - modes = ("default", "unquant", "dequant") + modes = (None, "high_precision", "dequantized") mode_results: dict[str, dict[str, float] | str] = {} for mode in modes: - try: - mode_recipe = make_recipe(recipe_name, backward_mode=mode) - - # Keep params identical across modes for a cleaner apples-to-apples read. 
- module = _make_linear_like_module( - module_type, - in_features, - out_features, - dtype, - bias=use_bias, - ) - _copy_named_parameters(base_module, module) - - # Warmup run to reduce first-use kernel setup noise. - _run_single_step(module, x, dy, mode_recipe) - - module.zero_grad(set_to_none=True) - x_run = x.detach().clone().requires_grad_(True) - autocast_ctx = te.autocast(enabled=True, recipe=mode_recipe) - - torch.cuda.synchronize() - torch.cuda.reset_peak_memory_stats() - fwd_start_mem = torch.cuda.memory_allocated() - with autocast_ctx: - y = module(x_run) - if isinstance(y, tuple): - y = y[0] - torch.cuda.synchronize() - fwd_peak_alloc = float(torch.cuda.max_memory_allocated() - fwd_start_mem) - fwd_peak_reserved = float(torch.cuda.max_memory_reserved()) - - torch.cuda.reset_peak_memory_stats() - bwd_start_mem = torch.cuda.memory_allocated() - y.backward(dy) - torch.cuda.synchronize() - bwd_peak_alloc = float(torch.cuda.max_memory_allocated() - bwd_start_mem) - bwd_peak_reserved = float(torch.cuda.max_memory_reserved()) - - module.zero_grad(set_to_none=True) - x_run = x.detach().clone().requires_grad_(True) - autocast_ctx = te.autocast(enabled=True, recipe=mode_recipe) - - torch.cuda.synchronize() - torch.cuda.reset_peak_memory_stats() - e2e_start_mem = torch.cuda.memory_allocated() - with autocast_ctx: - y = module(x_run) - if isinstance(y, tuple): - y = y[0] - y.backward(dy) - torch.cuda.synchronize() - e2e_peak_alloc = float(torch.cuda.max_memory_allocated() - e2e_start_mem) - e2e_peak_reserved = float(torch.cuda.max_memory_reserved()) - - mode_results[mode] = { - "fwd_peak_alloc_mb": fwd_peak_alloc / (1024**2), - "fwd_peak_reserved_mb": fwd_peak_reserved / (1024**2), - "bwd_peak_alloc_mb": bwd_peak_alloc / (1024**2), - "bwd_peak_reserved_mb": bwd_peak_reserved / (1024**2), - "e2e_peak_alloc_mb": e2e_peak_alloc / (1024**2), - "e2e_peak_reserved_mb": e2e_peak_reserved / (1024**2), - } - except Exception as exc: # pragma: no cover - diagnostic reporting 
path - mode_results[mode] = f"{type(exc).__name__}: {exc}" + mode_str = "default" if mode is None else mode + # try: + mode_recipe = make_recipe(recipe_name, backward_override=mode) + + # Keep params identical across modes for a cleaner apples-to-apples read. + module = _make_linear_like_module( + module_type, + in_features, + out_features, + dtype, + bias=use_bias, + ) + _copy_named_parameters(base_module, module) + + # Warmup run to reduce first-use kernel setup noise. + _run_single_step(module, x, dy, mode_recipe) + + module.zero_grad(set_to_none=True) + x_run = x.detach().clone().requires_grad_(True) + autocast_ctx = te.autocast(enabled=True, recipe=mode_recipe) + + torch.cuda.synchronize() + torch.cuda.reset_peak_memory_stats() + fwd_start_mem = torch.cuda.memory_allocated() + with autocast_ctx: + y = module(x_run) + if isinstance(y, tuple): + y = y[0] + torch.cuda.synchronize() + fwd_peak_alloc = float(torch.cuda.max_memory_allocated() - fwd_start_mem) + fwd_peak_reserved = float(torch.cuda.max_memory_reserved()) + + torch.cuda.reset_peak_memory_stats() + bwd_start_mem = torch.cuda.memory_allocated() + y.backward(dy) + torch.cuda.synchronize() + bwd_peak_alloc = float(torch.cuda.max_memory_allocated() - bwd_start_mem) + bwd_peak_reserved = float(torch.cuda.max_memory_reserved()) + + module.zero_grad(set_to_none=True) + x_run = x.detach().clone().requires_grad_(True) + autocast_ctx = te.autocast(enabled=True, recipe=mode_recipe) + + torch.cuda.synchronize() + torch.cuda.reset_peak_memory_stats() + e2e_start_mem = torch.cuda.memory_allocated() + with autocast_ctx: + y = module(x_run) + if isinstance(y, tuple): + y = y[0] + y.backward(dy) + torch.cuda.synchronize() + e2e_peak_alloc = float(torch.cuda.max_memory_allocated() - e2e_start_mem) + e2e_peak_reserved = float(torch.cuda.max_memory_reserved()) + + mode_results[mode_str] = { + "fwd_peak_alloc_mb": fwd_peak_alloc / (1024**2), + "fwd_peak_reserved_mb": fwd_peak_reserved / (1024**2), + "bwd_peak_alloc_mb": 
bwd_peak_alloc / (1024**2), + "bwd_peak_reserved_mb": bwd_peak_reserved / (1024**2), + "e2e_peak_alloc_mb": e2e_peak_alloc / (1024**2), + "e2e_peak_reserved_mb": e2e_peak_reserved / (1024**2), + } + # except Exception as exc: # pragma: no cover - diagnostic reporting path + # mode_results[mode_str] = f"{type(exc).__name__}: {exc}" print( - "\n[backward_mode_memory_peak_report] " + "\n[backward_override_memory_peak_report] " f"recipe={recipe_name} module_type={module_type} " f"dtype={dtype} input_shape={input_shape} out_features={out_features}" ) @@ -1791,7 +1794,7 @@ def test_backward_mode_memory_peak_report( metric_col_width = 9 delta_col_width = 18 columns = ( - ("mode", metric_col_width), + ("mode_str", delta_col_width), ("fwd_alloc", metric_col_width), ("bwd_alloc", metric_col_width), ("e2e_alloc", metric_col_width), @@ -1813,9 +1816,10 @@ def _format_delta_with_pct(delta: float, base: float) -> str: default_metrics = mode_results.get("default") for mode in modes: - metrics = mode_results[mode] + mode_str = "default" if mode is None else mode + metrics = mode_results[mode_str] if isinstance(metrics, str): - print(f"{mode:>{metric_col_width}} | ERROR: {metrics}") + print(f"{mode_str:>{delta_col_width}} | ERROR: {metrics}") continue if isinstance(default_metrics, dict): @@ -1831,7 +1835,7 @@ def _format_delta_with_pct(delta: float, base: float) -> str: delta_e2e_str = "n/a" print( - f"{mode:>{metric_col_width}} | " + f"{mode_str:>{delta_col_width}} | " f"{metrics['fwd_peak_alloc_mb']:{metric_col_width}.2f} | " f"{metrics['bwd_peak_alloc_mb']:{metric_col_width}.2f} | " f"{metrics['e2e_peak_alloc_mb']:{metric_col_width}.2f} | " diff --git a/tests/pytorch/test_cpu_offloading.py b/tests/pytorch/test_cpu_offloading.py index 514dcdf501..50196782f2 100644 --- a/tests/pytorch/test_cpu_offloading.py +++ b/tests/pytorch/test_cpu_offloading.py @@ -19,7 +19,7 @@ from transformer_engine.pytorch.fp8 import FP8GlobalStateManager import transformer_engine.pytorch as te from 
transformer_engine.common import recipe -from utils import ModelConfig, skip_unsupported_backward_mode +from utils import ModelConfig, skip_unsupported_backward_override import transformer_engine_torch as tex # Check supported quantization schemes @@ -417,14 +417,14 @@ def test_multiple_tensor_offload(self, recipe): class TestTELayers: @pytest.mark.parametrize("layer_type", Utils.get_layer_names()) @pytest.mark.parametrize("recipe", quantization_recipes) - @pytest.mark.parametrize("backward_mode", ["default", "unquant", "dequant"]) - def test_sanity(self, layer_type, recipe, backward_mode): + @pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"]) + def test_sanity(self, layer_type, recipe, backward_override): Utils.memory_leak_check() - skip_unsupported_backward_mode(layer_type, recipe, backward_mode) + skip_unsupported_backward_override(layer_type, recipe, backward_override) if recipe is not None: recipe = copy.deepcopy(recipe) - recipe.backward_mode = backward_mode + recipe.backward_override = backward_override # Skip ops-based layers with Float8BlockScaling recipe if ( layer_type in ["linear_op", "layernorm_mlp_ops"] @@ -464,14 +464,14 @@ def test_sanity(self, layer_type, recipe, backward_mode): @pytest.mark.parametrize("layer_type", Utils.get_layer_names()) @pytest.mark.parametrize("recipe", quantization_recipes) - @pytest.mark.parametrize("backward_mode", ["default", "unquant", "dequant"]) - def test_memory(self, layer_type, recipe, backward_mode): + @pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"]) + def test_memory(self, layer_type, recipe, backward_override): Utils.memory_leak_check() - skip_unsupported_backward_mode(layer_type, recipe, backward_mode) + skip_unsupported_backward_override(layer_type, recipe, backward_override) if recipe is not None: recipe = copy.deepcopy(recipe) - recipe.backward_mode = backward_mode + recipe.backward_override = backward_override # Skip ops-based layers 
with Float8BlockScaling recipe if ( @@ -536,7 +536,7 @@ def test_memory(self, layer_type, recipe, backward_mode): out = out + 1 out = sync_function(out) del inp - if backward_mode == "default": + if backward_override is None: assert Utils.get_cuda_memory_mb() == pytest.approx(init_cuda_memory, 0.1) else: assert ( @@ -555,14 +555,14 @@ def test_memory(self, layer_type, recipe, backward_mode): @pytest.mark.parametrize("layer_type", Utils.get_layer_names()) @pytest.mark.parametrize("recipe", quantization_recipes) - @pytest.mark.parametrize("backward_mode", ["default", "unquant", "dequant"]) - def test_manual_synchronization(self, recipe, layer_type, backward_mode): + @pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"]) + def test_manual_synchronization(self, recipe, layer_type, backward_override): Utils.memory_leak_check() - skip_unsupported_backward_mode(layer_type, recipe, backward_mode) + skip_unsupported_backward_override(layer_type, recipe, backward_override) if recipe is not None: recipe = copy.deepcopy(recipe) - recipe.backward_mode = backward_mode + recipe.backward_override = backward_override # Skip ops-based layers with Float8BlockScaling recipe if ( @@ -624,7 +624,7 @@ def test_manual_synchronization(self, recipe, layer_type, backward_mode): out_2.sum().backward() @pytest.mark.parametrize("recipe", quantization_recipes) - @pytest.mark.parametrize("backward_mode", ["default", "unquant", "dequant"]) + @pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"]) @pytest.mark.parametrize("layer_type", Utils.get_layer_names()) @pytest.mark.parametrize("use_cuda_graphs", [True, False]) @pytest.mark.parametrize("retain_pinned_cpu_buffers", [True, False]) @@ -632,16 +632,16 @@ def test_manual_synchronization(self, recipe, layer_type, backward_mode): def test_numerics( self, recipe, - backward_mode, + backward_override, layer_type, use_cuda_graphs, backend, retain_pinned_cpu_buffers, ): - 
skip_unsupported_backward_mode(layer_type, recipe, backward_mode) + skip_unsupported_backward_override(layer_type, recipe, backward_override) if recipe is not None: recipe = copy.deepcopy(recipe) - recipe.backward_mode = backward_mode + recipe.backward_override = backward_override # Skip ops-based layers with Float8BlockScaling recipe if ( diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py index bf304dc240..a782dadc60 100644 --- a/tests/pytorch/test_cuda_graphs.py +++ b/tests/pytorch/test_cuda_graphs.py @@ -25,7 +25,7 @@ from transformer_engine.pytorch.quantization import FP8GlobalStateManager import transformer_engine.pytorch.ops as te_ops from transformer_engine.common import recipe -from utils import ModelConfig, reset_rng_states, skip_unsupported_backward_mode +from utils import ModelConfig, reset_rng_states, skip_unsupported_backward_override # Check if FP8 is supported. fp8_available = is_fp8_available() @@ -361,7 +361,7 @@ def _test_cuda_graphs( @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("fp8_params", (False, True)) @pytest.mark.parametrize("fp8_recipe", fp8_recipes + [None], ids=lambda r: type(r).__name__) -@pytest.mark.parametrize("backward_mode", ("default", "unquant", "dequant")) +@pytest.mark.parametrize("backward_override", (None, "high_precision", "dequantized")) def test_make_graphed_callables( *, module: str, @@ -370,16 +370,16 @@ def test_make_graphed_callables( dtype: torch.dtype, fp8_params: bool, fp8_recipe: recipe.Recipe, - backward_mode: str, + backward_override: str, fp8_weight_caching: bool = False, ) -> None: fp8 = fp8_recipe is not None - skip_unsupported_backward_mode(module, fp8_recipe, backward_mode) + skip_unsupported_backward_override(module, fp8_recipe, backward_override) if fp8: fp8_recipe = copy.deepcopy(fp8_recipe) - fp8_recipe.backward_mode = backward_mode + fp8_recipe.backward_override = backward_override if fp8_params and not fp8: pytest.skip("FP8 needed for FP8 
parameters.") @@ -449,21 +449,21 @@ def test_make_graphed_callables( @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("fp8_params", (False, True)) @pytest.mark.parametrize("fp8_recipe", fp8_recipes, ids=lambda r: type(r).__name__) -@pytest.mark.parametrize("backward_mode", ("default", "unquant", "dequant")) +@pytest.mark.parametrize("backward_override", (None, "high_precision", "dequantized")) def test_make_graphed_callables_with_fp8_weight_caching( *, module: str, dtype: torch.dtype, fp8_params: bool, fp8_recipe: recipe.Recipe, - backward_mode: str, + backward_override: str, ) -> None: test_make_graphed_callables( module=module, dtype=dtype, fp8_params=fp8_params, fp8_recipe=fp8_recipe, - backward_mode=backward_mode, + backward_override=backward_override, fp8_weight_caching=True, ) diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index fd82996cbb..18502ee374 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -38,7 +38,7 @@ import transformer_engine_torch as tex from transformer_engine.pytorch.cpp_extensions import general_gemm from transformer_engine.pytorch.tensor.utils import replace_raw_data -from utils import ModelConfig, skip_unsupported_backward_mode +from utils import ModelConfig, skip_unsupported_backward_override # Only run FP8 tests on supported devices. 
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) @@ -384,7 +384,7 @@ def test_sanity_normalization_amp(dtype, model, skip_wgrad, skip_dgrad, normaliz @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("fp8_recipe", fp8_recipes) -@pytest.mark.parametrize("backward_mode", ["default", "unquant", "dequant"]) +@pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"]) @pytest.mark.parametrize("model", ["small", "weird"]) @pytest.mark.parametrize("skip_wgrad", all_boolean) @pytest.mark.parametrize("zero_centered_gamma", all_boolean) @@ -394,7 +394,7 @@ def test_sanity_normalization_amp(dtype, model, skip_wgrad, skip_dgrad, normaliz def test_sanity_layernorm_linear( dtype, fp8_recipe, - backward_mode, + backward_override, model, skip_wgrad, zero_centered_gamma, @@ -404,10 +404,10 @@ def test_sanity_layernorm_linear( ): config = model_configs[model] - skip_unsupported_backward_mode("layernorm_linear", fp8_recipe, backward_mode) + skip_unsupported_backward_override("layernorm_linear", fp8_recipe, backward_override) if fp8_recipe is not None: fp8_recipe = copy.deepcopy(fp8_recipe) - fp8_recipe.backward_mode = backward_mode + fp8_recipe.backward_override = backward_override if fp8_recipe is not None: if not is_fp8_supported(config): @@ -432,20 +432,20 @@ def test_sanity_layernorm_linear( @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("fp8_recipe", fp8_recipes) -@pytest.mark.parametrize("backward_mode", ["default", "unquant", "dequant"]) +@pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"]) @pytest.mark.parametrize("model", ["small", "weird"]) @pytest.mark.parametrize("skip_wgrad", all_boolean) @pytest.mark.parametrize("skip_dgrad", all_boolean) @pytest.mark.parametrize("microbatching", all_boolean) def test_sanity_linear( - dtype, fp8_recipe, backward_mode, model, skip_wgrad, skip_dgrad, microbatching + dtype, fp8_recipe, backward_override, 
model, skip_wgrad, skip_dgrad, microbatching ): config = model_configs[model] - skip_unsupported_backward_mode("linear", fp8_recipe, backward_mode) + skip_unsupported_backward_override("linear", fp8_recipe, backward_override) if fp8_recipe is not None: fp8_recipe = copy.deepcopy(fp8_recipe) - fp8_recipe.backward_mode = backward_mode + fp8_recipe.backward_override = backward_override if fp8_recipe is not None: if not is_fp8_supported(config): @@ -470,20 +470,20 @@ def test_sanity_linear( @pytest.mark.parametrize("bs", batch_sizes_with_zero) @pytest.mark.parametrize("model", ["small", "weird"]) @pytest.mark.parametrize("fp8_recipe", fp8_recipes) -@pytest.mark.parametrize("backward_mode", ["default", "unquant", "dequant"]) +@pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"]) @pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("use_bias", all_boolean) def test_sanity_linear_with_zero_tokens( - dtype, bs, model, fp8_recipe, backward_mode, fp8_model_params, use_bias + dtype, bs, model, fp8_recipe, backward_override, fp8_model_params, use_bias ): config = model_configs[model] ffn_hidden_size = 4 * config.hidden_size num_tokens = bs * config.max_seqlen_q - skip_unsupported_backward_mode("linear", fp8_recipe, backward_mode) + skip_unsupported_backward_override("linear", fp8_recipe, backward_override) if fp8_recipe is not None: fp8_recipe = copy.deepcopy(fp8_recipe) - fp8_recipe.backward_mode = backward_mode + fp8_recipe.backward_override = backward_override if fp8_recipe is not None: if not is_fp8_supported(config): @@ -511,7 +511,7 @@ def test_sanity_linear_with_zero_tokens( @pytest.mark.parametrize("bs", batch_sizes_with_zero) @pytest.mark.parametrize("model", ["small", "weird"]) @pytest.mark.parametrize("fp8_recipe", fp8_recipes) -@pytest.mark.parametrize("backward_mode", ["default", "unquant", "dequant"]) +@pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"]) 
@pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("use_bias", all_boolean) @pytest.mark.parametrize("single_param", all_boolean) @@ -522,7 +522,7 @@ def test_sanity_grouped_linear( bs, model, fp8_recipe, - backward_mode, + backward_override, fp8_model_params, use_bias, single_param, @@ -535,10 +535,10 @@ def test_sanity_grouped_linear( bs = bs * 16 num_tokens = bs * config.max_seqlen_q * (num_gemms - 1) - skip_unsupported_backward_mode("grouped_linear", fp8_recipe, backward_mode) + skip_unsupported_backward_override("grouped_linear", fp8_recipe, backward_override) if fp8_recipe is not None: fp8_recipe = copy.deepcopy(fp8_recipe) - fp8_recipe.backward_mode = backward_mode + fp8_recipe.backward_override = backward_override if fp8_recipe is not None: if not is_fp8_supported(config): diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 830ca6eecc..d783a893f6 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -150,18 +150,18 @@ def make_recipe(name: Optional[str], **recipe_kwargs: Any) -> Optional[Recipe]: raise ValueError(f"Unsupported quantization scheme ({name})") -def skip_unsupported_backward_mode( +def skip_unsupported_backward_override( layer_type: str, - quant_recipe: Recipe, - backward_mode: str, + quant_recipe: Optional[Recipe], + backward_override: Optional[str], ) -> None: - """Skip known unsupported layer/recipe/backward-mode combinations used in tests.""" - if backward_mode is None or backward_mode == "default": + """Skip known unsupported layer/recipe/backward-override combinations used in tests.""" + if backward_override is None: return - if quant_recipe is None and backward_mode in ("unquant", "dequant"): - pytest.skip(f"Not a quantized recipe, cannot use backward mode {backward_mode}.") - if quant_recipe.delayed() and backward_mode in ("unquant", "dequant"): - pytest.skip(f"Delayed scaling does not support backward mode {backward_mode}.") + if quant_recipe is None and backward_override is 
not None: + pytest.skip(f"Not a quantized recipe, cannot use backward override {backward_override}.") + if quant_recipe.delayed() and backward_override is not None: + pytest.skip(f"Delayed scaling does not support backward override {backward_override}.") if layer_type in ( "layernorm_mlp", "layernorm_mlp_nocheckpoint", @@ -169,7 +169,7 @@ def skip_unsupported_backward_mode( "transformer", "transformer_layer", ): - pytest.skip(f"{layer_type} does not support NVTE_BACKWARD_MODE={backward_mode}.") + pytest.skip(f"{layer_type} does not support NVTE_BACKWARD_OVERRIDE={backward_override}.") # Cached RNG state diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 9058f155c4..fd44d69a7d 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -11,17 +11,17 @@ from pydantic.dataclasses import dataclass -_BACKWARD_MODES = ("default", "unquant", "dequant") +_BACKWARD_OVERRIDES = (None, "high_precision", "dequantized") -def _resolve_backward_mode(mode: Optional[str] = None) -> str: - """Return validated backward mode from argument or NVTE_BACKWARD_MODE env.""" +def _resolve_backward_override(mode: Optional[str] = None) -> Optional[str]: + """Return validated backward override from argument or NVTE_BACKWARD_OVERRIDE env.""" if mode is None: - mode = os.getenv("NVTE_BACKWARD_MODE", "default") - mode = mode.lower() - assert ( - mode in _BACKWARD_MODES - ), f"Invalid NVTE_BACKWARD_MODE value {mode!r}. Supported values are: default|unquant|dequant." + mode = os.getenv("NVTE_BACKWARD_OVERRIDE", None) + assert mode in _BACKWARD_OVERRIDES, ( + f"Invalid NVTE_BACKWARD_OVERRIDE value {mode!r}. Supported values are:" + " high_precision|dequantized." + ) return mode @@ -202,8 +202,8 @@ def scaling_factor_compute(amax: Tensor, `LayerNormLinear (BF16 output) -> (cast to FP8 ) FP8 DPA (cast to BF16) -> Linear`. 
When `fp8_mha = True, fp8_dpa = True`, it becomes `LayerNormLinear (FP8 output) -> FP8 DPA -> Linear`. - backward_mode : {'default', 'unquant', 'dequant'}, default = 'default' - Backward precision mode. Delayed scaling only supports `default`. + backward_override : {None, 'high_precision', 'dequantized'}, default = None + Backward precision mode. Delayed scaling only supports None. Notes ----- @@ -227,14 +227,14 @@ def scaling_factor_compute(amax: Tensor, reduce_amax: bool = True fp8_dpa: bool = False fp8_mha: bool = False - backward_mode: str = field(default_factory=_resolve_backward_mode) + backward_override: Optional[str] = field(default_factory=_resolve_backward_override) def __post_init__(self) -> None: - self.backward_mode = _resolve_backward_mode(self.backward_mode) + self.backward_override = _resolve_backward_override(self.backward_override) assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." assert ( - self.backward_mode == "default" - ), "Delayed scaling only supports backward_mode=default." + self.backward_override is None + ), "Delayed scaling only supports backward_override=None." def __repr__(self) -> str: return ( @@ -245,7 +245,7 @@ def __repr__(self) -> str: f"reduce_amax={self.reduce_amax}, " f"fp8_dpa={self.fp8_dpa}, " f"fp8_mha={self.fp8_mha}, " - f"backward_mode={self.backward_mode}" + f"backward_override={self.backward_override}" ) @@ -259,10 +259,10 @@ class Float8CurrentScaling(Recipe): fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.HYBRID Controls the FP8 data format used during forward and backward pass. - backward_mode : {'default', 'unquant', 'dequant'}, default = 'default' - Backward precision mode. `default` performs quantized backward, - `unquant` keeps original high-precision operands for backward, - and `dequant` dequantizes saved operands to the active high-precision + backward_override : {None, 'high_precision', 'dequantized'}, default = None + Backward precision mode. 
None does not modify backward behavior, + `high_precision` keeps original high-precision operands for backward, + and `dequantized` dequantizes saved operands to the active high-precision compute dtype (e.g. BF16/FP16/FP32) for backward. """ @@ -276,10 +276,10 @@ class Float8CurrentScaling(Recipe): fp8_gemm_wgrad: MMParams = MMParams(use_split_accumulator=True) fp8_dpa: bool = False fp8_mha: bool = False - backward_mode: str = field(default_factory=_resolve_backward_mode) + backward_override: Optional[str] = field(default_factory=_resolve_backward_override) def __post_init__(self) -> None: - self.backward_mode = _resolve_backward_mode(self.backward_mode) + self.backward_override = _resolve_backward_override(self.backward_override) assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." def __repr__(self) -> str: @@ -294,7 +294,7 @@ def __repr__(self) -> str: f"fp8_gemm_wgrad={self.fp8_gemm_wgrad}, " f"fp8_dpa={self.fp8_dpa}, " f"fp8_mha={self.fp8_mha}, " - f"backward_mode={self.backward_mode}" + f"backward_override={self.backward_override}" ) @@ -321,10 +321,10 @@ class MXFP8BlockScaling(Recipe): fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.E4M3 Controls the FP8 data format used during forward and backward pass. - backward_mode : {'default', 'unquant', 'dequant'}, default = 'default' - Backward precision mode. `default` performs quantized backward, - `unquant` keeps original high-precision operands for backward, - and `dequant` dequantizes saved operands to the active high-precision + backward_override : {None, 'high_precision', 'dequantized'}, default = None + Backward precision mode. None does not modify backward behavior, + `high_precision` keeps original high-precision operands for backward, + and `dequantized` dequantizes saved operands to the active high-precision compute dtype (e.g. BF16/FP16/FP32) for backward. 
""" @@ -332,10 +332,10 @@ class MXFP8BlockScaling(Recipe): fp8_format: Format = Format.E4M3 fp8_dpa: bool = False fp8_mha: bool = False - backward_mode: str = field(default_factory=_resolve_backward_mode) + backward_override: Optional[str] = field(default_factory=_resolve_backward_override) def __post_init__(self) -> None: - self.backward_mode = _resolve_backward_mode(self.backward_mode) + self.backward_override = _resolve_backward_override(self.backward_override) assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." def __repr__(self) -> str: @@ -343,7 +343,7 @@ def __repr__(self) -> str: f"recipe_type={self.__class__.__name__}, " f"margin={self.margin}, " f"format={str(self.fp8_format).split('.')[1]}, " - f"backward_mode={self.backward_mode}" + f"backward_override={self.backward_override}" ) @@ -372,10 +372,10 @@ class Float8BlockScaling(Recipe): fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.E4M3 Controls the FP8 data format used during forward and backward pass. - backward_mode : {'default', 'unquant', 'dequant'}, default = 'default' - Backward precision mode. `default` performs quantized backward, - `unquant` keeps original high-precision operands for backward, - and `dequant` dequantizes saved operands to the active high-precision + backward_override : {None, 'high_precision', 'dequantized'}, default = None + Backward precision mode. None does not modify backward behavior, + `high_precision` keeps original high-precision operands for backward, + and `dequantized` dequantizes saved operands to the active high-precision compute dtype (e.g. BF16/FP16/FP32) for backward. 
""" @@ -393,10 +393,10 @@ class Float8BlockScaling(Recipe): fp8_gemm_wgrad: MMParams = MMParams(use_split_accumulator=True) fp8_dpa: bool = False fp8_mha: bool = False - backward_mode: str = field(default_factory=_resolve_backward_mode) + backward_override: Optional[str] = field(default_factory=_resolve_backward_override) def __post_init__(self) -> None: - self.backward_mode = _resolve_backward_mode(self.backward_mode) + self.backward_override = _resolve_backward_override(self.backward_override) assert self.x_block_scaling_dim in [1, 2], "Only 1D or 2D blocks supported for x" assert self.w_block_scaling_dim in [1, 2], "Only 1D or 2D blocks supported for w" assert self.grad_block_scaling_dim in [1, 2], "Only 1D or 2D blocks supported for grad" @@ -432,7 +432,7 @@ def __repr__(self) -> str: f"fp8_gemm_wgrad={self.fp8_gemm_wgrad}, " f"fp8_dpa={self.fp8_dpa}, " f"fp8_mha={self.fp8_mha}, " - f"backward_mode={self.backward_mode}" + f"backward_override={self.backward_override}" ) @@ -481,10 +481,10 @@ class NVFP4BlockScaling(Recipe): If set to `True`, stochastic rounding is disabled during quantization for all tensors. disable_2d_quantization : bool, default = False If set to `True`, 1D block scaling with block size 16 is used for all tensors. - backward_mode : {'default', 'unquant', 'dequant'}, default = 'default' - Backward precision mode. `default` performs quantized backward, - `unquant` keeps original high-precision operands for backward, - and `dequant` dequantizes saved operands to the active high-precision + backward_override : {None, 'high_precision', 'dequantized'}, default = None + Backward precision mode. None does not modify backward behavior, + `high_precision` keeps original high-precision operands for backward, + and `dequantized` dequantizes saved operands to the active high-precision compute dtype (e.g. BF16/FP16/FP32) for backward. 
""" @@ -501,10 +501,10 @@ class NVFP4BlockScaling(Recipe): # Not applying quantization to attention for now fp8_dpa: bool = False fp8_mha: bool = False - backward_mode: str = field(default_factory=_resolve_backward_mode) + backward_override: Optional[str] = field(default_factory=_resolve_backward_override) def __post_init__(self) -> None: - self.backward_mode = _resolve_backward_mode(self.backward_mode) + self.backward_override = _resolve_backward_override(self.backward_override) assert self.fp4_format == Format.E2M1, "Only E2M1 is supported for NVFP4 scaling" assert self.fp8_format == Format.E4M3, "Only E4M3 is supported for NVFP4 scaling" @@ -534,7 +534,7 @@ def __repr__(self) -> str: f"fp8_format={str(self.fp8_format).split('.')[1]}, " f"fp8_dpa={self.fp8_dpa}, " f"fp8_mha={self.fp8_mha}, " - f"backward_mode={self.backward_mode}, " + f"backward_override={self.backward_override}, " f"fp4_quant_fwd_inp={self.fp4_quant_fwd_inp}, " f"fp4_quant_fwd_weight={self.fp4_quant_fwd_weight}, " f"fp4_quant_bwd_grad={self.fp4_quant_bwd_grad}, " @@ -566,10 +566,10 @@ class CustomRecipe(Recipe): - forward: "linear_input", "linear_weight", "linear_output" - backward: "linear_grad_output", "linear_grad_input" - backward_mode : {'default', 'unquant', 'dequant'}, default = 'default' - Backward precision mode. `default` performs quantized backward, - `unquant` keeps original high-precision operands for backward, - and `dequant` dequantizes saved operands to the active high-precision + backward_override : {None, 'high_precision', 'dequantized'}, default = None + Backward precision mode. None does not modify backward behavior, + `high_precision` keeps original high-precision operands for backward, + and `dequantized` dequantizes saved operands to the active high-precision compute dtype (e.g. BF16/FP16/FP32) for backward. 
""" @@ -577,14 +577,14 @@ class CustomRecipe(Recipe): fp8_dpa: bool = False fp8_mha: bool = False - backward_mode: str = field(default_factory=_resolve_backward_mode) + backward_override: Optional[str] = field(default_factory=_resolve_backward_override) def __post_init__(self) -> None: - self.backward_mode = _resolve_backward_mode(self.backward_mode) + self.backward_override = _resolve_backward_override(self.backward_override) def __repr__(self) -> str: return ( f"recipe_type={self.__class__.__name__}, " f"qfactory={self.qfactory}, " - f"backward_mode={self.backward_mode}" + f"backward_override={self.backward_override}" ) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 2ca1f1ace2..004173100b 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1184,7 +1184,7 @@ def grad_output_preprocess( grad_output = grad_output.reshape((-1, grad_output.shape[-1])) grad_output = grad_output.contiguous() gather_grad_output = row_parallel_mode and ctx.sequence_parallel - use_fp8_bwd = ctx.fp8 and ctx.backward_mode == "default" + use_fp8_bwd = ctx.fp8 and ctx.backward_override is None # Non-FP8 case: bgrad is fused with wgrad for this case. 
if not use_fp8_bwd and not ctx.debug: diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index cabbc8930e..d3e3ee4aac 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -98,10 +98,10 @@ def forward( debug, ) = non_tensor_args if fp8: - backward_mode = FP8GlobalStateManager.get_fp8_recipe().backward_mode + backward_override = FP8GlobalStateManager.get_fp8_recipe().backward_override else: - backward_mode = "default" - if backward_mode == "unquant": + backward_override = None + if backward_override == "high_precision": save_original_input = True num_gemms = len(m_splits) @@ -121,11 +121,11 @@ def forward( is_grad_enabled and weight_requires_grad and not save_original_input - and backward_mode == "default" + and backward_override is None ), ) columnwise_usage = is_grad_enabled and inp.requires_grad - if backward_mode in ("unquant", "dequant"): + if backward_override is not None: columnwise_usage = False if not columnwise_usage: columnwise_usage = ( @@ -251,8 +251,8 @@ def forward( else: for inputmat in inputmats: if isinstance(inputmat, QuantizedTensorStorage): - if backward_mode in ("unquant", "dequant"): - # In dequant mode we should dequantize directly from + if backward_override is not None: + # In dequantized mode we should dequantize directly from # fprop quantized layouts without retargeting usage. 
inputmat.update_usage(rowwise_usage=True, columnwise_usage=False) else: @@ -307,7 +307,7 @@ def forward( ctx.activation_dtype = activation_dtype ctx.fp8 = fp8 ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None - ctx.backward_mode = backward_mode + ctx.backward_override = backward_override ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch @@ -326,8 +326,8 @@ def forward( ctx.save_original_input = save_original_input ctx.input_quantizers = input_quantizers - # Non-quantized backward mode overrides - if backward_mode in ("unquant", "dequant"): + # backward overrides + if backward_override is not None: ctx.fp8 = False ctx.debug = False ctx.ub_overlap_ag = False @@ -434,7 +434,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], device=ctx.device, ) weights_for_dgrad = weights - if ctx.backward_mode == "dequant": + if ctx.backward_override == "dequantized": weights_for_dgrad = [ ( weight.dequantize(dtype=ctx.activation_dtype) @@ -443,7 +443,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ) for weight in weights ] - elif ctx.backward_mode == "unquant": + elif ctx.backward_override == "high_precision": weights_for_dgrad = [ ( weight.dequantize(dtype=ctx.activation_dtype) @@ -513,7 +513,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmats = torch.split( cast_if_needed(inp_view, ctx.activation_dtype), ctx.m_splits ) - elif ctx.backward_mode == "dequant": + elif ctx.backward_override == "dequantized": inputmats_dequant = [] for m_split, inputmat in zip(ctx.m_splits, inputmats): if isinstance(inputmat, QuantizedTensorStorage): @@ -1147,7 +1147,9 @@ def _get_quantizers(self): grad_output_quantizers[i].internal = True grad_output_quantizers[i].optimize_for_gemm = True fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() - if fp8_recipe.backward_mode == 
"dequant" and (fp8_recipe.mxfp8() or fp8_recipe.nvfp4()): + if fp8_recipe.backward_override == "dequantized" and ( + fp8_recipe.mxfp8() or fp8_recipe.nvfp4() + ): for input_quantizer in input_quantizers: input_quantizer.optimize_for_gemm = False if torch.is_grad_enabled(): diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index e1dd660d50..06947d537b 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -141,9 +141,9 @@ def forward( debug, ) = non_tensor_args if fp8: - backward_mode = FP8GlobalStateManager.get_fp8_recipe().backward_mode + backward_override = FP8GlobalStateManager.get_fp8_recipe().backward_override else: - backward_mode = "default" + backward_override = None # NVTX label for profiling nvtx_label = "transformer_engine._LayerNormLinear.forward" @@ -204,7 +204,7 @@ def forward( raise ValueError("Missing quantizer for input tensor") input_quantizer.set_usage( rowwise=True, - columnwise=backward_needs_input and backward_mode == "default", + columnwise=backward_needs_input and backward_override is None, ) if with_input_all_gather and input_quantizer.supports_only_rowwise_all_gather(): # All-gather is not supported with FP8 column-wise data @@ -218,7 +218,7 @@ def forward( and not debug and not return_layernorm_output and not return_layernorm_output_gathered - and backward_mode == "default" + and backward_override is None and not custom # TODO(negvet): and not FP8GlobalStateManager.get_fp8_recipe().custom() ) @@ -242,7 +242,7 @@ def forward( ln_out_return = None if return_layernorm_output or return_layernorm_output_gathered: ln_out_return = ln_out - ln_out_hp = ln_out if backward_mode == "unquant" else None + ln_out_hp = ln_out if backward_override == "high_precision" else None # ------------------------------------------------------ # Prepare GEMM input tensor @@ -306,7 +306,7 @@ def forward( elif 
weight_quantizer is not None: weight_quantizer.set_usage( rowwise=True, - columnwise=is_grad_enabled and backward_mode == "default", + columnwise=is_grad_enabled and backward_override is None, ) # Get quantized weight @@ -421,7 +421,7 @@ def forward( if is_grad_enabled: ln_out_to_save = ln_out - if backward_mode == "unquant": + if backward_override == "high_precision": ln_out_to_save = ln_out_hp ctx.weight_quantizer = weight_quantizer ctx.ln_out_needs_gather = ( @@ -429,7 +429,7 @@ def forward( ) # Input with column-wise usage is needed for wgrad GEMM. - if backward_needs_input and backward_mode == "default": + if backward_needs_input and backward_override is None: if isinstance(ln_out, QuantizedTensorStorage): # For sequence parallel in vanilla FP8, rowwise data is # to gather the input. For MXFP8, columnwise only data @@ -507,7 +507,7 @@ def forward( ctx.activation_dtype = activation_dtype ctx.fp8 = fp8 ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None - ctx.backward_mode = backward_mode + ctx.backward_override = backward_override ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch @@ -538,8 +538,8 @@ def forward( ctx.wgrad_store = wgrad_store ctx.debug = debug - # Non-quantized backward mode overrides - if backward_mode in ("unquant", "dequant"): + # backward overrides + if backward_override is not None: ctx.fp8 = False ctx.debug = False ctx.ub_overlap_ag = False @@ -691,7 +691,7 @@ def backward( # -------------------------------------------------- ln_out_total = None ln_out_total_work = None - if ctx.backward_mode == "dequant": + if ctx.backward_override == "dequantized": if isinstance(ln_out, QuantizedTensorStorage): ln_out = ln_out.dequantize(dtype=ctx.activation_dtype) else: @@ -768,12 +768,12 @@ def backward( # Note: dx = dy * w nvtx_range_push(f"{nvtx_label}.dgrad_gemm") weight_for_dgrad = weight - if ctx.backward_mode == "dequant": + if 
ctx.backward_override == "dequantized": if isinstance(weight_for_dgrad, QuantizedTensorStorage): weight_for_dgrad = weight_for_dgrad.dequantize(dtype=ctx.activation_dtype) else: weight_for_dgrad = cast_if_needed(weight_for_dgrad, ctx.activation_dtype) - elif ctx.backward_mode == "unquant": + elif ctx.backward_override == "high_precision": weight_for_dgrad = origin_weight if isinstance(weight_for_dgrad, QuantizedTensorStorage): weight_for_dgrad = weight_for_dgrad.dequantize(dtype=ctx.activation_dtype) @@ -1676,7 +1676,9 @@ def _get_quantizers(self, fp8_output, fp8_grad, is_grad_enabled): if fp8_grad: grad_input_quantizer = self.quantizers["scaling_bwd"][FP8BwdTensorIdx.GRAD_INPUT1] fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() - if fp8_recipe.backward_mode == "dequant" and (fp8_recipe.mxfp8() or fp8_recipe.nvfp4()): + if fp8_recipe.backward_override == "dequantized" and ( + fp8_recipe.mxfp8() or fp8_recipe.nvfp4() + ): input_quantizer.optimize_for_gemm = False if grad_output_quantizer is not None: grad_output_quantizer.optimize_for_gemm = False diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 4f206c866e..d1a2352ea8 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -235,12 +235,13 @@ def _forward( recompute_for_bwd, ) = non_tensor_args if fp8: - backward_mode = FP8GlobalStateManager.get_fp8_recipe().backward_mode + backward_override = FP8GlobalStateManager.get_fp8_recipe().backward_override else: - backward_mode = "default" - assert backward_mode == "default", ( - "NVTE_BACKWARD_MODE=unquant/dequant is not implemented in LayerNormMLP. " - "Replace LayerNormMLP with LayerNormLinear + Linear to enable unquant/dequant backward." + backward_override = None + assert backward_override is None, ( + "NVTE_BACKWARD_OVERRIDE=high_precision/dequantized is not implemented in LayerNormMLP." 
+ " Replace LayerNormMLP with LayerNormLinear + Linear to enable" + " high_precision/dequantized backward." ) # if grad is enabled and this is not the bwd stage, we must save this so bwd knows which path to take @@ -788,7 +789,7 @@ def _forward( ctx.fc2_main_grad_func = lambda: fc2_weight.main_grad ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None - ctx.backward_mode = backward_mode + ctx.backward_override = backward_override ctx.fc1_grad_input_quantizer = fc1_grad_input_quantizer ctx.fc1_grad_weight_quantizer = fc1_grad_weight_quantizer ctx.fc1_grad_output_quantizer = fc1_grad_output_quantizer diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 9e142a0b02..82fba5f33d 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -129,10 +129,10 @@ def forward( debug, ) = non_tensor_args if fp8: - backward_mode = FP8GlobalStateManager.get_fp8_recipe().backward_mode + backward_override = FP8GlobalStateManager.get_fp8_recipe().backward_override else: - backward_mode = "default" - if backward_mode == "unquant": + backward_override = None + if backward_override == "high_precision": save_original_input = True # NVTX label for profiling @@ -195,7 +195,7 @@ def forward( own_quantized_input = True input_quantizer.set_usage( rowwise=True, - columnwise=backward_needs_input and backward_mode == "default", + columnwise=backward_needs_input and backward_override is None, ) if isinstance( input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) @@ -242,7 +242,7 @@ def forward( columnwise=( backward_needs_input and not save_original_input - and backward_mode == "default" + and backward_override is None ), ) inputmat = input_quantizer(inputmat) @@ -268,7 +268,7 @@ def forward( # for debug mode we create quantizer every iteration, thus we need to set the quantizer states if weight_quantizer is not None and (not isinstance(weight, QuantizedTensor) 
or debug): columnwise_usage = is_grad_enabled and inp.requires_grad - if backward_mode in ("unquant", "dequant"): + if backward_override is not None: columnwise_usage = False if not columnwise_usage: columnwise_usage = ( @@ -403,8 +403,8 @@ def forward( and own_quantized_input and isinstance(inputmat, QuantizedTensorStorage) ): - if backward_mode in ("unquant", "dequant"): - # In dequant mode we should dequantize directly from the + if backward_override is not None: + # In dequantized mode we should dequantize directly from the # fprop quantized tensor layout without retargeting usage. inputmat.update_usage(rowwise_usage=True, columnwise_usage=False) elif ( @@ -462,7 +462,7 @@ def forward( ctx.activation_dtype = activation_dtype ctx.fp8 = fp8 ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None - ctx.backward_mode = backward_mode + ctx.backward_override = backward_override ctx.input_quantizer = input_quantizer ctx.grad_input_quantizer = grad_input_quantizer ctx.grad_weight_quantizer = grad_weight_quantizer @@ -506,8 +506,8 @@ def forward( FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module ctx.wgrad_store = wgrad_store - # Non-quantized backward mode overrides - if backward_mode in ("unquant", "dequant"): + # backward overrides + if backward_override is not None: ctx.fp8 = False ctx.debug = False ctx.ub_overlap_ag = False @@ -756,12 +756,12 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], nvtx_range_push(f"{nvtx_label}.dgrad_gemm") weight_for_dgrad = weight_fp8 - if ctx.backward_mode == "dequant": + if ctx.backward_override == "dequantized": if isinstance(weight_for_dgrad, QuantizedTensorStorage): weight_for_dgrad = weight_for_dgrad.dequantize(dtype=ctx.activation_dtype) else: weight_for_dgrad = cast_if_needed(weight_for_dgrad, ctx.activation_dtype) - elif ctx.backward_mode == "unquant": + elif ctx.backward_override == "high_precision": weight_for_dgrad = weight if isinstance(weight_for_dgrad, 
QuantizedTensorStorage): weight_for_dgrad = weight_for_dgrad.dequantize(dtype=ctx.activation_dtype) @@ -1542,7 +1542,9 @@ def _get_quantizers(self, fp8_output, fp8_grad, is_grad_enabled): if fp8_grad: grad_input_quantizer = self.quantizers["scaling_bwd"][FP8BwdTensorIdx.GRAD_INPUT1] fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() - if fp8_recipe.backward_mode == "dequant" and (fp8_recipe.mxfp8() or fp8_recipe.nvfp4()): + if fp8_recipe.backward_override == "dequantized" and ( + fp8_recipe.mxfp8() or fp8_recipe.nvfp4() + ): input_quantizer.optimize_for_gemm = False if grad_output_quantizer is not None: grad_output_quantizer.optimize_for_gemm = False diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index accc5fbe6a..17594726cc 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -333,7 +333,7 @@ def pre_fuser_forward(self, *, requires_grad: bool) -> None: # but discard the quantized weights. 
weight_requires_grad = requires_grad and self.weight.requires_grad columnwise_usage = weight_requires_grad - if FP8GlobalStateManager.get_fp8_recipe().backward_mode in ("unquant", "dequant"): + if FP8GlobalStateManager.get_fp8_recipe().backward_override is not None: columnwise_usage = False input_quantizer = self.get_quantizer("forward", 0) weight_quantizer = self.get_quantizer("forward", 1) @@ -360,7 +360,7 @@ def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: grad_output_quantizer.optimize_for_gemm = True if FP8GlobalStateManager.is_fp8_enabled(): fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() - if fp8_recipe.backward_mode in ("unquant", "dequant") and ( + if fp8_recipe.backward_override is not None and ( fp8_recipe.mxfp8() or fp8_recipe.nvfp4() ): if input_quantizer is not None: @@ -432,7 +432,7 @@ def _functional_forward( tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None, sequence_parallel: bool = False, with_quantized_compute: bool = False, - backward_mode: str = "default", + backward_override: Optional[str] = None, input_quantizer: Optional[Quantizer] = None, weight_quantizer: Optional[Quantizer] = None, output_quantizer: Optional[Quantizer] = None, @@ -472,8 +472,8 @@ def _functional_forward( distributing along inner dimension (embedding dim) with_quantized_compute: bool, default = False Whether to perform compute with quantized data. - backward_mode: {`"default"`, `"unquant"`, `"dequant"`}, default = `"default"` - Backward-mode policy for quantized compute. + backward_override: {`None`, `"high_precision"`, `"dequantized"`}, default = `None` + Backward-override policy for quantized compute. input_quantizer: Quantizer, optional Builder class for quantized input tensor. 
weight_quantizer: Quantizer, optional @@ -527,7 +527,7 @@ def _functional_forward( raise ValueError("Missing quantizer for input tensor") input_quantizer.set_usage( rowwise=True, - columnwise=weight_requires_grad and backward_mode == "default", + columnwise=weight_requires_grad and backward_override is None, ) if with_x_all_gather: input_quantizer.set_usage(columnwise=False) @@ -562,7 +562,7 @@ def _functional_forward( raise ValueError("Missing quantizer for weight tensor") weight_quantizer.set_usage( rowwise=True, - columnwise=input_requires_grad and backward_mode == "default", + columnwise=input_requires_grad and backward_override is None, ) w = weight_quantizer(w) @@ -636,7 +636,7 @@ def _functional_forward( w is not weight and with_quantized_compute and is_quantized_tensor(w) - and backward_mode == "default" + and backward_override is None ): w.update_usage(rowwise_usage=False, columnwise_usage=True) else: @@ -647,7 +647,7 @@ def _functional_forward( if ( with_quantized_compute and is_quantized_tensor(x_local) - and backward_mode == "default" + and backward_override is None ): if not (isinstance(x_local, Float8TensorStorage) and with_x_all_gather): # FP8 does not support all-gather of transpose data @@ -999,9 +999,9 @@ def op_forward( grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() if with_quantized_compute: - backward_mode = FP8GlobalStateManager.get_fp8_recipe().backward_mode + backward_override = FP8GlobalStateManager.get_fp8_recipe().backward_override else: - backward_mode = "default" + backward_override = None # Get autocast dtype if needed if torch.is_autocast_enabled(): @@ -1018,7 +1018,7 @@ def op_forward( tensor_parallel_group=self.tensor_parallel_group, sequence_parallel=self.sequence_parallel, with_quantized_compute=with_quantized_compute, - backward_mode=backward_mode, + backward_override=backward_override, input_quantizer=input_quantizer, weight_quantizer=weight_quantizer, 
output_quantizer=output_quantizer, @@ -1028,7 +1028,7 @@ def op_forward( # Save state for backward pass if ctx.requires_grad: - if backward_mode == "unquant": + if backward_override == "high_precision": saved_input = input_ if weight_requires_grad else None saved_weight = self.weight if input_requires_grad else None else: @@ -1037,8 +1037,8 @@ def op_forward( if is_cpu_offload_enabled(): mark_activation_offload(saved_input) ctx.save_for_backward(saved_input, saved_weight) - ctx.with_quantized_compute = with_quantized_compute and backward_mode == "default" - ctx.backward_mode = backward_mode + ctx.with_quantized_compute = with_quantized_compute and backward_override is None + ctx.backward_override = backward_override ctx.input_quantizer = input_quantizer ctx.weight_quantizer = weight_quantizer ctx.grad_output_quantizer = grad_output_quantizer diff --git a/transformer_engine/pytorch/ops/basic/bias.py b/transformer_engine/pytorch/ops/basic/bias.py index ad147a8d85..88f563b2c5 100644 --- a/transformer_engine/pytorch/ops/basic/bias.py +++ b/transformer_engine/pytorch/ops/basic/bias.py @@ -127,7 +127,7 @@ def op_forward( ctx.grad_input_quantizer = prev_op_grad_output_quantizer if FP8GlobalStateManager.is_fp8_enabled(): fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() - if fp8_recipe.backward_mode in ("unquant", "dequant"): + if fp8_recipe.backward_override is not None: ctx.grad_input_quantizer = None return x + b diff --git a/transformer_engine/pytorch/ops/basic/quantize.py b/transformer_engine/pytorch/ops/basic/quantize.py index c5474c18a0..d0c1137d91 100644 --- a/transformer_engine/pytorch/ops/basic/quantize.py +++ b/transformer_engine/pytorch/ops/basic/quantize.py @@ -59,10 +59,10 @@ def op_forward( quantize_forward = fp8_enabled and self._quantize_forward quantize_backward = fp8_enabled and self._quantize_backward - # Backward quantization is controlled by recipe backward mode. + # Backward quantization is controlled by recipe backward override. 
if fp8_enabled: recipe = FP8GlobalStateManager.get_fp8_recipe() - quantize_backward = quantize_backward and recipe.backward_mode == "default" + quantize_backward = quantize_backward and recipe.backward_override is None # Quantize if needed out = input_ diff --git a/transformer_engine/pytorch/ops/fused/backward_activation_bias.py b/transformer_engine/pytorch/ops/fused/backward_activation_bias.py index 7b3025c03e..3950316a3c 100644 --- a/transformer_engine/pytorch/ops/fused/backward_activation_bias.py +++ b/transformer_engine/pytorch/ops/fused/backward_activation_bias.py @@ -105,8 +105,8 @@ def fuse_backward_ops( """ # Check if recipe supports bias activation fusion. - # unquant/dequant backward modes should use unfused backward ops. - if recipe is None or recipe.backward_mode in ("unquant", "dequant"): + # high_precision/dequantized backward overrides should use unfused backward ops. + if recipe is None or recipe.backward_override is not None: return ops # Scan through ops, fusing if possible diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index 19d2f679fb..8df929f799 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -93,9 +93,9 @@ def fuser_forward( grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() if with_quantized_compute: - backward_mode = FP8GlobalStateManager.get_fp8_recipe().backward_mode + backward_override = FP8GlobalStateManager.get_fp8_recipe().backward_override else: - backward_mode = "default" + backward_override = None # Get autocast dtype if needed if torch.is_autocast_enabled(): @@ -113,7 +113,7 @@ def fuser_forward( tensor_parallel_group=linear_op.tensor_parallel_group, sequence_parallel=linear_op.sequence_parallel, with_quantized_compute=with_quantized_compute, - 
backward_mode=backward_mode, + backward_override=backward_override, input_quantizer=input_quantizer, weight_quantizer=weight_quantizer, output_quantizer=output_quantizer, @@ -123,7 +123,7 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: - if backward_mode == "unquant": + if backward_override == "high_precision": saved_input = input_ if weight_requires_grad else None saved_weight = linear_op.weight if input_requires_grad else None else: @@ -133,9 +133,9 @@ def fuser_forward( mark_activation_offload(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) linear_op_ctx.with_quantized_compute = ( - with_quantized_compute and backward_mode == "default" + with_quantized_compute and backward_override is None ) - linear_op_ctx.backward_mode = backward_mode + linear_op_ctx.backward_override = backward_override linear_op_ctx.input_quantizer = input_quantizer linear_op_ctx.weight_quantizer = weight_quantizer linear_op_ctx.grad_output_quantizer = grad_output_quantizer @@ -145,7 +145,7 @@ def fuser_forward( linear_op_ctx.weight_requires_grad = weight_requires_grad if bias_op is not None and bias_op_ctx.requires_grad: bias_op_ctx.grad_input_quantizer = linear_op.get_grad_output_quantizer() - if backward_mode in ("unquant", "dequant"): + if backward_override is not None: bias_op_ctx.grad_input_quantizer = None return output, [() for _ in range(len(self.basic_ops))] diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index 5d2997a50a..5376a7d264 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -87,9 +87,9 @@ def fuser_forward( grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() if with_quantized_compute: - backward_mode = FP8GlobalStateManager.get_fp8_recipe().backward_mode + 
backward_override = FP8GlobalStateManager.get_fp8_recipe().backward_override else: - backward_mode = "default" + backward_override = None # Get autocast dtype if needed if torch.is_autocast_enabled(): @@ -110,7 +110,7 @@ def fuser_forward( tensor_parallel_group=linear_op.tensor_parallel_group, sequence_parallel=linear_op.sequence_parallel, with_quantized_compute=with_quantized_compute, - backward_mode=backward_mode, + backward_override=backward_override, input_quantizer=input_quantizer, weight_quantizer=weight_quantizer, output_quantizer=output_quantizer, @@ -120,7 +120,7 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: - if backward_mode == "unquant": + if backward_override == "high_precision": saved_input = input_ if weight_requires_grad else None saved_weight = linear_op.weight if input_requires_grad else None else: @@ -130,9 +130,9 @@ def fuser_forward( mark_activation_offload(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) linear_op_ctx.with_quantized_compute = ( - with_quantized_compute and backward_mode == "default" + with_quantized_compute and backward_override is None ) - linear_op_ctx.backward_mode = backward_mode + linear_op_ctx.backward_override = backward_override linear_op_ctx.input_quantizer = input_quantizer linear_op_ctx.weight_quantizer = weight_quantizer linear_op_ctx.grad_output_quantizer = grad_output_quantizer @@ -142,7 +142,7 @@ def fuser_forward( linear_op_ctx.weight_requires_grad = weight_requires_grad if bias_op is not None and bias_op_ctx.requires_grad: bias_op_ctx.grad_input_quantizer = ( - None if backward_mode != "default" else linear_op.get_grad_output_quantizer() + None if backward_override is not None else linear_op.get_grad_output_quantizer() ) return output, [() for _ in range(len(self.basic_ops))] diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py index c1eeac484f..abeb39adfa 
100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py @@ -66,9 +66,9 @@ def fuser_forward( grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() if with_quantized_compute: - backward_mode = FP8GlobalStateManager.get_fp8_recipe().backward_mode + backward_override = FP8GlobalStateManager.get_fp8_recipe().backward_override else: - backward_mode = "default" + backward_override = None # Get extra input tensor for add operation extra_input = basic_op_extra_inputs[2][0] @@ -91,7 +91,7 @@ def fuser_forward( tensor_parallel_group=linear_op.tensor_parallel_group, sequence_parallel=linear_op.sequence_parallel, with_quantized_compute=with_quantized_compute, - backward_mode=backward_mode, + backward_override=backward_override, input_quantizer=input_quantizer, weight_quantizer=weight_quantizer, output_quantizer=output_quantizer, @@ -101,7 +101,7 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: - if backward_mode == "unquant": + if backward_override == "high_precision": saved_input = input_ if weight_requires_grad else None saved_weight = linear_op.weight if input_requires_grad else None else: @@ -111,9 +111,9 @@ def fuser_forward( mark_activation_offload(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) linear_op_ctx.with_quantized_compute = ( - with_quantized_compute and backward_mode == "default" + with_quantized_compute and backward_override is None ) - linear_op_ctx.backward_mode = backward_mode + linear_op_ctx.backward_override = backward_override linear_op_ctx.input_quantizer = input_quantizer linear_op_ctx.weight_quantizer = weight_quantizer linear_op_ctx.grad_output_quantizer = grad_output_quantizer diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py index 
54411f650d..84073be6f8 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py @@ -388,17 +388,17 @@ def fuse_forward_ops( """ - # Disable Userbuffers for non-quantized backward modes. - # In unquant/dequant modes we want to avoid all UB-specific overlap + # Disable Userbuffers for backward overrides. + # In high_precision/dequantized modes we want to avoid all UB-specific overlap # paths and run through the standard non-UB operator sequence instead. recipe = unused.get("recipe", None) if recipe is not None: - backward_mode = recipe.backward_mode + backward_override = recipe.backward_override elif FP8GlobalStateManager.is_fp8_enabled(): - backward_mode = FP8GlobalStateManager.get_fp8_recipe().backward_mode + backward_override = FP8GlobalStateManager.get_fp8_recipe().backward_override else: - backward_mode = "default" - if backward_mode in ("unquant", "dequant"): + backward_override = None + if backward_override is not None: return ops # Return immediately if environment is not distributed diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index 616c075ad8..32b43e42fc 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -339,7 +339,7 @@ def __init__( # Cache and detect change of state relevant for fusing operations self.recipe_type = None self.first_op_requiring_backward = 0 - self.backward_mode = "default" + self.backward_override = None self._last_amax_history_len = 0 # Flatten list of parameters @@ -416,14 +416,14 @@ def maybe_fuse_ops( # Early exit if fusion parameters haven't changed need_reset = False recipe_type = type(recipe) - backward_mode = recipe.backward_mode if recipe is not None else "default" - fusion_params = (recipe_type, first_op_requiring_backward, backward_mode) + backward_override = recipe.backward_override if recipe is not None else None + fusion_params = 
(recipe_type, first_op_requiring_backward, backward_override) if fusion_params != ( self.recipe_type, self.first_op_requiring_backward, - self.backward_mode, + self.backward_override, ): - # Recipe type, backward mode, or grad requirements have changed + # Recipe type, backward override, or grad requirements have changed need_reset = True elif ( recipe is not None @@ -457,7 +457,7 @@ def maybe_fuse_ops( ) # Save current fusion params - self.recipe_type, self.first_op_requiring_backward, self.backward_mode = fusion_params + self.recipe_type, self.first_op_requiring_backward, self.backward_override = fusion_params # Save amax history length if isinstance(recipe, DelayedScaling): From 00893bb0b78c46b2d7d9ff24a43dbf84ace93a7a Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sat, 14 Mar 2026 14:16:28 -0700 Subject: [PATCH 60/61] Rename unit test Signed-off-by: Ziang Li --- qa/L0_pytorch_unittest/test.sh | 2 +- .../{test_backward_mode.py => test_backward_override.py} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename tests/pytorch/{test_backward_mode.py => test_backward_override.py} (100%) diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index a9df8a1bb6..9085c33dd9 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -42,7 +42,7 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_backward_mode.xml 
$TE_PATH/tests/pytorch/test_backward_mode.py || test_fail "test_backward_mode.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_backward_override.xml $TE_PATH/tests/pytorch/test_backward_override.py || test_fail "test_backward_override.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py" diff --git a/tests/pytorch/test_backward_mode.py b/tests/pytorch/test_backward_override.py similarity index 100% rename from tests/pytorch/test_backward_mode.py rename to tests/pytorch/test_backward_override.py From 433880d3736152a9dc90ca95d111506d57029e25 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sat, 14 Mar 2026 15:24:25 -0700 Subject: [PATCH 61/61] Simplify env var parsing Signed-off-by: Ziang Li --- transformer_engine/common/recipe/__init__.py | 47 ++++++++++---------- 1 file changed, 24 insertions(+), 23 deletions(-) diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index fd44d69a7d..67b6f87067 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -14,17 +14,6 @@ _BACKWARD_OVERRIDES = (None, "high_precision", "dequantized") -def _resolve_backward_override(mode: Optional[str] = None) -> Optional[str]: - """Return validated backward override from argument or NVTE_BACKWARD_OVERRIDE env.""" - if mode is None: - mode = os.getenv("NVTE_BACKWARD_OVERRIDE", None) - assert mode in _BACKWARD_OVERRIDES, ( - f"Invalid NVTE_BACKWARD_OVERRIDE value {mode!r}. 
Supported values are:" - " high_precision|dequantized." - ) - return mode - - class _FormatHelper(NamedTuple): """ Stores max FP8 values for fprop and bprop a `Format`. @@ -227,11 +216,13 @@ def scaling_factor_compute(amax: Tensor, reduce_amax: bool = True fp8_dpa: bool = False fp8_mha: bool = False - backward_override: Optional[str] = field(default_factory=_resolve_backward_override) + backward_override: Optional[str] = os.getenv("NVTE_BACKWARD_OVERRIDE", None) def __post_init__(self) -> None: - self.backward_override = _resolve_backward_override(self.backward_override) assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." + assert ( + self.backward_override in _BACKWARD_OVERRIDES + ), "NVTE_BACKWARD_OVERRIDE must be unset or one of: 'high_precision', 'dequantized'." assert ( self.backward_override is None ), "Delayed scaling only supports backward_override=None." @@ -276,11 +267,13 @@ class Float8CurrentScaling(Recipe): fp8_gemm_wgrad: MMParams = MMParams(use_split_accumulator=True) fp8_dpa: bool = False fp8_mha: bool = False - backward_override: Optional[str] = field(default_factory=_resolve_backward_override) + backward_override: Optional[str] = os.getenv("NVTE_BACKWARD_OVERRIDE", None) def __post_init__(self) -> None: - self.backward_override = _resolve_backward_override(self.backward_override) assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." + assert ( + self.backward_override in _BACKWARD_OVERRIDES + ), "NVTE_BACKWARD_OVERRIDE must be unset or one of: 'high_precision', 'dequantized'." 
def __repr__(self) -> str: return ( @@ -332,11 +325,13 @@ class MXFP8BlockScaling(Recipe): fp8_format: Format = Format.E4M3 fp8_dpa: bool = False fp8_mha: bool = False - backward_override: Optional[str] = field(default_factory=_resolve_backward_override) + backward_override: Optional[str] = os.getenv("NVTE_BACKWARD_OVERRIDE", None) def __post_init__(self) -> None: - self.backward_override = _resolve_backward_override(self.backward_override) assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." + assert ( + self.backward_override in _BACKWARD_OVERRIDES + ), "NVTE_BACKWARD_OVERRIDE must be unset or one of: 'high_precision', 'dequantized'." def __repr__(self) -> str: return ( @@ -393,10 +388,9 @@ class Float8BlockScaling(Recipe): fp8_gemm_wgrad: MMParams = MMParams(use_split_accumulator=True) fp8_dpa: bool = False fp8_mha: bool = False - backward_override: Optional[str] = field(default_factory=_resolve_backward_override) + backward_override: Optional[str] = os.getenv("NVTE_BACKWARD_OVERRIDE", None) def __post_init__(self) -> None: - self.backward_override = _resolve_backward_override(self.backward_override) assert self.x_block_scaling_dim in [1, 2], "Only 1D or 2D blocks supported for x" assert self.w_block_scaling_dim in [1, 2], "Only 1D or 2D blocks supported for w" assert self.grad_block_scaling_dim in [1, 2], "Only 1D or 2D blocks supported for grad" @@ -416,6 +410,9 @@ def __post_init__(self) -> None: not self.fp8_dpa and not self.fp8_mha ), "FP8 attention is not supported for Float8BlockScaling." assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." + assert ( + self.backward_override in _BACKWARD_OVERRIDES + ), "NVTE_BACKWARD_OVERRIDE must be unset or one of: 'high_precision', 'dequantized'." 
def __repr__(self) -> str: return ( @@ -501,12 +498,14 @@ class NVFP4BlockScaling(Recipe): # Not applying quantization to attention for now fp8_dpa: bool = False fp8_mha: bool = False - backward_override: Optional[str] = field(default_factory=_resolve_backward_override) + backward_override: Optional[str] = os.getenv("NVTE_BACKWARD_OVERRIDE", None) def __post_init__(self) -> None: - self.backward_override = _resolve_backward_override(self.backward_override) assert self.fp4_format == Format.E2M1, "Only E2M1 is supported for NVFP4 scaling" assert self.fp8_format == Format.E4M3, "Only E4M3 is supported for NVFP4 scaling" + assert ( + self.backward_override in _BACKWARD_OVERRIDES + ), "NVTE_BACKWARD_OVERRIDE must be unset or one of: 'high_precision', 'dequantized'." # Quantization params # Note: RHT is currently only applied to column-wise usage so that @@ -577,10 +576,12 @@ class CustomRecipe(Recipe): fp8_dpa: bool = False fp8_mha: bool = False - backward_override: Optional[str] = field(default_factory=_resolve_backward_override) + backward_override: Optional[str] = os.getenv("NVTE_BACKWARD_OVERRIDE", None) def __post_init__(self) -> None: - self.backward_override = _resolve_backward_override(self.backward_override) + assert ( + self.backward_override in _BACKWARD_OVERRIDES + ), "NVTE_BACKWARD_OVERRIDE must be unset or one of: 'high_precision', 'dequantized'." def __repr__(self) -> str: return (