
Commit 97d0c40

[2/N] Autobucketing: add estimation utils for autobucketing

ghstack-source-id: 10eff9b
Pull-Request: #128
1 parent: 6e80732

4 files changed (+820, -3 lines changed)

autoparallel/auto_bucketing.py

Lines changed: 2 additions & 0 deletions
@@ -19,6 +19,7 @@ class simplefsdp_autobucketing_config:
     - load_cache: set to True to load cache from save_estimation_path
     - enable_bucket_ir: set to True to bucket all_gather/reduce_scatter
     - enable_reorder_ir: set to True to reorder all_gather/reduce_scatter
+    - calibrate_number: number of samples to calibrate during comm estimation
     """

     relax_ratio = 0
@@ -27,6 +28,7 @@ class simplefsdp_autobucketing_config:
     save_estimation_path = "/mnt/mffuse/cache_ruisi/estimation_mast.pkl"
     enable_bucket_ir = True
     enable_reorder_ir = True
+    calibrate_number = 40


 def simple_fsdp_autobucketing_reordering_pass(
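For orientation, these options are plain class attributes, so a caller can override them before the pass runs; a minimal sketch (the override site is an assumption, not shown in this commit):

from autoparallel.auto_bucketing import simplefsdp_autobucketing_config as cfg

cfg.load_cache = True                             # reuse a saved estimation
cfg.save_estimation_path = "/tmp/estimation.pkl"  # hypothetical path
cfg.calibrate_number = 40                         # samples per comm-estimation point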

autoparallel/autobucketing_util/bucket_utils.py

Lines changed: 127 additions & 3 deletions
@@ -4,15 +4,23 @@
 # LICENSE file in the root directory of this source tree.

 # mypy: ignore-errors
-from typing import Any, Callable, Dict
+from functools import reduce
+from typing import Any, Callable, Dict, Optional, Union

 import torch
-from torch._inductor import scheduler
+from torch._inductor import ir, scheduler
 from torch._inductor.dependencies import WeakDep
-from torch._inductor.utils import buf_name_to_fused_snode, is_collective
+from torch._inductor.ir import NoneLayout
+from torch._inductor.utils import buf_name_to_fused_snode, is_collective, is_wait
+from torch.distributed import ProcessGroup
+from torch.distributed.distributed_c10d import _resolve_process_group
 from torch.utils._ordered_set import OrderedSet


+def get_data_size(size):
+    return reduce(lambda x, y: x * y, size)
+
+
 def _find_recursive_deps_of_snode(
     snode: "scheduler.BaseSchedulerNode",
     collected_node_set: OrderedSet["scheduler.BaseSchedulerNode"],
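get_data_size above just multiplies a shape out into an element count; a standalone check of the helper:

from functools import reduce

def get_data_size(size):
    return reduce(lambda x, y: x * y, size)

assert get_data_size([8, 512, 512]) == 8 * 512 * 512  # elements, not bytes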
@@ -117,3 +125,119 @@ def get_bucketable_ir_nodes(
             bucketable_ir_nodes.add(snode.node.get_name())

     return bucketable_ir_nodes
+
+
+def check_ir_node_bucketable(
+    ir_node: "ir.IRNode", bucketable_ir_nodes: set[str]
+) -> bool:
+    """
+    Determine whether an AG/RS node (or its wait node) comes from bucketable nodes
+    """
+    ir_node_origins = list(getattr(ir_node, "origins", []))
+    if len(ir_node_origins) == 0:
+        # bucketed AG and RS don't have origins
+        return True
+
+    if is_wait(ir_node):
+        ir_node = ir_node.inputs[0]
+
+    if is_collective(
+        ir_node, op=torch.ops._c10d_functional.all_gather_into_tensor.default
+    ):
+        ir_node_name = ir_node.get_name()
+    elif is_collective(
+        ir_node, op=torch.ops._c10d_functional.reduce_scatter_tensor.default
+    ):
+        ir_node_name = ir_node.get_name()
+    else:
+        return False
+
+    return ir_node_name in bucketable_ir_nodes
+
+
+def _get_fx_node(
+    snode_or_ir_node: Union["scheduler.BaseSchedulerNode", "ir.IRNode"],
+    expected_op: Callable[..., Any],
+) -> Optional[torch.fx.Node]:
+    if isinstance(snode_or_ir_node, scheduler.BaseSchedulerNode):
+        origins = snode_or_ir_node.node.get_origins()
+    elif isinstance(snode_or_ir_node, ir.IRNode):
+        origins = snode_or_ir_node.origins
+    else:
+        raise ValueError(
+            f"Expected BaseSchedulerNode or IRNode, got {type(snode_or_ir_node)}. "
+            f"Offending value: {snode_or_ir_node}"
+        )
+    origins_with_expected_op = [o for o in origins if o.target == expected_op]
+    # expect exactly one origin fx node with the expected op; bail out otherwise
+    if len(origins_with_expected_op) != 1:
+        print(
+            "[Get FX exception] origins_with_expected_op",
+            origins_with_expected_op,
+            "expected_op",
+            expected_op,
+            "snode_or_ir_node",
+            snode_or_ir_node,
+        )
+        return None
+    return origins_with_expected_op[0]
+
+
+def get_snode_process_group_info(
+    snode: "scheduler.BaseSchedulerNode",
+    expected_op: Callable[..., Any],
+    resolve_pg: bool = False,
+) -> Optional[tuple[int, Union[str, ProcessGroup]]]:
+    fx_node = _get_fx_node(snode, expected_op=expected_op)
+    # return None if the snode doesn't have a valid fx_node
+    if fx_node is None:
+        return None
+
+    # group_size/group_name sit at op-specific positions in the argument lists
+    if expected_op == torch.ops._c10d_functional.all_gather_into_tensor.default:
+        group_size, group_name = (
+            snode.node.constant_args[0],
+            snode.node.constant_args[1],
+        )
+    elif expected_op == torch.ops._c10d_functional.reduce_scatter_tensor.default:
+        group_size, group_name = (
+            snode.node.constant_args[1],
+            snode.node.constant_args[2],
+        )
+    elif expected_op == torch.ops._c10d_functional.all_reduce_.default:
+        group_size, group_name = fx_node.args[1], fx_node.args[2]
+    elif expected_op == torch.ops._c10d_functional.all_to_all_single.default:
+        group_size, group_name = 0, fx_node.args[3]
+    else:
+        raise ValueError(f"Unsupported op {expected_op}")
+
+    if resolve_pg:
+        group_name = _resolve_process_group(group_name)
+    return group_size, group_name
+
+
+def get_snode_tensor_info(
+    snode: "scheduler.BaseSchedulerNode", return_data_size: bool = False
+) -> tuple[Any, ...]:
+    input_dtype, input_device = (
+        snode.node.inputs[0].layout.dtype,
+        snode.node.inputs[0].layout.device,
+    )
+    input_size = get_data_size(snode.node.inputs[0].layout.size)
+
+    if not isinstance(snode.node.layout, NoneLayout):
+        output_dtype, output_device = (
+            snode.node.layout.dtype,
+            snode.node.layout.device,
+        )
+        output_size = get_data_size(snode.node.layout.size)
+    else:
+        # In all_reduce, layout is NoneLayout.
+        # We set output info to be the same as input info as a special treatment.
+        output_dtype, output_device, output_size = input_dtype, input_device, input_size
+
+    result = (input_dtype, input_device, output_dtype, output_device)
+    if return_data_size:
+        result += (input_size, output_size)
+    return result
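The tuple returned by get_snode_tensor_info is positional and downstream callers slice it, so the layout matters. An illustration with made-up values (not taken from the diff):

import torch

# layout: (input_dtype, input_device, output_dtype, output_device,
#          input_size, output_size) when return_data_size=True
node_tensor_info = (
    torch.bfloat16, torch.device("cuda", 0),   # input dtype/device
    torch.bfloat16, torch.device("cuda", 0),   # output dtype/device
    4 * 1024 * 1024,                           # input numel
    32 * 1024 * 1024,                          # output numel
)
dtype_device_key = node_tensor_info[:-2]  # what callers combine with pg info
input_size, output_size = node_tensor_info[-2], node_tensor_info[-1]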
Lines changed: 225 additions & 0 deletions
@@ -0,0 +1,225 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# mypy: ignore-errors
import os
import pickle
from collections import defaultdict

import torch
import torch.distributed as c10d
from torch._inductor import memory, scheduler
from torch._inductor.utils import is_collective
from torch._inductor.virtualized import V
from torch.utils._ordered_set import OrderedSet

from ..auto_bucketing import simplefsdp_autobucketing_config
from .bucket_utils import (
    check_ir_node_bucketable,
    get_snode_process_group_info,
    get_snode_tensor_info,
)
from .estimation_utils import (
    CommPerfCache,
    CompPerfCache,
    benchmark_and_cache_comm_dicts,
    estimate_comp_time,
)


def sync_dict_across_ranks(runtime_dict, world_size, group=None):
    # all-gather every rank's values, then keep the per-key median so all
    # ranks agree on one runtime estimate
    gathered_lists = [None for _ in range(world_size)]
    c10d.all_gather_object(gathered_lists, list(runtime_dict.values()), group=group)
    median_gathered_time = torch.median(torch.tensor(gathered_lists), dim=0).values
    for idx, key in enumerate(runtime_dict):
        runtime_dict[key] = median_gathered_time[idx]
    return runtime_dict
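# Hedged sanity check of the median reduction above (illustration only, not
# part of the commit): with three ranks reporting two keys each,
#   gathered_lists = [[1.0, 10.0], [2.0, 12.0], [3.0, 11.0]]
#   torch.median(torch.tensor(gathered_lists), dim=0).values == tensor([2., 11.])
# i.e. each key independently takes its median across ranks.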
def benchmark_and_sync_runtime(
    sched: "scheduler.Scheduler",
    snodes: list["scheduler.BaseSchedulerNode"],
    name_to_buf: dict[str, "scheduler.SchedulerBuffer"],
    name_to_fused_node: dict[str, "scheduler.BaseSchedulerNode"],
    bucketable_nodes: set[str],
    configs: "simplefsdp_autobucketing_config",
):
    world_size = c10d.distributed_c10d.get_world_size()

    # message sizes per collective, keyed by (dtype, device, pg) info
    fsdp_ag_input_size_dict = defaultdict(list)
    fsdp_rs_output_size_dict = defaultdict(list)
    non_fsdp_ag_input_size_dict = defaultdict(list)
    non_fsdp_rs_input_size_dict = defaultdict(list)
    all_reduce_input_size_dict = defaultdict(list)
    all_to_all_input_size_dict = defaultdict(list)
    comp_cache, comm_cache = CompPerfCache(), CommPerfCache()

    cali_num_samples = configs.calibrate_number
    comp_time_dict = defaultdict(float)
    memory_dict = defaultdict(int)
    peak_memory_per_step_dict = defaultdict(int)
    fsdp_ag_idx = -1
    release_steps = [0]

    graph_outputs = OrderedSet(V.graph.get_output_names())
    graph_inputs = OrderedSet(V.graph.graph_inputs.keys())
    _, name_to_freeable_input_buf = memory.prepare_planning_info(
        snodes,
        name_to_buf,
        name_to_fused_node,
        graph_inputs,
        graph_outputs,
    )
    _, memories_at_nodes = memory.estimate_peak_memory(
        snodes, name_to_freeable_input_buf, graph_outputs
    )
    for idx, snode in enumerate(snodes):
        if is_collective(
            snode.node, op=torch.ops._c10d_functional.all_gather_into_tensor.default
        ):
            fsdp_ag_idx += 1
            release_steps.append(idx)
            node_tensor_info = get_snode_tensor_info(snode, return_data_size=True)
            node_pg_info = get_snode_process_group_info(
                snode,
                expected_op=torch.ops._c10d_functional.all_gather_into_tensor.default,
                resolve_pg=True,
            )
            if node_pg_info is None:
                continue
            node_info = node_tensor_info[:-2] + node_pg_info
            input_size = node_tensor_info[-2]
            if check_ir_node_bucketable(snode.node, bucketable_nodes):
                # For FSDP, we assume they all have the same group size
                fsdp_ag_input_size_dict[node_info].append(input_size)
            else:
                non_fsdp_ag_input_size_dict[node_info].append(input_size)
        elif is_collective(
            snode.node, op=torch.ops._c10d_functional.reduce_scatter_tensor.default
        ):
            node_tensor_info = get_snode_tensor_info(snode, return_data_size=True)
            node_pg_info = get_snode_process_group_info(
                snode,
                expected_op=torch.ops._c10d_functional.reduce_scatter_tensor.default,
                resolve_pg=True,
            )
            if node_pg_info is None:
                continue
            node_info = node_tensor_info[:-2] + node_pg_info
            output_size = node_tensor_info[-1]
            if check_ir_node_bucketable(snode.node, bucketable_nodes):
                # For FSDP, we assume they all have the same group size
                fsdp_rs_output_size_dict[node_info].append(output_size)
            else:
                non_fsdp_rs_input_size_dict[node_info].append(output_size)
        elif is_collective(
            snode.node, op=torch.ops._c10d_functional.all_reduce_.default
        ):
            node_tensor_info = get_snode_tensor_info(snode, return_data_size=True)
            node_pg_info = get_snode_process_group_info(
                snode,
                expected_op=torch.ops._c10d_functional.all_reduce_.default,
                resolve_pg=True,
            )
            if node_pg_info is None:
                continue
            node_info = node_tensor_info[:-2] + node_pg_info
            input_size = node_tensor_info[-2]
            all_reduce_input_size_dict[node_info].append(input_size)
        elif is_collective(
            snode.node, op=torch.ops._c10d_functional.all_to_all_single.default
        ):
            node_tensor_info = get_snode_tensor_info(snode, return_data_size=True)
            node_pg_info = get_snode_process_group_info(
                snode,
                expected_op=torch.ops._c10d_functional.all_to_all_single.default,
                resolve_pg=True,
            )
            if node_pg_info is None:
                continue
            node_info = node_tensor_info[:-2] + node_pg_info
            input_size = node_tensor_info[-2]
            all_to_all_input_size_dict[node_info].append(input_size)
        else:
            if not is_collective(snode.node):
                # attribute compute time and memory deltas to the interval
                # opened by the most recent FSDP all_gather
                comp_time = estimate_comp_time(
                    sched, snode, verbose=False, comp_cache=comp_cache
                )
                comp_time_dict[fsdp_ag_idx] += comp_time
                memory_dict[fsdp_ag_idx] = max(
                    abs(
                        memories_at_nodes[idx + 1]
                        - memories_at_nodes[release_steps[-1]]
                    ),
                    memory_dict[fsdp_ag_idx],
                )
                peak_memory_per_step_dict[fsdp_ag_idx] = max(
                    memories_at_nodes[idx + 1], peak_memory_per_step_dict[fsdp_ag_idx]
                )
            else:
                print(
                    "[Relaxed Setting] untracked communication",
                    snode.node.python_kernel_name,
                )

    # Sync total compute time
    comp_time_dict = sync_dict_across_ranks(comp_time_dict, world_size)
    memory_dict = sync_dict_across_ranks(memory_dict, world_size)
    peak_memory_per_step_dict = sync_dict_across_ranks(
        peak_memory_per_step_dict, world_size
    )

    if configs.load_cache and os.path.exists(configs.save_estimation_path):
        with open(configs.save_estimation_path, "rb") as file:
            cache = pickle.load(file)
        comm_cache.cache = cache
        comm_cache._update_max_size()
        return comm_cache, comp_time_dict, memory_dict, peak_memory_per_step_dict

    # FSDP collectives are calibrated with `calibrate_number` samples; the
    # remaining collectives only get a few samples each
    benchmark_params = [
        (
            fsdp_ag_input_size_dict,
            "torch.ops._c10d_functional.all_gather_into_tensor.default",
            cali_num_samples,
        ),
        (
            fsdp_rs_output_size_dict,
            "torch.ops._c10d_functional.reduce_scatter_tensor.default",
            cali_num_samples,
        ),
        (
            non_fsdp_ag_input_size_dict,
            "torch.ops._c10d_functional.all_gather_into_tensor.default",
            3,
        ),
        (
            non_fsdp_rs_input_size_dict,
            "torch.ops._c10d_functional.reduce_scatter_tensor.default",
            3,
        ),
        (
            all_reduce_input_size_dict,
            "torch.ops._c10d_functional.all_reduce_.default",
            3,
        ),
        (
            all_to_all_input_size_dict,
            "torch.ops._c10d_functional.all_to_all_single.default",
            3,
        ),
    ]
    for input_size_dict, op_name, num_samples in benchmark_params:
        if len(input_size_dict) > 0:
            benchmark_and_cache_comm_dicts(
                comm_cache, input_size_dict, op_name, num_samples
            )

    median_runtimes = sync_dict_across_ranks(comm_cache.cache, world_size)
    comm_cache.cache = median_runtimes
    comm_cache._update_max_size()
    with open(configs.save_estimation_path, "wb") as file:
        pickle.dump(comm_cache.cache, file)
    return comm_cache, comp_time_dict, memory_dict, peak_memory_per_step_dict
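estimation_utils itself is not part of this view, so the exact logic behind benchmark_and_cache_comm_dicts is not shown. As a rough, assumption-laden sketch of what timing one all_gather configuration for such a cache can look like (function name and structure are hypothetical, not the file's actual contents):

import torch
import torch.distributed as c10d


def benchmark_all_gather_ms(input_numel, dtype, device, group, num_samples):
    # hypothetical helper: time all_gather_into_tensor for one
    # (size, dtype, pg) point with CUDA events, return the median in ms
    world = c10d.get_world_size(group)
    inp = torch.randn(input_numel, dtype=dtype, device=device)
    out = torch.empty(input_numel * world, dtype=dtype, device=device)
    times = []
    for _ in range(num_samples):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        c10d.all_gather_into_tensor(out, inp, group=group)
        end.record()
        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))  # milliseconds
    return sorted(times)[len(times) // 2]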
