|
| 1 | +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. |
| 2 | +# |
| 3 | +# This source code is licensed under the BSD license found in the |
| 4 | +# LICENSE file in the root directory of this source tree. |
| 5 | + |
| 6 | +# mypy: ignore-errors |
| 7 | +from typing import Any, Callable, Dict |
| 8 | + |
| 9 | +import torch |
| 10 | +from torch._inductor import scheduler |
| 11 | +from torch._inductor.dependencies import WeakDep |
| 12 | +from torch._inductor.utils import buf_name_to_fused_snode, is_collective |
| 13 | +from torch.utils._ordered_set import OrderedSet |
| 14 | + |
| 15 | + |
| 16 | +def _find_recursive_deps_of_snode( |
| 17 | + snode: "scheduler.BaseSchedulerNode", |
| 18 | + collected_node_set: OrderedSet["scheduler.BaseSchedulerNode"], |
| 19 | + name_to_buf: Dict[str, "scheduler.SchedulerBuffer"], |
| 20 | + name_to_fused_node: Dict[str, "scheduler.BaseSchedulerNode"], |
| 21 | + criteria_cb: Callable[[Any], bool] = lambda snode: False, |
| 22 | + allow_weak_dep: bool = True, |
| 23 | +): |
| 24 | + if criteria_cb(snode): |
| 25 | + return |
| 26 | + collected_node_set.add(snode) |
| 27 | + for dep in snode.unmet_dependencies: |
| 28 | + if isinstance(dep, WeakDep) and not allow_weak_dep: |
| 29 | + continue |
| 30 | + defining_op_for_dep = buf_name_to_fused_snode( |
| 31 | + dep.name, name_to_buf, name_to_fused_node |
| 32 | + ) |
| 33 | + if defining_op_for_dep in collected_node_set: |
| 34 | + continue |
| 35 | + _find_recursive_deps_of_snode( |
| 36 | + defining_op_for_dep, |
| 37 | + collected_node_set, |
| 38 | + name_to_buf, |
| 39 | + name_to_fused_node, |
| 40 | + criteria_cb=criteria_cb, |
| 41 | + ) |
| 42 | + |
| 43 | + |
| 44 | +def _find_recursive_users_of_snode( |
| 45 | + snode: "scheduler.BaseSchedulerNode", |
| 46 | + collected_node_set: OrderedSet["scheduler.BaseSchedulerNode"], |
| 47 | + name_to_buf: Dict[str, "scheduler.SchedulerBuffer"], |
| 48 | + name_to_fused_node: Dict[str, "scheduler.BaseSchedulerNode"], |
| 49 | + criteria_cb: Callable[[Any], bool] = lambda snode: False, |
| 50 | +): |
| 51 | + if criteria_cb(snode): |
| 52 | + return |
| 53 | + collected_node_set.add(snode) |
| 54 | + for o in snode.get_outputs(): |
| 55 | + for user in o.users: |
| 56 | + assert user.node is not None |
| 57 | + if user.node.get_name() == "OUTPUT": |
| 58 | + continue |
| 59 | + if user.node.get_name() not in name_to_fused_node: |
| 60 | + continue |
| 61 | + user_op = name_to_fused_node[user.node.get_name()] |
| 62 | + if user_op in collected_node_set: |
| 63 | + continue |
| 64 | + _find_recursive_users_of_snode( |
| 65 | + user_op, |
| 66 | + collected_node_set, |
| 67 | + name_to_buf, |
| 68 | + name_to_fused_node, |
| 69 | + criteria_cb=criteria_cb, |
| 70 | + ) |
| 71 | + |
| 72 | + |
| 73 | +def get_bucketable_ir_nodes( |
| 74 | + snodes: list["torch._inductor.scheduler.BaseSchedulerNode"], |
| 75 | + name_to_fused_node: Dict[str, "scheduler.BaseSchedulerNode"], |
| 76 | + name_to_buf: Dict[str, "scheduler.SchedulerBuffer"], |
| 77 | +) -> set[str]: |
| 78 | + """ |
| 79 | + This function selects the ir nodes' names that are bucketable |
| 80 | + From first principle, only all-gathers that gather parameters and reduce-scatters |
| 81 | + that update model gradients could be bucketed together. |
| 82 | + Thus, bucketable all-gathers's deps are (1) computed buffer for dtype conversion (optional) |
| 83 | + (2) all-gather itself |
| 84 | + bucketable reduce-scatter wait's users are (1) reduce-scatter wait itself |
| 85 | + """ |
| 86 | + bucketable_ir_nodes = set() |
| 87 | + for snode in snodes: |
| 88 | + if is_collective( |
| 89 | + snode.node, op=torch.ops._c10d_functional.all_gather_into_tensor.default |
| 90 | + ): |
| 91 | + ag_related_snode_set: OrderedSet[ |
| 92 | + "torch._inductor.scheduler.BaseSchedulerNode" |
| 93 | + ] = OrderedSet() |
| 94 | + _find_recursive_deps_of_snode( |
| 95 | + snode, |
| 96 | + ag_related_snode_set, |
| 97 | + name_to_buf, |
| 98 | + name_to_fused_node, |
| 99 | + allow_weak_dep=False, |
| 100 | + ) |
| 101 | + if len(ag_related_snode_set) <= 2: |
| 102 | + bucketable_ir_nodes.add(snode.node.get_name()) |
| 103 | + elif is_collective( |
| 104 | + snode.node, op=torch.ops._c10d_functional.reduce_scatter_tensor.default |
| 105 | + ): |
| 106 | + wait_snode = snode.get_outputs()[0].users[0].node |
| 107 | + wait_snode_recursive_users: OrderedSet[ |
| 108 | + "torch._inductor.scheduler.BaseSchedulerNode" |
| 109 | + ] = OrderedSet() |
| 110 | + _find_recursive_users_of_snode( |
| 111 | + wait_snode, |
| 112 | + wait_snode_recursive_users, |
| 113 | + name_to_buf, |
| 114 | + name_to_fused_node, |
| 115 | + ) |
| 116 | + if len(wait_snode_recursive_users) <= 1: |
| 117 | + bucketable_ir_nodes.add(snode.node.get_name()) |
| 118 | + |
| 119 | + return bucketable_ir_nodes |
0 commit comments