From de58f327ac21f14d3524e8371eb9e5ab563ffc4f Mon Sep 17 00:00:00 2001 From: Robin Zhang Date: Fri, 3 Apr 2026 05:20:12 -0700 Subject: [PATCH 1/6] init commit Signed-off-by: Robin Zhang --- transformer_engine/pytorch/graph.py | 410 +++++++++++++++++++++------- 1 file changed, 305 insertions(+), 105 deletions(-) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 86b8a4acf4..991e0894c7 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -110,6 +110,7 @@ def _make_graphed_callables( _reuse_graph_input_output_buffers: bool = False, pre_warmup_hook: Optional[Callable] = None, post_warmup_hook: Optional[Callable] = None, + capture_time_hooks: Optional[List[Optional[Dict[str, Dict]]]] = None, ) -> SingleOrTuple[Callable]: """ Helper method for `make_graphed_callables` @@ -448,111 +449,197 @@ def _make_graphed_callables( visited_te_modules = {} need_bwd_dw_graph = {} + def _run_warmup_forward(func_idx, func): + """Run forward for one callable during warmup; returns flattened outputs.""" + args = sample_args[func_idx] + kwargs = sample_kwargs[func_idx] + + def hook_fn( + module, inputs, outputs, func_idx=func_idx + ): # pylint: disable=unused-argument + modules = set() + if isinstance(module, TransformerEngineBaseModule): + modules.add(module) + # If forward is called on a BasicOperation directly the hook will run + elif isinstance(module, BasicOperation): + modules.add(module) + # If forward is called on a te.ops.Sequential it is not called on its constituent ops + elif isinstance(module, Sequential): + if module._module_groups is None: + raise RuntimeError( + "module._module_groups should have been initialized by warmup" + ) + for module_group in module._module_groups: + if isinstance(module_group, OperationFuser): + for basic_op in module_group._basic_ops: + modules.add(basic_op) + if modules: + if func_idx not in visited_te_modules: + visited_te_modules[func_idx] = modules + else: + visited_te_modules[func_idx].update(modules) + + if ( + capture_time_hooks is not None + and func_idx < len(capture_time_hooks) + and capture_time_hooks[func_idx] is not None + and "forward_pre" in capture_time_hooks[func_idx] + ): + for hook in capture_time_hooks[func_idx]["forward_pre"].values(): + hook(func, args, kwargs) + + hooks = [] + for module in func.modules(): + hooks.append(module.register_forward_hook(hook_fn)) + outputs, _ = _tree_flatten(func(*args, **kwargs)) + for hook in hooks: + hook.remove() + + if ( + capture_time_hooks is not None + and func_idx < len(capture_time_hooks) + and capture_time_hooks[func_idx] is not None + and "forward" in capture_time_hooks[func_idx] + ): + for hook in capture_time_hooks[func_idx]["forward"].values(): + hook(func, args, outputs) + + return outputs + + def _run_warmup_backward(func_idx, func, outputs, warmup_iter): + """Run dgrad backward for one callable during warmup.""" + static_input_surface = per_callable_static_input_surfaces[func_idx] + + if ( + capture_time_hooks is not None + and func_idx < len(capture_time_hooks) + and capture_time_hooks[func_idx] is not None + and "pre_backward" in capture_time_hooks[func_idx] + ): + for hook in capture_time_hooks[func_idx]["pre_backward"].values(): + hook(func) + + inputs = tuple(i for i in static_input_surface if i.requires_grad) + with _none_grad_context_wrapper(inputs): + outputs_requiring_grad = tuple( + o for o in outputs if o is not None and o.requires_grad + ) + torch.autograd.backward( + outputs_requiring_grad, + grad_tensors=tuple(torch.empty_like(o) for o in outputs_requiring_grad), + ) + grad_inputs = tuple(input.grad for input in inputs) + + if ( + capture_time_hooks is not None + and func_idx < len(capture_time_hooks) + and capture_time_hooks[func_idx] is not None + and "backward" in capture_time_hooks[func_idx] + ): + for hook in capture_time_hooks[func_idx]["backward"].values(): + hook(func) + + # Filter module params that get None grad from grad_inputs and remove them + # from static_input_surface. This is to ensure that the backward hooks + # registered to these params are not wrongly triggered. + num_required_grad_sample_args = sum( + arg.requires_grad for arg in flatten_sample_args[func_idx] + ) + required_grad_input_idx = [] + for i, arg in enumerate(static_input_surface): + if arg.requires_grad: + required_grad_input_idx.append(i) + module_params_with_grad = [] + for grad_inputs_idx, inputs_idx in enumerate(required_grad_input_idx): + if ( + grad_inputs[grad_inputs_idx] is None + and grad_inputs_idx < num_required_grad_sample_args + ): + if not allow_unused_input: + raise RuntimeError( + "The input tensor requires grad, but the grad is None after" + " backward pass." + ) + elif ( + grad_inputs[grad_inputs_idx] is not None + and grad_inputs_idx >= num_required_grad_sample_args + ): + module_params_with_grad.append(static_input_surface[inputs_idx]) + if len(module_params_with_grad) != len(per_callable_module_params[func_idx]): + if warmup_iter != 0: + raise RuntimeError( + "no-grad params should only be used as inputs in the first warmup" + f" iteration, but found in iteration {warmup_iter}" + ) + per_callable_module_params[func_idx] = tuple(module_params_with_grad) + static_input_surface = flatten_sample_args[func_idx] + tuple( + module_params_with_grad + ) + per_callable_static_input_surfaces[func_idx] = static_input_surface + + # Run wgrad. This is essential for some TE modules when they have + # delay_wgrad_compute enabled. + need_backward_dw = False + for module in visited_te_modules.get(func_idx, set()): + if hasattr(module, "need_backward_dw") and module.need_backward_dw(): + need_backward_dw = True + module.backward_dw() + need_bwd_dw_graph[func_idx] = need_backward_dw + # Run warmup and do the above filtering. with torch.cuda.stream(torch.cuda.Stream()): - for func_idx, func in zip(warmup_func_idx, warmup_func): - args = sample_args[func_idx] - kwargs = sample_kwargs[func_idx] - static_input_surface = per_callable_static_input_surfaces[func_idx] - - def hook_fn( - module, inputs, outputs, func_idx=func_idx - ): # pylint: disable=unused-argument - modules = set() - if isinstance(module, TransformerEngineBaseModule): - modules.add(module) - # If forward is called on a BasicOperation directly the hook will run - elif isinstance(module, BasicOperation): - modules.add(module) - # If forward is called on a te.ops.Sequential it is not called on its constituent ops - elif isinstance(module, Sequential): - if module._module_groups is None: - raise RuntimeError( - "module._module_groups should have been initialized by warmup" - ) - for module_group in module._module_groups: - if isinstance(module_group, OperationFuser): - for basic_op in module_group._basic_ops: - modules.add(basic_op) - if modules: - if func_idx not in visited_te_modules: - visited_te_modules[func_idx] = modules - else: - visited_te_modules[func_idx].update(modules) - - if pre_warmup_hook is not None: - pre_warmup_hook() - for warmup_iter in range(num_warmup_iters): - hooks = [] - for module in func.modules(): - hook = module.register_forward_hook(hook_fn) - hooks.append(hook) - outputs, _ = _tree_flatten(func(*args, **kwargs)) - for hook in hooks: - hook.remove() - if is_training: - inputs = tuple(i for i in static_input_surface if i.requires_grad) - with _none_grad_context_wrapper(inputs): - outputs_requiring_grad = tuple( - o for o in outputs if o is not None and o.requires_grad - ) - torch.autograd.backward( - outputs_requiring_grad, - grad_tensors=tuple(torch.empty_like(o) for o in outputs_requiring_grad), - ) - grad_inputs = tuple(input.grad for input in inputs) + if pre_warmup_hook is not None: + pre_warmup_hook() - # Filter module params that get None grad from grad_inputs and remove them - # from static_input_surface. This is to ensure that the backward hooks - # registered to these params are not wrongly triggered. - num_required_grad_sample_args = sum( - arg.requires_grad for arg in flatten_sample_args[func_idx] - ) - required_grad_input_idx = [] - for i, arg in enumerate(static_input_surface): - if arg.requires_grad: - required_grad_input_idx.append(i) - module_params_with_grad = [] - for grad_inputs_idx, inputs_idx in enumerate(required_grad_input_idx): - if ( - grad_inputs[grad_inputs_idx] is None - and grad_inputs_idx < num_required_grad_sample_args - ): - if not allow_unused_input: - raise RuntimeError( - "The input tensor requires grad, but the grad is None after" - " backward pass." + for warmup_iter in range(num_warmup_iters): + if _order is None: + # All forwards in order, then all backwards in reverse order. + warmup_outputs = [] + for func_idx, func in zip(warmup_func_idx, warmup_func): + outputs = _run_warmup_forward(func_idx, func) + warmup_outputs.append((func_idx, func, outputs)) + if is_training: + for func_idx, func, outputs in reversed(warmup_outputs): + _run_warmup_backward(func_idx, func, outputs, warmup_iter) + else: + # Follow _order exactly, mirroring the capture phase. + per_fwd_outputs = {} # per_callable_fwd_idx -> flattened outputs + fwd_idx = [0] * num_model_chunks + bwd_idx = [0] * num_model_chunks + for c_id in _order: + if c_id > 0: + # Forward pass for chunk c_id. + m_chunk = c_id - 1 + for l_no in range(_num_layers_per_chunk[m_chunk]): + per_callable_fwd_idx = ( + _prefix_num_layers[m_chunk] * num_microbatches + ) + (fwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] + l_no) + func = callables[_prefix_num_layers[m_chunk] + l_no] + outputs = _run_warmup_forward(per_callable_fwd_idx, func) + per_fwd_outputs[per_callable_fwd_idx] = outputs + fwd_idx[m_chunk] += 1 + elif ceil(c_id) == c_id: + # Backward pass for chunk -c_id. + if is_training: + m_chunk = -c_id - 1 + for l_no in reversed(range(_num_layers_per_chunk[m_chunk])): + per_callable_bwd_idx = ( + _prefix_num_layers[m_chunk] * num_microbatches + ) + ( + bwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] + l_no ) - elif ( - grad_inputs[grad_inputs_idx] is not None - and grad_inputs_idx >= num_required_grad_sample_args - ): - module_params_with_grad.append(static_input_surface[inputs_idx]) - if len(module_params_with_grad) != len(per_callable_module_params[func_idx]): - if warmup_iter != 0: - raise RuntimeError( - "no-grad params should only be used as inputs in the first warmup" - f" iteration, but found in iteration {warmup_iter}" - ) - per_callable_module_params[func_idx] = tuple(module_params_with_grad) - static_input_surface = flatten_sample_args[func_idx] + tuple( - module_params_with_grad - ) - per_callable_static_input_surfaces[func_idx] = static_input_surface - - # Run wgrad. This is essential for some TE modules when they have - # delay_wgrad_compute enabled. - need_backward_dw = False - for module in visited_te_modules.get(func_idx, set()): - if hasattr(module, "need_backward_dw") and module.need_backward_dw(): - need_backward_dw = True - module.backward_dw() - need_bwd_dw_graph[func_idx] = need_backward_dw - else: - grad_inputs = None - del outputs, grad_inputs - if post_warmup_hook is not None: - post_warmup_hook() + func = callables[_prefix_num_layers[m_chunk] + l_no] + outputs = per_fwd_outputs[per_callable_bwd_idx] + _run_warmup_backward( + per_callable_bwd_idx, + func, + outputs, + warmup_iter + ) + bwd_idx[m_chunk] += 1 + + if post_warmup_hook is not None: + post_warmup_hook() torch.cuda.synchronize() # All captures here share a mempool. To avoid replays corrupting each other's memory, @@ -585,8 +672,30 @@ def hook_fn( args = sample_args[per_callable_fwd_idx] kwargs = sample_kwargs[per_callable_fwd_idx] fwd_graph = fwd_graphs[per_callable_fwd_idx] + + # Call forward_pre hooks before forward graph capture (outside capture context) + if ( + capture_time_hooks is not None + and per_callable_fwd_idx < len(capture_time_hooks) + and capture_time_hooks[per_callable_fwd_idx] is not None + and "forward_pre" in capture_time_hooks[per_callable_fwd_idx] + ): + for hook in capture_time_hooks[per_callable_fwd_idx]["forward_pre"].values(): + hook(func, args, kwargs) # forward_pre hook signature: (module, args, kwargs) + with _graph_context_wrapper(fwd_graph, pool=mempool): outputs = func(*args, **kwargs) + + # Call forward hooks after forward graph capture (outside capture context) + if ( + capture_time_hooks is not None + and per_callable_fwd_idx < len(capture_time_hooks) + and capture_time_hooks[per_callable_fwd_idx] is not None + and "forward" in capture_time_hooks[per_callable_fwd_idx] + ): + for hook in capture_time_hooks[per_callable_fwd_idx]["forward"].values(): + hook(func, args, outputs) # forward hook signature: (module, inputs, output) + flatten_outputs, spec = _tree_flatten(outputs) per_callable_static_outputs[per_callable_fwd_idx] = tuple(flatten_outputs) per_callable_output_unflatten_spec[per_callable_fwd_idx] = spec @@ -683,6 +792,22 @@ def hook_fn( for o in static_outputs ) if is_training: + # Call pre_backward hooks before backward graph capture (outside capture context) + if ( + capture_time_hooks is not None + and per_callable_bwd_idx < len(capture_time_hooks) + and capture_time_hooks[per_callable_bwd_idx] is not None + and "pre_backward" in capture_time_hooks[per_callable_bwd_idx] + ): + # Get the callable module for this backward index + callable_module = graph_callables[per_callable_bwd_idx] + for hook in capture_time_hooks[per_callable_bwd_idx][ + "pre_backward" + ].values(): + # During capture, call with the actual module (not None) + # FSDP hooks need to access module attributes + hook(callable_module) + inputs = tuple(i for i in static_input_surface if i.requires_grad) with _none_grad_context_wrapper(inputs), _graph_context_wrapper( bwd_graph, pool=mempool @@ -696,6 +821,20 @@ def hook_fn( ) grad_inputs = tuple(input.grad for input in inputs) + # Call backward hooks after backward graph capture (outside capture context) + if ( + capture_time_hooks is not None + and per_callable_bwd_idx < len(capture_time_hooks) + and capture_time_hooks[per_callable_bwd_idx] is not None + and "backward" in capture_time_hooks[per_callable_bwd_idx] + ): + # Get the callable module for this backward index + callable_module = graph_callables[per_callable_bwd_idx] + for hook in capture_time_hooks[per_callable_bwd_idx]["backward"].values(): + # During capture, call with the actual module (not None) + # FSDP hooks need to access module attributes + hook(callable_module) + # Constructs a tuple suitable for returning from Graphed.backward: # Pads out the actually-needed grads with Nones in gradient slots for inputs # that don't require grad. I couldn't think of a one-liner for this pattern. @@ -750,12 +889,32 @@ def hook_fn( # Capture forward graphs per_callable_static_outputs = [] per_callable_output_unflatten_spec = [] - graph_id = 0 - for func, args, kwargs, fwd_graph in zip(callables, sample_args, sample_kwargs, fwd_graphs): + for func_idx, (func, args, kwargs, fwd_graph) in enumerate( + zip(callables, sample_args, sample_kwargs, fwd_graphs) + ): + # Call forward_pre hooks before forward graph capture (outside capture context) + if ( + capture_time_hooks is not None + and func_idx < len(capture_time_hooks) + and capture_time_hooks[func_idx] is not None + and "forward_pre" in capture_time_hooks[func_idx] + ): + for hook in capture_time_hooks[func_idx]["forward_pre"].values(): + hook(func, args, kwargs) + with _graph_context_wrapper(fwd_graph, pool=mempool): outputs = func(*args, **kwargs) - graph_callables[graph_id] = func - graph_id += 1 + graph_callables[func_idx] = func + + # Call forward hooks after forward graph capture (outside capture context) + if ( + capture_time_hooks is not None + and func_idx < len(capture_time_hooks) + and capture_time_hooks[func_idx] is not None + and "forward" in capture_time_hooks[func_idx] + ): + for hook in capture_time_hooks[func_idx]["forward"].values(): + hook(func, args, outputs) flatten_outputs, spec = _tree_flatten(outputs) per_callable_static_outputs.append(tuple(flatten_outputs)) @@ -777,6 +936,17 @@ def hook_fn( for o in static_outputs ) if is_training: + # Call pre_backward hooks before backward graph capture (outside capture context) + if ( + capture_time_hooks is not None + and bwd_idx < len(capture_time_hooks) + and capture_time_hooks[bwd_idx] is not None + and "pre_backward" in capture_time_hooks[bwd_idx] + ): + callable_module = graph_callables[bwd_idx] + for hook in capture_time_hooks[bwd_idx]["pre_backward"].values(): + hook(callable_module) + inputs = tuple(i for i in static_input_surface if i.requires_grad) with _none_grad_context_wrapper(inputs), _graph_context_wrapper( bwd_graph, pool=mempool @@ -788,6 +958,17 @@ def hook_fn( ) grad_inputs = tuple(input.grad for input in inputs) + # Call backward hooks after backward graph capture (outside capture context) + if ( + capture_time_hooks is not None + and bwd_idx < len(capture_time_hooks) + and capture_time_hooks[bwd_idx] is not None + and "backward" in capture_time_hooks[bwd_idx] + ): + callable_module = graph_callables[bwd_idx] + for hook in capture_time_hooks[bwd_idx]["backward"].values(): + hook(callable_module) + if need_bwd_dw_graph[bwd_idx]: with _graph_context_wrapper(bwd_dw_graph, pool=mempool): for module in visited_te_modules[bwd_idx]: @@ -1145,6 +1326,7 @@ def make_graphed_callables( _reuse_graph_input_output_buffers: bool = False, pre_warmup_hook: Optional[Callable] = None, post_warmup_hook: Optional[Callable] = None, + capture_time_hooks: Optional[List[Optional[Dict[str, Dict]]]] = None, ) -> Union[Callable, Tuple[Callable, ...]]: """ Make CUDA graph version of Transformer Engine modules @@ -1187,6 +1369,23 @@ def make_graphed_callables( A hook function that will be called before the warmup iterations. post_warmup_hook: callable, default = None A hook function that will be called after the warmup iterations. + capture_time_hooks: list of dict, optional + Per-callable hooks invoked at capture time (during warmup iterations and + graph capture), but intentionally executed **outside** the CUDA graph + capture context so they are **not** recorded into the graph and will + **not** be replayed. Use this for operations that are inherently + non-capturable but essential for correct module execution, such as + CPU-side state updates. + Each element corresponds to one callable and is a dict with any subset + of the following keys: + - ``"forward_pre"``: dict of hooks called *before* the forward pass. + Hook signature: ``hook(module, args, kwargs)``. + - ``"forward"``: dict of hooks called *after* the forward pass. + Hook signature: ``hook(module, args, output)``. + - ``"pre_backward"``: dict of hooks called *before* the backward pass. + Hook signature: ``hook(module)``. + - ``"backward"``: dict of hooks called *after* the backward pass. + Hook signature: ``hook(module)``. Quantization parameters ----------------------- @@ -1380,6 +1579,7 @@ def call_func(self, *args, **kwargs): _reuse_graph_input_output_buffers=_reuse_graph_input_output_buffers, pre_warmup_hook=pre_warmup_hook, post_warmup_hook=post_warmup_hook, + capture_time_hooks=capture_time_hooks, ) # Ensures warmup does not affect numerics for ops such as dropout. From 7f095c52ee264d533962019605a85b9936d9c8ec Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 3 Apr 2026 14:19:44 +0000 Subject: [PATCH 2/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/graph.py | 40 +++++++++++++---------------- 1 file changed, 18 insertions(+), 22 deletions(-) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 991e0894c7..7d4e2540e7 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -454,9 +454,7 @@ def _run_warmup_forward(func_idx, func): args = sample_args[func_idx] kwargs = sample_kwargs[func_idx] - def hook_fn( - module, inputs, outputs, func_idx=func_idx - ): # pylint: disable=unused-argument + def hook_fn(module, inputs, outputs, func_idx=func_idx): # pylint: disable=unused-argument modules = set() if isinstance(module, TransformerEngineBaseModule): modules.add(module) @@ -521,9 +519,7 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter): inputs = tuple(i for i in static_input_surface if i.requires_grad) with _none_grad_context_wrapper(inputs): - outputs_requiring_grad = tuple( - o for o in outputs if o is not None and o.requires_grad - ) + outputs_requiring_grad = tuple(o for o in outputs if o is not None and o.requires_grad) torch.autograd.backward( outputs_requiring_grad, grad_tensors=tuple(torch.empty_like(o) for o in outputs_requiring_grad), @@ -557,8 +553,7 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter): ): if not allow_unused_input: raise RuntimeError( - "The input tensor requires grad, but the grad is None after" - " backward pass." + "The input tensor requires grad, but the grad is None after backward pass." ) elif ( grad_inputs[grad_inputs_idx] is not None @@ -572,9 +567,7 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter): f" iteration, but found in iteration {warmup_iter}" ) per_callable_module_params[func_idx] = tuple(module_params_with_grad) - static_input_surface = flatten_sample_args[func_idx] + tuple( - module_params_with_grad - ) + static_input_surface = flatten_sample_args[func_idx] + tuple(module_params_with_grad) per_callable_static_input_surfaces[func_idx] = static_input_surface # Run wgrad. This is essential for some TE modules when they have @@ -625,16 +618,11 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter): for l_no in reversed(range(_num_layers_per_chunk[m_chunk])): per_callable_bwd_idx = ( _prefix_num_layers[m_chunk] * num_microbatches - ) + ( - bwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] + l_no - ) + ) + (bwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] + l_no) func = callables[_prefix_num_layers[m_chunk] + l_no] outputs = per_fwd_outputs[per_callable_bwd_idx] _run_warmup_backward( - per_callable_bwd_idx, - func, - outputs, - warmup_iter + per_callable_bwd_idx, func, outputs, warmup_iter ) bwd_idx[m_chunk] += 1 @@ -680,8 +668,12 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter): and capture_time_hooks[per_callable_fwd_idx] is not None and "forward_pre" in capture_time_hooks[per_callable_fwd_idx] ): - for hook in capture_time_hooks[per_callable_fwd_idx]["forward_pre"].values(): - hook(func, args, kwargs) # forward_pre hook signature: (module, args, kwargs) + for hook in capture_time_hooks[per_callable_fwd_idx][ + "forward_pre" + ].values(): + hook( + func, args, kwargs + ) # forward_pre hook signature: (module, args, kwargs) with _graph_context_wrapper(fwd_graph, pool=mempool): outputs = func(*args, **kwargs) @@ -694,7 +686,9 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter): and "forward" in capture_time_hooks[per_callable_fwd_idx] ): for hook in capture_time_hooks[per_callable_fwd_idx]["forward"].values(): - hook(func, args, outputs) # forward hook signature: (module, inputs, output) + hook( + func, args, outputs + ) # forward hook signature: (module, inputs, output) flatten_outputs, spec = _tree_flatten(outputs) per_callable_static_outputs[per_callable_fwd_idx] = tuple(flatten_outputs) @@ -830,7 +824,9 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter): ): # Get the callable module for this backward index callable_module = graph_callables[per_callable_bwd_idx] - for hook in capture_time_hooks[per_callable_bwd_idx]["backward"].values(): + for hook in capture_time_hooks[per_callable_bwd_idx][ + "backward" + ].values(): # During capture, call with the actual module (not None) # FSDP hooks need to access module attributes hook(callable_module) From 7536553c9551a130c71247ca4ea328460ddcfd1c Mon Sep 17 00:00:00 2001 From: Robin Zhang Date: Tue, 7 Apr 2026 07:02:57 -0700 Subject: [PATCH 3/6] resolve comments Signed-off-by: Robin Zhang --- transformer_engine/pytorch/graph.py | 143 ++++++++++++++-------------- 1 file changed, 72 insertions(+), 71 deletions(-) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 7d4e2540e7..e4d3eff196 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -360,6 +360,12 @@ def _make_graphed_callables( + "for each callable must contain only Tensors. Other types are not allowed." ) + if capture_time_hooks is not None and len(capture_time_hooks) != len(callables): + raise ValueError( + f"capture_time_hooks has {len(capture_time_hooks)} entries but there are " + f"{len(callables)} callables" + ) + # If a callable is an nn.Module, its graph's full input surface is the args the user explicitly # passes to forward (ie, its sample_args) AND the module's parameter attributes. # Note: These per_callable_* variables are not actually @@ -449,7 +455,7 @@ def _make_graphed_callables( visited_te_modules = {} need_bwd_dw_graph = {} - def _run_warmup_forward(func_idx, func): + def _run_warmup_forward(func_idx, func, callable_idx): """Run forward for one callable during warmup; returns flattened outputs.""" args = sample_args[func_idx] kwargs = sample_kwargs[func_idx] @@ -479,42 +485,44 @@ def hook_fn(module, inputs, outputs, func_idx=func_idx): # pylint: disable=unus if ( capture_time_hooks is not None - and func_idx < len(capture_time_hooks) - and capture_time_hooks[func_idx] is not None - and "forward_pre" in capture_time_hooks[func_idx] + and capture_time_hooks[callable_idx] is not None + and "pre_forward" in capture_time_hooks[callable_idx] ): - for hook in capture_time_hooks[func_idx]["forward_pre"].values(): - hook(func, args, kwargs) + for hook in capture_time_hooks[callable_idx]["pre_forward"].values(): + result = hook(func, args, kwargs) + if result is not None: + args, kwargs = result hooks = [] for module in func.modules(): hooks.append(module.register_forward_hook(hook_fn)) - outputs, _ = _tree_flatten(func(*args, **kwargs)) + outputs = func(*args, **kwargs) for hook in hooks: hook.remove() if ( capture_time_hooks is not None - and func_idx < len(capture_time_hooks) - and capture_time_hooks[func_idx] is not None - and "forward" in capture_time_hooks[func_idx] + and capture_time_hooks[callable_idx] is not None + and "forward" in capture_time_hooks[callable_idx] ): - for hook in capture_time_hooks[func_idx]["forward"].values(): - hook(func, args, outputs) + for hook in capture_time_hooks[callable_idx]["forward"].values(): + result = hook(func, args, outputs) + if result is not None: + outputs = result + outputs, _ = _tree_flatten(outputs) return outputs - def _run_warmup_backward(func_idx, func, outputs, warmup_iter): + def _run_warmup_backward(func_idx, func, outputs, warmup_iter, callable_idx): """Run dgrad backward for one callable during warmup.""" static_input_surface = per_callable_static_input_surfaces[func_idx] if ( capture_time_hooks is not None - and func_idx < len(capture_time_hooks) - and capture_time_hooks[func_idx] is not None - and "pre_backward" in capture_time_hooks[func_idx] + and capture_time_hooks[callable_idx] is not None + and "pre_backward" in capture_time_hooks[callable_idx] ): - for hook in capture_time_hooks[func_idx]["pre_backward"].values(): + for hook in capture_time_hooks[callable_idx]["pre_backward"].values(): hook(func) inputs = tuple(i for i in static_input_surface if i.requires_grad) @@ -528,11 +536,10 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter): if ( capture_time_hooks is not None - and func_idx < len(capture_time_hooks) - and capture_time_hooks[func_idx] is not None - and "backward" in capture_time_hooks[func_idx] + and capture_time_hooks[callable_idx] is not None + and "backward" in capture_time_hooks[callable_idx] ): - for hook in capture_time_hooks[func_idx]["backward"].values(): + for hook in capture_time_hooks[callable_idx]["backward"].values(): hook(func) # Filter module params that get None grad from grad_inputs and remove them @@ -589,11 +596,11 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter): # All forwards in order, then all backwards in reverse order. warmup_outputs = [] for func_idx, func in zip(warmup_func_idx, warmup_func): - outputs = _run_warmup_forward(func_idx, func) + outputs = _run_warmup_forward(func_idx, func, func_idx) warmup_outputs.append((func_idx, func, outputs)) if is_training: for func_idx, func, outputs in reversed(warmup_outputs): - _run_warmup_backward(func_idx, func, outputs, warmup_iter) + _run_warmup_backward(func_idx, func, outputs, warmup_iter, func_idx) else: # Follow _order exactly, mirroring the capture phase. per_fwd_outputs = {} # per_callable_fwd_idx -> flattened outputs @@ -604,11 +611,12 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter): # Forward pass for chunk c_id. m_chunk = c_id - 1 for l_no in range(_num_layers_per_chunk[m_chunk]): + callable_idx = _prefix_num_layers[m_chunk] + l_no per_callable_fwd_idx = ( _prefix_num_layers[m_chunk] * num_microbatches ) + (fwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] + l_no) - func = callables[_prefix_num_layers[m_chunk] + l_no] - outputs = _run_warmup_forward(per_callable_fwd_idx, func) + func = callables[callable_idx] + outputs = _run_warmup_forward(per_callable_fwd_idx, func, callable_idx) per_fwd_outputs[per_callable_fwd_idx] = outputs fwd_idx[m_chunk] += 1 elif ceil(c_id) == c_id: @@ -616,13 +624,14 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter): if is_training: m_chunk = -c_id - 1 for l_no in reversed(range(_num_layers_per_chunk[m_chunk])): + callable_idx = _prefix_num_layers[m_chunk] + l_no per_callable_bwd_idx = ( _prefix_num_layers[m_chunk] * num_microbatches ) + (bwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] + l_no) - func = callables[_prefix_num_layers[m_chunk] + l_no] + func = callables[callable_idx] outputs = per_fwd_outputs[per_callable_bwd_idx] _run_warmup_backward( - per_callable_bwd_idx, func, outputs, warmup_iter + per_callable_bwd_idx, func, outputs, warmup_iter, callable_idx ) bwd_idx[m_chunk] += 1 @@ -653,7 +662,8 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter): # Capture forward graph for model chunk c_id, microbatch fwd_idx[c_id-1] m_chunk = c_id - 1 for l_no in range(_num_layers_per_chunk[m_chunk]): - func = callables[_prefix_num_layers[m_chunk] + l_no] + callable_idx = _prefix_num_layers[m_chunk] + l_no + func = callables[callable_idx] per_callable_fwd_idx = (_prefix_num_layers[m_chunk] * num_microbatches) + ( fwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] + l_no ) @@ -661,19 +671,16 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter): kwargs = sample_kwargs[per_callable_fwd_idx] fwd_graph = fwd_graphs[per_callable_fwd_idx] - # Call forward_pre hooks before forward graph capture (outside capture context) + # Call pre_forward hooks before forward graph capture (outside capture context) if ( capture_time_hooks is not None - and per_callable_fwd_idx < len(capture_time_hooks) - and capture_time_hooks[per_callable_fwd_idx] is not None - and "forward_pre" in capture_time_hooks[per_callable_fwd_idx] + and capture_time_hooks[callable_idx] is not None + and "pre_forward" in capture_time_hooks[callable_idx] ): - for hook in capture_time_hooks[per_callable_fwd_idx][ - "forward_pre" - ].values(): - hook( - func, args, kwargs - ) # forward_pre hook signature: (module, args, kwargs) + for hook in capture_time_hooks[callable_idx]["pre_forward"].values(): + result = hook(func, args, kwargs) + if result is not None: + args, kwargs = result with _graph_context_wrapper(fwd_graph, pool=mempool): outputs = func(*args, **kwargs) @@ -681,14 +688,13 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter): # Call forward hooks after forward graph capture (outside capture context) if ( capture_time_hooks is not None - and per_callable_fwd_idx < len(capture_time_hooks) - and capture_time_hooks[per_callable_fwd_idx] is not None - and "forward" in capture_time_hooks[per_callable_fwd_idx] + and capture_time_hooks[callable_idx] is not None + and "forward" in capture_time_hooks[callable_idx] ): - for hook in capture_time_hooks[per_callable_fwd_idx]["forward"].values(): - hook( - func, args, outputs - ) # forward hook signature: (module, inputs, output) + for hook in capture_time_hooks[callable_idx]["forward"].values(): + result = hook(func, args, outputs) + if result is not None: + outputs = result flatten_outputs, spec = _tree_flatten(outputs) per_callable_static_outputs[per_callable_fwd_idx] = tuple(flatten_outputs) @@ -700,6 +706,7 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter): m_chunk = -ceil(c_id) - 1 previous_per_callable_bwd_idx = None for l_no in list(reversed(range(_num_layers_per_chunk[m_chunk]))): + callable_idx = _prefix_num_layers[m_chunk] + l_no per_callable_bwd_idx = (_prefix_num_layers[m_chunk] * num_microbatches) + ( bwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] + l_no ) @@ -789,17 +796,14 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter): # Call pre_backward hooks before backward graph capture (outside capture context) if ( capture_time_hooks is not None - and per_callable_bwd_idx < len(capture_time_hooks) - and capture_time_hooks[per_callable_bwd_idx] is not None - and "pre_backward" in capture_time_hooks[per_callable_bwd_idx] + and capture_time_hooks[callable_idx] is not None + and "pre_backward" in capture_time_hooks[callable_idx] ): # Get the callable module for this backward index callable_module = graph_callables[per_callable_bwd_idx] - for hook in capture_time_hooks[per_callable_bwd_idx][ + for hook in capture_time_hooks[callable_idx][ "pre_backward" ].values(): - # During capture, call with the actual module (not None) - # FSDP hooks need to access module attributes hook(callable_module) inputs = tuple(i for i in static_input_surface if i.requires_grad) @@ -818,17 +822,12 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter): # Call backward hooks after backward graph capture (outside capture context) if ( capture_time_hooks is not None - and per_callable_bwd_idx < len(capture_time_hooks) - and capture_time_hooks[per_callable_bwd_idx] is not None - and "backward" in capture_time_hooks[per_callable_bwd_idx] + and capture_time_hooks[callable_idx] is not None + and "backward" in capture_time_hooks[callable_idx] ): # Get the callable module for this backward index callable_module = graph_callables[per_callable_bwd_idx] - for hook in capture_time_hooks[per_callable_bwd_idx][ - "backward" - ].values(): - # During capture, call with the actual module (not None) - # FSDP hooks need to access module attributes + for hook in capture_time_hooks[callable_idx]["backward"].values(): hook(callable_module) # Constructs a tuple suitable for returning from Graphed.backward: @@ -888,15 +887,16 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter): for func_idx, (func, args, kwargs, fwd_graph) in enumerate( zip(callables, sample_args, sample_kwargs, fwd_graphs) ): - # Call forward_pre hooks before forward graph capture (outside capture context) + # Call pre_forward hooks before forward graph capture (outside capture context) if ( capture_time_hooks is not None - and func_idx < len(capture_time_hooks) and capture_time_hooks[func_idx] is not None - and "forward_pre" in capture_time_hooks[func_idx] + and "pre_forward" in capture_time_hooks[func_idx] ): - for hook in capture_time_hooks[func_idx]["forward_pre"].values(): - hook(func, args, kwargs) + for hook in capture_time_hooks[func_idx]["pre_forward"].values(): + result = hook(func, args, kwargs) + if result is not None: + args, kwargs = result with _graph_context_wrapper(fwd_graph, pool=mempool): outputs = func(*args, **kwargs) @@ -905,12 +905,13 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter): # Call forward hooks after forward graph capture (outside capture context) if ( capture_time_hooks is not None - and func_idx < len(capture_time_hooks) and capture_time_hooks[func_idx] is not None and "forward" in capture_time_hooks[func_idx] ): for hook in capture_time_hooks[func_idx]["forward"].values(): - hook(func, args, outputs) + result = hook(func, args, outputs) + if result is not None: + outputs = result flatten_outputs, spec = _tree_flatten(outputs) per_callable_static_outputs.append(tuple(flatten_outputs)) @@ -935,7 +936,6 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter): # Call pre_backward hooks before backward graph capture (outside capture context) if ( capture_time_hooks is not None - and bwd_idx < len(capture_time_hooks) and capture_time_hooks[bwd_idx] is not None and "pre_backward" in capture_time_hooks[bwd_idx] ): @@ -957,7 +957,6 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter): # Call backward hooks after backward graph capture (outside capture context) if ( capture_time_hooks is not None - and bwd_idx < len(capture_time_hooks) and capture_time_hooks[bwd_idx] is not None and "backward" in capture_time_hooks[bwd_idx] ): @@ -1362,9 +1361,11 @@ def make_graphed_callables( when `_order` is provided. All callables in `modules` are assumed to have inputs and outputs with the same dtype and shape. pre_warmup_hook: callable, default = None - A hook function that will be called before the warmup iterations. + A hook function that will be called once before all warmup iterations + (not once per callable). post_warmup_hook: callable, default = None - A hook function that will be called after the warmup iterations. + A hook function that will be called once after all warmup iterations + (not once per callable). capture_time_hooks: list of dict, optional Per-callable hooks invoked at capture time (during warmup iterations and graph capture), but intentionally executed **outside** the CUDA graph @@ -1374,7 +1375,7 @@ def make_graphed_callables( CPU-side state updates. Each element corresponds to one callable and is a dict with any subset of the following keys: - - ``"forward_pre"``: dict of hooks called *before* the forward pass. + - ``"pre_forward"``: dict of hooks called *before* the forward pass. Hook signature: ``hook(module, args, kwargs)``. - ``"forward"``: dict of hooks called *after* the forward pass. Hook signature: ``hook(module, args, output)``. From d5c5e711b0e92274765952ae2b422da2871b5500 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 7 Apr 2026 14:03:54 +0000 Subject: [PATCH 4/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/graph.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index e4d3eff196..daa8ae28d4 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -801,9 +801,7 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter, callable_idx): ): # Get the callable module for this backward index callable_module = graph_callables[per_callable_bwd_idx] - for hook in capture_time_hooks[callable_idx][ - "pre_backward" - ].values(): + for hook in capture_time_hooks[callable_idx]["pre_backward"].values(): hook(callable_module) inputs = tuple(i for i in static_input_surface if i.requires_grad) From 8baa2c901d53ab86ae68e57da87d5f6c5d05b6c6 Mon Sep 17 00:00:00 2001 From: Robin Zhang Date: Tue, 7 Apr 2026 23:19:11 -0700 Subject: [PATCH 5/6] align hook naming and signature with Pytorch nn.Module Signed-off-by: Robin Zhang --- transformer_engine/pytorch/graph.py | 260 ++++++++++++++++++---------- 1 file changed, 171 insertions(+), 89 deletions(-) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index daa8ae28d4..6014a3e05c 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -332,13 +332,18 @@ def _make_graphed_callables( if isinstance(c, torch.nn.Module): if not ( len(c._backward_hooks) == 0 + and len(c._backward_pre_hooks) == 0 and len(c._forward_hooks) == 0 and len(c._forward_pre_hooks) == 0 ): raise RuntimeError( "Modules must not have hooks registered at the time they are passed. " + "However, registering hooks on modules after passing them " - + "through make_graphed_callables is allowed." + + "through make_graphed_callables is allowed. " + + "If you have to use hooks during capture time, you can provide them " + + "in the capture_time_hooks argument, and they will be executed outside " + + "the CUDA graph capture context, meaning they will not be recorded into " + + "the graph and will not be replayed." ) if not all(b.requires_grad is False for b in c.buffers()): raise RuntimeError( @@ -483,15 +488,22 @@ def hook_fn(module, inputs, outputs, func_idx=func_idx): # pylint: disable=unus else: visited_te_modules[func_idx].update(modules) - if ( - capture_time_hooks is not None - and capture_time_hooks[callable_idx] is not None - and "pre_forward" in capture_time_hooks[callable_idx] - ): - for hook in capture_time_hooks[callable_idx]["pre_forward"].values(): - result = hook(func, args, kwargs) - if result is not None: - args, kwargs = result + if capture_time_hooks is not None and capture_time_hooks[callable_idx] is not None: + _hooks = capture_time_hooks[callable_idx] + _pre_fwd_with_kwargs = _hooks.get("forward_pre_hooks_with_kwargs", {}) + for hook_id, hook in _hooks.get("forward_pre_hooks", {}).items(): + if hook_id in _pre_fwd_with_kwargs: + if hook(func, args, kwargs) is not None: + raise RuntimeError( + "capture_time_hooks forward_pre_hooks must not return a value " + "(args/kwargs must not be modified via hook return)" + ) + else: + if hook(func, args) is not None: + raise RuntimeError( + "capture_time_hooks forward_pre_hooks must not return a value " + "(args must not be modified via hook return)" + ) hooks = [] for module in func.modules(): @@ -500,15 +512,22 @@ def hook_fn(module, inputs, outputs, func_idx=func_idx): # pylint: disable=unus for hook in hooks: hook.remove() - if ( - capture_time_hooks is not None - and capture_time_hooks[callable_idx] is not None - and "forward" in capture_time_hooks[callable_idx] - ): - for hook in capture_time_hooks[callable_idx]["forward"].values(): - result = hook(func, args, outputs) - if result is not None: - outputs = result + if capture_time_hooks is not None and capture_time_hooks[callable_idx] is not None: + _hooks = capture_time_hooks[callable_idx] + _fwd_with_kwargs = _hooks.get("forward_hooks_with_kwargs", {}) + for hook_id, hook in _hooks.get("forward_hooks", {}).items(): + if hook_id in _fwd_with_kwargs: + if hook(func, args, kwargs, outputs) is not None: + raise RuntimeError( + "capture_time_hooks forward_hooks must not return a value " + "(output must not be modified via hook return)" + ) + else: + if hook(func, args, outputs) is not None: + raise RuntimeError( + "capture_time_hooks forward_hooks must not return a value " + "(output must not be modified via hook return)" + ) outputs, _ = _tree_flatten(outputs) return outputs @@ -517,30 +536,37 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter, callable_idx): """Run dgrad backward for one callable during warmup.""" static_input_surface = per_callable_static_input_surfaces[func_idx] + inputs = tuple(i for i in static_input_surface if i.requires_grad) + outputs_requiring_grad = tuple(o for o in outputs if o is not None and o.requires_grad) + grad_outputs = tuple(torch.empty_like(o) for o in outputs_requiring_grad) + if ( capture_time_hooks is not None and capture_time_hooks[callable_idx] is not None - and "pre_backward" in capture_time_hooks[callable_idx] + and "backward_pre_hooks" in capture_time_hooks[callable_idx] ): - for hook in capture_time_hooks[callable_idx]["pre_backward"].values(): - hook(func) + for hook in capture_time_hooks[callable_idx]["backward_pre_hooks"].values(): + if hook(func, grad_outputs) is not None: + raise RuntimeError( + "capture_time_hooks backward_pre_hooks must not return a value " + "(grad_output must not be modified via hook return)" + ) - inputs = tuple(i for i in static_input_surface if i.requires_grad) with _none_grad_context_wrapper(inputs): - outputs_requiring_grad = tuple(o for o in outputs if o is not None and o.requires_grad) - torch.autograd.backward( - outputs_requiring_grad, - grad_tensors=tuple(torch.empty_like(o) for o in outputs_requiring_grad), - ) + torch.autograd.backward(outputs_requiring_grad, grad_tensors=grad_outputs) grad_inputs = tuple(input.grad for input in inputs) if ( capture_time_hooks is not None and capture_time_hooks[callable_idx] is not None - and "backward" in capture_time_hooks[callable_idx] + and "backward_hooks" in capture_time_hooks[callable_idx] ): - for hook in capture_time_hooks[callable_idx]["backward"].values(): - hook(func) + for hook in capture_time_hooks[callable_idx]["backward_hooks"].values(): + if hook(func, grad_inputs, grad_outputs) is not None: + raise RuntimeError( + "capture_time_hooks backward_hooks must not return a value " + "(grad_input must not be modified via hook return)" + ) # Filter module params that get None grad from grad_inputs and remove them # from static_input_surface. This is to ensure that the backward hooks @@ -672,29 +698,43 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter, callable_idx): fwd_graph = fwd_graphs[per_callable_fwd_idx] # Call pre_forward hooks before forward graph capture (outside capture context) - if ( - capture_time_hooks is not None - and capture_time_hooks[callable_idx] is not None - and "pre_forward" in capture_time_hooks[callable_idx] - ): - for hook in capture_time_hooks[callable_idx]["pre_forward"].values(): - result = hook(func, args, kwargs) - if result is not None: - args, kwargs = result + if capture_time_hooks is not None and capture_time_hooks[callable_idx] is not None: + _hooks = capture_time_hooks[callable_idx] + _pre_fwd_with_kwargs = _hooks.get("forward_pre_hooks_with_kwargs", {}) + for hook_id, hook in _hooks.get("forward_pre_hooks", {}).items(): + if hook_id in _pre_fwd_with_kwargs: + if hook(func, args, kwargs) is not None: + raise RuntimeError( + "capture_time_hooks forward_pre_hooks must not return a value " + "(args/kwargs must not be modified via hook return)" + ) + else: + if hook(func, args) is not None: + raise RuntimeError( + "capture_time_hooks forward_pre_hooks must not return a value " + "(args must not be modified via hook return)" + ) with _graph_context_wrapper(fwd_graph, pool=mempool): outputs = func(*args, **kwargs) # Call forward hooks after forward graph capture (outside capture context) - if ( - capture_time_hooks is not None - and capture_time_hooks[callable_idx] is not None - and "forward" in capture_time_hooks[callable_idx] - ): - for hook in capture_time_hooks[callable_idx]["forward"].values(): - result = hook(func, args, outputs) - if result is not None: - outputs = result + if capture_time_hooks is not None and capture_time_hooks[callable_idx] is not None: + _hooks = capture_time_hooks[callable_idx] + _fwd_with_kwargs = _hooks.get("forward_hooks_with_kwargs", {}) + for hook_id, hook in _hooks.get("forward_hooks", {}).items(): + if hook_id in _fwd_with_kwargs: + if hook(func, args, kwargs, outputs) is not None: + raise RuntimeError( + "capture_time_hooks forward_hooks must not return a value " + "(output must not be modified via hook return)" + ) + else: + if hook(func, args, outputs) is not None: + raise RuntimeError( + "capture_time_hooks forward_hooks must not return a value " + "(output must not be modified via hook return)" + ) flatten_outputs, spec = _tree_flatten(outputs) per_callable_static_outputs[per_callable_fwd_idx] = tuple(flatten_outputs) @@ -797,12 +837,18 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter, callable_idx): if ( capture_time_hooks is not None and capture_time_hooks[callable_idx] is not None - and "pre_backward" in capture_time_hooks[callable_idx] + and "backward_pre_hooks" in capture_time_hooks[callable_idx] ): # Get the callable module for this backward index callable_module = graph_callables[per_callable_bwd_idx] - for hook in capture_time_hooks[callable_idx]["pre_backward"].values(): - hook(callable_module) + for hook in capture_time_hooks[callable_idx][ + "backward_pre_hooks" + ].values(): + if hook(callable_module, static_grad_outputs) is not None: + raise RuntimeError( + "capture_time_hooks backward_pre_hooks must not return a value " + "(grad_output must not be modified via hook return)" + ) inputs = tuple(i for i in static_input_surface if i.requires_grad) with _none_grad_context_wrapper(inputs), _graph_context_wrapper( @@ -821,12 +867,16 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter, callable_idx): if ( capture_time_hooks is not None and capture_time_hooks[callable_idx] is not None - and "backward" in capture_time_hooks[callable_idx] + and "backward_hooks" in capture_time_hooks[callable_idx] ): # Get the callable module for this backward index callable_module = graph_callables[per_callable_bwd_idx] - for hook in capture_time_hooks[callable_idx]["backward"].values(): - hook(callable_module) + for hook in capture_time_hooks[callable_idx]["backward_hooks"].values(): + if hook(callable_module, grad_inputs, static_grad_outputs) is not None: + raise RuntimeError( + "capture_time_hooks backward_hooks must not return a value " + "(grad_input must not be modified via hook return)" + ) # Constructs a tuple suitable for returning from Graphed.backward: # Pads out the actually-needed grads with Nones in gradient slots for inputs @@ -886,30 +936,44 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter, callable_idx): zip(callables, sample_args, sample_kwargs, fwd_graphs) ): # Call pre_forward hooks before forward graph capture (outside capture context) - if ( - capture_time_hooks is not None - and capture_time_hooks[func_idx] is not None - and "pre_forward" in capture_time_hooks[func_idx] - ): - for hook in capture_time_hooks[func_idx]["pre_forward"].values(): - result = hook(func, args, kwargs) - if result is not None: - args, kwargs = result + if capture_time_hooks is not None and capture_time_hooks[func_idx] is not None: + _hooks = capture_time_hooks[func_idx] + _pre_fwd_with_kwargs = _hooks.get("forward_pre_hooks_with_kwargs", {}) + for hook_id, hook in _hooks.get("forward_pre_hooks", {}).items(): + if hook_id in _pre_fwd_with_kwargs: + if hook(func, args, kwargs) is not None: + raise RuntimeError( + "capture_time_hooks forward_pre_hooks must not return a value " + "(args/kwargs must not be modified via hook return)" + ) + else: + if hook(func, args) is not None: + raise RuntimeError( + "capture_time_hooks forward_pre_hooks must not return a value " + "(args must not be modified via hook return)" + ) with _graph_context_wrapper(fwd_graph, pool=mempool): outputs = func(*args, **kwargs) graph_callables[func_idx] = func # Call forward hooks after forward graph capture (outside capture context) - if ( - capture_time_hooks is not None - and capture_time_hooks[func_idx] is not None - and "forward" in capture_time_hooks[func_idx] - ): - for hook in capture_time_hooks[func_idx]["forward"].values(): - result = hook(func, args, outputs) - if result is not None: - outputs = result + if capture_time_hooks is not None and capture_time_hooks[func_idx] is not None: + _hooks = capture_time_hooks[func_idx] + _fwd_with_kwargs = _hooks.get("forward_hooks_with_kwargs", {}) + for hook_id, hook in _hooks.get("forward_hooks", {}).items(): + if hook_id in _fwd_with_kwargs: + if hook(func, args, kwargs, outputs) is not None: + raise RuntimeError( + "capture_time_hooks forward_hooks must not return a value " + "(output must not be modified via hook return)" + ) + else: + if hook(func, args, outputs) is not None: + raise RuntimeError( + "capture_time_hooks forward_hooks must not return a value " + "(output must not be modified via hook return)" + ) flatten_outputs, spec = _tree_flatten(outputs) per_callable_static_outputs.append(tuple(flatten_outputs)) @@ -935,11 +999,15 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter, callable_idx): if ( capture_time_hooks is not None and capture_time_hooks[bwd_idx] is not None - and "pre_backward" in capture_time_hooks[bwd_idx] + and "backward_pre_hooks" in capture_time_hooks[bwd_idx] ): callable_module = graph_callables[bwd_idx] - for hook in capture_time_hooks[bwd_idx]["pre_backward"].values(): - hook(callable_module) + for hook in capture_time_hooks[bwd_idx]["backward_pre_hooks"].values(): + if hook(callable_module, static_grad_outputs) is not None: + raise RuntimeError( + "capture_time_hooks backward_pre_hooks must not return a value " + "(grad_output must not be modified via hook return)" + ) inputs = tuple(i for i in static_input_surface if i.requires_grad) with _none_grad_context_wrapper(inputs), _graph_context_wrapper( @@ -956,11 +1024,15 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter, callable_idx): if ( capture_time_hooks is not None and capture_time_hooks[bwd_idx] is not None - and "backward" in capture_time_hooks[bwd_idx] + and "backward_hooks" in capture_time_hooks[bwd_idx] ): callable_module = graph_callables[bwd_idx] - for hook in capture_time_hooks[bwd_idx]["backward"].values(): - hook(callable_module) + for hook in capture_time_hooks[bwd_idx]["backward_hooks"].values(): + if hook(callable_module, grad_inputs, static_grad_outputs) is not None: + raise RuntimeError( + "capture_time_hooks backward_hooks must not return a value " + "(grad_input must not be modified via hook return)" + ) if need_bwd_dw_graph[bwd_idx]: with _graph_context_wrapper(bwd_dw_graph, pool=mempool): @@ -1370,17 +1442,27 @@ def make_graphed_callables( capture context so they are **not** recorded into the graph and will **not** be replayed. Use this for operations that are inherently non-capturable but essential for correct module execution, such as - CPU-side state updates. + CPU-side state updates. All the hooks must return None, meaning that + modifying tensors is not supported. Any attempt to return a non-None + value will raise ``RuntimeError``. Each element corresponds to one callable and is a dict with any subset - of the following keys: - - ``"pre_forward"``: dict of hooks called *before* the forward pass. - Hook signature: ``hook(module, args, kwargs)``. - - ``"forward"``: dict of hooks called *after* the forward pass. - Hook signature: ``hook(module, args, output)``. - - ``"pre_backward"``: dict of hooks called *before* the backward pass. - Hook signature: ``hook(module)``. - - ``"backward"``: dict of hooks called *after* the backward pass. - Hook signature: ``hook(module)``. + of the following keys (names mirror PyTorch's ``nn.Module`` hook + attributes): + - ``"forward_pre_hooks"``: ``{hook_id: hook_fn}`` dict of *all* pre-forward + hooks (both plain and with-kwargs). Plain signature: ``hook(module, args)``. + With-kwargs signature: ``hook(module, args, kwargs)``. + - ``"forward_pre_hooks_with_kwargs"``: ``{hook_id: True}`` flag set marking + which entries in ``"forward_pre_hooks"`` should be called with kwargs. + - ``"forward_hooks"``: ``{hook_id: hook_fn}`` dict of *all* post-forward + hooks (both plain and with-kwargs). Plain signature: + ``hook(module, args, output)``. With-kwargs signature: + ``hook(module, args, kwargs, output)``. + - ``"forward_hooks_with_kwargs"``: ``{hook_id: True}`` flag set marking + which entries in ``"forward_hooks"`` should be called with kwargs. + - ``"backward_pre_hooks"``: ``{hook_id: hook_fn}`` dict of hooks called + *before* the backward pass. Signature: ``hook(module, grad_output)``. + - ``"backward_hooks"``: ``{hook_id: hook_fn}`` dict of hooks called *after* + the backward pass. Signature: ``hook(module, grad_input, grad_output)``. Quantization parameters ----------------------- From 100b6e36aad7c040c6395b4c1fa0f85639a9e1e4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 8 Apr 2026 06:21:04 +0000 Subject: [PATCH 6/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/graph.py | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 6014a3e05c..964f1a5a16 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -698,28 +698,34 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter, callable_idx): fwd_graph = fwd_graphs[per_callable_fwd_idx] # Call pre_forward hooks before forward graph capture (outside capture context) - if capture_time_hooks is not None and capture_time_hooks[callable_idx] is not None: + if ( + capture_time_hooks is not None + and capture_time_hooks[callable_idx] is not None + ): _hooks = capture_time_hooks[callable_idx] _pre_fwd_with_kwargs = _hooks.get("forward_pre_hooks_with_kwargs", {}) for hook_id, hook in _hooks.get("forward_pre_hooks", {}).items(): if hook_id in _pre_fwd_with_kwargs: if hook(func, args, kwargs) is not None: raise RuntimeError( - "capture_time_hooks forward_pre_hooks must not return a value " - "(args/kwargs must not be modified via hook return)" + "capture_time_hooks forward_pre_hooks must not return a" + " value (args/kwargs must not be modified via hook return)" ) else: if hook(func, args) is not None: raise RuntimeError( - "capture_time_hooks forward_pre_hooks must not return a value " - "(args must not be modified via hook return)" + "capture_time_hooks forward_pre_hooks must not return a" + " value (args must not be modified via hook return)" ) with _graph_context_wrapper(fwd_graph, pool=mempool): outputs = func(*args, **kwargs) # Call forward hooks after forward graph capture (outside capture context) - if capture_time_hooks is not None and capture_time_hooks[callable_idx] is not None: + if ( + capture_time_hooks is not None + and capture_time_hooks[callable_idx] is not None + ): _hooks = capture_time_hooks[callable_idx] _fwd_with_kwargs = _hooks.get("forward_hooks_with_kwargs", {}) for hook_id, hook in _hooks.get("forward_hooks", {}).items(): @@ -846,8 +852,8 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter, callable_idx): ].values(): if hook(callable_module, static_grad_outputs) is not None: raise RuntimeError( - "capture_time_hooks backward_pre_hooks must not return a value " - "(grad_output must not be modified via hook return)" + "capture_time_hooks backward_pre_hooks must not return a" + " value (grad_output must not be modified via hook return)" ) inputs = tuple(i for i in static_input_surface if i.requires_grad) @@ -872,7 +878,10 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter, callable_idx): # Get the callable module for this backward index callable_module = graph_callables[per_callable_bwd_idx] for hook in capture_time_hooks[callable_idx]["backward_hooks"].values(): - if hook(callable_module, grad_inputs, static_grad_outputs) is not None: + if ( + hook(callable_module, grad_inputs, static_grad_outputs) + is not None + ): raise RuntimeError( "capture_time_hooks backward_hooks must not return a value " "(grad_input must not be modified via hook return)"