Skip to content

Commit 3616eef

Browse files
author
Hariprasad Ravishankar
committed
Fix out-of-bounds indexing with transposedconv negative padding
1 parent e530dca commit 3616eef

File tree

3 files changed

+128
-26
lines changed

3 files changed

+128
-26
lines changed

lib/Conversion/TorchToLinalg/Linear.cpp

Lines changed: 55 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1583,7 +1583,6 @@ Value ConvertAtenConvolutionOp::createTransposedInputPadding(
15831583
SmallVector<Value> insertSliceOffsets{c0, c0};
15841584

15851585
SmallVector<Value> inputSizes = getTensorSizes(rewriter, loc, input);
1586-
SmallVector<Value> sliceSizes{inputSizes[0], inputSizes[1]};
15871586

15881587
// For the case in which the padding dimension value is negative,
15891588
// we will need to shrink the dimension. Note in the PyTorch
@@ -1597,19 +1596,27 @@ Value ConvertAtenConvolutionOp::createTransposedInputPadding(
15971596
Value c2 = rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(2));
15981597

15991598
for (size_t i = 0; i < numSpatialDims; i++) {
1599+
// Calculate inner size: (input_size - 1) * stride + 1
16001600
Value innerSize = rewriter.createOrFold<arith::SubIOp>(loc, inDims[i], c1);
16011601
innerSize = rewriter.createOrFold<arith::MulIOp>(
16021602
loc, innerSize, castIntToIndex(rewriter, loc, strideIntValues[i]));
16031603
innerSize = rewriter.createOrFold<arith::AddIOp>(loc, innerSize, c1);
1604+
innerSizes.push_back(innerSize);
16041605

16051606
Value offset = rewriter.createOrFold<arith::SubIOp>(loc, weightDims[i], c1);
16061607
offset = rewriter.createOrFold<arith::MulIOp>(
16071608
loc, offset, castIntToIndex(rewriter, loc, dilationIntValues[i]));
16081609
offset = rewriter.createOrFold<arith::SubIOp>(
16091610
loc, offset, castIntToIndex(rewriter, loc, paddingIntValues[i]));
16101611

1612+
// We need to crop or pad from two sides - top&bottom or left&right.
1613+
// Therefore multiply by 2.
16111614
Value outerSize = rewriter.createOrFold<arith::MulIOp>(loc, offset, c2);
1615+
1616+
// Crop or pad based on the sign of offset
16121617
outerSize = rewriter.createOrFold<arith::AddIOp>(loc, outerSize, innerSize);
1618+
1619+
// Add optional padding values
16131620
outerSize = rewriter.createOrFold<arith::AddIOp>(
16141621
loc, outerSize,
16151622
castIntToIndex(rewriter, loc, outputPaddingIntValues[i]));
@@ -1624,39 +1631,69 @@ Value ConvertAtenConvolutionOp::createTransposedInputPadding(
16241631
auto posOffset =
16251632
rewriter.createOrFold<arith::MulIOp>(loc, offset, negOneConst);
16261633

1627-
// Compute the reduced dimension size due to negative padding.
1628-
auto sizeReduction =
1629-
rewriter.createOrFold<arith::MulIOp>(loc, posOffset, c2);
1630-
sliceSizes.push_back(rewriter.createOrFold<arith::SubIOp>(
1631-
loc, inputSizes[i + 2], sizeReduction));
1632-
16331634
extractSliceOffsets.push_back(posOffset);
16341635
insertSliceOffsets.push_back(c0);
16351636
} else {
1636-
sliceSizes.push_back(inputSizes[i + 2]);
16371637
extractSliceOffsets.push_back(c0);
16381638
insertSliceOffsets.push_back(offset);
16391639
}
16401640
}
1641-
Value initTensor = createInitTensor(rewriter, loc, outerSizes, inputDTy, pad);
16421641

16431642
// Insert input into allocated tensor
16441643
SmallVector<Value> strideIndexValues{c1, c1};
16451644
for (auto stride : strideIntValues)
16461645
strideIndexValues.push_back(castIntToIndex(rewriter, loc, stride));
16471646

1648-
auto insertSliceOpInput = input;
16491647
if (anyDimensionPaddingIsNegative) {
1650-
insertSliceOpInput = rewriter.create<tensor::ExtractSliceOp>(
1648+
1649+
// Some dimensions may need padding and some dimensions need cropping
1650+
1651+
// 1. Allocate a maxSizes buffer (max of inner and outer for each dim)
1652+
// 2. Insert the input into maxSizes buffer at appropriate offsets (if
1653+
// insertSliceOffsets is positive, pad; 0 no padding) and stride
1654+
// 3. Extract the final outerSizes from maxSizes buffer
1655+
1656+
// Create the "max size" tensor to accommodate both padding and cropping
1657+
SmallVector<Value> maxSizes{inBatch, inChannels};
1658+
for (size_t i = 0; i < numSpatialDims; ++i) {
1659+
Value innerDim = innerSizes[i + 2];
1660+
Value outerDim = outerSizes[i + 2];
1661+
Value isPadding = rewriter.create<arith::CmpIOp>(
1662+
loc, arith::CmpIPredicate::ugt, outerDim, innerDim);
1663+
Value maxDim =
1664+
rewriter.create<arith::SelectOp>(loc, isPadding, outerDim, innerDim);
1665+
maxSizes.push_back(maxDim);
1666+
}
1667+
1668+
Value initMaxTensor =
1669+
createInitTensor(rewriter, loc, maxSizes, inputDTy, pad);
1670+
1671+
// Insert input
1672+
auto paddedTensor = rewriter.create<tensor::InsertSliceOp>(
16511673
loc, torch_to_linalg::removeSizeInformation(rewriter, loc, input),
1652-
extractSliceOffsets, sliceSizes, strideIndexValues);
1653-
}
1674+
initMaxTensor, insertSliceOffsets, inputSizes, strideIndexValues);
16541675

1655-
auto paddedInput = rewriter.create<tensor::InsertSliceOp>(
1656-
loc,
1657-
torch_to_linalg::removeSizeInformation(rewriter, loc, insertSliceOpInput),
1658-
initTensor, insertSliceOffsets, sliceSizes, strideIndexValues);
1659-
return paddedInput;
1676+
SmallVector<Value> allOnesStrides(inputSizes.size(), c1);
1677+
1678+
// Crop. Extract the final tensor from the "max" tensor
1679+
auto finalTensor = rewriter.create<tensor::ExtractSliceOp>(
1680+
loc,
1681+
torch_to_linalg::removeSizeInformation(rewriter, loc, paddedTensor),
1682+
extractSliceOffsets, outerSizes, allOnesStrides);
1683+
1684+
return finalTensor;
1685+
1686+
} else {
1687+
1688+
Value initPaddedTensor =
1689+
createInitTensor(rewriter, loc, outerSizes, inputDTy, pad);
1690+
1691+
// Insert the original input into the outer tensor with calculated offsets
1692+
auto paddedInput = rewriter.create<tensor::InsertSliceOp>(
1693+
loc, torch_to_linalg::removeSizeInformation(rewriter, loc, input),
1694+
initPaddedTensor, insertSliceOffsets, inputSizes, strideIndexValues);
1695+
return paddedInput;
1696+
}
16601697
}
16611698

16621699
namespace {

projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1988,7 +1988,7 @@ def forward(self, inputVec, weight, bias):
19881988
inputVec,
19891989
weight,
19901990
bias=bias,
1991-
stride=[1],
1991+
stride=[4],
19921992
padding=[3],
19931993
dilation=[1],
19941994
transposed=True,
@@ -2034,6 +2034,38 @@ def TransposedConv2dNegativePadding_basic(module, tu: TestUtils):
20342034
module.forward(tu.rand(1, 1, 4, 7), tu.rand(1, 2, 3, 3), tu.rand(2))
20352035

20362036

2037+
class TransposedConv2dPositiveAndNegativePadding(torch.nn.Module):
2038+
def __init__(self):
2039+
super().__init__()
2040+
2041+
@export
2042+
@annotate_args(
2043+
[
2044+
None,
2045+
([1, 1, 4, 7], torch.float32, True),
2046+
([1, 2, 3, 3], torch.float32, True),
2047+
([2], torch.float32, True),
2048+
]
2049+
)
2050+
def forward(self, inputVec, weight, bias):
2051+
return torch.ops.aten.convolution(
2052+
inputVec,
2053+
weight,
2054+
bias=bias,
2055+
stride=[4, 4],
2056+
padding=[0, 3],
2057+
dilation=[1, 1],
2058+
transposed=True,
2059+
output_padding=[0, 0],
2060+
groups=1,
2061+
)
2062+
2063+
2064+
@register_test_case(module_factory=lambda: TransposedConv2dPositiveAndNegativePadding())
2065+
def TransposedConv2dPositiveAndNegativePadding_basic(module, tu: TestUtils):
2066+
module.forward(tu.rand(1, 1, 4, 7), tu.rand(1, 2, 3, 3), tu.rand(2))
2067+
2068+
20372069
class TransposedConv3dNegativePadding(torch.nn.Module):
20382070
def __init__(self):
20392071
super().__init__()
@@ -2052,7 +2084,7 @@ def forward(self, inputVec, weight, bias):
20522084
inputVec,
20532085
weight,
20542086
bias=bias,
2055-
stride=[1, 1, 1],
2087+
stride=[4, 4, 4],
20562088
padding=[2, 1, 3],
20572089
dilation=[1, 1, 1],
20582090
transposed=True,

test/Conversion/TorchToLinalg/convolution.mlir

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -152,12 +152,17 @@ func.func @transposedGroupedConvolution2D(%arg0: !torch.vtensor<[1,2,5,7],f32>)
152152
}
153153

154154
// CHECK-LABEL: func.func @tranConv2dNegativePadding(
155-
// CHECK-SAME: %[[INPUT_VTENSOR:.*]]: !torch.vtensor<[1,1,4,7],f32>) -> !torch.vtensor<[1,2,6,3],f32>
156-
// CHECK: %[[IN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[INPUT_VTENSOR]] : !torch.vtensor<[1,1,4,7],f32> -> tensor<1x1x4x7xf32>
157-
// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[IN_TENSOR]][0, 0, 0, 1] [1, 1, 4, 5] [1, 1, 1, 1] : tensor<1x1x4x7xf32> to tensor<1x1x4x5xf32>
158-
// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[EXTRACTED_SLICE]] into %[[INIT_TENSOR:.*]][0, 0, 2, 0] [1, 1, 4, 5] [1, 1, 1, 1] : tensor<1x1x4x5xf32> into tensor<1x1x8x5xf32>
159-
// CHECK: %[[OUT_TENSOR:.*]] = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%[[INSERTED_SLICE]], %[[WEIGHTS:.*]] : tensor<1x1x8x5xf32>, tensor<2x1x3x3xf32>) outs(%[[INIT_OUT_TENSOR:.*]] : tensor<1x2x6x3xf32>) -> tensor<1x2x6x3xf32>
160-
// CHECK: %[[OUT_VTENSOR:.*]] = torch_c.from_builtin_tensor %[[OUT_TENSOR]] : tensor<1x2x6x3xf32> -> !torch.vtensor<[1,2,6,3],f32>
155+
// CHECK-SAME: %[[INPUT_VTENSOR:.*]]: !torch.vtensor<[1,1,4,7],f32>) -> !torch.vtensor<[1,2,6,3],f32> attributes {torch.assume_strict_symbolic_shapes} {
156+
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
157+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
158+
// CHECK-DAG: %[[C0F:.*]] = arith.constant 0.000000e+00 : f32
159+
// CHECK: %[[INPUT_TENSOR:.*]] = torch_c.to_builtin_tensor %[[INPUT_VTENSOR]] : !torch.vtensor<[1,1,4,7],f32> -> tensor<1x1x4x7xf32>
160+
// CHECK: %[[EMPTY_UNSTRIDED_TENSOR:.*]] = tensor.empty() : tensor<1x1x8x7xf32>
161+
// CHECK: %[[ZEROS_UNSTRIDED_TENSOR:.*]] = linalg.fill ins(%[[C0F]] : f32) outs(%[[EMPTY_UNSTRIDED_TENSOR]] : tensor<1x1x8x7xf32>) -> tensor<1x1x8x7xf32>
162+
// CHECK: %[[INPUT_UNSTRIDED_TENSOR:.*]] = tensor.insert_slice %[[INPUT_TENSOR]] into %[[ZEROS_UNSTRIDED_TENSOR]][0, 0, 2, 0] [1, 1, 4, 7] [1, 1, 1, 1] : tensor<1x1x4x7xf32> into tensor<1x1x8x7xf32>
163+
// CHECK: %[[CROPPED_UNSTRIDED_TENSOR:.*]] = tensor.extract_slice %[[INPUT_UNSTRIDED_TENSOR]][0, 0, 0, 1] [1, 1, 8, 5] [1, 1, 1, 1] : tensor<1x1x8x7xf32> to tensor<1x1x8x5xf32>
164+
// CHECK: %[[OUT_TENSOR:.*]] = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%[[CROPPED_UNSTRIDED_TENSOR]], %[[WEIGHTS:.*]] : tensor<1x1x8x5xf32>, tensor<2x1x3x3xf32>) outs(%[[INIT_OUT_TENSOR:.*]] : tensor<1x2x6x3xf32>) -> tensor<1x2x6x3xf32>
165+
// CHECK: %[[OUT_VTENSOR:.*]] = torch_c.from_builtin_tensor %[[OUT_TENSOR]] : tensor<1x2x6x3xf32> -> !torch.vtensor<[1,2,6,3],f32>
161166
func.func @tranConv2dNegativePadding(%arg0: !torch.vtensor<[1, 1, 4, 7],f32>) -> !torch.vtensor<[1, 2, 6, 3],f32> attributes {torch.assume_strict_symbolic_shapes} {
162167
%int0 = torch.constant.int 0
163168
%true = torch.constant.bool true
@@ -174,3 +179,31 @@ func.func @tranConv2dNegativePadding(%arg0: !torch.vtensor<[1, 1, 4, 7],f32>) ->
174179
%6 = torch.aten.convolution %arg0, %0, %1, %2, %3, %4, %true, %5, %int1 : !torch.vtensor<[1, 1, 4, 7],f32>, !torch.vtensor<[1,2,3,3],f32>, !torch.vtensor<[2],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1, 2, 6, 3],f32>
175180
return %6 : !torch.vtensor<[1, 2, 6, 3],f32>
176181
}
182+
183+
// CHECK-LABEL: func.func @tranConv2dNegativeAndPositivePadding(
184+
// CHECK-SAME: %[[INPUT_VTENSOR:.*]]: !torch.vtensor<[1,1,4,7],f32>,
185+
// CHECK-SAME: %[[WEIGHTS_VTENSOR:.*]]: !torch.vtensor<[1,2,3,3],f32>,
186+
// CHECK-SAME: %[[BIAS_VTENSOR:.*]]: !torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2,15,21],f32> {
187+
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
188+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
189+
// CHECK-DAG: %[[C0F:.*]] = arith.constant 0.000000e+00 : f32
190+
// CHECK: %[[INPUT_TENSOR:.*]] = torch_c.to_builtin_tensor %[[INPUT_VTENSOR]] : !torch.vtensor<[1,1,4,7],f32> -> tensor<1x1x4x7xf32>
191+
// CHECK: %[[EMPTY_UNSTRIDED_TENSOR:.*]] = tensor.empty() : tensor<1x1x17x25xf32>
192+
// CHECK: %[[ZEROS_UNSTRIDED_TENSOR:.*]] = linalg.fill ins(%[[C0F]] : f32) outs(%[[EMPTY_UNSTRIDED_TENSOR]] : tensor<1x1x17x25xf32>) -> tensor<1x1x17x25xf32>
193+
// CHECK: %[[INPUT_UNSTRIDED_TENSOR:.*]] = tensor.insert_slice %[[INPUT_TENSOR]] into %[[ZEROS_UNSTRIDED_TENSOR]][0, 0, 2, 0] [1, 1, 4, 7] [1, 1, 4, 4] : tensor<1x1x4x7xf32> into tensor<1x1x17x25xf32>
194+
// CHECK: %[[CROPPED_UNSTRIDED_TENSOR:.*]] = tensor.extract_slice %[[INPUT_UNSTRIDED_TENSOR]][0, 0, 0, 1] [1, 1, 17, 23] [1, 1, 1, 1] : tensor<1x1x17x25xf32> to tensor<1x1x17x23xf32>
195+
// CHECK: %[[OUT_TENSOR:.*]] = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%[[CROPPED_UNSTRIDED_TENSOR]], %[[WEIGHTS:.*]] : tensor<1x1x17x23xf32>, tensor<2x1x3x3xf32>) outs(%[[INIT_OUT_TENSOR:.*]] : tensor<1x2x15x21xf32>) -> tensor<1x2x15x21xf32>
196+
// CHECK: %[[OUT_VTENSOR:.*]] = torch_c.from_builtin_tensor %[[OUT_TENSOR]] : tensor<1x2x15x21xf32> -> !torch.vtensor<[1,2,15,21],f32>
197+
func.func @tranConv2dNegativeAndPositivePadding(%arg0: !torch.vtensor<[1,1,4,7],f32>, %arg1: !torch.vtensor<[1,2,3,3],f32>, %arg2: !torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2,15,21],f32> {
198+
%int1 = torch.constant.int 1
199+
%int3 = torch.constant.int 3
200+
%int0 = torch.constant.int 0
201+
%int4 = torch.constant.int 4
202+
%true = torch.constant.bool true
203+
%0 = torch.prim.ListConstruct %int4, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
204+
%1 = torch.prim.ListConstruct %int0, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
205+
%2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
206+
%3 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
207+
%4 = torch.aten.convolution %arg0, %arg1, %arg2, %0, %1, %2, %true, %3, %int1 : !torch.vtensor<[1,1,4,7],f32>, !torch.vtensor<[1,2,3,3],f32>, !torch.vtensor<[2],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,2,15,21],f32>
208+
return %4 : !torch.vtensor<[1,2,15,21],f32>
209+
}

0 commit comments

Comments
 (0)