Commit 2dc17fc

hariprasadravi, Hariprasad Ravishankar
authored and committed
[LINALG] Fix: Incorrect linalg lowering for aten.convolution_transpose with negative effective padding (llvm#4369)
The `torch-to-linalg` lowering for `aten.convolution` (with `transposed=true`) incorrectly handles cases where the effective padding is negative. The logic for this is contained in `createTransposedInputPadding`.

The original implementation had two critical flaws:

- **Incorrect Math**: The logic block for negative padding (`if (anyDimensionPaddingIsNegative)`) attempted to "pre-crop" the input tensor before un-striding. The math used to calculate these slice offsets and sizes was incorrect, resulting in `tensor.extract_slice` operations with out-of-bounds offsets and negative sizes, causing the compiler to fail.
- **Failed "Mixed-Mode" Logic**: The code was built on an "all-or-nothing" assumption. It failed to handle "mixed-mode" padding, where one spatial dimension required padding (positive offset) while another required cropping (negative offset). It would enter the negative padding path and apply cropping logic to all dimensions, leading to out-of-bounds errors when it tried to crop a dimension that should have been padded.

This patch refactors the logic into two clean, robust paths:

**All-Padding Path (`else` block):**

- Trigger: All spatial dimensions have an effective padding offset >= 0.
- Action: Retains the original, efficient "fast path." It uses a single `tensor.insert_slice` to perform both un-striding (with strides) and padding (with offsets) in one operation.

**Safe Path (`if (anyDimensionPaddingIsNegative)` block):**

- Trigger: At least one spatial dimension has a negative effective padding offset.
- Action: This path is now a unified, robust 3-step process that correctly handles both all-crop and mixed-mode scenarios:
  1. Create "Super-Tensor": It computes a `maxSizes` tensor, which is the "union" of the padded and un-strided sizes (i.e., `max(innerSize, outerSize)` for each dimension).
  2. Pad & Un-stride: It performs a single `tensor.insert_slice` of the original input into this `maxSizes` tensor. This one operation correctly applies all positive padding (via `insertSliceOffsets`) and un-striding (via `strideIndexValues`).
  3. Crop: It performs a final `tensor.extract_slice` to crop the `maxSizes` tensor down to the final `outerSizes`. This correctly applies all negative padding (via `extractSliceOffsets`).

This new logic resolves all known failure cases and is validated by the new TransposedConv{1,2,3}dNegativePadding test cases, which specifically target this functionality.

---------

Co-authored-by: Hariprasad Ravishankar <hravisha@ah-hravisha-l.dhcp.mathworks.com>
Co-authored-by: Hariprasad Ravishankar <hravisha@mathworks.com>
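The size and offset arithmetic described in the commit message can be sketched per spatial dimension as follows. This is a plain-Python illustration, not code from the patch: the `PaddingPlan` dataclass and the `plan_dim` helper are hypothetical names, and the formulas are taken from the comment the patch adds above `createTransposedInputPadding`.

```python
from dataclasses import dataclass


@dataclass
class PaddingPlan:
    """Hypothetical container for the per-dimension quantities named in the patch."""
    inner: int           # innerSize: input size after un-striding
    outer: int           # outerSize: size handed to the forward convolution
    max_size: int        # size of the intermediate "super-tensor"
    insert_offset: int   # insertSliceOffsets entry (positive padding)
    extract_offset: int  # extractSliceOffsets entry (cropping)


def plan_dim(in_dim, kernel, stride, padding, dilation, output_padding=0):
    # innerSize[i] = (inDim[i] - 1) * stride[i] + 1
    inner = (in_dim - 1) * stride + 1
    # offset[i] = (weightDim[i] - 1) * dilation[i] - padding[i]
    offset = (kernel - 1) * dilation - padding
    # outerSize[i] = innerSize[i] + 2 * offset[i] + outputPadding[i]
    outer = inner + 2 * offset + output_padding
    return PaddingPlan(
        inner=inner,
        outer=outer,
        max_size=max(inner, outer),
        insert_offset=max(offset, 0),    # pad only when the offset is positive
        extract_offset=max(-offset, 0),  # crop only when the offset is negative
    )


# Mixed-mode case: 4x7 input, 3x3 kernel, stride 4, padding [0, 3], dilation 1.
height = plan_dim(in_dim=4, kernel=3, stride=4, padding=0, dilation=1)
width = plan_dim(in_dim=7, kernel=3, stride=4, padding=3, dilation=1)
print(height)
print(width)
```

With these parameters the height is padded (offset 2) while the width is cropped (offset -1), which is the mixed-mode situation the old code could not handle.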
1 parent b4c9774 commit 2dc17fc

4 files changed: +223 -36 lines changed

lib/Conversion/TorchToLinalg/Linear.cpp

Lines changed: 76 additions & 27 deletions
@@ -1570,6 +1570,25 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
 };
 } // namespace

+/*
+ * Calculates the dimensions and offsets needed to emulate a Transposed
+ * Convolution (like PyTorch's ConvTranspose2d) using a standard
+ * Forward Convolution.
+ *
+ * This involves creating a new tensor by:
+ * 1. Calculating `innerSizes`: The input size after dilation by `stride`.
+ *    innerSize[i] = (inDim[i] - 1) * stride[i] + 1
+ *
+ * 2. Calculating `outerSizes`: The final padded tensor size.
+ *    offset[i] = (weightDim[i] - 1) * dilation[i] - padding[i]
+ *    outerSize[i] = innerSize[i] + (2 * offset[i]) + outputPadding[i]
+ *
+ * If `offset[i]` is negative, this is treated as *cropping* the
+ * `innerSizes` tensor. This function calculates the
+ * `insertSliceOffsets` (padding) and `extractSliceOffsets` (cropping)
+ * to correctly place the (potentially cropped) inner tensor within the
+ * new outer tensor.
+ */
 Value ConvertAtenConvolutionOp::createTransposedInputPadding(
     Value inBatch, Value inChannels, SmallVector<Value> &inDims,
     SmallVector<Value> &weightDims, SmallVector<Value> &paddingIntValues,
@@ -1583,33 +1602,34 @@ Value ConvertAtenConvolutionOp::createTransposedInputPadding(
   SmallVector<Value> insertSliceOffsets{c0, c0};

   SmallVector<Value> inputSizes = getTensorSizes(rewriter, loc, input);
-  SmallVector<Value> sliceSizes{inputSizes[0], inputSizes[1]};
-
-  // For the case in which the padding dimension value is negative,
-  // we will need to shrink the dimension. Note in the PyTorch
-  // ConvTranspose2d operator documentation that the padding is
-  // defined by dilation * (kernel_size - 1) - padding. If the
-  // resulting padding is negative, PyTorch will extract elements
-  // from both sides of the dimension.
+
   SmallVector<Value> extractSliceOffsets{c0, c0};
   bool anyDimensionPaddingIsNegative = false;

   Value c2 = rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(2));

   for (size_t i = 0; i < numSpatialDims; i++) {
+    // Calculate inner size: (input_size - 1) * stride + 1
     Value innerSize = rewriter.createOrFold<arith::SubIOp>(loc, inDims[i], c1);
     innerSize = rewriter.createOrFold<arith::MulIOp>(
         loc, innerSize, castIntToIndex(rewriter, loc, strideIntValues[i]));
     innerSize = rewriter.createOrFold<arith::AddIOp>(loc, innerSize, c1);
+    innerSizes.push_back(innerSize);

     Value offset = rewriter.createOrFold<arith::SubIOp>(loc, weightDims[i], c1);
     offset = rewriter.createOrFold<arith::MulIOp>(
         loc, offset, castIntToIndex(rewriter, loc, dilationIntValues[i]));
     offset = rewriter.createOrFold<arith::SubIOp>(
         loc, offset, castIntToIndex(rewriter, loc, paddingIntValues[i]));

+    // We need to crop or pad from two sides - top&bottom or left&right.
+    // Therefore multiply by 2.
     Value outerSize = rewriter.createOrFold<arith::MulIOp>(loc, offset, c2);
+
+    // Crop or pad based on the sign of offset
     outerSize = rewriter.createOrFold<arith::AddIOp>(loc, outerSize, innerSize);
+
+    // Add optional padding values
     outerSize = rewriter.createOrFold<arith::AddIOp>(
         loc, outerSize,
         castIntToIndex(rewriter, loc, outputPaddingIntValues[i]));
@@ -1619,44 +1639,73 @@ Value ConvertAtenConvolutionOp::createTransposedInputPadding(
       // Make the negative value positive by multiplying by -1.
       anyDimensionPaddingIsNegative = true;
       auto offsetType = offset.getType();
-      auto negOneConst = rewriter.createOrFold<arith::ConstantOp>(
-          loc, offsetType, rewriter.getIntegerAttr(offsetType, -1));
+      auto negOneConst = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getIntegerAttr(offsetType, -1));
       auto posOffset =
           rewriter.createOrFold<arith::MulIOp>(loc, offset, negOneConst);

-      // Compute the reduced dimension size due to negative padding.
-      auto sizeReduction =
-          rewriter.createOrFold<arith::MulIOp>(loc, posOffset, c2);
-      sliceSizes.push_back(rewriter.createOrFold<arith::SubIOp>(
-          loc, inputSizes[i + 2], sizeReduction));
-
       extractSliceOffsets.push_back(posOffset);
       insertSliceOffsets.push_back(c0);
     } else {
-      sliceSizes.push_back(inputSizes[i + 2]);
       extractSliceOffsets.push_back(c0);
       insertSliceOffsets.push_back(offset);
     }
   }
-  Value initTensor = createInitTensor(rewriter, loc, outerSizes, inputDTy, pad);

   // Insert input into allocated tensor
   SmallVector<Value> strideIndexValues{c1, c1};
   for (auto stride : strideIntValues)
     strideIndexValues.push_back(castIntToIndex(rewriter, loc, stride));

-  auto insertSliceOpInput = input;
   if (anyDimensionPaddingIsNegative) {
-    insertSliceOpInput = rewriter.create<tensor::ExtractSliceOp>(
+    // Some dimensions may need padding and some dimensions need cropping
+
+    // 1. Allocate a maxSizes buffer (max of inner and outer for each dim)
+    // 2. Insert the input into maxSizes buffer at appropriate offsets (if
+    // insertSliceOffsets is positive, pad; 0 no padding) and stride
+    // 3. Extract the final outerSizes from maxSizes buffer
+
+    // Create the "max size" tensor to accommodate both padding and cropping
+    SmallVector<Value> maxSizes{inBatch, inChannels};
+    for (size_t i = 0; i < numSpatialDims; ++i) {
+      Value innerDim = innerSizes[i + 2];
+      Value outerDim = outerSizes[i + 2];
+      Value isPadding = rewriter.createOrFold<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::ugt, outerDim, innerDim);
+      Value maxDim = rewriter.createOrFold<arith::SelectOp>(loc, isPadding,
+                                                            outerDim, innerDim);
+      maxSizes.push_back(maxDim);
+    }
+
+    Value initMaxTensor =
+        createInitTensor(rewriter, loc, maxSizes, inputDTy, pad);
+
+    // Insert input
+    auto paddedTensor = rewriter.create<tensor::InsertSliceOp>(
         loc, torch_to_linalg::removeSizeInformation(rewriter, loc, input),
-        extractSliceOffsets, sliceSizes, strideIndexValues);
-  }
+        initMaxTensor, insertSliceOffsets, inputSizes, strideIndexValues);

-  auto paddedInput = rewriter.create<tensor::InsertSliceOp>(
-      loc,
-      torch_to_linalg::removeSizeInformation(rewriter, loc, insertSliceOpInput),
-      initTensor, insertSliceOffsets, sliceSizes, strideIndexValues);
-  return paddedInput;
+    SmallVector<Value> allOnesStrides(inputSizes.size(), c1);
+
+    // Crop. Extract the final tensor from the "max" tensor
+    auto finalTensor = rewriter.create<tensor::ExtractSliceOp>(
+        loc,
+        torch_to_linalg::removeSizeInformation(rewriter, loc, paddedTensor),
+        extractSliceOffsets, outerSizes, allOnesStrides);
+
+    return finalTensor;
+
+  } else {
+
+    Value initPaddedTensor =
+        createInitTensor(rewriter, loc, outerSizes, inputDTy, pad);
+
+    // Insert the original input into the outer tensor with calculated offsets
+    auto paddedInput = rewriter.create<tensor::InsertSliceOp>(
+        loc, torch_to_linalg::removeSizeInformation(rewriter, loc, input),
+        initPaddedTensor, insertSliceOffsets, inputSizes, strideIndexValues);
+    return paddedInput;
+  }
 }

 namespace {
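
To see why the allocate, insert, crop sequence in the new `if` branch yields a correct transposed convolution, here is a minimal single-channel 1-D sketch in plain Python. It is an illustration under stated assumptions, not the patch's code: both function names are hypothetical, lists stand in for tensors, and the kernel is reversed inside the sketch because the standard identity between transposed and forward convolution requires it (how the real pattern arranges the weights is not shown in this diff).

```python
def conv_transpose1d_direct(x, w, stride, padding, dilation):
    """Reference definition: scatter every input sample through the kernel,
    then trim `padding` elements from both ends."""
    k = len(w)
    full_len = (len(x) - 1) * stride + dilation * (k - 1) + 1
    full = [0.0] * full_len
    for i, xi in enumerate(x):
        for j, wj in enumerate(w):
            full[i * stride + j * dilation] += xi * wj
    return full[padding:full_len - padding]


def conv_transpose1d_via_padding(x, w, stride, padding, dilation):
    """Emulation that mirrors the three steps of the negative-padding path."""
    k = len(w)
    inner = (len(x) - 1) * stride + 1
    offset = (k - 1) * dilation - padding
    outer = inner + 2 * offset
    insert_offset = max(offset, 0)
    extract_offset = max(-offset, 0)

    # Step 1: zero-filled "super-tensor" large enough for padding and cropping.
    buf = [0.0] * max(inner, outer)
    # Step 2: strided insert (un-striding plus any positive padding).
    for i, xi in enumerate(x):
        buf[insert_offset + i * stride] = xi
    # Step 3: crop down to the outer size (applies any negative padding).
    cropped = buf[extract_offset:extract_offset + outer]

    # Stride-1 forward convolution with the reversed kernel (sketch assumption).
    flipped = w[::-1]
    out_len = outer - (k - 1) * dilation
    return [
        sum(cropped[o + j * dilation] * flipped[j] for j in range(k))
        for o in range(out_len)
    ]


x7 = [float(v) for v in range(1, 8)]
x5 = [float(v) for v in range(1, 6)]
w3 = [0.5, -1.0, 2.0]

# Parameter sets taken from the 1-D tests in this commit: (input, stride, padding, dilation).
for x, stride, padding, dilation in [(x7, 4, 3, 1), (x7, 1, 3, 1), (x5, 7, 10, 4)]:
    ref = conv_transpose1d_direct(x, w3, stride, padding, dilation)
    emu = conv_transpose1d_via_padding(x, w3, stride, padding, dilation)
    assert len(ref) == len(emu)
    assert all(abs(a - b) < 1e-9 for a, b in zip(ref, emu))
    print(stride, padding, dilation, len(emu))
```

All three parameter sets have a negative effective offset, so each one goes through the crop in step 3.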

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 7 additions & 0 deletions
@@ -3268,6 +3268,7 @@
     "TraceSignedIntModule_basic",
     "TraceUnsignedIntModule_basic",
     "TraceUnsignedIntModule_empty",
+    "TransposedConv1dNegativePaddingUnitStrideDyn_basic",
     "UniformModule_basic",
     "UniformNoCorrelationModule_basic",
     "UniformStaticShapeModule_basic",
@@ -3966,7 +3967,10 @@
     "TraceModule_empty",
     "TraceUnsignedIntModule_empty",
     "TransposedConv1dNegativePadding_basic",
+    "TransposedConv1dNegativePaddingUnitStrideDyn_basic",
+    "TransposedConv1dNegativePaddingLarge_basic",
     "TransposedConv2dNegativePadding_basic",
+    "TransposedConv2dPositiveAndNegativePadding_basic",
     "TransposedConv3dNegativePadding_basic",
     "UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic",
     "InterpolateDynamicModule_sizes_nearest",
@@ -5046,7 +5050,10 @@
     "TraceUnsignedIntModule_basic",
     "TraceUnsignedIntModule_empty",
     "TransposedConv1dNegativePadding_basic",
+    "TransposedConv1dNegativePaddingUnitStrideDyn_basic",
+    "TransposedConv1dNegativePaddingLarge_basic",
     "TransposedConv2dNegativePadding_basic",
+    "TransposedConv2dPositiveAndNegativePadding_basic",
     "TransposedConv3dNegativePadding_basic",
     "TupleModule_basic",
     "TypeAsDifferentModule_basic",

projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py

Lines changed: 101 additions & 3 deletions
@@ -1988,7 +1988,7 @@ def forward(self, inputVec, weight, bias):
             inputVec,
             weight,
             bias=bias,
-            stride=[1],
+            stride=[4],
             padding=[3],
             dilation=[1],
             transposed=True,
@@ -2002,6 +2002,72 @@ def TransposedConv1dNegativePadding_basic(module, tu: TestUtils):
     module.forward(tu.rand(1, 1, 7), tu.rand(1, 2, 3), tu.rand(2))


+class TransposedConv1dNegativePaddingUnitStrideDyn(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+            ([1, 2, 3], torch.float32, True),
+            ([2], torch.float32, True),
+        ]
+    )
+    def forward(self, inputVec, weight, bias):
+        return torch.ops.aten.convolution(
+            inputVec,
+            weight,
+            bias=bias,
+            stride=[1],
+            padding=[3],
+            dilation=[1],
+            transposed=True,
+            output_padding=[0],
+            groups=1,
+        )
+
+
+@register_test_case(
+    module_factory=lambda: TransposedConv1dNegativePaddingUnitStrideDyn()
+)
+def TransposedConv1dNegativePaddingUnitStrideDyn_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 1, 7), tu.rand(1, 2, 3), tu.rand(2))
+
+
+class TransposedConv1dNegativePaddingLarge(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([1, 17, 5], torch.float32, True),
+            ([17, 6, 3], torch.float32, True),
+            ([6], torch.float32, True),
+        ]
+    )
+    def forward(self, inputVec, weight, bias):
+        return torch.ops.aten.convolution(
+            inputVec,
+            weight,
+            bias=bias,
+            stride=[7],
+            padding=[10],
+            dilation=[4],
+            transposed=True,
+            output_padding=[0],
+            groups=1,
+        )
+
+
+@register_test_case(module_factory=lambda: TransposedConv1dNegativePaddingLarge())
+def TransposedConv1dNegativePaddingLarge_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 17, 5), tu.rand(17, 6, 3), tu.rand(6))
+
+
 class TransposedConv2dNegativePadding(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -2034,6 +2100,38 @@ def TransposedConv2dNegativePadding_basic(module, tu: TestUtils):
     module.forward(tu.rand(1, 1, 4, 7), tu.rand(1, 2, 3, 3), tu.rand(2))


+class TransposedConv2dPositiveAndNegativePadding(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([1, 1, 4, 7], torch.float32, True),
+            ([1, 2, 3, 3], torch.float32, True),
+            ([2], torch.float32, True),
+        ]
+    )
+    def forward(self, inputVec, weight, bias):
+        return torch.ops.aten.convolution(
+            inputVec,
+            weight,
+            bias=bias,
+            stride=[4, 4],
+            padding=[0, 3],
+            dilation=[1, 1],
+            transposed=True,
+            output_padding=[0, 0],
+            groups=1,
+        )
+
+
+@register_test_case(module_factory=lambda: TransposedConv2dPositiveAndNegativePadding())
+def TransposedConv2dPositiveAndNegativePadding_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 1, 4, 7), tu.rand(1, 2, 3, 3), tu.rand(2))
+
+
 class TransposedConv3dNegativePadding(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -2052,9 +2150,9 @@ def forward(self, inputVec, weight, bias):
             inputVec,
             weight,
             bias=bias,
-            stride=[1, 1, 1],
+            stride=[1, 5, 3],
             padding=[2, 1, 3],
-            dilation=[1, 1, 1],
+            dilation=[1, 2, 1],
             transposed=True,
             output_padding=[0, 0, 0],
             groups=1,
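
As a cross-check on the test parameters above, the sketch below computes the expected output extent and the effective padding offset for each spatial dimension of the 1-D and 2-D cases whose shapes are visible in this diff. It assumes the output-size formula from the PyTorch `ConvTranspose` documentation; the helper names and the `cases` table are hypothetical, and the 3-D test is left out because its input shape is not shown in the hunk.

```python
def conv_transpose_out_len(in_dim, kernel, stride, padding, dilation, output_padding=0):
    # Output-size formula as documented for PyTorch's ConvTranspose layers.
    return (in_dim - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1


def effective_offset(kernel, padding, dilation):
    # Negative values mean the un-strided input must be cropped, not padded.
    return (kernel - 1) * dilation - padding


# One (in_dim, kernel, stride, padding, dilation) tuple per spatial dimension.
cases = {
    "TransposedConv1dNegativePadding": [(7, 3, 4, 3, 1)],
    "TransposedConv1dNegativePaddingUnitStrideDyn": [(7, 3, 1, 3, 1)],
    "TransposedConv1dNegativePaddingLarge": [(5, 3, 7, 10, 4)],
    "TransposedConv2dPositiveAndNegativePadding": [(4, 3, 4, 0, 1), (7, 3, 4, 3, 1)],
}

summary = {}
for name, dims in cases.items():
    out_shape = tuple(conv_transpose_out_len(n, k, s, p, d) for n, k, s, p, d in dims)
    offsets = tuple(effective_offset(k, p, d) for _, k, _, p, d in dims)
    summary[name] = (out_shape, offsets)
    print(name, out_shape, offsets)
```

Every 1-D case has a negative offset, and the 2-D case mixes a positive offset (height) with a negative one (width), which is the scenario the new test name describes.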

test/Conversion/TorchToLinalg/convolution.mlir

Lines changed: 39 additions & 6 deletions
@@ -152,12 +152,17 @@ func.func @transposedGroupedConvolution2D(%arg0: !torch.vtensor<[1,2,5,7],f32>)
 }

 // CHECK-LABEL: func.func @tranConv2dNegativePadding(
-// CHECK-SAME: %[[INPUT_VTENSOR:.*]]: !torch.vtensor<[1,1,4,7],f32>) -> !torch.vtensor<[1,2,6,3],f32>
-// CHECK: %[[IN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[INPUT_VTENSOR]] : !torch.vtensor<[1,1,4,7],f32> -> tensor<1x1x4x7xf32>
-// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[IN_TENSOR]][0, 0, 0, 1] [1, 1, 4, 5] [1, 1, 1, 1] : tensor<1x1x4x7xf32> to tensor<1x1x4x5xf32>
-// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[EXTRACTED_SLICE]] into %[[INIT_TENSOR:.*]][0, 0, 2, 0] [1, 1, 4, 5] [1, 1, 1, 1] : tensor<1x1x4x5xf32> into tensor<1x1x8x5xf32>
-// CHECK: %[[OUT_TENSOR:.*]] = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%[[INSERTED_SLICE]], %[[WEIGHTS:.*]] : tensor<1x1x8x5xf32>, tensor<2x1x3x3xf32>) outs(%[[INIT_OUT_TENSOR:.*]] : tensor<1x2x6x3xf32>) -> tensor<1x2x6x3xf32>
-// CHECK: %[[OUT_VTENSOR:.*]] = torch_c.from_builtin_tensor %[[OUT_TENSOR]] : tensor<1x2x6x3xf32> -> !torch.vtensor<[1,2,6,3],f32>
+// CHECK-SAME: %[[INPUT_VTENSOR:.*]]: !torch.vtensor<[1,1,4,7],f32>) -> !torch.vtensor<[1,2,6,3],f32> attributes {torch.assume_strict_symbolic_shapes} {
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C0F:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[INPUT_TENSOR:.*]] = torch_c.to_builtin_tensor %[[INPUT_VTENSOR]] : !torch.vtensor<[1,1,4,7],f32> -> tensor<1x1x4x7xf32>
+// CHECK: %[[EMPTY_UNSTRIDED_TENSOR:.*]] = tensor.empty() : tensor<1x1x8x7xf32>
+// CHECK: %[[ZEROS_UNSTRIDED_TENSOR:.*]] = linalg.fill ins(%[[C0F]] : f32) outs(%[[EMPTY_UNSTRIDED_TENSOR]] : tensor<1x1x8x7xf32>) -> tensor<1x1x8x7xf32>
+// CHECK: %[[INPUT_UNSTRIDED_TENSOR:.*]] = tensor.insert_slice %[[INPUT_TENSOR]] into %[[ZEROS_UNSTRIDED_TENSOR]][0, 0, 2, 0] [1, 1, 4, 7] [1, 1, 1, 1] : tensor<1x1x4x7xf32> into tensor<1x1x8x7xf32>
+// CHECK: %[[CROPPED_UNSTRIDED_TENSOR:.*]] = tensor.extract_slice %[[INPUT_UNSTRIDED_TENSOR]][0, 0, 0, 1] [1, 1, 8, 5] [1, 1, 1, 1] : tensor<1x1x8x7xf32> to tensor<1x1x8x5xf32>
+// CHECK: %[[OUT_TENSOR:.*]] = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%[[CROPPED_UNSTRIDED_TENSOR]], %[[WEIGHTS:.*]] : tensor<1x1x8x5xf32>, tensor<2x1x3x3xf32>) outs(%[[INIT_OUT_TENSOR:.*]] : tensor<1x2x6x3xf32>) -> tensor<1x2x6x3xf32>
+// CHECK: %[[OUT_VTENSOR:.*]] = torch_c.from_builtin_tensor %[[OUT_TENSOR]] : tensor<1x2x6x3xf32> -> !torch.vtensor<[1,2,6,3],f32>
 func.func @tranConv2dNegativePadding(%arg0: !torch.vtensor<[1, 1, 4, 7],f32>) -> !torch.vtensor<[1, 2, 6, 3],f32> attributes {torch.assume_strict_symbolic_shapes} {
   %int0 = torch.constant.int 0
   %true = torch.constant.bool true
@@ -174,3 +179,31 @@ func.func @tranConv2dNegativePadding(%arg0: !torch.vtensor<[1, 1, 4, 7],f32>) ->
   %6 = torch.aten.convolution %arg0, %0, %1, %2, %3, %4, %true, %5, %int1 : !torch.vtensor<[1, 1, 4, 7],f32>, !torch.vtensor<[1,2,3,3],f32>, !torch.vtensor<[2],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1, 2, 6, 3],f32>
   return %6 : !torch.vtensor<[1, 2, 6, 3],f32>
 }
+
+// CHECK-LABEL: func.func @tranConv2dNegativeAndPositivePadding(
+// CHECK-SAME: %[[INPUT_VTENSOR:.*]]: !torch.vtensor<[1,1,4,7],f32>,
+// CHECK-SAME: %[[WEIGHTS_VTENSOR:.*]]: !torch.vtensor<[1,2,3,3],f32>,
+// CHECK-SAME: %[[BIAS_VTENSOR:.*]]: !torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2,15,21],f32> {
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C0F:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[INPUT_TENSOR:.*]] = torch_c.to_builtin_tensor %[[INPUT_VTENSOR]] : !torch.vtensor<[1,1,4,7],f32> -> tensor<1x1x4x7xf32>
+// CHECK: %[[EMPTY_UNSTRIDED_TENSOR:.*]] = tensor.empty() : tensor<1x1x17x25xf32>
+// CHECK: %[[ZEROS_UNSTRIDED_TENSOR:.*]] = linalg.fill ins(%[[C0F]] : f32) outs(%[[EMPTY_UNSTRIDED_TENSOR]] : tensor<1x1x17x25xf32>) -> tensor<1x1x17x25xf32>
+// CHECK: %[[INPUT_UNSTRIDED_TENSOR:.*]] = tensor.insert_slice %[[INPUT_TENSOR]] into %[[ZEROS_UNSTRIDED_TENSOR]][0, 0, 2, 0] [1, 1, 4, 7] [1, 1, 4, 4] : tensor<1x1x4x7xf32> into tensor<1x1x17x25xf32>
+// CHECK: %[[CROPPED_UNSTRIDED_TENSOR:.*]] = tensor.extract_slice %[[INPUT_UNSTRIDED_TENSOR]][0, 0, 0, 1] [1, 1, 17, 23] [1, 1, 1, 1] : tensor<1x1x17x25xf32> to tensor<1x1x17x23xf32>
+// CHECK: %[[OUT_TENSOR:.*]] = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%[[CROPPED_UNSTRIDED_TENSOR]], %[[WEIGHTS:.*]] : tensor<1x1x17x23xf32>, tensor<2x1x3x3xf32>) outs(%[[INIT_OUT_TENSOR:.*]] : tensor<1x2x15x21xf32>) -> tensor<1x2x15x21xf32>
+// CHECK: %[[OUT_VTENSOR:.*]] = torch_c.from_builtin_tensor %[[OUT_TENSOR]] : tensor<1x2x15x21xf32> -> !torch.vtensor<[1,2,15,21],f32>
+func.func @tranConv2dNegativeAndPositivePadding(%arg0: !torch.vtensor<[1,1,4,7],f32>, %arg1: !torch.vtensor<[1,2,3,3],f32>, %arg2: !torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2,15,21],f32> {
+  %int1 = torch.constant.int 1
+  %int3 = torch.constant.int 3
+  %int0 = torch.constant.int 0
+  %int4 = torch.constant.int 4
+  %true = torch.constant.bool true
+  %0 = torch.prim.ListConstruct %int4, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.prim.ListConstruct %int0, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
+  %4 = torch.aten.convolution %arg0, %arg1, %arg2, %0, %1, %2, %true, %3, %int1 : !torch.vtensor<[1,1,4,7],f32>, !torch.vtensor<[1,2,3,3],f32>, !torch.vtensor<[2],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,2,15,21],f32>
+  %return %4 : !torch.vtensor<[1,2,15,21],f32>
+}
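
One consequence of the CHECK lines in `@tranConv2dNegativeAndPositivePadding` is that the crop can discard input samples at the boundary. The sketch below works out, for one spatial dimension, which input indices survive the `tensor.extract_slice` and where they land in the cropped tensor. It assumes the offset and size formulas documented in the patch; `surviving_inputs` is a hypothetical helper, not part of torch-mlir.

```python
def surviving_inputs(in_dim, kernel, stride, padding, dilation, output_padding=0):
    """Return (input_index, position_in_cropped_tensor) for samples kept by the crop."""
    inner = (in_dim - 1) * stride + 1
    offset = (kernel - 1) * dilation - padding
    outer = inner + 2 * offset + output_padding
    insert_offset = max(offset, 0)
    extract_offset = max(-offset, 0)

    kept = []
    for i in range(in_dim):
        pos = insert_offset + i * stride  # position inside the max-size tensor
        if extract_offset <= pos < extract_offset + outer:
            kept.append((i, pos - extract_offset))
    return kept


# Height: 4 rows, offset +2 (padded). Width: 7 columns, offset -1 (cropped).
rows = surviving_inputs(in_dim=4, kernel=3, stride=4, padding=0, dilation=1)
cols = surviving_inputs(in_dim=7, kernel=3, stride=4, padding=3, dilation=1)
print(rows)
print(cols)
```

Under these assumptions all four rows are kept, while the first and last of the seven columns fall outside the `[1, 24)` crop window and no longer contribute to the convolution input.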
