Skip to content

Commit a4009c1

Browse files
committed
[1/N] Autobucketing: add configs and setups for autobucketing
ghstack-source-id: 2c82f50 Pull-Request: #127
1 parent 85efc8f commit a4009c1

File tree

2 files changed

+158
-0
lines changed

2 files changed

+158
-0
lines changed

autoparallel/auto_bucketing.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2+
#
3+
# This source code is licensed under the BSD license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import torch
7+
from .autobucketing_util import bucket_utils
8+
9+
10+
class simplefsdp_autobucketing_config:
    """
    Config for simplefsdp's autobucketing pass, which by default would give good performance.
    To make the results tunable, we expose the following parameters:
    - relax_ratio: relax comp time to include more comm in one bucket
        with this config, comp is updated as comp * (1 + relax_ratio)
    - peak_memory_offset: relax peak_memory to include more comm in one bucket
        with this config, peak_memory is updated as (peak_memory + peak_memory_offset)
    - load_cache: set to True to load cache from save_estimation_path
    - save_estimation_path: file path used to save/load the runtime-estimation cache
    - enable_bucket_ir: set to True to bucket all_gather/reduce_scatter
    - enable_reorder_ir: set to True to reorder all_gather/reduce_scatter
    """

    # Used as a plain attribute namespace; values are read directly off the class.
    relax_ratio = 0
    peak_memory_offset = 0
    load_cache = False
    # NOTE(review): hard-coded, host-specific path — presumably only valid on the
    # original author's cluster; callers elsewhere should override it.
    save_estimation_path = "/mnt/mffuse/cache_ruisi/estimation_mast.pkl"
    enable_bucket_ir = True
    enable_reorder_ir = True
29+
30+
31+
def simple_fsdp_autobucketing_reordering_pass(
    snodes: list["torch._inductor.scheduler.BaseSchedulerNode"],
    configs: "simplefsdp_autobucketing_config",
) -> list["torch._inductor.scheduler.BaseSchedulerNode"]:
    """Inductor reordering pass hook for simplefsdp autobucketing.

    Currently only computes the set of bucketable collective nodes via
    ``bucket_utils.get_bucketable_ir_nodes`` (the result is discarded for
    now) and returns ``snodes`` unchanged.

    Args:
        snodes: the scheduler nodes of the current graph, in schedule order.
        configs: autobucketing tuning knobs. NOTE(review): currently unused;
            presumably consumed by later stages of the pass — confirm.

    Returns:
        The input ``snodes``, unmodified.
    """
    # Guard: an empty schedule would make snodes[0] raise IndexError.
    if not snodes:
        return snodes
    scheduler = snodes[0].scheduler
    bucket_utils.get_bucketable_ir_nodes(
        snodes, scheduler.name_to_fused_node, scheduler.name_to_buf
    )
    return snodes
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2+
#
3+
# This source code is licensed under the BSD license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# mypy: ignore-errors
7+
from typing import Any, Callable, Dict
8+
9+
import torch
10+
from torch._inductor import scheduler
11+
from torch._inductor.dependencies import WeakDep
12+
from torch._inductor.utils import buf_name_to_fused_snode, is_collective
13+
from torch.utils._ordered_set import OrderedSet
14+
15+
16+
def _find_recursive_deps_of_snode(
    snode: "scheduler.BaseSchedulerNode",
    collected_node_set: OrderedSet["scheduler.BaseSchedulerNode"],
    name_to_buf: Dict[str, "scheduler.SchedulerBuffer"],
    name_to_fused_node: Dict[str, "scheduler.BaseSchedulerNode"],
    criteria_cb: Callable[[Any], bool] = lambda snode: False,
    allow_weak_dep: bool = True,
):
    """Depth-first collect ``snode`` and the (fused) producers of its unmet
    dependencies into ``collected_node_set``.

    The walk prunes at any node for which ``criteria_cb`` returns True (that
    node is not added). WeakDeps are skipped when ``allow_weak_dep`` is False.
    """
    if criteria_cb(snode):
        return
    collected_node_set.add(snode)
    for dependency in snode.unmet_dependencies:
        if not allow_weak_dep and isinstance(dependency, WeakDep):
            continue
        producer = buf_name_to_fused_snode(
            dependency.name, name_to_buf, name_to_fused_node
        )
        if producer in collected_node_set:
            continue
        # NOTE(review): allow_weak_dep is not propagated into the recursive
        # call, so nested levels fall back to the default True — confirm
        # this is intentional.
        _find_recursive_deps_of_snode(
            producer,
            collected_node_set,
            name_to_buf,
            name_to_fused_node,
            criteria_cb=criteria_cb,
        )
42+
43+
44+
def _find_recursive_users_of_snode(
    snode: "scheduler.BaseSchedulerNode",
    collected_node_set: OrderedSet["scheduler.BaseSchedulerNode"],
    name_to_buf: Dict[str, "scheduler.SchedulerBuffer"],
    name_to_fused_node: Dict[str, "scheduler.BaseSchedulerNode"],
    criteria_cb: Callable[[Any], bool] = lambda snode: False,
):
    """Depth-first collect ``snode`` and the (fused) consumers of its output
    buffers into ``collected_node_set``.

    The walk prunes at any node for which ``criteria_cb`` returns True (that
    node is not added); graph outputs ("OUTPUT") and users with no fused-node
    mapping are skipped.
    """
    if criteria_cb(snode):
        return
    collected_node_set.add(snode)
    for output_buf in snode.get_outputs():
        for user in output_buf.users:
            assert user.node is not None
            user_name = user.node.get_name()
            if user_name == "OUTPUT" or user_name not in name_to_fused_node:
                continue
            consumer = name_to_fused_node[user_name]
            if consumer in collected_node_set:
                continue
            _find_recursive_users_of_snode(
                consumer,
                collected_node_set,
                name_to_buf,
                name_to_fused_node,
                criteria_cb=criteria_cb,
            )
71+
72+
73+
def get_bucketable_ir_nodes(
74+
snodes: list["torch._inductor.scheduler.BaseSchedulerNode"],
75+
name_to_fused_node: Dict[str, "scheduler.BaseSchedulerNode"],
76+
name_to_buf: Dict[str, "scheduler.SchedulerBuffer"],
77+
) -> set[str]:
78+
"""
79+
This function selects the ir nodes' names that are bucketable
80+
From first principle, only all-gathers that gather parameters and reduce-scatters
81+
that update model gradients could be bucketed together.
82+
Thus, bucketable all-gathers's deps are (1) computed buffer for dtype conversion (optional)
83+
(2) all-gather itself
84+
bucketable reduce-scatter wait's users are (1) reduce-scatter wait itself
85+
"""
86+
bucketable_ir_nodes = set()
87+
for snode in snodes:
88+
if is_collective(
89+
snode.node, op=torch.ops._c10d_functional.all_gather_into_tensor.default
90+
):
91+
ag_related_snode_set: OrderedSet[
92+
"torch._inductor.scheduler.BaseSchedulerNode"
93+
] = OrderedSet()
94+
_find_recursive_deps_of_snode(
95+
snode,
96+
ag_related_snode_set,
97+
name_to_buf,
98+
name_to_fused_node,
99+
allow_weak_dep=False,
100+
)
101+
if len(ag_related_snode_set) <= 2:
102+
bucketable_ir_nodes.add(snode.node.get_name())
103+
elif is_collective(
104+
snode.node, op=torch.ops._c10d_functional.reduce_scatter_tensor.default
105+
):
106+
wait_snode = snode.get_outputs()[0].users[0].node
107+
wait_snode_recursive_users: OrderedSet[
108+
"torch._inductor.scheduler.BaseSchedulerNode"
109+
] = OrderedSet()
110+
_find_recursive_users_of_snode(
111+
wait_snode,
112+
wait_snode_recursive_users,
113+
name_to_buf,
114+
name_to_fused_node,
115+
)
116+
if len(wait_snode_recursive_users) <= 1:
117+
bucketable_ir_nodes.add(snode.node.get_name())
118+
119+
return bucketable_ir_nodes

0 commit comments

Comments
 (0)