File tree Expand file tree Collapse file tree 2 files changed +64
-0
lines changed Expand file tree Collapse file tree 2 files changed +64
-0
lines changed Original file line number Diff line number Diff line change 1+ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2+ #
3+ # This source code is licensed under the BSD license found in the
4+ # LICENSE file in the root directory of this source tree.
5+
6+ import dataclasses
7+
8+ import torch
9+ import torch .utils ._pytree as pytree
10+ from torch ._functorch ._aot_autograd .descriptors import AOTOutput
11+ from torch ._functorch .partitioners import _extract_graph_with_inputs_outputs
12+
13+
@dataclasses.dataclass(frozen=True)
class PrefetchOutput(AOTOutput):
    """Marker output descriptor for the prefetch subgraph.

    Carries no data of its own; its type alone tags an output produced by
    ``split_fsdp_prefetch`` as belonging to the extracted prefetch graph
    (as opposed to an original output of the joint graph).
    """

    pass
17+
18+
def split_fsdp_prefetch(g: torch.fx.Graph) -> tuple[torch.fx.Graph, torch.fx.Graph]:
    """Split ``g`` into a parameter-prefetch graph and a main compute graph.

    For each placeholder of ``g``, walk forward along the unique chain of
    single-user, single-input nodes hanging off it (presumably the
    all-gather / cast style transforms FSDP applies to a parameter before any
    real compute -- TODO confirm against the producing pass). The last node of
    each chain becomes an output of the prefetch graph and, correspondingly,
    an input of the main graph.

    Args:
        g: joint ``fx.Graph`` to split; assumed to contain exactly one
           ``output`` node (the standard fx invariant).

    Returns:
        ``(prefetch_g, main_g)``: ``prefetch_g`` maps the original
        placeholders to the chain-end nodes; ``main_g`` consumes those nodes
        and produces the original outputs of ``g``.
    """
    g_ins = g.find_nodes(op="placeholder")
    prefetch_g_outs_map = {}

    for g_in in g_ins:
        n = g_in
        while True:
            # Fan-out (or no users at all): the forward chain is no longer
            # unique, so stop here.
            if len(n.users) != 1:
                break
            user = next(iter(n.users))
            # Never walk onto the output node itself -- a placeholder routed
            # straight through to a single-input output node would otherwise
            # make the output node a "prefetch output", which is invalid.
            # Likewise stop at any multi-input op: that is where real compute
            # joins several values together.
            if user.op == "output" or len(user.all_input_nodes) > 1:
                break
            n = user
        prefetch_g_outs_map[g_in] = n

    prefetch_g_outs = list(prefetch_g_outs_map.values())
    prefetch_g_outs_descs: list[AOTOutput] = [
        PrefetchOutput() for _ in range(len(prefetch_g_outs))
    ]

    prefetch_g = _extract_graph_with_inputs_outputs(
        g,
        g_ins,
        prefetch_g_outs,
        prefetch_g_outs_descs,
    )

    # Original outputs (and their descriptors, when present) of the joint
    # graph become the outputs of the main graph.
    g_outs = pytree.arg_tree_leaves(*(n.args for n in g.find_nodes(op="output")))
    g_outs_descs = pytree.arg_tree_leaves(
        next(iter(g.find_nodes(op="output"))).meta.get("desc", [None] * len(g_outs))
    )
    main_g = _extract_graph_with_inputs_outputs(
        g,
        prefetch_g_outs,
        g_outs,
        g_outs_descs,
    )
    return prefetch_g, main_g
Original file line number Diff line number Diff line change @@ -253,6 +253,14 @@ def _pass(graph):
253253 print (f"Took { time .time () - t :.2f} s" )
254254 parallel_mod = autop .apply_placement (sharding_placement )
255255
# Smoke-test the experimental FSDP prefetch-splitting pass on the
# parallelized graph module produced above.
test_split_fsdp_prefetch = True
if test_split_fsdp_prefetch:
    # NOTE(review): assumes `autop.parallel_gm` is the fx.GraphModule produced
    # by `apply_placement` above -- confirm against the AutoParallel API.
    gm = autop.parallel_gm
    g = gm.graph
    from autoparallel.pipeline.passes import split_fsdp_prefetch

    # Results are currently unused; this only verifies the pass runs cleanly.
    prefetch_g, main_g = split_fsdp_prefetch(g)
256264# run weight init on our sharded DTensor params
257265parallel_mod .to_empty (device = "cuda" )
258266parallel_mod .init_weights ()
You can’t perform that action at this time.
0 commit comments