 # LICENSE file in the root directory of this source tree.

 # mypy: ignore-errors
-from functools import reduce
-from typing import Any, Callable, Dict, Union
+from typing import Any, Callable, Dict

 import torch
-from torch._inductor import ir, scheduler
+from torch._inductor import scheduler
 from torch._inductor.dependencies import WeakDep
-from torch._inductor.ir import NoneLayout
-from torch._inductor.utils import buf_name_to_fused_snode, is_collective, is_wait
-from torch.distributed import ProcessGroup
-from torch.distributed.distributed_c10d import _resolve_process_group
+from torch._inductor.utils import buf_name_to_fused_snode, is_collective
 from torch.utils._ordered_set import OrderedSet


-def get_data_size(size):
-    return reduce(lambda x, y: x * y, size)
-
-
 def _find_recursive_deps_of_snode(
     snode: "scheduler.BaseSchedulerNode",
     collected_node_set: OrderedSet["scheduler.BaseSchedulerNode"],
@@ -125,119 +117,3 @@ def get_bucketable_ir_nodes
             bucketable_ir_nodes.add(snode.node.get_name())

     return bucketable_ir_nodes
-
-
-def check_ir_node_bucketable(
-    ir_node: "ir.IRNode", bucketable_ir_nodes: set[str]
-) -> bool:
-    """
-    Determine if the AG/RS & AG/RS wait node is from bucketable nodes or not
-    """
-    ir_node_origins = list(getattr(ir_node, "origins", None))
-    if len(ir_node_origins) == 0:
-        # bucketed AG and RS doesn't have origins
-        return True
-
-    if is_wait(ir_node):
-        ir_node = ir_node.inputs[0]
-
-    if is_collective(
-        ir_node, op=torch.ops._c10d_functional.all_gather_into_tensor.default
-    ):
-        ir_node_name = ir_node.get_name()
-    elif is_collective(
-        ir_node, op=torch.ops._c10d_functional.reduce_scatter_tensor.default
-    ):
-        ir_node_name = ir_node.get_name()
-    else:
-        return False
-
-    if ir_node_name in bucketable_ir_nodes:
-        return True
-
-    return False
-
-
-def _get_fx_node(
-    snode_or_ir_node: Union["scheduler.BaseSchedulerNode", "ir.IRNode"],
-    expected_op: Any,
-) -> torch.fx.Node:
-    origins = None
-    if isinstance(snode_or_ir_node, scheduler.BaseSchedulerNode):
-        origins = snode_or_ir_node.node.get_origins()
-    elif isinstance(snode_or_ir_node, ir.IRNode):
-        origins = snode_or_ir_node.origins
-    else:
-        raise ValueError(
-            f"Expected BaseSchedulerNode or IRNode, got {type(snode_or_ir_node)}. Offending value: {snode_or_ir_node}"
-        )
-    origins_with_expected_op = [o for o in origins if o.target == expected_op]
-    if len(origins_with_expected_op) != 1:
-        print(
-            "[Get FX exception] origins_with_expected_op",
-            origins_with_expected_op,
-            "expected_op",
-            expected_op,
-            "snode_or_ir_node",
-            snode_or_ir_node,
-        )
-        return None
-    return origins_with_expected_op[0]
-
-
-def get_snode_process_group_info(
-    snode: "scheduler.BaseSchedulerNode",
-    expected_op: Any,
-    resolve_pg: bool = False,
-) -> tuple[int, Union[str, ProcessGroup]]:
-    fx_node = _get_fx_node(snode, expected_op=expected_op)
-    # return None if the snode doesn't have a valid fx_node
-    if fx_node is None:
-        return None
-
-    if expected_op == torch.ops._c10d_functional.all_gather_into_tensor.default:
-        group_size, group_name = (
-            snode.node.constant_args[0],
-            snode.node.constant_args[1],
-        )
-    elif expected_op == torch.ops._c10d_functional.reduce_scatter_tensor.default:
-        group_size, group_name = (
-            snode.node.constant_args[1],
-            snode.node.constant_args[2],
-        )
-    elif expected_op == torch.ops._c10d_functional.all_reduce_.default:
-        group_size, group_name = fx_node.args[1], fx_node.args[2]
-    elif expected_op == torch.ops._c10d_functional.all_to_all_single.default:
-        group_size, group_name = 0, fx_node.args[3]
-    else:
-        raise ValueError(f"Unsupported op {expected_op}")
-
-    if resolve_pg:
-        group_name = _resolve_process_group(group_name)
-    return group_size, group_name
-
-
-def get_snode_tensor_info(
-    snode: "scheduler.BaseSchedulerNode", return_data_size: bool = False
-) -> tuple[Any, ...]:
-    input_dtype, input_device = (
-        snode.node.inputs[0].layout.dtype,
-        snode.node.inputs[0].layout.device,
-    )
-    input_size = get_data_size(snode.node.inputs[0].layout.size)
-
-    if not isinstance(snode.node.layout, NoneLayout):
-        output_dtype, output_device = (
-            snode.node.layout.dtype,
-            snode.node.layout.device,
-        )
-        output_size = get_data_size(snode.node.layout.size)
-    else:
-        # In all_reduce, layout is NoneLayout
-        # We set output info to be the same as input info as a special treatment
-        output_dtype, output_device, output_size = input_dtype, input_device, input_size
-
-    result = (input_dtype, input_device, output_dtype, output_device)
-    if return_data_size:
-        result += (input_size, output_size)
-    return result
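
Note on the removed get_data_size helper: it simply multiplied all dimensions of a size together to get the element count. If an equivalent is still needed elsewhere, a minimal standalone sketch (not part of this diff; the name is reused only for illustration) is:

import math

def get_data_size(size) -> int:
    # Product of all dimensions of a shape, e.g. [2, 3, 4] -> 24.
    # For non-empty sizes this matches the removed
    # reduce(lambda x, y: x * y, size); math.prod additionally returns
    # 1 for an empty (0-d) size, where reduce would raise.
    return math.prod(size)

math.prod is available on Python 3.8+ and avoids the functools.reduce import that this commit drops.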
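For context on the removed get_snode_process_group_info: functional collectives (torch.ops._c10d_functional.*) carry the process-group name as a string argument, and the resolve_pg branch mapped that name back to a live ProcessGroup. A sketch of that last step, assuming the group name was previously registered with c10d (resolve_group is a hypothetical wrapper, not an API in this file):

from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import _resolve_process_group

def resolve_group(group_name: str) -> ProcessGroup:
    # _resolve_process_group is a private torch.distributed helper; it
    # raises if group_name does not correspond to a registered group.
    return _resolve_process_group(group_name)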