# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# mypy: ignore-errors
import os
import pickle
from collections import defaultdict

import torch
import torch.distributed as c10d
from torch._inductor import memory, scheduler
from torch._inductor.utils import is_collective
from torch._inductor.virtualized import V
from torch.utils._ordered_set import OrderedSet

from ..auto_bucketing import simplefsdp_autobucketing_config
from .bucket_utils import (
    check_ir_node_bucketable,
    get_snode_process_group_info,
    get_snode_tensor_info,
)
from .estimation_utils import (
    CommPerfCache,
    CompPerfCache,
    benchmark_and_cache_comm_dicts,
    estimate_comp_time,
)


def sync_dict_across_ranks(runtime_dict, world_size, group=None):
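    """Gather this dict's values from all ranks and replace each entry with
    the element-wise median across ranks, so every rank ends up with the
    same estimates. Assumes all ranks hold the same keys in the same order.
    """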
    gathered_lists = [None for _ in range(world_size)]
    c10d.all_gather_object(gathered_lists, list(runtime_dict.values()), group=group)
    median_gathered_time = torch.median(torch.tensor(gathered_lists), dim=0).values
    for idx, key in enumerate(runtime_dict):
        runtime_dict[key] = median_gathered_time[idx]
    return runtime_dict


def benchmark_and_sync_runtime(
    sched: "scheduler.Scheduler",
    snodes: list["scheduler.BaseSchedulerNode"],
    name_to_buf: dict[str, "scheduler.SchedulerBuffer"],
    name_to_fused_node: dict[str, "scheduler.BaseSchedulerNode"],
    bucketable_nodes: set[str],
    configs: "simplefsdp_autobucketing_config",
):
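    """Estimate per-node compute time, memory usage, and collective runtimes
    for the scheduled graph, then sync the estimates across ranks.

    Returns a tuple of (comm_cache, comp_time_dict, memory_dict,
    peak_memory_per_step_dict); the dicts are keyed by the index of the
    enclosing FSDP all-gather interval.
    """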
    world_size = c10d.distributed_c10d.get_world_size()

    fsdp_ag_input_size_dict = defaultdict(list)
    fsdp_rs_output_size_dict = defaultdict(list)
    non_fsdp_ag_input_size_dict = defaultdict(list)
    non_fsdp_rs_input_size_dict = defaultdict(list)
    all_reduce_input_size_dict = defaultdict(list)
    all_to_all_input_size_dict = defaultdict(list)
    comp_cache, comm_cache = CompPerfCache(), CommPerfCache()

    cali_num_samples = configs.calibrate_number
    comp_time_dict = defaultdict(float)
    memory_dict = defaultdict(int)
    peak_memory_per_step_dict = defaultdict(int)
    fsdp_ag_idx = -1
    release_steps = [0]

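    # Run Inductor's memory planning once so memories_at_nodes gives the
    # estimated live-memory level at each scheduling step.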
    graph_outputs = OrderedSet(V.graph.get_output_names())
    graph_inputs = OrderedSet(V.graph.graph_inputs.keys())
    _, name_to_freeable_input_buf = memory.prepare_planning_info(
        snodes,
        name_to_buf,
        name_to_fused_node,
        graph_inputs,
        graph_outputs,
    )
    _, memories_at_nodes = memory.estimate_peak_memory(
        snodes, name_to_freeable_input_buf, graph_outputs
    )

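    # Walk the schedule once: bucket each collective by its (tensor,
    # process group) signature for later benchmarking, and attribute
    # compute time and memory to the interval since the last FSDP
    # all-gather.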
    for idx, snode in enumerate(snodes):
        if is_collective(
            snode.node, op=torch.ops._c10d_functional.all_gather_into_tensor.default
        ):
            fsdp_ag_idx += 1
            release_steps.append(idx)
            node_tensor_info = get_snode_tensor_info(snode, return_data_size=True)
            node_pg_info = get_snode_process_group_info(
                snode,
                expected_op=torch.ops._c10d_functional.all_gather_into_tensor.default,
                resolve_pg=True,
            )
            if node_pg_info is None:
                continue
            node_info = node_tensor_info[:-2] + node_pg_info
            input_size = node_tensor_info[-2]
            if check_ir_node_bucketable(snode.node, bucketable_nodes):
                # For FSDP, we assume they all have the same group size
                fsdp_ag_input_size_dict[node_info].append(input_size)
            else:
                non_fsdp_ag_input_size_dict[node_info].append(input_size)
        elif is_collective(
            snode.node, op=torch.ops._c10d_functional.reduce_scatter_tensor.default
        ):
            node_tensor_info = get_snode_tensor_info(snode, return_data_size=True)
            node_pg_info = get_snode_process_group_info(
                snode,
                expected_op=torch.ops._c10d_functional.reduce_scatter_tensor.default,
                resolve_pg=True,
            )
            if node_pg_info is None:
                continue
            node_info = node_tensor_info[:-2] + node_pg_info
            output_size = node_tensor_info[-1]
            if check_ir_node_bucketable(snode.node, bucketable_nodes):
                # For FSDP, we assume they all have the same group size
                fsdp_rs_output_size_dict[node_info].append(output_size)
            else:
                non_fsdp_rs_input_size_dict[node_info].append(output_size)
        elif is_collective(
            snode.node, op=torch.ops._c10d_functional.all_reduce_.default
        ):
            node_tensor_info = get_snode_tensor_info(snode, return_data_size=True)
            node_pg_info = get_snode_process_group_info(
                snode,
                expected_op=torch.ops._c10d_functional.all_reduce_.default,
                resolve_pg=True,
            )
            if node_pg_info is None:
                continue
            node_info = node_tensor_info[:-2] + node_pg_info
            input_size = node_tensor_info[-2]
            all_reduce_input_size_dict[node_info].append(input_size)
        elif is_collective(
            snode.node, op=torch.ops._c10d_functional.all_to_all_single.default
        ):
            node_tensor_info = get_snode_tensor_info(snode, return_data_size=True)
            node_pg_info = get_snode_process_group_info(
                snode,
                expected_op=torch.ops._c10d_functional.all_to_all_single.default,
                resolve_pg=True,
            )
            if node_pg_info is None:
                continue
            node_info = node_tensor_info[:-2] + node_pg_info
            input_size = node_tensor_info[-2]
            all_to_all_input_size_dict[node_info].append(input_size)
        else:
            if not is_collective(snode.node):
                comp_time = estimate_comp_time(
                    sched, snode, verbose=False, comp_cache=comp_cache
                )
                comp_time_dict[fsdp_ag_idx] += comp_time
                # Track the largest memory swing within the current
                # all-gather interval, plus the overall peak at each step.
                memory_dict[fsdp_ag_idx] = max(
                    abs(
                        memories_at_nodes[idx + 1]
                        - memories_at_nodes[release_steps[-1]]
                    ),
                    memory_dict[fsdp_ag_idx],
                )
                peak_memory_per_step_dict[fsdp_ag_idx] = max(
                    memories_at_nodes[idx + 1], peak_memory_per_step_dict[fsdp_ag_idx]
                )
            else:
                print(
                    "[Relaxed Setting] untracked communication",
                    snode.node.python_kernel_name,
                )

    # Sync compute time and memory estimates so all ranks agree
    comp_time_dict = sync_dict_across_ranks(comp_time_dict, world_size)
    memory_dict = sync_dict_across_ranks(memory_dict, world_size)
    peak_memory_per_step_dict = sync_dict_across_ranks(
        peak_memory_per_step_dict, world_size
    )

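    # Reuse communication estimates from a previous run, if any, and skip
    # benchmarking entirely.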
    if configs.load_cache and os.path.exists(configs.save_estimation_path):
        with open(configs.save_estimation_path, "rb") as file:
            cache = pickle.load(file)
            comm_cache.cache = cache
            comm_cache._update_max_size()
        return comm_cache, comp_time_dict, memory_dict, peak_memory_per_step_dict

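    # Each entry is (size_dict, op_name, num_samples): FSDP collectives use
    # the configured calibration budget; other collectives get 3 samples.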
    benchmark_params = [
        (
            fsdp_ag_input_size_dict,
            "torch.ops._c10d_functional.all_gather_into_tensor.default",
            cali_num_samples,
        ),
        (
            fsdp_rs_output_size_dict,
            "torch.ops._c10d_functional.reduce_scatter_tensor.default",
            cali_num_samples,
        ),
        (
            non_fsdp_ag_input_size_dict,
            "torch.ops._c10d_functional.all_gather_into_tensor.default",
            3,
        ),
        (
            non_fsdp_rs_input_size_dict,
            "torch.ops._c10d_functional.reduce_scatter_tensor.default",
            3,
        ),
        (
            all_reduce_input_size_dict,
            "torch.ops._c10d_functional.all_reduce_.default",
            3,
        ),
        (
            all_to_all_input_size_dict,
            "torch.ops._c10d_functional.all_to_all_single.default",
            3,
        ),
    ]
    for input_size_dict, op_name, num_samples in benchmark_params:
        if len(input_size_dict) > 0:
            benchmark_and_cache_comm_dicts(
                comm_cache, input_size_dict, op_name, num_samples
            )

    median_runtimes = sync_dict_across_ranks(comm_cache.cache, world_size)
    comm_cache.cache = median_runtimes
    comm_cache._update_max_size()
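    # Persist the synced runtimes so future runs can reuse them via
    # configs.load_cache.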
    with open(configs.save_estimation_path, "wb") as file:
        pickle.dump(comm_cache.cache, file)
    return comm_cache, comp_time_dict, memory_dict, peak_memory_per_step_dict