diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 86b8a4acf4..964f1a5a16 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -110,6 +110,7 @@ def _make_graphed_callables( _reuse_graph_input_output_buffers: bool = False, pre_warmup_hook: Optional[Callable] = None, post_warmup_hook: Optional[Callable] = None, + capture_time_hooks: Optional[List[Optional[Dict[str, Dict]]]] = None, ) -> SingleOrTuple[Callable]: """ Helper method for `make_graphed_callables` @@ -331,13 +332,18 @@ def _make_graphed_callables( if isinstance(c, torch.nn.Module): if not ( len(c._backward_hooks) == 0 + and len(c._backward_pre_hooks) == 0 and len(c._forward_hooks) == 0 and len(c._forward_pre_hooks) == 0 ): raise RuntimeError( "Modules must not have hooks registered at the time they are passed. " + "However, registering hooks on modules after passing them " - + "through make_graphed_callables is allowed." + + "through make_graphed_callables is allowed. " + + "If you have to use hooks during capture time, you can provide them " + + "in the capture_time_hooks argument, and they will be executed outside " + + "the CUDA graph capture context, meaning they will not be recorded into " + + "the graph and will not be replayed." ) if not all(b.requires_grad is False for b in c.buffers()): raise RuntimeError( @@ -359,6 +365,12 @@ def _make_graphed_callables( + "for each callable must contain only Tensors. Other types are not allowed." ) + if capture_time_hooks is not None and len(capture_time_hooks) != len(callables): + raise ValueError( + f"capture_time_hooks has {len(capture_time_hooks)} entries but there are " + f"{len(callables)} callables" + ) + # If a callable is an nn.Module, its graph's full input surface is the args the user explicitly # passes to forward (ie, its sample_args) AND the module's parameter attributes. 
# Note: These per_callable_* variables are not actually @@ -448,111 +460,209 @@ def _make_graphed_callables( visited_te_modules = {} need_bwd_dw_graph = {} - # Run warmup and do the above filtering. - with torch.cuda.stream(torch.cuda.Stream()): - for func_idx, func in zip(warmup_func_idx, warmup_func): - args = sample_args[func_idx] - kwargs = sample_kwargs[func_idx] - static_input_surface = per_callable_static_input_surfaces[func_idx] - - def hook_fn( - module, inputs, outputs, func_idx=func_idx - ): # pylint: disable=unused-argument - modules = set() - if isinstance(module, TransformerEngineBaseModule): - modules.add(module) - # If forward is called on a BasicOperation directly the hook will run - elif isinstance(module, BasicOperation): - modules.add(module) - # If forward is called on a te.ops.Sequential it is not called on its constituent ops - elif isinstance(module, Sequential): - if module._module_groups is None: + def _run_warmup_forward(func_idx, func, callable_idx): + """Run forward for one callable during warmup; returns flattened outputs.""" + args = sample_args[func_idx] + kwargs = sample_kwargs[func_idx] + + def hook_fn(module, inputs, outputs, func_idx=func_idx): # pylint: disable=unused-argument + modules = set() + if isinstance(module, TransformerEngineBaseModule): + modules.add(module) + # If forward is called on a BasicOperation directly the hook will run + elif isinstance(module, BasicOperation): + modules.add(module) + # If forward is called on a te.ops.Sequential it is not called on its constituent ops + elif isinstance(module, Sequential): + if module._module_groups is None: + raise RuntimeError( + "module._module_groups should have been initialized by warmup" + ) + for module_group in module._module_groups: + if isinstance(module_group, OperationFuser): + for basic_op in module_group._basic_ops: + modules.add(basic_op) + if modules: + if func_idx not in visited_te_modules: + visited_te_modules[func_idx] = modules + else: + 
visited_te_modules[func_idx].update(modules) + + if capture_time_hooks is not None and capture_time_hooks[callable_idx] is not None: + _hooks = capture_time_hooks[callable_idx] + _pre_fwd_with_kwargs = _hooks.get("forward_pre_hooks_with_kwargs", {}) + for hook_id, hook in _hooks.get("forward_pre_hooks", {}).items(): + if hook_id in _pre_fwd_with_kwargs: + if hook(func, args, kwargs) is not None: raise RuntimeError( - "module._module_groups should have been initialized by warmup" + "capture_time_hooks forward_pre_hooks must not return a value " + "(args/kwargs must not be modified via hook return)" ) - for module_group in module._module_groups: - if isinstance(module_group, OperationFuser): - for basic_op in module_group._basic_ops: - modules.add(basic_op) - if modules: - if func_idx not in visited_te_modules: - visited_te_modules[func_idx] = modules - else: - visited_te_modules[func_idx].update(modules) - - if pre_warmup_hook is not None: - pre_warmup_hook() - for warmup_iter in range(num_warmup_iters): - hooks = [] - for module in func.modules(): - hook = module.register_forward_hook(hook_fn) - hooks.append(hook) - outputs, _ = _tree_flatten(func(*args, **kwargs)) - for hook in hooks: - hook.remove() - if is_training: - inputs = tuple(i for i in static_input_surface if i.requires_grad) - with _none_grad_context_wrapper(inputs): - outputs_requiring_grad = tuple( - o for o in outputs if o is not None and o.requires_grad + else: + if hook(func, args) is not None: + raise RuntimeError( + "capture_time_hooks forward_pre_hooks must not return a value " + "(args must not be modified via hook return)" ) - torch.autograd.backward( - outputs_requiring_grad, - grad_tensors=tuple(torch.empty_like(o) for o in outputs_requiring_grad), + + hooks = [] + for module in func.modules(): + hooks.append(module.register_forward_hook(hook_fn)) + outputs = func(*args, **kwargs) + for hook in hooks: + hook.remove() + + if capture_time_hooks is not None and capture_time_hooks[callable_idx] 
is not None: + _hooks = capture_time_hooks[callable_idx] + _fwd_with_kwargs = _hooks.get("forward_hooks_with_kwargs", {}) + for hook_id, hook in _hooks.get("forward_hooks", {}).items(): + if hook_id in _fwd_with_kwargs: + if hook(func, args, kwargs, outputs) is not None: + raise RuntimeError( + "capture_time_hooks forward_hooks must not return a value " + "(output must not be modified via hook return)" ) - grad_inputs = tuple(input.grad for input in inputs) + else: + if hook(func, args, outputs) is not None: + raise RuntimeError( + "capture_time_hooks forward_hooks must not return a value " + "(output must not be modified via hook return)" + ) + + outputs, _ = _tree_flatten(outputs) + return outputs + + def _run_warmup_backward(func_idx, func, outputs, warmup_iter, callable_idx): + """Run dgrad backward for one callable during warmup.""" + static_input_surface = per_callable_static_input_surfaces[func_idx] + + inputs = tuple(i for i in static_input_surface if i.requires_grad) + outputs_requiring_grad = tuple(o for o in outputs if o is not None and o.requires_grad) + grad_outputs = tuple(torch.empty_like(o) for o in outputs_requiring_grad) - # Filter module params that get None grad from grad_inputs and remove them - # from static_input_surface. This is to ensure that the backward hooks - # registered to these params are not wrongly triggered. 
- num_required_grad_sample_args = sum( - arg.requires_grad for arg in flatten_sample_args[func_idx] + if ( + capture_time_hooks is not None + and capture_time_hooks[callable_idx] is not None + and "backward_pre_hooks" in capture_time_hooks[callable_idx] + ): + for hook in capture_time_hooks[callable_idx]["backward_pre_hooks"].values(): + if hook(func, grad_outputs) is not None: + raise RuntimeError( + "capture_time_hooks backward_pre_hooks must not return a value " + "(grad_output must not be modified via hook return)" ) - required_grad_input_idx = [] - for i, arg in enumerate(static_input_surface): - if arg.requires_grad: - required_grad_input_idx.append(i) - module_params_with_grad = [] - for grad_inputs_idx, inputs_idx in enumerate(required_grad_input_idx): - if ( - grad_inputs[grad_inputs_idx] is None - and grad_inputs_idx < num_required_grad_sample_args - ): - if not allow_unused_input: - raise RuntimeError( - "The input tensor requires grad, but the grad is None after" - " backward pass." + + with _none_grad_context_wrapper(inputs): + torch.autograd.backward(outputs_requiring_grad, grad_tensors=grad_outputs) + grad_inputs = tuple(input.grad for input in inputs) + + if ( + capture_time_hooks is not None + and capture_time_hooks[callable_idx] is not None + and "backward_hooks" in capture_time_hooks[callable_idx] + ): + for hook in capture_time_hooks[callable_idx]["backward_hooks"].values(): + if hook(func, grad_inputs, grad_outputs) is not None: + raise RuntimeError( + "capture_time_hooks backward_hooks must not return a value " + "(grad_input must not be modified via hook return)" + ) + + # Filter module params that get None grad from grad_inputs and remove them + # from static_input_surface. This is to ensure that the backward hooks + # registered to these params are not wrongly triggered. 
+ num_required_grad_sample_args = sum( + arg.requires_grad for arg in flatten_sample_args[func_idx] + ) + required_grad_input_idx = [] + for i, arg in enumerate(static_input_surface): + if arg.requires_grad: + required_grad_input_idx.append(i) + module_params_with_grad = [] + for grad_inputs_idx, inputs_idx in enumerate(required_grad_input_idx): + if ( + grad_inputs[grad_inputs_idx] is None + and grad_inputs_idx < num_required_grad_sample_args + ): + if not allow_unused_input: + raise RuntimeError( + "The input tensor requires grad, but the grad is None after backward pass." + ) + elif ( + grad_inputs[grad_inputs_idx] is not None + and grad_inputs_idx >= num_required_grad_sample_args + ): + module_params_with_grad.append(static_input_surface[inputs_idx]) + if len(module_params_with_grad) != len(per_callable_module_params[func_idx]): + if warmup_iter != 0: + raise RuntimeError( + "no-grad params should only be used as inputs in the first warmup" + f" iteration, but found in iteration {warmup_iter}" + ) + per_callable_module_params[func_idx] = tuple(module_params_with_grad) + static_input_surface = flatten_sample_args[func_idx] + tuple(module_params_with_grad) + per_callable_static_input_surfaces[func_idx] = static_input_surface + + # Run wgrad. This is essential for some TE modules when they have + # delay_wgrad_compute enabled. + need_backward_dw = False + for module in visited_te_modules.get(func_idx, set()): + if hasattr(module, "need_backward_dw") and module.need_backward_dw(): + need_backward_dw = True + module.backward_dw() + need_bwd_dw_graph[func_idx] = need_backward_dw + + # Run warmup and do the above filtering. + with torch.cuda.stream(torch.cuda.Stream()): + if pre_warmup_hook is not None: + pre_warmup_hook() + + for warmup_iter in range(num_warmup_iters): + if _order is None: + # All forwards in order, then all backwards in reverse order. 
+ warmup_outputs = [] + for func_idx, func in zip(warmup_func_idx, warmup_func): + outputs = _run_warmup_forward(func_idx, func, func_idx) + warmup_outputs.append((func_idx, func, outputs)) + if is_training: + for func_idx, func, outputs in reversed(warmup_outputs): + _run_warmup_backward(func_idx, func, outputs, warmup_iter, func_idx) + else: + # Follow _order exactly, mirroring the capture phase. + per_fwd_outputs = {} # per_callable_fwd_idx -> flattened outputs + fwd_idx = [0] * num_model_chunks + bwd_idx = [0] * num_model_chunks + for c_id in _order: + if c_id > 0: + # Forward pass for chunk c_id. + m_chunk = c_id - 1 + for l_no in range(_num_layers_per_chunk[m_chunk]): + callable_idx = _prefix_num_layers[m_chunk] + l_no + per_callable_fwd_idx = ( + _prefix_num_layers[m_chunk] * num_microbatches + ) + (fwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] + l_no) + func = callables[callable_idx] + outputs = _run_warmup_forward(per_callable_fwd_idx, func, callable_idx) + per_fwd_outputs[per_callable_fwd_idx] = outputs + fwd_idx[m_chunk] += 1 + elif ceil(c_id) == c_id: + # Backward pass for chunk -c_id. 
+ if is_training: + m_chunk = -c_id - 1 + for l_no in reversed(range(_num_layers_per_chunk[m_chunk])): + callable_idx = _prefix_num_layers[m_chunk] + l_no + per_callable_bwd_idx = ( + _prefix_num_layers[m_chunk] * num_microbatches + ) + (bwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] + l_no) + func = callables[callable_idx] + outputs = per_fwd_outputs[per_callable_bwd_idx] + _run_warmup_backward( + per_callable_bwd_idx, func, outputs, warmup_iter, callable_idx ) - elif ( - grad_inputs[grad_inputs_idx] is not None - and grad_inputs_idx >= num_required_grad_sample_args - ): - module_params_with_grad.append(static_input_surface[inputs_idx]) - if len(module_params_with_grad) != len(per_callable_module_params[func_idx]): - if warmup_iter != 0: - raise RuntimeError( - "no-grad params should only be used as inputs in the first warmup" - f" iteration, but found in iteration {warmup_iter}" - ) - per_callable_module_params[func_idx] = tuple(module_params_with_grad) - static_input_surface = flatten_sample_args[func_idx] + tuple( - module_params_with_grad - ) - per_callable_static_input_surfaces[func_idx] = static_input_surface - - # Run wgrad. This is essential for some TE modules when they have - # delay_wgrad_compute enabled. - need_backward_dw = False - for module in visited_te_modules.get(func_idx, set()): - if hasattr(module, "need_backward_dw") and module.need_backward_dw(): - need_backward_dw = True - module.backward_dw() - need_bwd_dw_graph[func_idx] = need_backward_dw - else: - grad_inputs = None - del outputs, grad_inputs - if post_warmup_hook is not None: - post_warmup_hook() + bwd_idx[m_chunk] += 1 + + if post_warmup_hook is not None: + post_warmup_hook() torch.cuda.synchronize() # All captures here share a mempool. 
To avoid replays corrupting each other's memory, @@ -578,15 +688,60 @@ def hook_fn( # Capture forward graph for model chunk c_id, microbatch fwd_idx[c_id-1] m_chunk = c_id - 1 for l_no in range(_num_layers_per_chunk[m_chunk]): - func = callables[_prefix_num_layers[m_chunk] + l_no] + callable_idx = _prefix_num_layers[m_chunk] + l_no + func = callables[callable_idx] per_callable_fwd_idx = (_prefix_num_layers[m_chunk] * num_microbatches) + ( fwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] + l_no ) args = sample_args[per_callable_fwd_idx] kwargs = sample_kwargs[per_callable_fwd_idx] fwd_graph = fwd_graphs[per_callable_fwd_idx] + + # Call pre_forward hooks before forward graph capture (outside capture context) + if ( + capture_time_hooks is not None + and capture_time_hooks[callable_idx] is not None + ): + _hooks = capture_time_hooks[callable_idx] + _pre_fwd_with_kwargs = _hooks.get("forward_pre_hooks_with_kwargs", {}) + for hook_id, hook in _hooks.get("forward_pre_hooks", {}).items(): + if hook_id in _pre_fwd_with_kwargs: + if hook(func, args, kwargs) is not None: + raise RuntimeError( + "capture_time_hooks forward_pre_hooks must not return a" + " value (args/kwargs must not be modified via hook return)" + ) + else: + if hook(func, args) is not None: + raise RuntimeError( + "capture_time_hooks forward_pre_hooks must not return a" + " value (args must not be modified via hook return)" + ) + with _graph_context_wrapper(fwd_graph, pool=mempool): outputs = func(*args, **kwargs) + + # Call forward hooks after forward graph capture (outside capture context) + if ( + capture_time_hooks is not None + and capture_time_hooks[callable_idx] is not None + ): + _hooks = capture_time_hooks[callable_idx] + _fwd_with_kwargs = _hooks.get("forward_hooks_with_kwargs", {}) + for hook_id, hook in _hooks.get("forward_hooks", {}).items(): + if hook_id in _fwd_with_kwargs: + if hook(func, args, kwargs, outputs) is not None: + raise RuntimeError( + "capture_time_hooks forward_hooks must not 
return a value " + "(output must not be modified via hook return)" + ) + else: + if hook(func, args, outputs) is not None: + raise RuntimeError( + "capture_time_hooks forward_hooks must not return a value " + "(output must not be modified via hook return)" + ) + flatten_outputs, spec = _tree_flatten(outputs) per_callable_static_outputs[per_callable_fwd_idx] = tuple(flatten_outputs) per_callable_output_unflatten_spec[per_callable_fwd_idx] = spec @@ -597,6 +752,7 @@ def hook_fn( m_chunk = -ceil(c_id) - 1 previous_per_callable_bwd_idx = None for l_no in list(reversed(range(_num_layers_per_chunk[m_chunk]))): + callable_idx = _prefix_num_layers[m_chunk] + l_no per_callable_bwd_idx = (_prefix_num_layers[m_chunk] * num_microbatches) + ( bwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] + l_no ) @@ -683,6 +839,23 @@ def hook_fn( for o in static_outputs ) if is_training: + # Call pre_backward hooks before backward graph capture (outside capture context) + if ( + capture_time_hooks is not None + and capture_time_hooks[callable_idx] is not None + and "backward_pre_hooks" in capture_time_hooks[callable_idx] + ): + # Get the callable module for this backward index + callable_module = graph_callables[per_callable_bwd_idx] + for hook in capture_time_hooks[callable_idx][ + "backward_pre_hooks" + ].values(): + if hook(callable_module, static_grad_outputs) is not None: + raise RuntimeError( + "capture_time_hooks backward_pre_hooks must not return a" + " value (grad_output must not be modified via hook return)" + ) + inputs = tuple(i for i in static_input_surface if i.requires_grad) with _none_grad_context_wrapper(inputs), _graph_context_wrapper( bwd_graph, pool=mempool @@ -696,6 +869,24 @@ def hook_fn( ) grad_inputs = tuple(input.grad for input in inputs) + # Call backward hooks after backward graph capture (outside capture context) + if ( + capture_time_hooks is not None + and capture_time_hooks[callable_idx] is not None + and "backward_hooks" in capture_time_hooks[callable_idx] + 
): + # Get the callable module for this backward index + callable_module = graph_callables[per_callable_bwd_idx] + for hook in capture_time_hooks[callable_idx]["backward_hooks"].values(): + if ( + hook(callable_module, grad_inputs, static_grad_outputs) + is not None + ): + raise RuntimeError( + "capture_time_hooks backward_hooks must not return a value " + "(grad_input must not be modified via hook return)" + ) + # Constructs a tuple suitable for returning from Graphed.backward: # Pads out the actually-needed grads with Nones in gradient slots for inputs # that don't require grad. I couldn't think of a one-liner for this pattern. @@ -750,12 +941,48 @@ def hook_fn( # Capture forward graphs per_callable_static_outputs = [] per_callable_output_unflatten_spec = [] - graph_id = 0 - for func, args, kwargs, fwd_graph in zip(callables, sample_args, sample_kwargs, fwd_graphs): + for func_idx, (func, args, kwargs, fwd_graph) in enumerate( + zip(callables, sample_args, sample_kwargs, fwd_graphs) + ): + # Call pre_forward hooks before forward graph capture (outside capture context) + if capture_time_hooks is not None and capture_time_hooks[func_idx] is not None: + _hooks = capture_time_hooks[func_idx] + _pre_fwd_with_kwargs = _hooks.get("forward_pre_hooks_with_kwargs", {}) + for hook_id, hook in _hooks.get("forward_pre_hooks", {}).items(): + if hook_id in _pre_fwd_with_kwargs: + if hook(func, args, kwargs) is not None: + raise RuntimeError( + "capture_time_hooks forward_pre_hooks must not return a value " + "(args/kwargs must not be modified via hook return)" + ) + else: + if hook(func, args) is not None: + raise RuntimeError( + "capture_time_hooks forward_pre_hooks must not return a value " + "(args must not be modified via hook return)" + ) + with _graph_context_wrapper(fwd_graph, pool=mempool): outputs = func(*args, **kwargs) - graph_callables[graph_id] = func - graph_id += 1 + graph_callables[func_idx] = func + + # Call forward hooks after forward graph capture (outside 
capture context) + if capture_time_hooks is not None and capture_time_hooks[func_idx] is not None: + _hooks = capture_time_hooks[func_idx] + _fwd_with_kwargs = _hooks.get("forward_hooks_with_kwargs", {}) + for hook_id, hook in _hooks.get("forward_hooks", {}).items(): + if hook_id in _fwd_with_kwargs: + if hook(func, args, kwargs, outputs) is not None: + raise RuntimeError( + "capture_time_hooks forward_hooks must not return a value " + "(output must not be modified via hook return)" + ) + else: + if hook(func, args, outputs) is not None: + raise RuntimeError( + "capture_time_hooks forward_hooks must not return a value " + "(output must not be modified via hook return)" + ) flatten_outputs, spec = _tree_flatten(outputs) per_callable_static_outputs.append(tuple(flatten_outputs)) @@ -777,6 +1004,20 @@ def hook_fn( for o in static_outputs ) if is_training: + # Call pre_backward hooks before backward graph capture (outside capture context) + if ( + capture_time_hooks is not None + and capture_time_hooks[bwd_idx] is not None + and "backward_pre_hooks" in capture_time_hooks[bwd_idx] + ): + callable_module = graph_callables[bwd_idx] + for hook in capture_time_hooks[bwd_idx]["backward_pre_hooks"].values(): + if hook(callable_module, static_grad_outputs) is not None: + raise RuntimeError( + "capture_time_hooks backward_pre_hooks must not return a value " + "(grad_output must not be modified via hook return)" + ) + inputs = tuple(i for i in static_input_surface if i.requires_grad) with _none_grad_context_wrapper(inputs), _graph_context_wrapper( bwd_graph, pool=mempool @@ -788,6 +1029,20 @@ def hook_fn( ) grad_inputs = tuple(input.grad for input in inputs) + # Call backward hooks after backward graph capture (outside capture context) + if ( + capture_time_hooks is not None + and capture_time_hooks[bwd_idx] is not None + and "backward_hooks" in capture_time_hooks[bwd_idx] + ): + callable_module = graph_callables[bwd_idx] + for hook in 
capture_time_hooks[bwd_idx]["backward_hooks"].values(): + if hook(callable_module, grad_inputs, static_grad_outputs) is not None: + raise RuntimeError( + "capture_time_hooks backward_hooks must not return a value " + "(grad_input must not be modified via hook return)" + ) + if need_bwd_dw_graph[bwd_idx]: with _graph_context_wrapper(bwd_dw_graph, pool=mempool): for module in visited_te_modules[bwd_idx]: @@ -1145,6 +1400,7 @@ def make_graphed_callables( _reuse_graph_input_output_buffers: bool = False, pre_warmup_hook: Optional[Callable] = None, post_warmup_hook: Optional[Callable] = None, + capture_time_hooks: Optional[List[Optional[Dict[str, Dict]]]] = None, ) -> Union[Callable, Tuple[Callable, ...]]: """ Make CUDA graph version of Transformer Engine modules @@ -1184,9 +1440,38 @@ def make_graphed_callables( when `_order` is provided. All callables in `modules` are assumed to have inputs and outputs with the same dtype and shape. pre_warmup_hook: callable, default = None - A hook function that will be called before the warmup iterations. + A hook function that will be called once before all warmup iterations + (not once per callable). post_warmup_hook: callable, default = None - A hook function that will be called after the warmup iterations. + A hook function that will be called once after all warmup iterations + (not once per callable). capture_time_hooks: list of dict, optional + Per-callable hooks invoked at capture time (during warmup iterations and + graph capture), but intentionally executed **outside** the CUDA graph + capture context so they are **not** recorded into the graph and will + **not** be replayed. Use this for operations that are inherently + non-capturable but essential for correct module execution, such as + CPU-side state updates. All hooks must return None: replacing args, kwargs, + outputs, or gradients via a hook's return value is not supported. Any attempt to return a non-None + value will raise ``RuntimeError``. 
+ Each element corresponds to one callable and is a dict with any subset + of the following keys (names mirror PyTorch's ``nn.Module`` hook + attributes): + - ``"forward_pre_hooks"``: ``{hook_id: hook_fn}`` dict of *all* pre-forward + hooks (both plain and with-kwargs). Plain signature: ``hook(module, args)``. + With-kwargs signature: ``hook(module, args, kwargs)``. + - ``"forward_pre_hooks_with_kwargs"``: ``{hook_id: True}`` flag set marking + which entries in ``"forward_pre_hooks"`` should be called with kwargs. + - ``"forward_hooks"``: ``{hook_id: hook_fn}`` dict of *all* post-forward + hooks (both plain and with-kwargs). Plain signature: + ``hook(module, args, output)``. With-kwargs signature: + ``hook(module, args, kwargs, output)``. + - ``"forward_hooks_with_kwargs"``: ``{hook_id: True}`` flag set marking + which entries in ``"forward_hooks"`` should be called with kwargs. + - ``"backward_pre_hooks"``: ``{hook_id: hook_fn}`` dict of hooks called + *before* the backward pass. Signature: ``hook(module, grad_output)``. + - ``"backward_hooks"``: ``{hook_id: hook_fn}`` dict of hooks called *after* + the backward pass. Signature: ``hook(module, grad_input, grad_output)``. Quantization parameters ----------------------- @@ -1380,6 +1665,7 @@ def call_func(self, *args, **kwargs): _reuse_graph_input_output_buffers=_reuse_graph_input_output_buffers, pre_warmup_hook=pre_warmup_hook, post_warmup_hook=post_warmup_hook, + capture_time_hooks=capture_time_hooks, ) # Ensures warmup does not affect numerics for ops such as dropout.