Skip to content

Commit 85efc8f

Browse files
authored
Add alias to all nodes if num_users > 1 (#111)
* Add alias to all nodes if num_users > 1 * Make it optional
1 parent bf39515 commit 85efc8f

File tree

2 files changed

+30
-12
lines changed

2 files changed

+30
-12
lines changed

autoparallel/api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@ def build_model_graph(self):
277277
_replace_view_mm_view_with_einsum(gm)
278278
# now add aliases nodes to the graph to
279279
# give more room for optimizations
280-
_add_alias(gm)
280+
_add_alias(gm, version="v1")
281281
trace_structured(
282282
"artifact",
283283
metadata_fn=lambda: {

autoparallel/graph_utils.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def update_joint_with_descriptors(
7373
joint_with_descriptors._aot_state.fw_metadata.traced_tangents = new_local_tangents
7474

7575

76-
def _add_alias(gm):
76+
def _add_alias(gm, version="v1"):
7777
"""
7878
Helper function to add alias nodes to every node in the graph
7979
this gives more configuration opportunities
@@ -82,16 +82,8 @@ def _add_alias(gm):
8282

8383
nodes = [n for n in graph.nodes if n.op == "call_function"]
8484
node_map = {node: idx for idx, node in enumerate(nodes)}
85-
inputs = graph.find_nodes(op="placeholder")
86-
for node in inputs:
87-
if len(node.users) == 0:
88-
# node is not used, don't add alias for it
89-
continue
90-
if (
91-
len(node.users) == 1
92-
and list(node.users)[0].target == torch.ops.autoparallel.dtype_cast.default
93-
):
94-
node = list(node.users)[0]
85+
86+
def _insert_alias(node):
9587
first_user = nodes[min(node_map[n] for n in node.users)]
9688
with graph.inserting_before(first_user):
9789
alias_node = graph.call_function(torch.ops.aten.alias.default, args=(node,))
@@ -102,6 +94,32 @@ def delete_user_cb(n):
10294

10395
node.replace_all_uses_with(alias_node, delete_user_cb=delete_user_cb)
10496

97+
inputs = graph.find_nodes(op="placeholder")
98+
if version == "v1":
99+
# only on inputs
100+
for node in inputs:
101+
if len(node.users) == 0:
102+
# node is not used, don't add alias for it
103+
continue
104+
if (
105+
len(node.users) == 1
106+
and list(node.users)[0].target
107+
== torch.ops.autoparallel.dtype_cast.default
108+
):
109+
node = list(node.users)[0]
110+
_insert_alias(node)
111+
elif version == "v2":
112+
# for every node that has more than one user
113+
for node in inputs + nodes:
114+
if len(node.users) < 2:
115+
continue
116+
# don't add alias for ops which return tuple for now
117+
if not isinstance(node.meta["val"], torch.Tensor):
118+
continue
119+
_insert_alias(node)
120+
else:
121+
raise ValueError(f"Unknown version {version}")
122+
105123
"""
106124
for node in nodes:
107125
# skip ops which return tuple

0 commit comments

Comments
 (0)