Skip to content

Commit 004bbd2

Browse files
Erik-Lundellzingo
andauthored
Add get_cond_while_submodules utility (#15699)
Adding while to get_control_flow_submodules seems risky, could have unintended side effects. The new function is similar to get_control_flow_submodules though, so refactor that to share logic. cc @JacobSzwejbka @angelayi @freddan80 @per @zingo @oscarandersson8218 @digantdesai Signed-off-by: Erik Lundell <erik.lundell@arm.com> Co-authored-by: Zingo Andersen <zingo.andersen@arm.com>
1 parent e774b77 commit 004bbd2

File tree

1 file changed

+47
-9
lines changed

1 file changed

+47
-9
lines changed

exir/graph_module.py

Lines changed: 47 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22
# All rights reserved.
3+
# Copyright 2025 Arm Limited and/or its affiliates.
34
#
45
# This source code is licensed under the BSD-style license found in the
56
# LICENSE file in the root directory of this source tree.
@@ -10,6 +11,7 @@
1011
from typing import Callable, Dict, List, Tuple, Union
1112

1213
import torch
14+
from torch._ops import HigherOrderOperator
1315

1416

1517
LeafValue = Union[
@@ -46,30 +48,66 @@ def _get_submodule(
4648
return submod_node.target, submodule, node
4749

4850

49-
def get_control_flow_submodules(
51+
def _get_control_flow_submodules(
5052
graph_module: torch.fx.GraphModule,
53+
op_to_submodule_arg_index: dict[HigherOrderOperator, list[int]],
5154
) -> List[Tuple[str, torch.fx.GraphModule, torch.fx.Node]]:
5255
"""
5356
Returns a list of submodules used for control flow operations
54-
(torch.ops.higher_order.cond/map) that are in the given toplevel graph (does not look
55-
into submodules). Specifically, the returned value is a list containing a
56-
tuple of (name of the submodule that's stored in the graph module, the
57+
that are in the given toplevel graph (does not look
58+
into submodules). Specifically, the returned value is a list containing
59+
tuples of (name of the submodule that's stored in the graph module, the
5760
submodule itself, and the fx node that uses this submodule).
5861
"""
5962
control_flow_submodules = []
6063
for node in graph_module.graph.nodes:
6164
if node.op != "call_function":
6265
continue
6366

64-
if node.target is torch.ops.higher_order.cond:
65-
control_flow_submodules.append(_get_submodule(graph_module, node, 1))
66-
control_flow_submodules.append(_get_submodule(graph_module, node, 2))
67-
if node.target is torch.ops.higher_order.map_impl:
68-
control_flow_submodules.append(_get_submodule(graph_module, node, 0))
67+
for op in op_to_submodule_arg_index:
68+
if node.target is not op:
69+
continue
70+
for i in op_to_submodule_arg_index[op]:
71+
control_flow_submodules.append(_get_submodule(graph_module, node, i))
6972

7073
return control_flow_submodules
7174

7275

76+
def get_control_flow_submodules(
77+
graph_module: torch.fx.GraphModule,
78+
) -> List[Tuple[str, torch.fx.GraphModule, torch.fx.Node]]:
79+
"""
80+
Returns a list of submodules used for control flow operations
81+
(torch.ops.higher_order.cond/map) that are in the given toplevel graph (does not look
82+
into submodules). Specifically, the returned value is a list containing
83+
tuples of (name of the submodule that's stored in the graph module, the
84+
submodule itself, and the fx node that uses this submodule).
85+
"""
86+
return _get_control_flow_submodules(
87+
graph_module,
88+
{torch.ops.higher_order.cond: [1, 2], torch.ops.higher_order.map_impl: [0]},
89+
)
90+
91+
92+
def get_cond_while_submodules(
93+
graph_module: torch.fx.GraphModule,
94+
) -> List[Tuple[str, torch.fx.GraphModule, torch.fx.Node]]:
95+
"""
96+
Returns a list of submodules used for control flow operations
97+
(torch.ops.higher_order.cond/while_loop) that are in the given toplevel graph (does not look
98+
into submodules). Specifically, the returned value is a list containing
99+
tuples of (name of the submodule that's stored in the graph module, the
100+
submodule itself, and the fx node that uses this submodule).
101+
"""
102+
return _get_control_flow_submodules(
103+
graph_module,
104+
{
105+
torch.ops.higher_order.cond: [1, 2],
106+
torch.ops.higher_order.while_loop: [0, 1],
107+
},
108+
)
109+
110+
73111
def bfs_trace_with_node_process(
74112
gm: torch.fx.GraphModule, node_op: Callable[[torch.fx.Node], None]
75113
) -> None:

0 commit comments

Comments
 (0)