
Commit a97802c

Revert "[2/N] Autobucketing: add estimation utils for autobucketing (#128)"
This reverts commit 6206d9d.
1 parent 6206d9d commit a97802c

File tree

4 files changed: +3, -825 lines

autoparallel/auto_bucketing.py

Lines changed: 0 additions & 2 deletions
@@ -19,7 +19,6 @@ class simplefsdp_autobucketing_config:
     - load_cache: set to True to load cache from save_estimation_path
     - enable_bucket_ir: set to True to bucket all_gather/reduce_scatter
     - enable_reorder_ir: set to True to reorder all_gather/reduce_satter
-    - calibrate_number: number of samples to calibrate during comm estimation
     """

     relax_ratio = 0
@@ -28,7 +27,6 @@ class simplefsdp_autobucketing_config:
     save_estimation_path = "/mnt/mffuse/cache_ruisi/estimation_mast.pkl"
     enable_bucket_ir = True
     enable_reorder_ir = True
-    calibrate_number = 40


 def simple_fsdp_autobucketing_reordering_pass(
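
For orientation, a minimal, illustrative sketch of how the remaining knobs could be set after this revert. It assumes the module path autoparallel.auto_bucketing matches the file location and that the options are plain class attributes, as the hunk above suggests; the values are placeholders, not recommendations.

from autoparallel.auto_bucketing import simplefsdp_autobucketing_config as cfg

# Hypothetical configuration sketch; attribute names come from the diff above.
cfg.load_cache = False                              # recompute estimations instead of loading the cached pickle
cfg.save_estimation_path = "/tmp/estimation.pkl"    # placeholder path
cfg.enable_bucket_ir = True                         # bucket all_gather / reduce_scatter
cfg.enable_reorder_ir = True                        # reorder the bucketed collectives
# cfg.calibrate_number = 40                         # removed by this revert; no longer a valid knob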

autoparallel/autobucketing_util/bucket_utils.py

Lines changed: 3 additions & 127 deletions
@@ -4,23 +4,15 @@
 # LICENSE file in the root directory of this source tree.

 # mypy: ignore-errors
-from functools import reduce
-from typing import Any, Callable, Dict, Union
+from typing import Any, Callable, Dict

 import torch
-from torch._inductor import ir, scheduler
+from torch._inductor import scheduler
 from torch._inductor.dependencies import WeakDep
-from torch._inductor.ir import NoneLayout
-from torch._inductor.utils import buf_name_to_fused_snode, is_collective, is_wait
-from torch.distributed import ProcessGroup
-from torch.distributed.distributed_c10d import _resolve_process_group
+from torch._inductor.utils import buf_name_to_fused_snode, is_collective
 from torch.utils._ordered_set import OrderedSet


-def get_data_size(size):
-    return reduce(lambda x, y: x * y, size)
-
-
 def _find_recursive_deps_of_snode(
     snode: "scheduler.BaseSchedulerNode",
     collected_node_set: OrderedSet["scheduler.BaseSchedulerNode"],
@@ -125,119 +117,3 @@ def get_bucketable_ir_nodes(
         bucketable_ir_nodes.add(snode.node.get_name())

     return bucketable_ir_nodes
-
-
-def check_ir_node_bucketable(
-    ir_node: "ir.IRNode", bucketable_ir_nodes: set[str]
-) -> bool:
-    """
-    Determine if the AG/RS & AG/RS wait node is from bucketable nodes or not
-    """
-    ir_node_origins = list(getattr(ir_node, "origins", None))
-    if len(ir_node_origins) == 0:
-        # bucketed AG and RS doesn't have origins
-        return True
-
-    if is_wait(ir_node):
-        ir_node = ir_node.inputs[0]
-
-    if is_collective(
-        ir_node, op=torch.ops._c10d_functional.all_gather_into_tensor.default
-    ):
-        ir_node_name = ir_node.get_name()
-    elif is_collective(
-        ir_node, op=torch.ops._c10d_functional.reduce_scatter_tensor.default
-    ):
-        ir_node_name = ir_node.get_name()
-    else:
-        return False
-
-    if ir_node_name in bucketable_ir_nodes:
-        return True
-
-    return False
-
-
-def _get_fx_node(
-    snode_or_ir_node: Union["scheduler.BaseSchedulerNode", "ir.IRNode"],
-    expected_op: Any,
-) -> torch.fx.Node:
-    origins = None
-    if isinstance(snode_or_ir_node, scheduler.BaseSchedulerNode):
-        origins = snode_or_ir_node.node.get_origins()
-    elif isinstance(snode_or_ir_node, ir.IRNode):
-        origins = snode_or_ir_node.origins
-    else:
-        raise ValueError(
-            f"Expected BaseSchedulerNode or IRNode, got {type(snode_or_ir_node)}. Offending value: {snode_or_ir_node}"
-        )
-    origins_with_expected_op = [o for o in origins if o.target == expected_op]
-    if len(origins_with_expected_op) != 1:
-        print(
-            "[Get FX exception] origins_with_expected_op",
-            origins_with_expected_op,
-            "expected_op",
-            expected_op,
-            "snode_or_ir_node",
-            snode_or_ir_node,
-        )
-        return None
-    return origins_with_expected_op[0]
-
-
-def get_snode_process_group_info(
-    snode: "scheduler.BaseSchedulerNode",
-    expected_op: Any,
-    resolve_pg: bool = False,
-) -> tuple[int, Union[str, ProcessGroup]]:
-    fx_node = _get_fx_node(snode, expected_op=expected_op)
-    # return None if the snode doesn't have a valid fx_node
-    if fx_node is None:
-        return None
-
-    if expected_op == torch.ops._c10d_functional.all_gather_into_tensor.default:
-        group_size, group_name = (
-            snode.node.constant_args[0],
-            snode.node.constant_args[1],
-        )
-    elif expected_op == torch.ops._c10d_functional.reduce_scatter_tensor.default:
-        group_size, group_name = (
-            snode.node.constant_args[1],
-            snode.node.constant_args[2],
-        )
-    elif expected_op == torch.ops._c10d_functional.all_reduce_.default:
-        group_size, group_name = fx_node.args[1], fx_node.args[2]
-    elif expected_op == torch.ops._c10d_functional.all_to_all_single.default:
-        group_size, group_name = 0, fx_node.args[3]
-    else:
-        raise ValueError(f"Unsupported op {expected_op}")
-
-    if resolve_pg:
-        group_name = _resolve_process_group(group_name)
-    return group_size, group_name
-
-
-def get_snode_tensor_info(
-    snode: "scheduler.BaseSchedulerNode", return_data_size: bool = False
-) -> tuple[Any, ...]:
-    input_dtype, input_device = (
-        snode.node.inputs[0].layout.dtype,
-        snode.node.inputs[0].layout.device,
-    )
-    input_size = get_data_size(snode.node.inputs[0].layout.size)
-
-    if not isinstance(snode.node.layout, NoneLayout):
-        output_dtype, output_device = (
-            snode.node.layout.dtype,
-            snode.node.layout.device,
-        )
-        output_size = get_data_size(snode.node.layout.size)
-    else:
-        # In all_reduce, layout is NoneLayout
-        # We set output info to be the same as input info as a special treatment
-        output_dtype, output_device, output_size = input_dtype, input_device, input_size
-
-    result = (input_dtype, input_device, output_dtype, output_device)
-    if return_data_size:
-        result += (input_size, output_size)
-    return result
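
As a side note, the reverted get_data_size helper folds multiplication over a layout's dimensions to get an element count, which get_snode_tensor_info then pairs with dtype and device information. A self-contained illustration of that helper, independent of the Inductor types used above:

from functools import reduce

def get_data_size(size):
    # Element count of a tensor shape, e.g. [4, 1024, 1024] -> 4194304.
    return reduce(lambda x, y: x * y, size)

assert get_data_size([4, 1024, 1024]) == 4 * 1024 * 1024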

autoparallel/autobucketing_util/estimation.py

Lines changed: 0 additions & 229 deletions
This file was deleted.

0 commit comments