@@ -34,7 +34,7 @@ def get_inputs_for_current_iteration(g, input_id, iter_index):
3434
3535
3636def create_loop_body_graph (parent_g , gather_input_ids , output_data_type , output_shape , trip_count_input_ids ,
37- rank , loop_name ):
37+ rank ):
3838 g = parent_g .create_new_graph_with_same_config ()
3939 g .parent_graph = parent_g
4040 iter_name = utils .make_name ("i" )
@@ -112,9 +112,9 @@ def create_if_op(g, input_ids, output_data_type, output_shape):
112112 out_name = utils .port_name (op_name )
113113
114114 # output a scalar
115- if_node = g . make_node ( "If" , [ input_ids [ 0 ]], outputs = [ out_name ], name = op_name , skip_conversion = True )
116- if_node . set_body_graph_as_attr ( "then_branch " , true_graph )
117- if_node . set_body_graph_as_attr ( "else_branch" , false_graph )
115+ branches = { "then_branch" : true_graph , "else_branch" : false_graph }
116+ if_node = g . make_node ( "If " , [ input_ids [ 0 ]], outputs = [ out_name ], name = op_name ,
117+ skip_conversion = True , branches = branches )
118118 return if_node , out_name
119119
120120
@@ -152,12 +152,11 @@ def create_loop_op(g, gather_input_ids, output_type, output_shape, trip_count_in
152152 cond_var_name , # termination condition
153153 fake_val_name # initial value of loop-carried dependencies
154154 ]
155+ loop_body = create_loop_body_graph (g , gather_input_ids , output_type , output_shape , trip_count_input_ids , rank )
155156 # define an extra scan output
157+ branches = {"body" : loop_body }
156158 loop_node = g .make_node ("Loop" , loop_inputs , output_count = 2 , op_name_scope = "select_loop" ,
157- skip_conversion = False )
158- loop_body = create_loop_body_graph (g , gather_input_ids , output_type , output_shape , trip_count_input_ids ,
159- rank , loop_node .name )
160- loop_node .set_body_graph_as_attr ("body" , loop_body )
159+ skip_conversion = False , branches = branches )
161160 return loop_node
162161
163162
@@ -223,8 +222,9 @@ def make_range_non_const(ctx, start, limit, delta, output, scope_name, shape, dt
223222
224223 # loop
225224 loop_inputs = [trip_count_node .output [0 ], cond_name , start ]
226- loop_node = ctx .make_node ("Loop" , loop_inputs , output_count = 2 , op_name_scope = base_name , name = "loop" )
227- loop_node .set_body_graph_as_attr ("body" , g )
225+ branches = {"body" : g }
226+ loop_node = ctx .make_node ("Loop" , loop_inputs ,
227+ output_count = 2 , op_name_scope = base_name , name = "loop" , branches = branches )
228228
229229 ctx .make_node ("Identity" , [loop_node .output [1 ]], name = base_name , shapes = [shape ], dtypes = [dtype ], outputs = [output ])
230230
@@ -404,15 +404,16 @@ def version_1(cls, ctx, node, **kwargs):
404404 ctx .remove_node (node .name )
405405
406406 # replace the original node
407- if_node = ctx .make_node ("If" , node .input [:1 ], name = node .name , output_count = len (output_shapes ),
408- shapes = output_shapes , dtypes = output_dtypes , skip_conversion = True )
409-
407+ branches = {}
410408 for branch in ["then_branch" , "else_branch" ]:
411409 func_name = node .get_attr_str (branch )
412410 g = find_function (func_name )
413411 g .parent_graph = ctx
414412 wire_if_branch (ctx , g , inputs , output_shapes , output_dtypes , func_name , node .name )
415- if_node .set_body_graph_as_attr (branch , g )
413+ branches [branch ] = g
414+
415+ _ = ctx .make_node ("If" , node .input [:1 ], name = node .name , output_count = len (output_shapes ),
416+ shapes = output_shapes , dtypes = output_dtypes , skip_conversion = True , branches = branches )
416417
417418
418419@tf_op (["If" ])
@@ -431,15 +432,16 @@ def version_1(cls, ctx, node, **kwargs):
431432 ctx .remove_node (node .name )
432433
433434 # replace the original node
434- if_node = ctx .make_node ("If" , node .input [:1 ], name = node .name , output_count = len (output_shapes ),
435- shapes = output_shapes , dtypes = output_dtypes , skip_conversion = True )
436-
435+ branches = {}
437436 for branch in ["then_branch" , "else_branch" ]:
438437 func_name = node .get_attr_str (branch )
439438 g = find_function (func_name )
440439 g .parent_graph = ctx
441440 wire_if_branch (ctx , g , inputs , output_shapes , output_dtypes , func_name , node .name )
442- if_node .set_body_graph_as_attr (branch , g )
441+ branches [branch ] = g
442+
443+ _ = ctx .make_node ("If" , node .input [:1 ], name = node .name , output_count = len (output_shapes ),
444+ shapes = output_shapes , dtypes = output_dtypes , skip_conversion = True , branches = branches )
443445
444446
445447@tf_op (["TensorListSetItem" ])
@@ -610,9 +612,11 @@ def version_7(cls, ctx, node, **kwargs):
610612 output_dtypes = output_dtypes [2 :]
611613 output_names = output_names [2 :]
612614
615+ branches = {"body" : body }
613616 loop_node = ctx .make_node ("Loop" , [maximum_iterations_name , cond_outputs [0 ]] + loop_vars ,
614617 output_count = len (output_shapes ), name = node .name + "_loop" ,
615- shapes = output_shapes , dtypes = output_dtypes , skip_conversion = True )
618+ shapes = output_shapes , dtypes = output_dtypes , skip_conversion = True ,
619+ branches = branches )
616620
617621 output_map = dict (zip (output_names , loop_node .output ))
618622
@@ -628,7 +632,6 @@ def version_7(cls, ctx, node, **kwargs):
628632 for i , n in enumerate (body .inputs ):
629633 if body .get_dtype (n .output [0 ]) == onnx_pb .TensorProto .UNDEFINED :
630634 body .set_dtype (n .output [0 ], ctx .get_dtype (loop_node .input [i ]))
631- loop_node .set_body_graph_as_attr ("body" , body )
632635
633636
634637def wire_while_body (parent_g , g , loop_node_inputs , body_input_to_state_var , cond_input_to_state_var , output_shapes ,
@@ -801,13 +804,14 @@ def prefix_graph(g, scope):
801804 attr = node .attr
802805 if node .is_graph_input ():
803806 continue
804- new_node = g .make_node (node .type , node .input , name = node .name , output_count = len (node .output ),
805- shapes = output_shapes , dtypes = output_dtypes , attr = attr ,
806- op_name_scope = scope , skip_conversion = True )
807+ branches = {}
807808 attr_graphs = node .get_body_graphs ()
808809 if attr_graphs :
809810 for k , v in attr_graphs .items ():
810- new_node .set_body_graph_as_attr (k , v )
811+ branches [k ] = v
812+ new_node = g .make_node (node .type , node .input , name = node .name , output_count = len (node .output ),
813+ shapes = output_shapes , dtypes = output_dtypes , attr = attr ,
814+ op_name_scope = scope , skip_conversion = True , branches = branches )
811815 for old_output , new_output in zip (node .output , new_node .output ):
812816 for i , oname in enumerate (g .outputs ):
813817 if old_output == oname :
0 commit comments