Add capture_time_hooks to make_graphed_callables for non-capturable per-callable hooks #2831

buptzyb wants to merge 4 commits into NVIDIA:main
Conversation
Signed-off-by: Robin Zhang <robinz@nvidia.com>
for more information, see https://pre-commit.ci
Greptile Summary

This PR adds a capture_time_hooks parameter to make_graphed_callables for per-callable hooks that must run outside the CUDA graph capture context. The implementation is clean: hooks are dispatched symmetrically across both the `_order is not None` and `_order is None` capture paths.

Confidence Score: 5/5. Safe to merge; the implementation is correct for the intended FSDP use-case and all prior P1 concerns are resolved. No P0 or P1 issues found. The forward-hook output consistency concern from a prior review round is resolved: both warmup and capture now pass raw (pre-flatten) outputs to the forward hook. The backward hook fires at the same relative position (before wgrad) in both warmup and capture. The sole remaining issue is a P2 documentation gap: the docstring does not warn that pre_forward hooks returning new (non-static) tensor objects will cause silent stale-data reads during replay. This does not affect the primary FSDP use-case.

Important Files Changed: transformer_engine/pytorch/graph.py — specifically the pre_forward hook documentation under make_graphed_callables.
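The P2 pitfall flagged above (pre_forward hooks returning new tensor objects causing stale-data reads on replay) can be illustrated with a minimal pure-Python sketch. Lists stand in for CUDA-graph static tensors, and `replay`, `safe_pre_forward`, and `unsafe_pre_forward` are hypothetical illustration names, not part of the actual API:

```python
# Stand-in for a CUDA-graph static input buffer: "replay" always reads
# this exact object, mimicking how a captured graph pins its pointers.
static_input = [1.0, 2.0]

def replay():
    # The "captured graph" computes from the pinned buffer, whatever it holds.
    return [x * 10 for x in static_input]

def safe_pre_forward(new_values):
    # Safe: update the static buffer in place (analogous to tensor.copy_),
    # so the captured computation sees the new data on replay.
    static_input[:] = new_values

def unsafe_pre_forward(new_values):
    # Unsafe: returning a brand-new object is invisible to the captured
    # graph; replay silently keeps reading the stale static buffer.
    return list(new_values)

safe_pre_forward([3.0, 4.0])
print(replay())              # [30.0, 40.0] -- new data is visible

unsafe_pre_forward([9.0, 9.0])
print(replay())              # still [30.0, 40.0] -- the new object is ignored
```

This is only an analogy for the documented behavior; in real usage the safe pattern is to copy new data into the same static tensors that were used at capture time.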
Sequence Diagram

sequenceDiagram
participant Caller
participant GC as make_graphed_callables
participant CUDA as CUDAGraph
participant H as capture_time_hooks[i]
rect rgb(240, 248, 255)
Note over Caller,H: Warmup phase (outside CUDA graph context)
GC->>H: pre_forward(func, args, kwargs)
H-->>GC: (opt) updated args, kwargs
GC->>GC: func(*args, **kwargs)
GC->>H: forward(func, args, outputs)
H-->>GC: (opt) updated outputs
GC->>H: pre_backward(func)
GC->>GC: torch.autograd.backward(...)
GC->>H: backward(func)
end
rect rgb(255, 245, 220)
Note over Caller,H: Graph capture phase
GC->>H: pre_forward(func, args, kwargs)
H-->>GC: (opt) updated args, kwargs
GC->>CUDA: begin fwd capture
GC->>GC: func(*args, **kwargs)
GC->>CUDA: end fwd capture
GC->>H: forward(func, args, outputs)
H-->>GC: (opt) updated outputs
GC->>H: pre_backward(func)
GC->>CUDA: begin bwd capture
GC->>GC: torch.autograd.backward(...)
GC->>CUDA: end bwd capture
GC->>H: backward(func)
end
rect rgb(230, 255, 230)
Note over Caller,CUDA: Runtime replay (no hooks)
Caller->>GC: forward(*live_inputs)
GC->>CUDA: fwd_graph.replay()
CUDA-->>Caller: static_outputs
Caller->>GC: backward(*grads)
GC->>CUDA: bwd_graph.replay()
end
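The warmup/capture dispatch shown in the diagram can be sketched as a plain-Python loop. This is a hedged sketch, not the actual transformer_engine implementation: `run_hooks`, `one_iteration`, and the `trace` list are invented for illustration, and the hooks dict follows the Dict[hook_type, Dict[handle_id, hook_fn]] format described in this PR:

```python
def run_hooks(hooks, hook_type, *hook_args):
    """Dispatch every registered hook of one type, in handle-id order.

    `hooks` follows the PR's format: Dict[hook_type, Dict[handle_id, fn]].
    Returns the last non-None hook result so hooks may rewrite args/outputs.
    """
    result = None
    for _handle_id, fn in sorted((hooks or {}).get(hook_type, {}).items()):
        out = fn(*hook_args)
        if out is not None:
            result = out
    return result

def one_iteration(func, args, hooks, trace):
    """One warmup (or capture) pass for a single callable."""
    # pre_forward may rewrite the inputs -- runs outside any graph capture
    updated = run_hooks(hooks, "pre_forward", func, args)
    if updated is not None:
        args = updated
    trace.append("pre_forward")

    outputs = func(*args)            # the only part a CUDA graph would record
    trace.append("forward_compute")

    updated = run_hooks(hooks, "forward", func, args, outputs)
    if updated is not None:
        outputs = updated
    trace.append("forward")

    run_hooks(hooks, "pre_backward", func)
    trace.append("pre_backward")
    # ... backward compute would run here (torch.autograd.backward) ...
    run_hooks(hooks, "backward", func)
    trace.append("backward")
    return outputs
```

During the capture phase the same dispatch runs, with the graph begin/end calls wrapped only around the compute steps, which is why the hooks are never recorded into the graph.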
pre_warmup_hook: Optional[Callable] = None,
post_warmup_hook: Optional[Callable] = None,
capture_time_hooks: Optional[List[Optional[Dict[str, Dict]]]] = None,
Nit: It's kind of uncomfortable that the APIs for the warmup and capture hooks are inconsistent. Also, the name capture_time_hooks is slightly misleading because the hooks are applied during warmup, not just during capture.
Yes, I've thought about the naming for a long time, but still failed to come up with a good name...
At first I wanted to call it non_capturable_hooks or graph_exterior_hooks, but that only indicates the hooks are not captured into the graph; it doesn't say they won't be called around graph replay. I also thought of warmup_capture_hooks, but that confuses users when we have three hook arguments all with "warmup" in their names... Maybe graph_construction_hooks is better?
And per_callable_hooks? But that doesn't distinguish between capture and replay time...
So eventually I came to capture_time_hooks. I persuaded myself that warmup is also part of the capture time. Do you have a better suggestion?
Description
Summary
make_graphed_callables previously had no mechanism to run per-callable hooks during warmup/capture that must execute outside the CUDA graph capture context (i.e., hooks that are inherently non-capturable, such as CPU-side state updates or FSDP parameter un-shard/re-shard calls).
This PR adds a new capture_time_hooks parameter to make_graphed_callables that accepts per-callable hooks invoked at capture time (warmup iterations and graph capture), but intentionally executed outside the CUDA graph capture context so they are not recorded into the graph and will not be replayed.
Changes
transformer_engine/pytorch/graph.py:
- Add a capture_time_hooks: Optional[List[Optional[Dict[str, Dict]]]] parameter to make_graphed_callables.
- Hooks are invoked during warmup iterations and graph capture, in both the `_order is not None` and `_order is None` capture paths.
- Hooks use the _forward_pre_hooks/_forward_hooks format: Dict[hook_type, Dict[handle_id, hook_fn]], where hook types are pre_forward, forward, pre_backward, and backward.
- Renamed capture_hooks → capture_time_hooks, with an updated docstring clarifying the non-capturable semantics.
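Under that format, a per-callable hooks list might be assembled as follows. This is a hedged sketch: `unshard`, `reshard`, the handle ids, and the layer names are made up for illustration, and the commented-out make_graphed_callables call only shows where the argument would be passed:

```python
# capture_time_hooks is a list with one entry per graphed callable;
# an entry may be None for a callable that needs no hooks.
def unshard(func, args, kwargs):
    """Illustrative FSDP-style pre_forward hook (runs outside capture)."""
    return None   # returning None keeps args/kwargs unchanged

def reshard(func):
    """Illustrative backward hook (runs outside capture)."""

hooks_for_layer0 = {
    "pre_forward": {0: unshard},   # Dict[handle_id, hook_fn]
    "backward": {0: reshard},
}
capture_time_hooks = [hooks_for_layer0, None]   # second callable: no hooks

# graphed = make_graphed_callables(
#     (layer0, layer1), sample_args,
#     capture_time_hooks=capture_time_hooks,
# )
```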
Motivation
Used by Megatron-LM's FSDP integration: during CUDA Graph capture, PyTorch's memory allocator is frozen, causing FSDP parameter un-shard/re-shard (which requires allocation) to fail. By routing FSDP hooks through capture_time_hooks, they execute outside the capture context and are manually driven at the right points during warmup/capture, while the graph itself only records pure GPU compute.
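The allocator constraint can be mimicked with a toy model: an allocator that refuses new allocations while "capturing", so any allocation must be driven from a hook that runs before the capture begins. Everything here (`FrozenDuringCapture`, `unshard_hook`, `capture`) is a made-up analogy, not PyTorch or transformer_engine code:

```python
class FrozenDuringCapture:
    """Toy allocator: refuses new allocations while 'capturing',
    loosely mimicking the CUDA caching allocator under graph capture."""
    def __init__(self):
        self.capturing = False

    def alloc(self, n):
        if self.capturing:
            raise RuntimeError("allocation during graph capture")
        return [0] * n

allocator = FrozenDuringCapture()

def unshard_hook():
    # Runs outside the capture context, so allocation is allowed here.
    return allocator.alloc(4)

def capture(compute):
    # Only pure compute may run between capture begin and end.
    allocator.capturing = True
    try:
        return compute()
    finally:
        allocator.capturing = False

params = unshard_hook()              # pre_forward hook: allocate first
out = capture(lambda: sum(params))   # the "graph" records compute only
```

Trying to allocate inside `capture` raises, which is the failure mode capture_time_hooks is designed to avoid: the un-shard allocation happens in the hook, before capture starts.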
Hook Invocation Order
For each callable at each warmup/capture iteration:
1. pre_forward hooks — before func(*args, **kwargs)
2. forward hooks — after forward
3. pre_backward hooks — before torch.autograd.backward
4. backward hooks — after backward

Type of change
Checklist: