|
| 1 | +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. |
| 2 | +# |
| 3 | +# This source code is licensed under the BSD license found in the |
| 4 | +# LICENSE file in the root directory of this source tree. |
| 5 | + |
| 6 | +import functools |
| 7 | + |
| 8 | +# mypy: ignore-errors |
| 9 | +from collections import defaultdict |
| 10 | +from typing import Any, Dict |
| 11 | + |
| 12 | +import torch |
| 13 | +from torch._C._distributed_c10d import ReduceOp |
| 14 | +from torch._inductor import scheduler |
| 15 | +from torch._inductor.comm import _schedule_fallback_operation |
| 16 | +from torch._inductor.utils import is_collective |
| 17 | + |
| 18 | +from .bucket_utils import ( |
| 19 | + check_ir_node_bucketable, |
| 20 | + estimate_bucketed_snode_runtime, |
| 21 | + get_data_size, |
| 22 | + get_snode_process_group_info, |
| 23 | + get_snode_tensor_info, |
| 24 | +) |
| 25 | +from .estimation import benchmark_and_sync_runtime |
| 26 | + |
| 27 | + |
| 28 | +def get_dynamic_memory_threshold( |
| 29 | + peak_memory, |
| 30 | + peak_memory_per_step_dict, |
| 31 | + current_step, |
| 32 | +) -> int: |
| 33 | + """ |
| 34 | + this function calculates the memory gap from the current step's memory to the peak memory |
| 35 | + """ |
| 36 | + left_peak_memory = 0 |
| 37 | + right_peak_memory = 0 |
| 38 | + for idx, memory in peak_memory_per_step_dict.items(): |
| 39 | + if idx <= current_step: |
| 40 | + left_peak_memory = max(memory, left_peak_memory) |
| 41 | + if idx >= current_step: |
| 42 | + right_peak_memory = max(memory, right_peak_memory) |
| 43 | + current_peak_memory = min(left_peak_memory, right_peak_memory) |
| 44 | + return peak_memory - current_peak_memory |
| 45 | + |
| 46 | + |
| 47 | +def get_simplefsdp_auto_plan( |
| 48 | + sched: "scheduler.Scheduler", |
| 49 | + snodes: list["scheduler.BaseSchedulerNode"], |
| 50 | + name_to_buf: Dict[str, "scheduler.SchedulerBuffer"], |
| 51 | + name_to_fused_node: Dict[str, "scheduler.BaseSchedulerNode"], |
| 52 | + bucketable_nodes: set[str], |
| 53 | + configs: Any, |
| 54 | + verbose: bool = True, |
| 55 | +) -> tuple[ |
| 56 | + list[Dict[tuple[Any, ...], list["scheduler.BaseSchedulerNode"]]], |
| 57 | + list[Dict[tuple[Any, ...], list["scheduler.BaseSchedulerNode"]]], |
| 58 | +]: |
| 59 | + """ |
| 60 | + This function implements a greedy algorithm, which decides if the node could be bucketed |
| 61 | + with the previous one based on several criteria below: |
| 62 | + FWD Pass: |
| 63 | + (1) the bucketed AG communication could be overlapped by the previous computation; |
| 64 | + (2) the bucketed AG memory doesn't exceed peak memory; |
| 65 | + (3) bucketed AG communication size doesn't exceed 0.2*sum(fwd_ag_tensor_list), such |
| 66 | + that the estimated AG communication time is always in the calibration bound. |
| 67 | + BWD Pass: |
| 68 | + (1) the bucketed AG + RS communication could be overlapped by the previous computation; |
| 69 | + (2) the bucketed AG+RS memory doesn't exceed peak memory; |
| 70 | + (3) RS always have future compute to overlap it, such that its final exposed communication is small; |
| 71 | + (4) bucketed AG/RS communication size doesn't exceed 0.2* sum(fwd_ag_tensor_list) & 0.2* sum(bwd_rs_tensor_list), |
| 72 | + such that the estimated AG/RS communication time is always in the calibration bound. |
| 73 | + """ |
| 74 | + all_gather_plan = [] |
| 75 | + reduce_scatter_plan = [] |
| 76 | + current_ag_bucket: Dict[ |
| 77 | + tuple[Any, ...], list["scheduler.BaseSchedulerNode"] |
| 78 | + ] = defaultdict(list) |
| 79 | + current_rs_bucket: Dict[ |
| 80 | + tuple[Any, ...], list["scheduler.BaseSchedulerNode"] |
| 81 | + ] = defaultdict(list) |
| 82 | + schedule_fallback_operation = functools.partial( |
| 83 | + _schedule_fallback_operation, |
| 84 | + scheduler=sched, |
| 85 | + name_to_buf=name_to_buf, |
| 86 | + name_to_fused_node=name_to_fused_node, |
| 87 | + ) |
| 88 | + |
| 89 | + heuristic_info = { |
| 90 | + # time info |
| 91 | + "last_step_rs_comm_time": 0.0, |
| 92 | + "this_step_comp_time": 0.0, |
| 93 | + "this_step_rs_comm_time": 0.0, |
| 94 | + "next_step_comp_time": 0.0, |
| 95 | + "next_step_nonfsdp_comm_time": 0.0, |
| 96 | + # memory info |
| 97 | + "accumulated_gradient_size": 0, |
| 98 | + "last_step_rs_comm_size": 0, |
| 99 | + "this_step_rs_comm_out_size": 0, |
| 100 | + "this_step_rs_comm_inp_size": 0, |
| 101 | + "this_step_memory": 0, |
| 102 | + "next_step_memory": 0, |
| 103 | + } |
| 104 | + |
| 105 | + # sync runtime info across ranks |
| 106 | + ( |
| 107 | + comm_cache, |
| 108 | + comp_time_dict, |
| 109 | + memory_dict, |
| 110 | + peak_memory_per_step_dict, |
| 111 | + ) = benchmark_and_sync_runtime(sched, snodes, bucketable_nodes) |
| 112 | + future_comp_time = sum(comp_time_dict.values()) |
| 113 | + peak_memory = max(peak_memory_per_step_dict.values()) + configs.peak_memory_offset |
| 114 | + |
| 115 | + # autobucket algorithm |
| 116 | + bucketable_ag_idx = -1 |
| 117 | + seen_new_bucketable_ag = True |
| 118 | + for _, snode in enumerate(snodes): |
| 119 | + if is_collective( |
| 120 | + snode.node, op=torch.ops._c10d_functional.all_gather_into_tensor.default |
| 121 | + ) and check_ir_node_bucketable(snode.node, bucketable_nodes): |
| 122 | + bucketable_ag_idx += 1 |
| 123 | + seen_new_bucketable_ag = True |
| 124 | + future_comp_time -= comp_time_dict[bucketable_ag_idx] |
| 125 | + |
| 126 | + ag_node_info = get_snode_tensor_info( |
| 127 | + snode, return_data_size=False |
| 128 | + ) + get_snode_process_group_info( |
| 129 | + snode, |
| 130 | + expected_op=torch.ops._c10d_functional.all_gather_into_tensor.default, |
| 131 | + resolve_pg=False, |
| 132 | + ) |
| 133 | + current_ag_bucket[ag_node_info].append(snode) |
| 134 | + ( |
| 135 | + estimated_comm, |
| 136 | + comm_size_inp, |
| 137 | + comm_size_out, |
| 138 | + ) = estimate_bucketed_snode_runtime( |
| 139 | + current_ag_bucket, |
| 140 | + schedule_fallback_operation, |
| 141 | + name_to_buf, |
| 142 | + torch.ops._c10d_functional.all_gather_into_tensor.default, |
| 143 | + comm_cache, |
| 144 | + ) |
| 145 | + |
| 146 | + # Check if current bucketing breaks the greedy criteria |
| 147 | + # (1) Overlappping criteria |
| 148 | + comp_time = heuristic_info["this_step_comp_time"] * ( |
| 149 | + 1 + configs.relax_ratio |
| 150 | + ) |
| 151 | + comm_time = estimated_comm + heuristic_info["last_step_rs_comm_time"] |
| 152 | + break_overlap_criteria = comp_time < comm_time |
| 153 | + |
| 154 | + # (2) Memory criteria |
| 155 | + memory_threshold = get_dynamic_memory_threshold( |
| 156 | + peak_memory, |
| 157 | + peak_memory_per_step_dict, |
| 158 | + bucketable_ag_idx, |
| 159 | + ) |
| 160 | + # the buckted AG/RS are created on-the-fly, whose memory was not captured by the |
| 161 | + # estimate_peak_memory function. The bucketed_comm_memory consists of: |
| 162 | + # in FWD pass: |
| 163 | + # (1) all-gather copy-in (comm_size_inp): smaller buffers for dtype_conversion + bigger buffer to copy_in smaller buffers |
| 164 | + # thus, we have comm_size_inp*2 |
| 165 | + # (2) all-gather copy-out (comm_size_out): bigger buffer to copy_out from ag_wait + split out smaller buffers for compute |
| 166 | + # thus, we have comm_size_out*2 |
| 167 | + # in BWD pass: |
| 168 | + # TODO (ruisizhang123): we need to double check this. From memory trace, we can clearly see |
| 169 | + # these three regions stack together at a certain moment |
| 170 | + # due to reordering, the peak memory occurs at the end of current step's all-gather when last step & this step's reduce-scatter |
| 171 | + # are not cleared in time |
| 172 | + # (1) all-gather copy-in/copy-out (like FWD pass) |
| 173 | + # (2) last step's reduce-scatter: bigger buffer containts gradient |
| 174 | + # (3) next step's reduce-scatter: smaller buffers for dtype_conversion + bigger buffer to copy_in gradient |
| 175 | + bucketed_comm_memory = ( |
| 176 | + 2 * comm_size_inp |
| 177 | + + 2 * comm_size_out |
| 178 | + + heuristic_info["accumulated_gradient_size"] |
| 179 | + + heuristic_info["this_step_rs_comm_inp_size"] * 2 |
| 180 | + + heuristic_info["last_step_rs_comm_size"] |
| 181 | + ) |
| 182 | + break_memory_criteria = ( |
| 183 | + memory_threshold |
| 184 | + < heuristic_info["next_step_memory"] + bucketed_comm_memory |
| 185 | + ) |
| 186 | + |
| 187 | + # (3) Communication size criteria |
| 188 | + break_comm_size_criteria = ( |
| 189 | + comm_cache.ag_max_inp_size <= comm_size_inp |
| 190 | + or comm_cache.rs_max_out_size |
| 191 | + <= heuristic_info["this_step_rs_comm_out_size"] |
| 192 | + ) |
| 193 | + |
| 194 | + if ( |
| 195 | + break_overlap_criteria |
| 196 | + or break_memory_criteria |
| 197 | + or break_comm_size_criteria |
| 198 | + ): |
| 199 | + if heuristic_info["this_step_comp_time"] > 0: |
| 200 | + # if bucketing breaks the greedy criteria, pop the last node out |
| 201 | + overflow_ag = current_ag_bucket[ag_node_info].pop() |
| 202 | + all_gather_plan.append(current_ag_bucket) |
| 203 | + current_ag_bucket: Dict[ |
| 204 | + tuple[Any, ...], list["scheduler.BaseSchedulerNode"] |
| 205 | + ] = defaultdict(list) |
| 206 | + current_ag_bucket[ag_node_info].append(overflow_ag) |
| 207 | + else: |
| 208 | + # if there is no compute, we have to keep the all_gather to avoid deadlock |
| 209 | + all_gather_plan.append(current_ag_bucket) |
| 210 | + current_ag_bucket: Dict[ |
| 211 | + tuple[Any, ...], list["scheduler.BaseSchedulerNode"] |
| 212 | + ] = defaultdict(list) |
| 213 | + |
| 214 | + if verbose: |
| 215 | + print( |
| 216 | + "break_overlap_criteria", |
| 217 | + break_overlap_criteria, |
| 218 | + ) |
| 219 | + print("Current comm time", comm_time, "comp time", comp_time) |
| 220 | + print( |
| 221 | + "break_memory_criteria", |
| 222 | + break_memory_criteria, |
| 223 | + ) |
| 224 | + print( |
| 225 | + "memory_threshold", |
| 226 | + memory_threshold, |
| 227 | + "total memory", |
| 228 | + heuristic_info["next_step_memory"] + bucketed_comm_memory, |
| 229 | + ) |
| 230 | + print( |
| 231 | + "break_comm_size_criteria", |
| 232 | + break_comm_size_criteria, |
| 233 | + ) |
| 234 | + print("current_ag_bucket", all_gather_plan[-1]) |
| 235 | + |
| 236 | + # bucket reduce scatters if there are any |
| 237 | + if len(current_rs_bucket) > 0: |
| 238 | + ( |
| 239 | + current_estimated_rs, |
| 240 | + rs_comm_size_inp, |
| 241 | + rs_comm_size_out, |
| 242 | + ) = estimate_bucketed_snode_runtime( |
| 243 | + current_rs_bucket, |
| 244 | + schedule_fallback_operation, |
| 245 | + name_to_buf, |
| 246 | + torch.ops._c10d_functional.reduce_scatter_tensor.default, |
| 247 | + comm_cache, |
| 248 | + ReduceOp.AVG, |
| 249 | + ) |
| 250 | + heuristic_info["last_step_rs_comm_time"] = current_estimated_rs |
| 251 | + reduce_scatter_plan.append(current_rs_bucket) |
| 252 | + heuristic_info["last_step_rs_comm_size"] = rs_comm_size_out |
| 253 | + current_rs_bucket: Dict[ |
| 254 | + tuple[Any, ...], list["scheduler.BaseSchedulerNode"] |
| 255 | + ] = defaultdict(list) |
| 256 | + |
| 257 | + # update heuristic info for the next step |
| 258 | + ( |
| 259 | + heuristic_info["this_step_comp_time"], |
| 260 | + heuristic_info["this_step_memory"], |
| 261 | + ) = ( |
| 262 | + heuristic_info["next_step_comp_time"] |
| 263 | + + heuristic_info["next_step_nonfsdp_comm_time"], |
| 264 | + heuristic_info["next_step_memory"], |
| 265 | + ) |
| 266 | + ( |
| 267 | + heuristic_info["next_step_comp_time"], |
| 268 | + heuristic_info["next_step_memory"], |
| 269 | + ) = ( |
| 270 | + 0, |
| 271 | + 0, |
| 272 | + ) |
| 273 | + elif is_collective( |
| 274 | + snode.node, op=torch.ops._c10d_functional.reduce_scatter_tensor.default |
| 275 | + ) and check_ir_node_bucketable(snode.node, bucketable_nodes): |
| 276 | + node_info = get_snode_tensor_info( |
| 277 | + snode, return_data_size=False |
| 278 | + ) + get_snode_process_group_info( |
| 279 | + snode, |
| 280 | + expected_op=torch.ops._c10d_functional.reduce_scatter_tensor.default, |
| 281 | + resolve_pg=False, |
| 282 | + ) |
| 283 | + current_rs_bucket[node_info].append(snode) |
| 284 | + |
| 285 | + ( |
| 286 | + heuristic_info["this_step_rs_comm_time"], |
| 287 | + rs_comm_size_inp, |
| 288 | + rs_comm_size_out, |
| 289 | + ) = estimate_bucketed_snode_runtime( |
| 290 | + current_rs_bucket, |
| 291 | + schedule_fallback_operation, |
| 292 | + name_to_buf, |
| 293 | + torch.ops._c10d_functional.reduce_scatter_tensor.default, |
| 294 | + comm_cache, |
| 295 | + ReduceOp.AVG, |
| 296 | + ) |
| 297 | + heuristic_info["this_step_rs_comm_out_size"] = rs_comm_size_out |
| 298 | + heuristic_info["this_step_rs_comm_inp_size"] = rs_comm_size_inp |
| 299 | + heuristic_info["accumulated_gradient_size"] += get_data_size( |
| 300 | + snode.node.layout.size |
| 301 | + ) |
| 302 | + |
| 303 | + # Check if current bucketing breaks the greedy criteria |
| 304 | + # (4) future compute to overlap RS criteria |
| 305 | + break_rs_overlap_criteria = ( |
| 306 | + future_comp_time < heuristic_info["this_step_rs_comm_time"] * 5 |
| 307 | + ) |
| 308 | + if break_rs_overlap_criteria: |
| 309 | + reduce_scatter_plan.append(current_rs_bucket) |
| 310 | + heuristic_info["last_step_rs_comm_time"] = heuristic_info[ |
| 311 | + "this_step_rs_comm_time" |
| 312 | + ] |
| 313 | + heuristic_info["this_step_rs_comm_time"] = 0 |
| 314 | + current_rs_bucket: Dict[ |
| 315 | + tuple[Any, ...], list["scheduler.BaseSchedulerNode"] |
| 316 | + ] = defaultdict(list) |
| 317 | + else: |
| 318 | + # TODO (ruisizhang123): for now, we only consider FSDP + (TP & CP), whose comms are AG & RS & All_Reduce |
| 319 | + # For TP and CP, we consider the node as a "COMP" node with exposed communication as Comp time |
| 320 | + if is_collective(snode.node): |
| 321 | + current_comm = comm_cache.get_comm_time( |
| 322 | + snode.node.inputs[0].layout.size, |
| 323 | + snode.node.layout.size, |
| 324 | + getattr(snode.node, "python_kernel_name", ""), |
| 325 | + calibrated=True, |
| 326 | + ) |
| 327 | + heuristic_info["next_step_nonfsdp_comm_time"] += current_comm |
| 328 | + else: |
| 329 | + if seen_new_bucketable_ag: |
| 330 | + heuristic_info["next_step_memory"] += memory_dict[bucketable_ag_idx] |
| 331 | + heuristic_info["next_step_comp_time"] += comp_time_dict[ |
| 332 | + bucketable_ag_idx |
| 333 | + ] |
| 334 | + seen_new_bucketable_ag = False |
| 335 | + |
| 336 | + if len(current_ag_bucket) > 0 or len(all_gather_plan) == 0: |
| 337 | + all_gather_plan.append(current_ag_bucket) |
| 338 | + |
| 339 | + if len(current_rs_bucket) > 0 or len(reduce_scatter_plan) == 0: |
| 340 | + reduce_scatter_plan.append(current_rs_bucket) |
| 341 | + |
| 342 | + return all_gather_plan, reduce_scatter_plan |
0 commit comments