@@ -132,7 +132,7 @@ def version_6(cls, ctx, node, **kwargs):
132132 node .type = "Sum"
133133
134134
135- @tf_op ("SegmentSum" )
135+ @tf_op ([ "SegmentSum" , "SegmentProd" , "SegmentMax" , "SegmentMin" ] )
136136class SegmentSum ():
137137 @classmethod
138138 def version_9 (cls , ctx , node , ** kwargs ):
@@ -143,20 +143,48 @@ def version_9(cls, ctx, node, **kwargs):
143143 data_rank = len (data_shape )
144144 data_np_dtype = utils .map_onnx_to_numpy_type (ctx .get_dtype (data_inp ))
145145 seg_np_dtype = utils .map_onnx_to_numpy_type (ctx .get_dtype (segment_inp ))
146+ data_is_float = np .dtype (data_np_dtype ).kind == 'f'
147+ data_is_int = np .dtype (data_np_dtype ).kind == 'i'
148+ utils .make_sure (data_is_float or data_is_int , "dtype for Segment ops must be float or int" )
149+
150+ if node .type == "SegmentSum" :
151+ onnx_op = "ReduceSum"
152+ identity_value = np .array (0 , dtype = data_np_dtype )
153+ elif node .type == "SegmentProd" :
154+ onnx_op = "ReduceProd"
155+ identity_value = np .array (1 , dtype = data_np_dtype )
156+ elif node .type == "SegmentMax" :
157+ onnx_op = "ReduceMax"
158+ if data_is_float :
159+ identity_value = np .array ('-inf' , dtype = data_np_dtype )
160+ else :
161+ identity_value = np .iinfo (data_np_dtype ).min
162+ elif node .type == "SegmentMin" :
163+ onnx_op = "ReduceMin"
164+ if data_is_float :
165+ identity_value = np .array ('inf' , dtype = data_np_dtype )
166+ else :
167+ identity_value = np .iinfo (data_np_dtype ).max
168+
146169 max_segment = ctx .make_node ("ReduceMax" , [segment_inp ], attr = {'axes' : [0 ], 'keepdims' : 0 })
147170 one_const = ctx .make_const (utils .make_name ("const_one" ), np .array (1 , dtype = seg_np_dtype ))
171+ identity_const = ctx .make_const (utils .make_name ("const_identity" ), identity_value )
148172 num_segments = ctx .make_node ("Add" , [max_segment .output [0 ], one_const .output [0 ]])
149- onehot_values = ctx .make_const (utils .make_name ("onehot_values" ), np .array ([0 , 1 ], dtype = data_np_dtype ))
150- one_hot_node = ctx .make_node ("OneHot" , [segment_inp , num_segments .output [0 ], onehot_values .output [0 ]], attr = {'axis' : 0 })
151- one_hot_unsqueeze = one_hot_node
173+ # ORT doesn't support bool for OneHot so we use float32 and cast to bool
174+ onehot_values = ctx .make_const (utils .make_name ("onehot_values" ), np .array ([0 , 1 ], dtype = np .float32 ))
175+ one_hot_node = ctx .make_node ("OneHot" , [segment_inp , num_segments .output [0 ], onehot_values .output [0 ]],
176+ attr = {'axis' : 0 })
177+ one_hot_bool = ctx .make_node ("Cast" , [one_hot_node .output [0 ]], attr = {"to" : onnx_pb .TensorProto .BOOL })
178+ one_hot_unsqueeze = one_hot_bool
179+
152180 if data_rank > 1 :
153181 new_dims = list (range (2 , 2 + data_rank - 1 ))
154- one_hot_unsqueeze = ctx .make_node ("Unsqueeze" , [one_hot_node .output [0 ]], attr = {'axes' : new_dims })
182+ one_hot_unsqueeze = ctx .make_node ("Unsqueeze" , [one_hot_bool .output [0 ]], attr = {'axes' : new_dims })
155183
156- mul_node = ctx .make_node ("Mul " , [data_inp , one_hot_unsqueeze .output [0 ]])
184+ mul_node = ctx .make_node ("Where " , [one_hot_unsqueeze . output [ 0 ], data_inp , identity_const .output [0 ]])
157185
158186 shapes = node .output_shapes
159187 dtypes = node .output_dtypes
160188 ctx .remove_node (node .name )
161- sum_node = ctx .make_node ("ReduceSum" , [mul_node .output [0 ]], attr = {'axes' : [1 ], 'keepdims' : 0 },
162- name = node .name , outputs = node .output , shapes = shapes , dtypes = dtypes )
189+ ctx .make_node (onnx_op , [mul_node .output [0 ]], attr = {'axes' : [1 ], 'keepdims' : 0 },
190+ name = node .name , outputs = node .output , shapes = shapes , dtypes = dtypes )