@@ -429,7 +429,18 @@ def version_1(cls, ctx, node, **kwargs):
429429
430430 node .set_attr ("output_shape" , new_output_shape )
431431 else :
432- # FIXME: This case fails in edge cases where strides > 1
432+ utils .make_sure (ctx .opset >= 10 , "Opset 10 needed for Conv Backprop Input with non-constant shape" )
433+ strides = parse_dims_attr (node , node .get_attr ('strides' ).ints , spatial )
434+ use_strides_workaround = any (d > 1 for d in strides )
435+ if use_strides_workaround and ctx .opset < 12 :
436+ # When strides > 1, ONNX and TF have an implementation difference in ConvTranspose. ONNX outputs a
437+ # slightly smaller tensor which must be padded with a row of 0s. Pad with dynamic shape requires
438+ # opset >= 11 and Max of int64 needs opset >= 12. Depending on the output_shape, this row of 0s might
439+ # be shaved off, in which case TF and ONNX agree. When output_shape is dynamic it is impossible to
440+ # know at conversion time whether this is the case and the workaround is needed.
441+ logger .warning ("Conv Backprop Input with strides > 1 and non-constant shape has known bug. "
442+ "Workaround requires opset 12." )
443+ use_strides_workaround = False
433444 input_shape = ctx .make_node ("Cast" , [node .input [0 ]], attr = {'to' : TensorProto .INT64 })
434445 output_shape = ctx .make_node ("Shape" , [node .output [0 ]])
435446 output_h = GraphBuilder (ctx ).make_slice (
@@ -442,9 +453,17 @@ def version_1(cls, ctx, node, **kwargs):
442453 {"data" : input_shape .output [0 ], "ends" : [3 ], "starts" : [2 ], "axes" : [0 ]})
443454 diff_h = ctx .make_node ("Sub" , [output_h , expect_h ])
444455 diff_w = ctx .make_node ("Sub" , [output_w , expect_w ])
456+ nonneg_diff_h = diff_h
457+ nonneg_diff_w = diff_w
458+
459+ if use_strides_workaround :
460+ const_zero = ctx .make_const (utils .make_name (node .name + "_const_zero" ), np .array ([0 ], dtype = np .int64 ))
461+ nonneg_diff_h = ctx .make_node ("Max" , [diff_h .output [0 ], const_zero .output [0 ]])
462+ nonneg_diff_w = ctx .make_node ("Max" , [diff_w .output [0 ], const_zero .output [0 ]])
463+
445464 const_two = ctx .make_const (utils .make_name (node .name + "_const_two" ), np .array ([2 ], dtype = np .int64 ))
446- start_h = ctx .make_node ("Div" , [diff_h .output [0 ], const_two .output [0 ]])
447- start_w = ctx .make_node ("Div" , [diff_w .output [0 ], const_two .output [0 ]])
465+ start_h = ctx .make_node ("Div" , [nonneg_diff_h .output [0 ], const_two .output [0 ]])
466+ start_w = ctx .make_node ("Div" , [nonneg_diff_w .output [0 ], const_two .output [0 ]])
448467 end_h = ctx .make_node ("Add" , [start_h .output [0 ], expect_h ])
449468 end_w = ctx .make_node ("Add" , [start_w .output [0 ], expect_w ])
450469 if spatial == 3 :
@@ -453,7 +472,10 @@ def version_1(cls, ctx, node, **kwargs):
453472 expect_d = GraphBuilder (ctx ).make_slice (
454473 {"data" : input_shape .output [0 ], "ends" : [4 ], "starts" : [3 ], "axes" : [0 ]})
455474 diff_d = ctx .make_node ("Sub" , [output_d , expect_d ])
456- start_d = ctx .make_node ("Div" , [diff_d .output [0 ], const_two .output [0 ]])
475+ nonneg_diff_d = diff_d
476+ if use_strides_workaround :
477+ nonneg_diff_d = ctx .make_node ("Max" , [diff_d .output [0 ], const_zero .output [0 ]])
478+ start_d = ctx .make_node ("Div" , [nonneg_diff_d .output [0 ], const_two .output [0 ]])
457479 end_d = ctx .make_node ("Add" , [start_d .output [0 ], expect_d ])
458480
459481 starts = ctx .make_node ("Concat" , [start_h .output [0 ], start_w .output [0 ], start_d .output [0 ]],
@@ -471,10 +493,35 @@ def version_1(cls, ctx, node, **kwargs):
471493 [node .output [0 ], starts .output [0 ], ends .output [0 ], slice_axes .output [0 ]],
472494 shapes = output_shape_orig )
473495
496+ final_node = slice_node
497+
498+ if use_strides_workaround :
499+ cz = const_zero .output [0 ]
500+
501+ neg_diff_h = ctx .make_node ("Neg" , [diff_h .output [0 ]])
502+ shrink_h_by = ctx .make_node ("Max" , [neg_diff_h .output [0 ], const_zero .output [0 ]])
503+ shb = shrink_h_by .output [0 ]
504+
505+ neg_diff_w = ctx .make_node ("Neg" , [diff_w .output [0 ]])
506+ shrink_w_by = ctx .make_node ("Max" , [neg_diff_w .output [0 ], const_zero .output [0 ]])
507+ swb = shrink_w_by .output [0 ]
508+
509+ if spatial == 3 :
510+ neg_diff_d = ctx .make_node ("Neg" , [diff_d .output [0 ]])
511+ shrink_d_by = ctx .make_node ("Max" , [neg_diff_d .output [0 ], const_zero .output [0 ]])
512+ sdb = shrink_d_by .output [0 ]
513+ pads = ctx .make_node ("Concat" , [cz , cz , cz , cz , cz , cz , shb , swb , sdb , cz ], attr = {"axis" : 0 })
514+ padded_node = ctx .make_node ("Pad" , [slice_node .output [0 ], pads .output [0 ]])
515+ else :
516+ pads = ctx .make_node ("Concat" , [cz , cz , cz , cz , cz , shb , swb , cz ], attr = {"axis" : 0 })
517+ padded_node = ctx .make_node ("Pad" , [slice_node .output [0 ], pads .output [0 ]])
518+
519+ final_node = padded_node
520+
474521 downstream_nodes = ctx .find_output_consumers (node .output [0 ])
475522 downstream_nodes .remove (output_shape )
476523 downstream_nodes .remove (slice_node )
477- ctx .replace_all_inputs (node .output [0 ], slice_node .output [0 ], ops = downstream_nodes )
524+ ctx .replace_all_inputs (node .output [0 ], final_node .output [0 ], ops = downstream_nodes )
478525
479526 conv_dims_attr (node , "strides" , spatial = spatial )
480527 conv_dims_attr (node , "dilations" , spatial = spatial )
0 commit comments