
Commit 456374c

[1/N] Autobucketing: add configs and setups for autobucketing
ghstack-source-id: 0eda4b1 Pull-Request: #127
1 parent 85efc8f commit 456374c

2 files changed: 149 additions, 0 deletions

autoparallel/auto_bucketing.py

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
import torch

from .autobucketing_util import bucket_utils


class simplefsdp_autobucketing_config:
    """
    Config for simplefsdp's autobucketing pass, which gives good performance by default.
    To make the results tunable, we expose the following parameters:
    - relax_ratio: relax comp time to include more comm in one bucket;
      with this config, comp is updated as comp * (1 + relax_ratio)
    - peak_memory_offset: relax peak_memory to include more comm in one bucket;
      with this config, peak_memory is updated as (peak_memory + peak_memory_offset)
    - load_cache: set to True to load the cache from save_estimation_path
    - enable_bucket_ir: set to True to bucket all_gather/reduce_scatter
    - enable_reorder_ir: set to True to reorder all_gather/reduce_scatter
    """

    relax_ratio = 0
    peak_memory_offset = 0
    load_cache = False
    save_estimation_path = "/mnt/mffuse/cache_ruisi/estimation_mast.pkl"
    enable_bucket_ir = True
    enable_reorder_ir = True


def simple_fsdp_autobucketing_reordering_pass(
    snodes: list["torch._inductor.scheduler.BaseSchedulerNode"],
    configs: "simplefsdp_autobucketing_config",
) -> list["torch._inductor.scheduler.BaseSchedulerNode"]:
    # For now this only identifies bucketable collectives and returns the
    # schedule unchanged.
    scheduler = snodes[0].scheduler
    bucket_utils.get_bucketable_ir_nodes(
        snodes, scheduler.name_to_fused_node, scheduler.name_to_buf
    )
    return snodes
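
The reordering pass takes an extra `configs` argument, so a natural way to hook it up is to bind that argument and register the result as an Inductor communication-reordering pass. Below is a minimal usage sketch, assuming Inductor's `reorder_for_compute_comm_overlap_passes` hook accepts callables; this wiring and the example knob values are not part of this commit.

import functools

import torch._inductor.config as inductor_config

from autoparallel.auto_bucketing import (
    simple_fsdp_autobucketing_reordering_pass,
    simplefsdp_autobucketing_config,
)

# Tune the knobs exposed in the class docstring (illustrative values).
autobucketing_config = simplefsdp_autobucketing_config()
autobucketing_config.relax_ratio = 0.1        # allow ~10% extra comp per bucket
autobucketing_config.peak_memory_offset = 0   # no extra peak-memory headroom
autobucketing_config.load_cache = False       # recompute runtime estimations

# functools.partial binds `configs` so the pass matches the expected
# snodes -> snodes signature of a reordering pass.
inductor_config.reorder_for_compute_comm_overlap = True
inductor_config.reorder_for_compute_comm_overlap_passes = [
    functools.partial(
        simple_fsdp_autobucketing_reordering_pass,
        configs=autobucketing_config,
    )
]
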
autoparallel/autobucketing_util/bucket_utils.py

Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@
# mypy: ignore-errors
from typing import Any, Callable, Dict

import torch
from torch._inductor import scheduler
from torch._inductor.dependencies import WeakDep
from torch._inductor.utils import buf_name_to_fused_snode, is_collective
from torch.utils._ordered_set import OrderedSet


def _find_recursive_deps_of_snode(
    snode: "scheduler.BaseSchedulerNode",
    collected_node_set: OrderedSet["scheduler.BaseSchedulerNode"],
    name_to_buf: Dict[str, "scheduler.SchedulerBuffer"],
    name_to_fused_node: Dict[str, "scheduler.BaseSchedulerNode"],
    criteria_cb: Callable[[Any], bool] = lambda snode: False,
    allow_weak_dep: bool = True,
):
    if criteria_cb(snode):
        return
    collected_node_set.add(snode)
    for dep in snode.unmet_dependencies:
        if isinstance(dep, WeakDep) and not allow_weak_dep:
            continue
        defining_op_for_dep = buf_name_to_fused_snode(
            dep.name, name_to_buf, name_to_fused_node
        )
        if defining_op_for_dep in collected_node_set:
            continue
        _find_recursive_deps_of_snode(
            defining_op_for_dep,
            collected_node_set,
            name_to_buf,
            name_to_fused_node,
            criteria_cb=criteria_cb,
        )


def _find_recursive_users_of_snode(
    snode: "scheduler.BaseSchedulerNode",
    collected_node_set: OrderedSet["scheduler.BaseSchedulerNode"],
    name_to_buf: Dict[str, "scheduler.SchedulerBuffer"],
    name_to_fused_node: Dict[str, "scheduler.BaseSchedulerNode"],
    criteria_cb: Callable[[Any], bool] = lambda snode: False,
):
    if criteria_cb(snode):
        return
    collected_node_set.add(snode)
    for o in snode.get_outputs():
        for user in o.users:
            assert user.node is not None
            if user.node.get_name() == "OUTPUT":
                continue
            if user.node.get_name() not in name_to_fused_node:
                continue
            user_op = name_to_fused_node[user.node.get_name()]
            if user_op in collected_node_set:
                continue
            _find_recursive_users_of_snode(
                user_op,
                collected_node_set,
                name_to_buf,
                name_to_fused_node,
                criteria_cb=criteria_cb,
            )


def get_bucketable_ir_nodes(
    snodes: list["torch._inductor.scheduler.BaseSchedulerNode"],
    name_to_fused_node: Dict[str, "scheduler.BaseSchedulerNode"],
    name_to_buf: Dict[str, "scheduler.SchedulerBuffer"],
) -> set[str]:
    """
    Select the names of IR nodes that are bucketable.
    From first principles, only all-gathers that gather parameters and reduce-scatters
    that update model gradients can be bucketed together.
    Thus, a bucketable all-gather's recursive deps are (1) an optional computed buffer
    for dtype conversion and (2) the all-gather itself, and a bucketable reduce-scatter
    wait's only recursive user is the wait node itself.
    """
    bucketable_ir_nodes = set()
    for snode in snodes:
        if is_collective(
            snode.node, op=torch.ops._c10d_functional.all_gather_into_tensor.default
        ):
            ag_related_snode_set: OrderedSet[
                "torch._inductor.scheduler.BaseSchedulerNode"
            ] = OrderedSet()
            _find_recursive_deps_of_snode(
                snode,
                ag_related_snode_set,
                name_to_buf,
                name_to_fused_node,
                allow_weak_dep=False,
            )
            # At most the all-gather itself plus one optional dtype-conversion node.
            if len(ag_related_snode_set) <= 2:
                bucketable_ir_nodes.add(snode.node.get_name())
        elif is_collective(
            snode.node, op=torch.ops._c10d_functional.reduce_scatter_tensor.default
        ):
            wait_snode = snode.get_outputs()[0].users[0].node
            wait_snode_recursive_users: OrderedSet[
                "torch._inductor.scheduler.BaseSchedulerNode"
            ] = OrderedSet()
            _find_recursive_users_of_snode(
                wait_snode,
                wait_snode_recursive_users,
                name_to_buf,
                name_to_fused_node,
            )
            # The wait node itself should be the only recursive user.
            if len(wait_snode_recursive_users) <= 1:
                bucketable_ir_nodes.add(snode.node.get_name())

    return bucketable_ir_nodes
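
For context, a hypothetical sketch of how a later bucketing step might consume the returned names: group runs of consecutive bucketable all-gathers while leaving other collectives alone. The helper name `group_bucketable_all_gathers` and the grouping policy are illustrative assumptions, not part of this commit; `is_collective` is the same Inductor utility imported above.

import torch
from torch._inductor.utils import is_collective


def group_bucketable_all_gathers(snodes, bucketable_ir_nodes):
    # Collect consecutive bucketable all-gathers into candidate buckets.
    buckets, current = [], []
    for snode in snodes:
        if (
            is_collective(
                snode.node,
                op=torch.ops._c10d_functional.all_gather_into_tensor.default,
            )
            and snode.node.get_name() in bucketable_ir_nodes
        ):
            current.append(snode)
        elif current:
            buckets.append(current)
            current = []
    if current:
        buckets.append(current)
    return buckets
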
