
Add capture_time_hooks to make_graphed_callables for non-capturable per-callable hooks #2831

Open
buptzyb wants to merge 4 commits into NVIDIA:main from buptzyb:robinz/capture-time-hooks

Conversation

@buptzyb
Contributor

@buptzyb buptzyb commented Apr 3, 2026

Description

Summary

make_graphed_callables previously had no mechanism to run per-callable
hooks during warmup/capture that must execute outside the CUDA graph
capture context (i.e., hooks that are inherently non-capturable, such as
CPU-side state updates or FSDP parameter un-shard/re-shard calls).

This PR adds a new capture_time_hooks parameter to
make_graphed_callables that accepts per-callable hooks invoked at
capture time (warmup iterations and graph capture), but intentionally
executed outside the CUDA graph capture context so they are not
recorded into the graph and will not be replayed.

Changes

  • transformer_engine/pytorch/graph.py:
    • Add capture_time_hooks: Optional[List[Optional[Dict[str, Dict]]]]
      parameter to make_graphed_callables
    • Invoke hooks around forward and backward passes during both warmup
      iterations and graph capture, in both the _order is not None and
      _order is None capture paths
    • Hook dict structure mirrors PyTorch's _forward_pre_hooks /
      _forward_hooks format: Dict[hook_type, Dict[handle_id, hook_fn]]
      where hook types are pre_forward, forward, pre_backward,
      backward
    • Rename the parameter from capture_hooks to capture_time_hooks, with
      an updated docstring clarifying the non-capturable semantics
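Based on the description above, the hook container can be sketched in plain Python. The Dict[hook_type, Dict[handle_id, hook_fn]] layout follows the PR text; the dispatcher and hook signatures are illustrative assumptions, not the actual Transformer Engine API:

```python
from typing import Callable, Dict, List, Optional

calls: List[str] = []

def log_pre_forward(name: str) -> Callable:
    # Illustrative hook; a real hook might un-shard FSDP parameters.
    def hook(*args, **kwargs):
        calls.append(f"{name}:pre_forward")
    return hook

# One entry per callable; None disables hooks for that callable.
# Each entry maps hook_type -> {handle_id: hook_fn}, mirroring
# PyTorch's _forward_pre_hooks / _forward_hooks layout.
capture_time_hooks: List[Optional[Dict[str, Dict[int, Callable]]]] = [
    {
        "pre_forward": {0: log_pre_forward("layer0")},
        "forward": {},
        "pre_backward": {},
        "backward": {},
    },
    None,  # second callable opts out entirely
]

def run_hooks(hooks, hook_type, *args, **kwargs):
    # Minimal dispatcher, in the spirit of how make_graphed_callables
    # might invoke the hooks (hypothetical, for illustration only).
    if hooks is None:
        return
    for _handle_id, fn in hooks.get(hook_type, {}).items():
        fn(*args, **kwargs)

run_hooks(capture_time_hooks[0], "pre_forward")
run_hooks(capture_time_hooks[1], "pre_forward")  # no-op: hooks disabled
```

The nested handle_id dict keeps the structure drop-in compatible with hooks harvested from a torch.nn.Module's hook OrderedDicts.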

Motivation

Used by Megatron-LM's FSDP integration: during CUDA Graph capture,
PyTorch's memory allocator is frozen, causing FSDP parameter
un-shard/re-shard (which requires allocation) to fail. By routing FSDP
hooks through capture_time_hooks, they execute outside the capture
context and are manually driven at the right points during
warmup/capture, while the graph itself only records pure GPU compute.

Hook Invocation Order

For each callable at each warmup/capture iteration:

  1. pre_forward hooks — before func(*args, **kwargs)
  2. (CUDA graph capture context entered — capture iterations only)
  3. Forward pass
  4. (CUDA graph capture context exited)
  5. forward hooks — after the forward pass
  6. pre_backward hooks — before torch.autograd.backward
  7. (CUDA graph capture context entered — capture iterations only)
  8. Backward pass
  9. (CUDA graph capture context exited)
  10. backward hooks — after the backward pass
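The per-iteration ordering can be simulated with a small stub. The context manager below stands in for the real CUDA graph capture context (in the capture path the backward pass is wrapped in one as well, per the sequence diagram in the review); all names here are illustrative, not the actual implementation:

```python
from contextlib import contextmanager

trace = []

@contextmanager
def fake_capture(tag):
    # Stand-in for the CUDA graph capture context; only work done
    # inside this block would be recorded into the graph.
    trace.append(f"enter_{tag}_capture")
    yield
    trace.append(f"exit_{tag}_capture")

def capture_one_iteration(func, hooks):
    hooks["pre_forward"]()             # runs outside the graph
    with fake_capture("fwd"):          # only this region is recorded
        func()
    hooks["forward"]()                 # outside the graph
    hooks["pre_backward"]()            # outside the graph
    with fake_capture("bwd"):          # backward is also captured
        trace.append("backward_pass")  # torch.autograd.backward(...)
    hooks["backward"]()                # outside the graph

hooks = {k: (lambda k=k: trace.append(k))
         for k in ("pre_forward", "forward", "pre_backward", "backward")}
capture_one_iteration(lambda: trace.append("forward_pass"), hooks)
```

Since every hook fires outside the fake_capture blocks, nothing a hook does is replayed later, which is exactly the non-capturable semantics the PR describes.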

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

buptzyb and others added 2 commits April 3, 2026 05:20
Signed-off-by: Robin Zhang <robinz@nvidia.com>
@greptile-apps
Contributor

greptile-apps bot commented Apr 3, 2026

Greptile Summary

This PR adds a capture_time_hooks parameter to make_graphed_callables, allowing per-callable hooks to run outside the CUDA graph capture context during both warmup iterations and graph capture — enabling non-capturable operations such as FSDP parameter un-shard/re-shard to be driven manually at the correct points.

The implementation is clean: hooks are dispatched symmetrically across both the _order is None and _order is not None capture paths, the forward hook correctly receives raw (pre-flatten) outputs in both warmup and capture (resolving the inconsistency flagged in prior review), and the backward hook fires at a consistent point relative to wgrad in both paths. The one non-obvious constraint — that pre_forward hooks should not return entirely new tensor objects during capture, since the CUDA graph records the device addresses used at capture time — is left undocumented and could silently produce wrong replay behavior for callers unfamiliar with CUDA graph internals.

Confidence Score: 5/5

Safe to merge; the implementation is correct for the intended FSDP use-case and all prior P1 concerns are resolved.

No P0 or P1 issues found. The forward-hook output consistency concern from a prior review round is resolved — both warmup and capture now pass raw (pre-flatten) outputs to the forward hook. The backward hook fires at the same relative position (before wgrad) in both warmup and capture. The sole remaining issue is a P2 documentation gap: the docstring does not warn that pre_forward hooks returning new (non-static) tensor objects will cause silent stale-data reads during replay. This does not affect the primary FSDP use-case.
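The stale-data pitfall flagged above can be illustrated with a runnable sketch. A plain Python list stands in for a static input tensor whose device address the CUDA graph records at capture time; all names are hypothetical:

```python
# Stand-in for a static input buffer captured by the graph.
static_input = [0.0, 0.0, 0.0, 0.0]

def replay():
    # The "graph" always reads the buffer captured at capture time.
    return [x * 2 for x in static_input]

def good_pre_forward(new_data):
    # Correct: copy new data into the static buffer in place, so the
    # captured address sees the update.
    static_input[:] = new_data

def bad_pre_forward(new_data):
    # Incorrect: returns a brand-new object; the graph never sees it,
    # so replay silently reads stale data.
    return list(new_data)

good_pre_forward([1, 2, 3, 4])
out_good = replay()   # reflects the in-place update: [2, 4, 6, 8]

fresh = bad_pre_forward([9, 9, 9, 9])
out_bad = replay()    # still [2, 4, 6, 8]; the new object is ignored
```

This is the documentation gap the review calls out: a pre_forward hook must update the static tensors in place (e.g. copy_) rather than return fresh tensor objects.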

transformer_engine/pytorch/graph.py — specifically the pre_forward hook documentation under make_graphed_callables.

Important Files Changed

transformer_engine/pytorch/graph.py — Adds the capture_time_hooks parameter with per-callable pre_forward/forward/pre_backward/backward hooks executed outside the CUDA graph capture context; consistent hook ordering is confirmed across the warmup and capture paths.

Sequence Diagram

sequenceDiagram
    participant Caller
    participant GC as make_graphed_callables
    participant CUDA as CUDAGraph
    participant H as capture_time_hooks[i]

    rect rgb(240, 248, 255)
        Note over Caller,H: Warmup phase (outside CUDA graph context)
        GC->>H: pre_forward(func, args, kwargs)
        H-->>GC: (opt) updated args, kwargs
        GC->>GC: func(*args, **kwargs)
        GC->>H: forward(func, args, outputs)
        H-->>GC: (opt) updated outputs
        GC->>H: pre_backward(func)
        GC->>GC: torch.autograd.backward(...)
        GC->>H: backward(func)
    end

    rect rgb(255, 245, 220)
        Note over Caller,H: Graph capture phase
        GC->>H: pre_forward(func, args, kwargs)
        H-->>GC: (opt) updated args, kwargs
        GC->>CUDA: begin fwd capture
        GC->>GC: func(*args, **kwargs)
        GC->>CUDA: end fwd capture
        GC->>H: forward(func, args, outputs)
        H-->>GC: (opt) updated outputs
        GC->>H: pre_backward(func)
        GC->>CUDA: begin bwd capture
        GC->>GC: torch.autograd.backward(...)
        GC->>CUDA: end bwd capture
        GC->>H: backward(func)
    end

    rect rgb(230, 255, 230)
        Note over Caller,CUDA: Runtime replay (no hooks)
        Caller->>GC: forward(*live_inputs)
        GC->>CUDA: fwd_graph.replay()
        CUDA-->>Caller: static_outputs
        Caller->>GC: backward(*grads)
        GC->>CUDA: bwd_graph.replay()
    end


Reviews (2): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."

Comment on lines 1323 to +1325
pre_warmup_hook: Optional[Callable] = None,
post_warmup_hook: Optional[Callable] = None,
capture_time_hooks: Optional[List[Optional[Dict[str, Dict]]]] = None,
Collaborator

Nit: It's kind of uncomfortable that the APIs for the warmup and capture hooks are inconsistent. Also, the capture_time_hooks is slightly misleading because they are applied during warmup, not just during capture.

Contributor Author

Yes, I've thought about the naming for a long time, but still failed to come up with a good name...

At first I wanted to call it non_capturable_hooks or graph_exterior_hooks, but those only indicate that the hooks are not captured into the graph, without conveying that they also won't be called around graph replay. I also thought of warmup_capture_hooks, but that confuses users when we have three hook arguments all named after warmup... Maybe graph_construction_hooks is better?

And per_callable_hooks? But that doesn't distinguish between capture time and replay time...

So eventually I settled on capture_time_hooks. I persuaded myself that warmup is also part of capture time. Do you have a better suggestion?

buptzyb and others added 2 commits April 7, 2026 07:02
Signed-off-by: Robin Zhang <robinz@nvidia.com>
@buptzyb buptzyb requested a review from timmoon10 April 7, 2026 14:08
