Skip to content

Commit 415b736

Browse files
committed
Add a graph pass that splits FSDP prefetch out of the graph
stack-info: PR: #201, branch: IvanKobzarev/stack/9
1 parent f1887eb commit 415b736

File tree

2 files changed

+64
-0
lines changed

2 files changed

+64
-0
lines changed

autoparallel/pipeline/passes.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2+
#
3+
# This source code is licensed under the BSD license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import dataclasses
7+
8+
import torch
9+
import torch.utils._pytree as pytree
10+
from torch._functorch._aot_autograd.descriptors import AOTOutput
11+
from torch._functorch.partitioners import _extract_graph_with_inputs_outputs
12+
13+
14+
@dataclasses.dataclass(frozen=True)
15+
class PrefetchOutput(AOTOutput):
16+
pass
17+
18+
19+
def split_fsdp_prefetch(g: torch.fx.Graph) -> tuple[torch.fx.Graph, torch.fx.Graph]:
    """Split ``g`` into a prefetch graph and a main graph.

    For every placeholder, walk forward along its chain of
    single-consumer, single-input nodes (presumably the FSDP
    all-gather / resharding chain hanging off each parameter — TODO
    confirm against callers) to find the last node that depends only on
    that placeholder.  Those chain endpoints become the outputs of the
    prefetch graph and the inputs of the main graph, so the prefetch
    part can be scheduled ahead of the rest of the computation.

    Args:
        g: the graph to split.  It is read, not mutated; two new graphs
            are extracted from it.

    Returns:
        A ``(prefetch_g, main_g)`` pair:
        * ``prefetch_g`` maps the original placeholders to the chain
          endpoints (outputs described by :class:`PrefetchOutput`).
        * ``main_g`` takes the chain endpoints as inputs and produces
          the original graph's outputs (with their original output
          descriptors, when present in the output node's ``meta``).
    """
    g_ins = g.find_nodes(op="placeholder")
    prefetch_g_outs_map = {}

    for g_in in g_ins:
        n = g_in
        # Follow the value while it has exactly one consumer and that
        # consumer reads no other values, i.e. the chain is still
        # private to this input.
        while True:
            if len(n.users) != 1:
                break
            user = next(iter(n.users))
            # Bug fix: never walk into the graph's output node.  A
            # placeholder whose chain ends directly at ``output`` would
            # otherwise make the output node itself an "output" of the
            # prefetch subgraph, which is malformed.
            if user.op == "output":
                break
            if len(user.all_input_nodes) > 1:
                break
            n = user
        # If no chain was walked, the placeholder maps to itself and is
        # simply forwarded through the prefetch graph.
        prefetch_g_outs_map[g_in] = n

    prefetch_g_outs = list(prefetch_g_outs_map.values())
    prefetch_g_outs_descs: list[AOTOutput] = [
        PrefetchOutput() for _ in range(len(prefetch_g_outs))
    ]

    prefetch_g = _extract_graph_with_inputs_outputs(
        g,
        g_ins,
        prefetch_g_outs,
        prefetch_g_outs_descs,
    )

    # Original outputs (and their descriptors, if recorded) become the
    # outputs of the main graph.
    g_outs = pytree.arg_tree_leaves(*(n.args for n in g.find_nodes(op="output")))
    g_outs_descs = pytree.arg_tree_leaves(
        next(iter(g.find_nodes(op="output"))).meta.get("desc", [None] * len(g_outs))
    )
    main_g = _extract_graph_with_inputs_outputs(
        g,
        prefetch_g_outs,
        g_outs,
        g_outs_descs,
    )
    return prefetch_g, main_g

examples/example_llama3.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,14 @@ def _pass(graph):
253253
print(f"Took {time.time() - t:.2f} s")
254254
parallel_mod = autop.apply_placement(sharding_placement)
255255

256+
test_split_fsdp_prefetch = True
257+
if test_split_fsdp_prefetch:
258+
gm = autop.parallel_gm
259+
g = gm.graph
260+
from autoparallel.pipeline.passes import split_fsdp_prefetch
261+
262+
prefetch_g, main_g = split_fsdp_prefetch(g)
263+
256264
# run weight init on our sharded DTensor params
257265
parallel_mod.to_empty(device="cuda")
258266
parallel_mod.init_weights()

0 commit comments

Comments
 (0)