Commit 5dc7399

[3/N] Autobucketing: add greedy algorithm for autobucketing
ghstack-source-id: 888f10e
Pull-Request: #129
1 parent: 861edfa

3 files changed: +460, -2 lines

autoparallel/auto_bucketing.py

Lines changed: 16 additions & 2 deletions
@@ -5,7 +5,7 @@
 
 import torch
 
-from .autobucketing_util import bucket_utils
+from .autobucketing_util import bucket_plan, bucket_utils
 
 
 class simplefsdp_autobucketing_config:
@@ -36,7 +36,21 @@ def simple_fsdp_autobucketing_reordering_pass(
     configs: "simplefsdp_autobucketing_config",
 ) -> list["torch._inductor.scheduler.BaseSchedulerNode"]:
     scheduler = snodes[0].scheduler
-    bucket_utils.get_bucketable_ir_nodes(
+    bucketable_nodes = bucket_utils.get_bucketable_ir_nodes(
         snodes, scheduler.name_to_fused_node, scheduler.name_to_buf
     )
+
+    assert (
+        not torch._inductor.config.allow_buffer_reuse
+    ), "bucketing algorithm requires torch._inductor.config.allow_buffer_reuse to be False"
+
+    if configs.enable_bucket_ir:
+        all_gather_plan, reduce_scatter_plan = bucket_plan.get_simplefsdp_auto_plan(
+            scheduler,
+            snodes,
+            scheduler.name_to_buf,
+            scheduler.name_to_fused_node,
+            bucketable_nodes,
+            configs,
+        )
     return snodes
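
For context, a minimal sketch of how this reordering pass could be plugged into inductor's compute/communication-overlap reordering hook. This is illustrative only and not part of this change; it assumes the reorder_for_compute_comm_overlap and reorder_for_compute_comm_overlap_passes config knobs accept a callable bound via functools.partial, since the hook's passes receive only the snode list.

import functools

import torch._inductor.config as inductor_config

from autoparallel.auto_bucketing import (
    simple_fsdp_autobucketing_reordering_pass,
    simplefsdp_autobucketing_config,
)

# The pass asserts that inductor buffer reuse is disabled (see the assert above).
inductor_config.allow_buffer_reuse = False
# Assumed wiring: register the pass with inductor's comm-overlap reordering hook,
# binding the bucketing config object so the pass only needs the snode list.
inductor_config.reorder_for_compute_comm_overlap = True
inductor_config.reorder_for_compute_comm_overlap_passes = [
    functools.partial(
        simple_fsdp_autobucketing_reordering_pass,
        configs=simplefsdp_autobucketing_config(),
    )
]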
Lines changed: 345 additions & 0 deletions
@@ -0,0 +1,345 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import functools

# mypy: ignore-errors
from collections import defaultdict
from typing import Any, Dict

import torch
from torch._C._distributed_c10d import ReduceOp
from torch._inductor import scheduler
from torch._inductor.comm import _schedule_fallback_operation
from torch._inductor.utils import is_collective

from .bucket_utils import (
    check_ir_node_bucketable,
    estimate_bucketed_snode_runtime,
    get_data_size,
    get_snode_process_group_info,
    get_snode_tensor_info,
)
from .estimation import benchmark_and_sync_runtime


def get_dynamic_memory_threshold(
    peak_memory,
    peak_memory_per_step_dict,
    current_step,
) -> int:
    """
    this function calculates the memory gap from the current step's memory to the peak memory
    """
    left_peak_memory = 0
    right_peak_memory = 0
    for idx, memory in peak_memory_per_step_dict.items():
        if idx <= current_step:
            left_peak_memory = max(memory, left_peak_memory)
        if idx >= current_step:
            right_peak_memory = max(memory, right_peak_memory)
    current_peak_memory = min(left_peak_memory, right_peak_memory)
    return peak_memory - current_peak_memory

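# Hypothetical worked example (illustration only, toy numbers, not from the original file):
# with peak_memory_per_step_dict = {0: 10, 1: 30, 2: 20} and current_step = 2,
# left_peak_memory = max(10, 30, 20) = 30 and right_peak_memory = 20, so the
# function returns peak_memory - min(30, 20) = peak_memory - 20, i.e. the
# headroom available at this step before the run-wide peak is reached.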

def get_simplefsdp_auto_plan(
    sched: "scheduler.Scheduler",
    snodes: list["scheduler.BaseSchedulerNode"],
    name_to_buf: Dict[str, "scheduler.SchedulerBuffer"],
    name_to_fused_node: Dict[str, "scheduler.BaseSchedulerNode"],
    bucketable_nodes: set[str],
    configs: Any,
    verbose: bool = True,
) -> tuple[
    list[Dict[tuple[Any, ...], list["scheduler.BaseSchedulerNode"]]],
    list[Dict[tuple[Any, ...], list["scheduler.BaseSchedulerNode"]]],
]:
    """
    This function implements a greedy algorithm, which decides whether a node can be bucketed
    with the previous one based on the criteria below:
    FWD Pass:
    (1) the bucketed AG communication can be overlapped by the previous computation;
    (2) the bucketed AG memory doesn't exceed peak memory;
    (3) the bucketed AG communication size doesn't exceed 0.2*sum(fwd_ag_tensor_list), such
        that the estimated AG communication time always stays within the calibration bound.
    BWD Pass:
    (1) the bucketed AG + RS communication can be overlapped by the previous computation;
    (2) the bucketed AG + RS memory doesn't exceed peak memory;
    (3) RS always has future compute to overlap it, so that its final exposed communication is small;
    (4) the bucketed AG/RS communication size doesn't exceed 0.2*sum(fwd_ag_tensor_list) & 0.2*sum(bwd_rs_tensor_list),
        such that the estimated AG/RS communication time always stays within the calibration bound.
    """
    all_gather_plan = []
    reduce_scatter_plan = []
    current_ag_bucket: Dict[
        tuple[Any, ...], list["scheduler.BaseSchedulerNode"]
    ] = defaultdict(list)
    current_rs_bucket: Dict[
        tuple[Any, ...], list["scheduler.BaseSchedulerNode"]
    ] = defaultdict(list)
    schedule_fallback_operation = functools.partial(
        _schedule_fallback_operation,
        scheduler=sched,
        name_to_buf=name_to_buf,
        name_to_fused_node=name_to_fused_node,
    )

    heuristic_info = {
        # time info
        "last_step_rs_comm_time": 0.0,
        "this_step_comp_time": 0.0,
        "this_step_rs_comm_time": 0.0,
        "next_step_comp_time": 0.0,
        "next_step_nonfsdp_comm_time": 0.0,
        # memory info
        "accumulated_gradient_size": 0,
        "last_step_rs_comm_size": 0,
        "this_step_rs_comm_out_size": 0,
        "this_step_rs_comm_inp_size": 0,
        "this_step_memory": 0,
        "next_step_memory": 0,
    }

    # sync runtime info across ranks
    (
        comm_cache,
        comp_time_dict,
        memory_dict,
        peak_memory_per_step_dict,
    ) = benchmark_and_sync_runtime(
        sched, snodes, name_to_buf, name_to_fused_node, bucketable_nodes, configs
    )
    future_comp_time = sum(comp_time_dict.values())
    peak_memory = max(peak_memory_per_step_dict.values()) + configs.peak_memory_offset

    # autobucket algorithm
    bucketable_ag_idx = -1
    seen_new_bucketable_ag = True
    for _, snode in enumerate(snodes):
        if is_collective(
            snode.node, op=torch.ops._c10d_functional.all_gather_into_tensor.default
        ) and check_ir_node_bucketable(snode.node, bucketable_nodes):
            bucketable_ag_idx += 1
            seen_new_bucketable_ag = True
            future_comp_time -= comp_time_dict[bucketable_ag_idx]

            ag_node_info = get_snode_tensor_info(
                snode, return_data_size=False
            ) + get_snode_process_group_info(
                snode,
                expected_op=torch.ops._c10d_functional.all_gather_into_tensor.default,
                resolve_pg=False,
            )
            current_ag_bucket[ag_node_info].append(snode)
            (
                estimated_comm,
                comm_size_inp,
                comm_size_out,
            ) = estimate_bucketed_snode_runtime(
                current_ag_bucket,
                schedule_fallback_operation,
                name_to_buf,
                "torch.ops._c10d_functional.all_gather_into_tensor.default",
                comm_cache,
            )

            # Check if current bucketing breaks the greedy criteria
            # (1) Overlapping criteria
            comp_time = heuristic_info["this_step_comp_time"] * (
                1 + configs.relax_ratio
            )
            comm_time = estimated_comm + heuristic_info["last_step_rs_comm_time"]
            break_overlap_criteria = comp_time < comm_time

            # (2) Memory criteria
            memory_threshold = get_dynamic_memory_threshold(
                peak_memory,
                peak_memory_per_step_dict,
                bucketable_ag_idx,
            )
            # the bucketed AG/RS are created on-the-fly, so their memory was not captured by the
            # estimate_peak_memory function. The bucketed_comm_memory consists of:
            # in FWD pass:
            # (1) all-gather copy-in (comm_size_inp): smaller buffers for dtype_conversion + bigger buffer to copy_in smaller buffers
            #     thus, we have comm_size_inp*2
            # (2) all-gather copy-out (comm_size_out): bigger buffer to copy_out from ag_wait + split out smaller buffers for compute
            #     thus, we have comm_size_out*2
            # in BWD pass:
            # TODO (ruisizhang123): we need to double check this. From the memory trace, we can clearly see
            # these three regions stack together at a certain moment:
            # due to reordering, the peak memory occurs at the end of the current step's all-gather when the last step's
            # and this step's reduce-scatters are not cleared in time
            # (1) all-gather copy-in/copy-out (like FWD pass)
            # (2) last step's reduce-scatter: bigger buffer contains the gradient
            # (3) next step's reduce-scatter: smaller buffers for dtype_conversion + bigger buffer to copy_in the gradient
            bucketed_comm_memory = (
                2 * comm_size_inp
                + 2 * comm_size_out
                + heuristic_info["this_step_rs_comm_inp_size"] * 2
                + heuristic_info["last_step_rs_comm_size"]
            )
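            # Hypothetical toy numbers for illustration (not from the original file): with
            # comm_size_inp = 4 MiB, comm_size_out = 32 MiB, this_step_rs_comm_inp_size = 8 MiB
            # and last_step_rs_comm_size = 16 MiB, the transient bucketed_comm_memory would be
            # 2*4 + 2*32 + 2*8 + 16 = 104 MiB, charged against the dynamic memory threshold
            # in the criterion below.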
            break_memory_criteria = (
                memory_threshold
                < heuristic_info["next_step_memory"] + bucketed_comm_memory
            )

            # (3) Communication size criteria
            break_comm_size_criteria = comm_cache.ag_max_inp_size < comm_size_inp
            if comm_cache.rs_max_out_size > 0:
                break_comm_size_criteria = (
                    break_comm_size_criteria
                    or comm_cache.rs_max_out_size
                    < heuristic_info["this_step_rs_comm_out_size"]
                )

            if (
                break_overlap_criteria
                or break_memory_criteria
                or break_comm_size_criteria
            ):
                if heuristic_info["this_step_comp_time"] > 0:
                    # if bucketing breaks the greedy criteria, pop the last node out
                    overflow_ag = current_ag_bucket[ag_node_info].pop()
                    all_gather_plan.append(current_ag_bucket)
                    current_ag_bucket: Dict[
                        tuple[Any, ...], list["scheduler.BaseSchedulerNode"]
                    ] = defaultdict(list)
                    current_ag_bucket[ag_node_info].append(overflow_ag)
                else:
                    # if there is no compute, we have to keep the all_gather to avoid deadlock
                    all_gather_plan.append(current_ag_bucket)
                    current_ag_bucket: Dict[
                        tuple[Any, ...], list["scheduler.BaseSchedulerNode"]
                    ] = defaultdict(list)

                if verbose:
                    print(
                        "break_overlap_criteria",
                        break_overlap_criteria,
                    )
                    print("Current comm time", comm_time, "comp time", comp_time)
                    print(
                        "break_memory_criteria",
                        break_memory_criteria,
                    )
                    print(
                        "memory_threshold",
                        memory_threshold,
                        "total memory",
                        heuristic_info["next_step_memory"] + bucketed_comm_memory,
                    )
                    print(
                        "break_comm_size_criteria",
                        break_comm_size_criteria,
                    )
                    print("current_ag_bucket", all_gather_plan[-1])

                # bucket reduce scatters if there are any
                if len(current_rs_bucket) > 0:
                    (
                        current_estimated_rs,
                        rs_comm_size_inp,
                        rs_comm_size_out,
                    ) = estimate_bucketed_snode_runtime(
                        current_rs_bucket,
                        schedule_fallback_operation,
                        name_to_buf,
                        "torch.ops._c10d_functional.reduce_scatter_tensor.default",
                        comm_cache,
                        ReduceOp.AVG,
                    )
                    heuristic_info["last_step_rs_comm_time"] = current_estimated_rs
                    reduce_scatter_plan.append(current_rs_bucket)
                    heuristic_info["last_step_rs_comm_size"] = rs_comm_size_out
                    current_rs_bucket: Dict[
                        tuple[Any, ...], list["scheduler.BaseSchedulerNode"]
                    ] = defaultdict(list)

                # update heuristic info for the next step
                (
                    heuristic_info["this_step_comp_time"],
                    heuristic_info["this_step_memory"],
                ) = (
                    heuristic_info["next_step_comp_time"]
                    + heuristic_info["next_step_nonfsdp_comm_time"],
                    heuristic_info["next_step_memory"],
                )
                (
                    heuristic_info["next_step_comp_time"],
                    heuristic_info["next_step_memory"],
                ) = (
                    0,
                    0,
                )
        elif is_collective(
            snode.node, op=torch.ops._c10d_functional.reduce_scatter_tensor.default
        ) and check_ir_node_bucketable(snode.node, bucketable_nodes):
            node_info = get_snode_tensor_info(
                snode, return_data_size=False
            ) + get_snode_process_group_info(
                snode,
                expected_op=torch.ops._c10d_functional.reduce_scatter_tensor.default,
                resolve_pg=False,
            )
            current_rs_bucket[node_info].append(snode)

            (
                heuristic_info["this_step_rs_comm_time"],
                rs_comm_size_inp,
                rs_comm_size_out,
            ) = estimate_bucketed_snode_runtime(
                current_rs_bucket,
                schedule_fallback_operation,
                name_to_buf,
                "torch.ops._c10d_functional.reduce_scatter_tensor.default",
                comm_cache,
                ReduceOp.AVG,
            )
            heuristic_info["this_step_rs_comm_out_size"] = rs_comm_size_out
            heuristic_info["this_step_rs_comm_inp_size"] = rs_comm_size_inp
            heuristic_info["accumulated_gradient_size"] += get_data_size(
                snode.node.layout.size
            )

            # Check if current bucketing breaks the greedy criteria
            # (4) future compute to overlap RS criteria
            break_rs_overlap_criteria = (
                future_comp_time < heuristic_info["this_step_rs_comm_time"] * 5
            )
            if break_rs_overlap_criteria:
                reduce_scatter_plan.append(current_rs_bucket)
                heuristic_info["last_step_rs_comm_time"] = heuristic_info[
                    "this_step_rs_comm_time"
                ]
                heuristic_info["this_step_rs_comm_time"] = 0
                current_rs_bucket: Dict[
                    tuple[Any, ...], list["scheduler.BaseSchedulerNode"]
                ] = defaultdict(list)
        else:
            # TODO (ruisizhang123): for now, we only consider FSDP + (TP & CP), whose comms are AG & RS & All_Reduce
            # For TP and CP, we consider the node as a "COMP" node with exposed communication as Comp time
            if is_collective(snode.node):
                current_comm = comm_cache.get_comm_time(
                    snode.node.inputs[0].layout.size,
                    snode.node.layout.size,
                    getattr(snode.node, "python_kernel_name", ""),
                    calibrated=True,
                )
                heuristic_info["next_step_nonfsdp_comm_time"] += current_comm
            else:
                if seen_new_bucketable_ag:
                    heuristic_info["next_step_memory"] += memory_dict[bucketable_ag_idx]
                    heuristic_info["next_step_comp_time"] += comp_time_dict[
                        bucketable_ag_idx
                    ]
                    seen_new_bucketable_ag = False

    if len(current_ag_bucket) > 0:
        all_gather_plan.append(current_ag_bucket)

    if len(current_rs_bucket) > 0:
        reduce_scatter_plan.append(current_rs_bucket)

    return all_gather_plan, reduce_scatter_plan
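
For intuition, a small hedged sketch of the shape of the returned plans; the keys and node names below are illustrative stand-ins, not values produced by this code (the real buckets hold BaseSchedulerNode objects keyed by tensor/process-group info).

from collections import defaultdict

# Each plan is a list of buckets; each bucket maps a (tensor-info..., process-group-info...)
# key to the scheduler nodes whose collectives get fused into a single bucketed call.
bucket_1 = defaultdict(list)
bucket_1[("bf16", "pg:0")] = ["ag_node_1", "ag_node_2"]  # greedy criteria held, two AGs share a bucket
bucket_2 = defaultdict(list)
bucket_2[("bf16", "pg:0")] = ["ag_node_3"]  # a criterion broke, so a new bucket was started
example_all_gather_plan = [bucket_1, bucket_2]
print(example_all_gather_plan)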
