@@ -140,6 +140,32 @@ def compress_graph_def(graph_def):
140140 tensor .tensor_content = b''
141141 return const_node_values
142142
143+ def get_index_from_strided_slice_of_shape (node , outputs_to_values ):
144+ """Returns the index of the dimension that the strided slice is reading from the shape node or None"""
145+ attr_vals = {
146+ 'shrink_axis_mask' : 1 ,
147+ 'ellipsis_mask' : 0 ,
148+ 'begin_mask' : 0 ,
149+ 'new_axis_mask' : 0 ,
150+ 'end_mask' : 0
151+ }
152+ for a in node .node_def .attr :
153+ if a in attr_vals :
154+ i = get_tf_node_attr (node , a )
155+ if i != attr_vals [a ]:
156+ return None
157+ i1 = outputs_to_values .get (node .inputs [1 ].name )
158+ i2 = outputs_to_values .get (node .inputs [2 ].name )
159+ i3 = outputs_to_values .get (node .inputs [3 ].name )
160+ if i1 is None or i2 is None or i3 is None :
161+ return None
162+ if i1 .shape != (1 ,) or i2 .shape != (1 ,) or i3 .shape != (1 ,):
163+ return None
164+ i1 , i2 , i3 = i1 [0 ], i2 [0 ], i3 [0 ]
165+ if i1 + 1 != i2 or i3 != 1 :
166+ return None
167+ return i1
168+
143169def compute_const_folding_using_tf (g , const_node_values ):
144170 """Find nodes with constant inputs and compute their values using TF"""
145171 if const_node_values is None :
@@ -149,6 +175,8 @@ def compute_const_folding_using_tf(g, const_node_values):
149175 ops = g .get_operations ()
150176 outputs_to_values = {}
151177 outputs_to_dtypes = {}
178+ outputs_to_shapes = {}
179+ shape_node_outputs = {}
152180
153181 for node in ops :
154182 # Load values of constants. Use const_node_values if possible
@@ -158,6 +186,14 @@ def compute_const_folding_using_tf(g, const_node_values):
158186 tensor .tensor_content = const_node_values [node .name ]
159187 outputs_to_values [node .outputs [0 ].name ] = get_tf_tensor_data (tensor )
160188 outputs_to_dtypes [node .outputs [0 ].name ] = node .outputs [0 ].dtype
189+ for out in node .outputs :
190+ outputs_to_shapes [out .name ] = get_tf_tensor_shape (out )
191+
192+ for node in ops :
193+ if node .type == "Shape" :
194+ shape = outputs_to_shapes .get (node .inputs [0 ].name )
195+ if shape is not None :
196+ shape_node_outputs [node .outputs [0 ].name ] = shape
161197
162198 unneeded_outputs = set ()
163199 progress = True
@@ -167,6 +203,14 @@ def compute_const_folding_using_tf(g, const_node_values):
167203 # Find ops with constant inputs and compute their values
168204 input_names = [i .name for i in node .inputs ]
169205 output_names = [i .name for i in node .outputs ]
206+ if node .type == 'StridedSlice' and input_names [0 ] in shape_node_outputs \
207+ and output_names [0 ] not in outputs_to_values :
208+ shape = shape_node_outputs [input_names [0 ]]
209+ i = get_index_from_strided_slice_of_shape (node , outputs_to_values )
210+ if i is not None and 0 <= i < len (shape ) and shape [i ] is not None :
211+ outputs_to_values [output_names [0 ]] = np .array (shape [i ])
212+ outputs_to_dtypes [node .outputs [0 ].name ] = node .outputs [0 ].dtype
213+ progress = True
170214 can_fold = node .type not in ['Enter' ]
171215 can_fold = can_fold and len (input_names ) > 0 and all (inp in outputs_to_values for inp in input_names )
172216 # We can only fold nodes with a single output
0 commit comments