2 changes: 2 additions & 0 deletions autoparallel/auto_bucketing.py
@@ -19,6 +19,7 @@ class simplefsdp_autobucketing_config:
- load_cache: set to True to load cache from save_estimation_path
- enable_bucket_ir: set to True to bucket all_gather/reduce_scatter
- enable_reorder_ir: set to True to reorder all_gather/reduce_scatter
- calibrate_number: number of calibration samples used during comm estimation
"""

relax_ratio = 0
@@ -27,6 +28,7 @@ class simplefsdp_autobucketing_config:
save_estimation_path = "/mnt/mffuse/cache_ruisi/estimation_mast.pkl"
enable_bucket_ir = True
enable_reorder_ir = True
calibrate_number = 40


def simple_fsdp_autobucketing_reordering_pass(
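For context, a minimal usage sketch (not part of this diff): the attribute names below come from the config shown above, while the import style and the pattern of overriding the class-level defaults before running the pass are assumptions.

from autoparallel.auto_bucketing import simplefsdp_autobucketing_config

# Hypothetical tuning: fewer calibration samples trade estimation accuracy for faster startup.
simplefsdp_autobucketing_config.calibrate_number = 20
simplefsdp_autobucketing_config.save_estimation_path = "/tmp/comm_estimation.pkl"
simplefsdp_autobucketing_config.load_cache = False  # re-benchmark instead of loading a cached estimation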
130 changes: 127 additions & 3 deletions autoparallel/autobucketing_util/bucket_utils.py
@@ -4,15 +4,23 @@
# LICENSE file in the root directory of this source tree.

# mypy: ignore-errors
from typing import Any, Callable, Dict
from functools import reduce
from typing import Any, Callable, Dict, Optional, Union

import torch
from torch._inductor import scheduler
from torch._inductor import ir, scheduler
from torch._inductor.dependencies import WeakDep
from torch._inductor.utils import buf_name_to_fused_snode, is_collective
from torch._inductor.ir import NoneLayout
from torch._inductor.utils import buf_name_to_fused_snode, is_collective, is_wait
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import _resolve_process_group
from torch.utils._ordered_set import OrderedSet


def get_data_size(size):
return reduce(lambda x, y: x * y, size, 1)  # product of dims; 1 for an empty size


def _find_recursive_deps_of_snode(
snode: "scheduler.BaseSchedulerNode",
collected_node_set: OrderedSet["scheduler.BaseSchedulerNode"],
@@ -117,3 +125,119 @@ def get_bucketable_ir_nodes(
bucketable_ir_nodes.add(snode.node.get_name())

return bucketable_ir_nodes


def check_ir_node_bucketable(
ir_node: "ir.IRNode", bucketable_ir_nodes: set[str]
) -> bool:
"""
Determine whether an AG/RS collective node (or its wait node) comes from bucketable nodes
"""
ir_node_origins = list(getattr(ir_node, "origins", None) or [])
if len(ir_node_origins) == 0:
# bucketed AG and RS don't have origins
return True

if is_wait(ir_node):
ir_node = ir_node.inputs[0]

if is_collective(
ir_node, op=torch.ops._c10d_functional.all_gather_into_tensor.default
):
ir_node_name = ir_node.get_name()
elif is_collective(
ir_node, op=torch.ops._c10d_functional.reduce_scatter_tensor.default
):
ir_node_name = ir_node.get_name()
else:
return False

if ir_node_name in bucketable_ir_nodes:
return True

return False


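# Return the unique FX origin node of `snode_or_ir_node` whose target matches `expected_op`;
# prints a diagnostic and returns None when no single match exists.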
def _get_fx_node(
snode_or_ir_node: Union["scheduler.BaseSchedulerNode", "ir.IRNode"],
expected_op: Any,
) -> Optional[torch.fx.Node]:
origins = None
if isinstance(snode_or_ir_node, scheduler.BaseSchedulerNode):
origins = snode_or_ir_node.node.get_origins()
elif isinstance(snode_or_ir_node, ir.IRNode):
origins = snode_or_ir_node.origins
else:
raise ValueError(
f"Expected BaseSchedulerNode or IRNode, got {type(snode_or_ir_node)}. Offending value: {snode_or_ir_node}"
)
origins_with_expected_op = [o for o in origins if o.target == expected_op]
if len(origins_with_expected_op) != 1:
print(
"[Get FX exception] origins_with_expected_op",
origins_with_expected_op,
"expected_op",
expected_op,
"snode_or_ir_node",
snode_or_ir_node,
)
return None
return origins_with_expected_op[0]


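# Extract (group_size, group_name) for the collective op behind `snode`;
# with resolve_pg=True the group name is resolved to a ProcessGroup.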
def get_snode_process_group_info(
snode: "scheduler.BaseSchedulerNode",
expected_op: Any,
resolve_pg: bool = False,
) -> Optional[tuple[int, Union[str, ProcessGroup]]]:
fx_node = _get_fx_node(snode, expected_op=expected_op)
# return None if the snode doesn't have a valid fx_node
if fx_node is None:
return None

if expected_op == torch.ops._c10d_functional.all_gather_into_tensor.default:
group_size, group_name = (
snode.node.constant_args[0],
snode.node.constant_args[1],
)
elif expected_op == torch.ops._c10d_functional.reduce_scatter_tensor.default:
group_size, group_name = (
snode.node.constant_args[1],
snode.node.constant_args[2],
)
elif expected_op == torch.ops._c10d_functional.all_reduce_.default:
group_size, group_name = fx_node.args[1], fx_node.args[2]
elif expected_op == torch.ops._c10d_functional.all_to_all_single.default:
group_size, group_name = 0, fx_node.args[3]
else:
raise ValueError(f"Unsupported op {expected_op}")

if resolve_pg:
group_name = _resolve_process_group(group_name)
return group_size, group_name


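# Collect (input_dtype, input_device, output_dtype, output_device) for a collective snode,
# optionally followed by (input_size, output_size) element counts.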
def get_snode_tensor_info(
snode: "scheduler.BaseSchedulerNode", return_data_size: bool = False
) -> tuple[Any, ...]:
input_dtype, input_device = (
snode.node.inputs[0].layout.dtype,
snode.node.inputs[0].layout.device,
)
input_size = get_data_size(snode.node.inputs[0].layout.size)

if not isinstance(snode.node.layout, NoneLayout):
output_dtype, output_device = (
snode.node.layout.dtype,
snode.node.layout.device,
)
output_size = get_data_size(snode.node.layout.size)
else:
# In all_reduce, the layout is NoneLayout,
# so we reuse the input info as the output info
output_dtype, output_device, output_size = input_dtype, input_device, input_size

result = (input_dtype, input_device, output_dtype, output_device)
if return_data_size:
result += (input_size, output_size)
return result
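A hypothetical illustration (example values invented; `snode` is assumed to wrap an all_gather_into_tensor) of how the two helpers above compose into the per-collective key used by the estimation pass below:

tensor_info = get_snode_tensor_info(snode, return_data_size=True)
# e.g. (torch.bfloat16, device("cuda:0"), torch.bfloat16, device("cuda:0"), 1048576, 8388608)
pg_info = get_snode_process_group_info(
    snode,
    expected_op=torch.ops._c10d_functional.all_gather_into_tensor.default,
    resolve_pg=True,
)
# e.g. (8, <ProcessGroup>)
node_info = tensor_info[:-2] + pg_info  # groups collectives sharing dtype/device/process group
input_size = tensor_info[-2]  # element count of the collective's input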
229 changes: 229 additions & 0 deletions autoparallel/autobucketing_util/estimation.py
@@ -0,0 +1,229 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# mypy: ignore-errors
import os
import pickle
from collections import defaultdict
from typing import Any

import torch
import torch.distributed as c10d
from torch._inductor import memory, scheduler
from torch._inductor.utils import is_collective
from torch._inductor.virtualized import V
from torch.utils._ordered_set import OrderedSet

from .bucket_utils import (
check_ir_node_bucketable,
get_snode_process_group_info,
get_snode_tensor_info,
)
from .estimation_utils import (
CommPerfCache,
CompPerfCache,
benchmark_and_cache_comm_dicts,
estimate_comp_time,
)


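# All-gather each rank's measured values and keep the elementwise median,
# so every rank ends up with the same runtime estimate per key.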
def sync_dict_across_ranks(runtime_dict, world_size, group=None):
gathered_lists = [None for _ in range(world_size)]
c10d.all_gather_object(gathered_lists, list(runtime_dict.values()), group=group)
median_gathered_time = torch.median(torch.tensor(gathered_lists), dim=0).values
for idx, (key, value) in enumerate(runtime_dict.items()):
runtime_dict[key] = median_gathered_time[idx]
return runtime_dict


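# Single pass over the scheduled nodes: group collective shapes by (tensor info, process group),
# accumulate compute time and memory per FSDP all_gather step, then benchmark the collected
# communication sizes and synchronize all estimates across ranks.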
def benchmark_and_sync_runtime(
sched: "scheduler.Scheduler",
snodes: list["scheduler.BaseSchedulerNode"],
name_to_buf: dict[str, "scheduler.SchedulerBuffer"],
name_to_fused_node: dict[str, "scheduler.BaseSchedulerNode"],
bucketable_nodes: set[str],
configs: Any,
):
world_size = c10d.distributed_c10d.get_world_size()

fsdp_ag_input_size_dict = defaultdict(list)
fsdp_rs_output_size_dict = defaultdict(list)
non_fsdp_ag_input_size_dict = defaultdict(list)
non_fsdp_rs_input_size_dict = defaultdict(list)
all_reduce_input_size_dict = defaultdict(list)
all_to_all_input_size_dict = defaultdict(list)
comp_cache, comm_cache = CompPerfCache(), CommPerfCache()

cali_num_samples = configs.calibrate_number
comp_time_dict = defaultdict(float)
memory_dict = defaultdict(int)
peak_memory_per_step_dict = defaultdict(int)
fsdp_ag_idx = -1
release_steps = [0]

graph_outputs = OrderedSet(V.graph.get_output_names())
graph_inputs = OrderedSet(V.graph.graph_inputs.keys())
_, name_to_freeable_input_buf = memory.prepare_planning_info(
snodes,
name_to_buf,
name_to_fused_node,
graph_inputs,
graph_outputs,
)
_, memories_at_nodes = memory.estimate_peak_memory(
snodes, name_to_freeable_input_buf, graph_outputs
)
# shift so that memory estimates are never negative
if min(memories_at_nodes) < 0:
shift_value = abs(min(memories_at_nodes))
memories_at_nodes = [x + shift_value for x in memories_at_nodes]

for idx, snode in enumerate(snodes):
if is_collective(
snode.node, op=torch.ops._c10d_functional.all_gather_into_tensor.default
):
fsdp_ag_idx += 1
release_steps.append(idx)
node_tensor_info = get_snode_tensor_info(snode, return_data_size=True)
node_pg_info = get_snode_process_group_info(
snode,
expected_op=torch.ops._c10d_functional.all_gather_into_tensor.default,
resolve_pg=True,
)
if node_pg_info is None:
continue
node_info = node_tensor_info[:-2] + node_pg_info
input_size = node_tensor_info[-2]
if check_ir_node_bucketable(snode.node, bucketable_nodes):
# For FSDP, we assume they all have the same group size
fsdp_ag_input_size_dict[node_info].append(input_size)
else:
non_fsdp_ag_input_size_dict[node_info].append(input_size)
elif is_collective(
snode.node, op=torch.ops._c10d_functional.reduce_scatter_tensor.default
):
node_tensor_info = get_snode_tensor_info(snode, return_data_size=True)
node_pg_info = get_snode_process_group_info(
snode,
expected_op=torch.ops._c10d_functional.reduce_scatter_tensor.default,
resolve_pg=True,
)
if node_pg_info is None:
continue
node_info = node_tensor_info[:-2] + node_pg_info
output_size = node_tensor_info[-1]
if check_ir_node_bucketable(snode.node, bucketable_nodes):
# For FSDP, we assume they all have the same group size
fsdp_rs_output_size_dict[node_info].append(output_size)
else:
non_fsdp_rs_input_size_dict[node_info].append(output_size)
elif is_collective(
snode.node, op=torch.ops._c10d_functional.all_reduce_.default
):
node_tensor_info = get_snode_tensor_info(snode, return_data_size=True)
node_pg_info = get_snode_process_group_info(
snode,
expected_op=torch.ops._c10d_functional.all_reduce_.default,
resolve_pg=True,
)
if node_pg_info is None:
continue
node_info = node_tensor_info[:-2] + node_pg_info
input_size = node_tensor_info[-2]
all_reduce_input_size_dict[node_info].append(input_size)
elif is_collective(
snode.node, op=torch.ops._c10d_functional.all_to_all_single.default
):
node_tensor_info = get_snode_tensor_info(snode, return_data_size=True)
node_pg_info = get_snode_process_group_info(
snode,
expected_op=torch.ops._c10d_functional.all_to_all_single.default,
resolve_pg=True,
)
if node_pg_info is None:
continue
node_info = node_tensor_info[:-2] + node_pg_info
input_size = node_tensor_info[-2]
all_to_all_input_size_dict[node_info].append(input_size)
else:
if not is_collective(snode.node):
comp_time = estimate_comp_time(
sched, snode, verbose=False, comp_cache=comp_cache
)
comp_time_dict[fsdp_ag_idx] += comp_time
memory_dict[fsdp_ag_idx] = max(
abs(
memories_at_nodes[idx + 1]
- memories_at_nodes[release_steps[-1]]
),
memory_dict[fsdp_ag_idx],
)
peak_memory_per_step_dict[fsdp_ag_idx] = max(
memories_at_nodes[idx + 1], peak_memory_per_step_dict[fsdp_ag_idx]
)
else:
print(
"[Relaxed Setting] untracked communication",
snode.node.python_kernel_name,
)

# Sync per-step compute time and memory stats across ranks
comp_time_dict = sync_dict_across_ranks(comp_time_dict, world_size)
memory_dict = sync_dict_across_ranks(memory_dict, world_size)
peak_memory_per_step_dict = sync_dict_across_ranks(
peak_memory_per_step_dict, world_size
)

if configs.load_cache and os.path.exists(configs.save_estimation_path):
with open(configs.save_estimation_path, "rb") as file:
cache = pickle.load(file)
comm_cache.cache = cache
comm_cache._update_max_size()
return comm_cache, comp_time_dict, memory_dict, peak_memory_per_step_dict

benchmark_params = [
(
fsdp_ag_input_size_dict,
"torch.ops._c10d_functional.all_gather_into_tensor.default",
cali_num_samples,
),
(
fsdp_rs_output_size_dict,
"torch.ops._c10d_functional.reduce_scatter_tensor.default",
cali_num_samples,
),
(
non_fsdp_ag_input_size_dict,
"torch.ops._c10d_functional.all_gather_into_tensor.default",
3,
),
(
non_fsdp_rs_input_size_dict,
"torch.ops._c10d_functional.reduce_scatter_tensor.default",
3,
),
(
all_reduce_input_size_dict,
"torch.ops._c10d_functional.all_reduce_.default",
3,
),
(
all_to_all_input_size_dict,
"torch.ops._c10d_functional.all_to_all_single.default",
3,
),
]
for input_size_dict, op_name, num_samples in benchmark_params:
if len(input_size_dict) > 0:
benchmark_and_cache_comm_dicts(
comm_cache, input_size_dict, op_name, num_samples
)

median_runtimes = sync_dict_across_ranks(comm_cache.cache, world_size)
comm_cache.cache = median_runtimes
comm_cache._update_max_size()
with open(configs.save_estimation_path, "wb") as file:
pickle.dump(comm_cache.cache, file)
return comm_cache, comp_time_dict, memory_dict, peak_memory_per_step_dict