@@ -73,7 +73,7 @@ def update_joint_with_descriptors(
7373 joint_with_descriptors ._aot_state .fw_metadata .traced_tangents = new_local_tangents
7474
7575
76- def _add_alias (gm ):
76+ def _add_alias (gm , version = "v1" ):
7777 """
7878 Helper function to add alias nodes to every node in the graph
7979 this gives more configuration opportunities
@@ -82,16 +82,8 @@ def _add_alias(gm):
8282
8383 nodes = [n for n in graph .nodes if n .op == "call_function" ]
8484 node_map = {node : idx for idx , node in enumerate (nodes )}
85- inputs = graph .find_nodes (op = "placeholder" )
86- for node in inputs :
87- if len (node .users ) == 0 :
88- # node is not used, don't add alias for it
89- continue
90- if (
91- len (node .users ) == 1
92- and list (node .users )[0 ].target == torch .ops .autoparallel .dtype_cast .default
93- ):
94- node = list (node .users )[0 ]
85+
86+ def _insert_alias (node ):
9587 first_user = nodes [min (node_map [n ] for n in node .users )]
9688 with graph .inserting_before (first_user ):
9789 alias_node = graph .call_function (torch .ops .aten .alias .default , args = (node ,))
@@ -102,6 +94,32 @@ def delete_user_cb(n):
10294
10395 node .replace_all_uses_with (alias_node , delete_user_cb = delete_user_cb )
10496
97+ inputs = graph .find_nodes (op = "placeholder" )
98+ if version == "v1" :
99+ # only on inputs
100+ for node in inputs :
101+ if len (node .users ) == 0 :
102+ # node is not used, don't add alias for it
103+ continue
104+ if (
105+ len (node .users ) == 1
106+ and list (node .users )[0 ].target
107+ == torch .ops .autoparallel .dtype_cast .default
108+ ):
109+ node = list (node .users )[0 ]
110+ _insert_alias (node )
111+ elif version == "v2" :
112+ # for every node that has more than one user
113+ for node in inputs + nodes :
114+ if len (node .users ) < 2 :
115+ continue
116+ # don't add alias for ops which return tuple for now
117+ if not isinstance (node .meta ["val" ], torch .Tensor ):
118+ continue
119+ _insert_alias (node )
120+ else :
121+ raise ValueError (f"Unknown version { version } " )
122+
105123 """
106124 for node in nodes:
107125 # skip ops which return tuple
0 commit comments