
Commit a984dd2

[3/N] Autobucketing: add greedy algorithm for autobucketing
ghstack-source-id: 08af129 Pull-Request: #129
1 parent: c10c095

3 files changed: +457 -2 lines changed

autoparallel/auto_bucketing.py

Lines changed: 16 additions & 2 deletions
@@ -5,7 +5,7 @@

 import torch

-from .autobucketing_util import bucket_utils
+from .autobucketing_util import bucket_plan, bucket_utils


 class simplefsdp_autobucketing_config:
@@ -36,7 +36,21 @@ def simple_fsdp_autobucketing_reordering_pass(
     configs: "simplefsdp_autobucketing_config",
 ) -> list["torch._inductor.scheduler.BaseSchedulerNode"]:
     scheduler = snodes[0].scheduler
-    bucket_utils.get_bucketable_ir_nodes(
+    bucketable_nodes = bucket_utils.get_bucketable_ir_nodes(
         snodes, scheduler.name_to_fused_node, scheduler.name_to_buf
     )
+
+    assert (
+        not torch._inductor.config.allow_buffer_reuse
+    ), "bucketing algorithm requires torch._inductor.config.allow_buffer_reuse to be False"
+
+    if configs.enable_bucket_ir:
+        all_gather_plan, reduce_scatter_plan = bucket_plan.get_simplefsdp_auto_plan(
+            scheduler,
+            snodes,
+            scheduler.name_to_buf,
+            scheduler.name_to_fused_node,
+            bucketable_nodes,
+            configs,
+        )
     return snodes
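
For reference, a minimal sketch of how this reordering pass might be wired up. The registration through torch._inductor.config.reorder_for_compute_comm_overlap_passes and the functools.partial binding of the config are assumptions based on Inductor's scheduler-pass interface, not part of this commit; only the allow_buffer_reuse requirement is asserted by the code above.

import functools

import torch._inductor.config as inductor_config
from autoparallel.auto_bucketing import (
    simple_fsdp_autobucketing_reordering_pass,
    simplefsdp_autobucketing_config,
)

# The assert added in this commit requires buffer reuse to be disabled.
inductor_config.allow_buffer_reuse = False

# Assumed wiring: bind the autobucketing config with functools.partial so the
# pass matches Inductor's snodes -> snodes pass signature, then register it as
# a compute/communication reordering pass.
configs = simplefsdp_autobucketing_config()
inductor_config.reorder_for_compute_comm_overlap = True
inductor_config.reorder_for_compute_comm_overlap_passes = [
    functools.partial(simple_fsdp_autobucketing_reordering_pass, configs=configs),
]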
autoparallel/autobucketing_util/bucket_plan.py

Lines changed: 342 additions & 0 deletions
@@ -0,0 +1,342 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import functools

# mypy: ignore-errors
from collections import defaultdict
from typing import Any, Dict

import torch
from torch._C._distributed_c10d import ReduceOp
from torch._inductor import scheduler
from torch._inductor.comm import _schedule_fallback_operation
from torch._inductor.utils import is_collective

from .bucket_utils import (
    check_ir_node_bucketable,
    estimate_bucketed_snode_runtime,
    get_data_size,
    get_snode_process_group_info,
    get_snode_tensor_info,
)
from .estimation import benchmark_and_sync_runtime


def get_dynamic_memory_threshold(
    peak_memory,
    peak_memory_per_step_dict,
    current_step,
) -> int:
    """
    this function calculates the memory gap from the current step's memory to the peak memory
    """
    left_peak_memory = 0
    right_peak_memory = 0
    for idx, memory in peak_memory_per_step_dict.items():
        if idx <= current_step:
            left_peak_memory = max(memory, left_peak_memory)
        if idx >= current_step:
            right_peak_memory = max(memory, right_peak_memory)
    current_peak_memory = min(left_peak_memory, right_peak_memory)
    return peak_memory - current_peak_memory


def get_simplefsdp_auto_plan(
    sched: "scheduler.Scheduler",
    snodes: list["scheduler.BaseSchedulerNode"],
    name_to_buf: Dict[str, "scheduler.SchedulerBuffer"],
    name_to_fused_node: Dict[str, "scheduler.BaseSchedulerNode"],
    bucketable_nodes: set[str],
    configs: Any,
    verbose: bool = True,
) -> tuple[
    list[Dict[tuple[Any, ...], list["scheduler.BaseSchedulerNode"]]],
    list[Dict[tuple[Any, ...], list["scheduler.BaseSchedulerNode"]]],
]:
    """
    This function implements a greedy algorithm, which decides if the node could be bucketed
    with the previous one based on several criteria below:
    FWD Pass:
    (1) the bucketed AG communication could be overlapped by the previous computation;
    (2) the bucketed AG memory doesn't exceed peak memory;
    (3) bucketed AG communication size doesn't exceed 0.2*sum(fwd_ag_tensor_list), such
        that the estimated AG communication time is always in the calibration bound.
    BWD Pass:
    (1) the bucketed AG + RS communication could be overlapped by the previous computation;
    (2) the bucketed AG+RS memory doesn't exceed peak memory;
    (3) RS always has future compute to overlap it, such that its final exposed communication is small;
    (4) bucketed AG/RS communication size doesn't exceed 0.2*sum(fwd_ag_tensor_list) & 0.2*sum(bwd_rs_tensor_list),
        such that the estimated AG/RS communication time is always in the calibration bound.
    """
    all_gather_plan = []
    reduce_scatter_plan = []
    current_ag_bucket: Dict[
        tuple[Any, ...], list["scheduler.BaseSchedulerNode"]
    ] = defaultdict(list)
    current_rs_bucket: Dict[
        tuple[Any, ...], list["scheduler.BaseSchedulerNode"]
    ] = defaultdict(list)
    schedule_fallback_operation = functools.partial(
        _schedule_fallback_operation,
        scheduler=sched,
        name_to_buf=name_to_buf,
        name_to_fused_node=name_to_fused_node,
    )

    heuristic_info = {
        # time info
        "last_step_rs_comm_time": 0.0,
        "this_step_comp_time": 0.0,
        "this_step_rs_comm_time": 0.0,
        "next_step_comp_time": 0.0,
        "next_step_nonfsdp_comm_time": 0.0,
        # memory info
        "accumulated_gradient_size": 0,
        "last_step_rs_comm_size": 0,
        "this_step_rs_comm_out_size": 0,
        "this_step_rs_comm_inp_size": 0,
        "this_step_memory": 0,
        "next_step_memory": 0,
    }

    # sync runtime info across ranks
    (
        comm_cache,
        comp_time_dict,
        memory_dict,
        peak_memory_per_step_dict,
    ) = benchmark_and_sync_runtime(sched, snodes, bucketable_nodes)
    future_comp_time = sum(comp_time_dict.values())
    peak_memory = max(peak_memory_per_step_dict.values()) + configs.peak_memory_offset

    # autobucket algorithm
    bucketable_ag_idx = -1
    seen_new_bucketable_ag = True
    for _, snode in enumerate(snodes):
        if is_collective(
            snode.node, op=torch.ops._c10d_functional.all_gather_into_tensor.default
        ) and check_ir_node_bucketable(snode.node, bucketable_nodes):
            bucketable_ag_idx += 1
            seen_new_bucketable_ag = True
            future_comp_time -= comp_time_dict[bucketable_ag_idx]

            ag_node_info = get_snode_tensor_info(
                snode, return_data_size=False
            ) + get_snode_process_group_info(
                snode,
                expected_op=torch.ops._c10d_functional.all_gather_into_tensor.default,
                resolve_pg=False,
            )
            current_ag_bucket[ag_node_info].append(snode)
            (
                estimated_comm,
                comm_size_inp,
                comm_size_out,
            ) = estimate_bucketed_snode_runtime(
                current_ag_bucket,
                schedule_fallback_operation,
                name_to_buf,
                torch.ops._c10d_functional.all_gather_into_tensor.default,
                comm_cache,
            )

            # Check if current bucketing breaks the greedy criteria
            # (1) Overlapping criteria
            comp_time = heuristic_info["this_step_comp_time"] * (
                1 + configs.relax_ratio
            )
            comm_time = estimated_comm + heuristic_info["last_step_rs_comm_time"]
            break_overlap_criteria = comp_time < comm_time

            # (2) Memory criteria
            memory_threshold = get_dynamic_memory_threshold(
                peak_memory,
                peak_memory_per_step_dict,
                bucketable_ag_idx,
            )
            # the bucketed AG/RS are created on-the-fly, whose memory was not captured by the
            # estimate_peak_memory function. The bucketed_comm_memory consists of:
            # in FWD pass:
            # (1) all-gather copy-in (comm_size_inp): smaller buffers for dtype_conversion + bigger buffer to copy_in smaller buffers
            #     thus, we have comm_size_inp*2
            # (2) all-gather copy-out (comm_size_out): bigger buffer to copy_out from ag_wait + split out smaller buffers for compute
            #     thus, we have comm_size_out*2
            # in BWD pass:
            # TODO (ruisizhang123): we need to double check this. From memory trace, we can clearly see
            # these three regions stack together at a certain moment
            # due to reordering, the peak memory occurs at the end of current step's all-gather when last step & this step's reduce-scatter
            # are not cleared in time
            # (1) all-gather copy-in/copy-out (like FWD pass)
            # (2) last step's reduce-scatter: bigger buffer contains gradient
            # (3) next step's reduce-scatter: smaller buffers for dtype_conversion + bigger buffer to copy_in gradient
            bucketed_comm_memory = (
                2 * comm_size_inp
                + 2 * comm_size_out
                + heuristic_info["accumulated_gradient_size"]
                + heuristic_info["this_step_rs_comm_inp_size"] * 2
                + heuristic_info["last_step_rs_comm_size"]
            )
            break_memory_criteria = (
                memory_threshold
                < heuristic_info["next_step_memory"] + bucketed_comm_memory
            )

            # (3) Communication size criteria
            break_comm_size_criteria = (
                comm_cache.ag_max_inp_size <= comm_size_inp
                or comm_cache.rs_max_out_size
                <= heuristic_info["this_step_rs_comm_out_size"]
            )

            if (
                break_overlap_criteria
                or break_memory_criteria
                or break_comm_size_criteria
            ):
                if heuristic_info["this_step_comp_time"] > 0:
                    # if bucketing breaks the greedy criteria, pop the last node out
                    overflow_ag = current_ag_bucket[ag_node_info].pop()
                    all_gather_plan.append(current_ag_bucket)
                    current_ag_bucket: Dict[
                        tuple[Any, ...], list["scheduler.BaseSchedulerNode"]
                    ] = defaultdict(list)
                    current_ag_bucket[ag_node_info].append(overflow_ag)
                else:
                    # if there is no compute, we have to keep the all_gather to avoid deadlock
                    all_gather_plan.append(current_ag_bucket)
                    current_ag_bucket: Dict[
                        tuple[Any, ...], list["scheduler.BaseSchedulerNode"]
                    ] = defaultdict(list)

                if verbose:
                    print(
                        "break_overlap_criteria",
                        break_overlap_criteria,
                    )
                    print("Current comm time", comm_time, "comp time", comp_time)
                    print(
                        "break_memory_criteria",
                        break_memory_criteria,
                    )
                    print(
                        "memory_threshold",
                        memory_threshold,
                        "total memory",
                        heuristic_info["next_step_memory"] + bucketed_comm_memory,
                    )
                    print(
                        "break_comm_size_criteria",
                        break_comm_size_criteria,
                    )
                    print("current_ag_bucket", all_gather_plan[-1])

                # bucket reduce scatters if there are any
                if len(current_rs_bucket) > 0:
                    (
                        current_estimated_rs,
                        rs_comm_size_inp,
                        rs_comm_size_out,
                    ) = estimate_bucketed_snode_runtime(
                        current_rs_bucket,
                        schedule_fallback_operation,
                        name_to_buf,
                        torch.ops._c10d_functional.reduce_scatter_tensor.default,
                        comm_cache,
                        ReduceOp.AVG,
                    )
                    heuristic_info["last_step_rs_comm_time"] = current_estimated_rs
                    reduce_scatter_plan.append(current_rs_bucket)
                    heuristic_info["last_step_rs_comm_size"] = rs_comm_size_out
                    current_rs_bucket: Dict[
                        tuple[Any, ...], list["scheduler.BaseSchedulerNode"]
                    ] = defaultdict(list)

                # update heuristic info for the next step
                (
                    heuristic_info["this_step_comp_time"],
                    heuristic_info["this_step_memory"],
                ) = (
                    heuristic_info["next_step_comp_time"]
                    + heuristic_info["next_step_nonfsdp_comm_time"],
                    heuristic_info["next_step_memory"],
                )
                (
                    heuristic_info["next_step_comp_time"],
                    heuristic_info["next_step_memory"],
                ) = (
                    0,
                    0,
                )
        elif is_collective(
            snode.node, op=torch.ops._c10d_functional.reduce_scatter_tensor.default
        ) and check_ir_node_bucketable(snode.node, bucketable_nodes):
            node_info = get_snode_tensor_info(
                snode, return_data_size=False
            ) + get_snode_process_group_info(
                snode,
                expected_op=torch.ops._c10d_functional.reduce_scatter_tensor.default,
                resolve_pg=False,
            )
            current_rs_bucket[node_info].append(snode)

            (
                heuristic_info["this_step_rs_comm_time"],
                rs_comm_size_inp,
                rs_comm_size_out,
            ) = estimate_bucketed_snode_runtime(
                current_rs_bucket,
                schedule_fallback_operation,
                name_to_buf,
                torch.ops._c10d_functional.reduce_scatter_tensor.default,
                comm_cache,
                ReduceOp.AVG,
            )
            heuristic_info["this_step_rs_comm_out_size"] = rs_comm_size_out
            heuristic_info["this_step_rs_comm_inp_size"] = rs_comm_size_inp
            heuristic_info["accumulated_gradient_size"] += get_data_size(
                snode.node.layout.size
            )

            # Check if current bucketing breaks the greedy criteria
            # (4) future compute to overlap RS criteria
            break_rs_overlap_criteria = (
                future_comp_time < heuristic_info["this_step_rs_comm_time"] * 5
            )
            if break_rs_overlap_criteria:
                reduce_scatter_plan.append(current_rs_bucket)
                heuristic_info["last_step_rs_comm_time"] = heuristic_info[
                    "this_step_rs_comm_time"
                ]
                heuristic_info["this_step_rs_comm_time"] = 0
                current_rs_bucket: Dict[
                    tuple[Any, ...], list["scheduler.BaseSchedulerNode"]
                ] = defaultdict(list)
        else:
            # TODO (ruisizhang123): for now, we only consider FSDP + (TP & CP), whose comms are AG & RS & All_Reduce
            # For TP and CP, we consider the node as a "COMP" node with exposed communication as Comp time
            if is_collective(snode.node):
                current_comm = comm_cache.get_comm_time(
                    snode.node.inputs[0].layout.size,
                    snode.node.layout.size,
                    getattr(snode.node, "python_kernel_name", ""),
                    calibrated=True,
                )
                heuristic_info["next_step_nonfsdp_comm_time"] += current_comm
            else:
                if seen_new_bucketable_ag:
                    heuristic_info["next_step_memory"] += memory_dict[bucketable_ag_idx]
                    heuristic_info["next_step_comp_time"] += comp_time_dict[
                        bucketable_ag_idx
                    ]
                    seen_new_bucketable_ag = False

    if len(current_ag_bucket) > 0 or len(all_gather_plan) == 0:
        all_gather_plan.append(current_ag_bucket)

    if len(current_rs_bucket) > 0 or len(reduce_scatter_plan) == 0:
        reduce_scatter_plan.append(current_rs_bucket)

    return all_gather_plan, reduce_scatter_plan
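
To make the memory criterion concrete, here is a small, self-contained rerun of the get_dynamic_memory_threshold logic with made-up per-step peak memories; the numbers, and the omission of configs.peak_memory_offset, are illustrative only and not taken from this commit.

# Illustrative per-step peak memories (bytes), keyed by bucketable all-gather
# index, as benchmark_and_sync_runtime would report them. Values are made up.
peak_memory_per_step_dict = {0: 10_000, 1: 40_000, 2: 25_000, 3: 15_000}
peak_memory = max(peak_memory_per_step_dict.values())  # offset omitted here

def get_dynamic_memory_threshold(peak_memory, peak_memory_per_step_dict, current_step):
    # Same logic as the committed helper: the gap between the run-wide peak and
    # the smaller of the maximum peaks to the left and right of the current step.
    left_peak_memory = 0
    right_peak_memory = 0
    for idx, memory in peak_memory_per_step_dict.items():
        if idx <= current_step:
            left_peak_memory = max(memory, left_peak_memory)
        if idx >= current_step:
            right_peak_memory = max(memory, right_peak_memory)
    return peak_memory - min(left_peak_memory, right_peak_memory)

# At step 2: left peak = 40_000, right peak = 25_000, so the threshold is
# 40_000 - 25_000 = 15_000 bytes of extra bucketed-communication memory that
# can be admitted before the run-wide peak would grow.
assert get_dynamic_memory_threshold(peak_memory, peak_memory_per_step_dict, 2) == 15_000

In the algorithm above, this threshold drives break_memory_criteria: a candidate bucket is rejected whenever the threshold is smaller than next_step_memory plus the estimated bucketed copy-in/copy-out and reduce-scatter buffers.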
