Skip to content

Commit 46a25d7

Browse files
authored
[torch-mlir][sparse] preserve sparsity during lowering torch to linalg (#2809)
This preserves sparsity at the most obvious places of lowering TORCH tensors to MLIR RankedTensorType tensors. Other places are marked for audit. With some initial lowering tests.
1 parent da7c6d2 commit 46a25d7

File tree

6 files changed

+50
-8
lines changed

6 files changed

+50
-8
lines changed

lib/Conversion/TorchToLinalg/DataMovement.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -978,6 +978,7 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
978978
return success();
979979
}
980980

981+
// TODO: audit possibility of sparsity on these tensors
981982
Type adjustedResultType = RankedTensorType::get(
982983
makeShapeLLVMCompatible(outputShape), resultType.getElementType());
983984
Type adjustedInputType = RankedTensorType::get(
@@ -1005,6 +1006,7 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
10051006
intermediateShape.push_back(sum);
10061007
}
10071008

1009+
// TODO: audit possibility of sparsity on these tensors
10081010
Type intermediateResultType =
10091011
RankedTensorType::get(makeShapeLLVMCompatible(intermediateShape),
10101012
resultType.getElementType());
@@ -1657,6 +1659,7 @@ class ConvertAtenSliceScatterOp
16571659
auto srcType = src.getType().cast<RankedTensorType>();
16581660
int64_t srcRank = srcType.getRank();
16591661
SmallVector<int64_t> srcAbstractSizes(srcRank, kUnknownSize);
1662+
// TODO: audit possibility of sparsity on these tensors
16601663
auto abstractSrcType = RankedTensorType::get(
16611664
makeShapeLLVMCompatible(srcAbstractSizes), srcType.getElementType());
16621665
Value abstractSrc =

lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,8 +206,8 @@ namespace {
206206
//
207207
// TODO: Find an optimal lowering.
208208
// current lowering is not optimal for bags of large embeddings.
209-
// Since it traverses the output tensor multiple times.
210-
//
209+
// Since it traverses the output tensor multiple times.
210+
//
211211
//
212212

213213
class ConvertAtenEmbeddingBagPaddingIdxOp

lib/Conversion/TorchToLinalg/Linear.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -377,8 +377,8 @@ class ConvertAtenMatmulOp : public OpConversionPattern<AtenMatmulOp> {
377377
// TODO: Improve usage of static shape information.
378378
SmallVector<int64_t> lhsTargetShape(lhsBroadcastToShape.size(),
379379
ShapedType::kDynamic);
380-
auto lhsBroadcastType =
381-
RankedTensorType::get(lhsTargetShape, lhsType.getElementType());
380+
auto lhsBroadcastType = RankedTensorType::get(
381+
lhsTargetShape, lhsType.getElementType(), lhsType.getEncoding());
382382
if (failed(torch_to_linalg::broadcastToGivenShape(
383383
op, rewriter, lhs, lhsBroadcastToShape, lhsBroadcastType,
384384
broadcastedLhs))) {
@@ -387,8 +387,8 @@ class ConvertAtenMatmulOp : public OpConversionPattern<AtenMatmulOp> {
387387
}
388388
SmallVector<int64_t> rhsTargetShape(rhsBroadcastToShape.size(),
389389
ShapedType::kDynamic);
390-
auto rhsBroadcastType =
391-
RankedTensorType::get(rhsTargetShape, rhsType.getElementType());
390+
auto rhsBroadcastType = RankedTensorType::get(
391+
rhsTargetShape, rhsType.getElementType(), rhsType.getEncoding());
392392
if (failed(torch_to_linalg::broadcastToGivenShape(
393393
op, rewriter, rhs, rhsBroadcastToShape, rhsBroadcastType,
394394
broadcastedRhs))) {
@@ -880,7 +880,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
880880
if(numSpacialDims != 2)
881881
return rewriter.notifyMatchFailure(
882882
op, "unimplemented: only 2D grouped convolution supported");
883-
883+
884884
// Special depthwise case
885885
auto inShape = makeShapeTorchCompatible(
886886
input.getType().cast<RankedTensorType>().getShape());
@@ -894,6 +894,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
894894
(weightShape[0] == kUnknownSize ? kUnknownSize
895895
: weightShape[0] * weightShape[1]),
896896
weightShape[2], weightShape[3]};
897+
// TODO: audit possibility of sparsity on this tensor
897898
Type collapsedType = RankedTensorType::get(
898899
makeShapeLLVMCompatible(collapsedShape), elementType);
899900
Value collapsedWeight = rewriter.create<tensor::CollapseShapeOp>(

lib/Conversion/TorchToLinalg/Utils.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ Value torch_to_linalg::getDynamicZeroPaddedTensor(
8787
*pad = castIntToIndex(b, loc, *pad);
8888

8989
Type elementType = input.getType().cast<RankedTensorType>().getElementType();
90+
// TODO: audit possibility of sparsity on this tensor
9091
Type inputType =
9192
RankedTensorType::get(makeShapeLLVMCompatible(llvm::ArrayRef<int64_t>(
9293
SmallVector<int64_t>(inRank, kUnknownSize))),

lib/Dialect/Torch/IR/TorchTypes.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -467,7 +467,8 @@ TensorType ValueTensorType::toBuiltinTensor() const {
467467
Type elementType = convertDtypeToBuiltinElementType(getContext(), getDtype());
468468
if (!elementType)
469469
return nullptr;
470-
return RankedTensorType::get(makeShapeLLVMCompatible(getSizes()), elementType);
470+
return RankedTensorType::get(makeShapeLLVMCompatible(getSizes()), elementType,
471+
getOptionalSparsity());
471472
}
472473

473474
LogicalResult
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s
2+
3+
// -----
4+
5+
#CSR = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
6+
7+
// CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
8+
// CHECK-LABEL: func.func @sum(
9+
// CHECK-SAME: %[[A:.*]]: !torch.vtensor<[64,64],f32,#[[$CSR]]>) -> !torch.vtensor<[],f32>
10+
// CHECK: %[[S:.*]] = torch_c.to_builtin_tensor %[[A]] : !torch.vtensor<[64,64],f32,#[[$CSR]]> -> tensor<64x64xf32, #[[$CSR]]>
11+
// CHECK: linalg.generic {{{.*}}} ins(%[[S]] : tensor<64x64xf32, #[[$CSR]]>)
12+
func.func @sum(%arg0: !torch.vtensor<[64,64],f32,#CSR>) -> !torch.vtensor<[],f32> {
13+
%none = torch.constant.none
14+
%0 = torch.aten.sum %arg0, %none
15+
: !torch.vtensor<[64,64],f32,#CSR>, !torch.none -> !torch.vtensor<[],f32>
16+
return %0 : !torch.vtensor<[],f32>
17+
}
18+
19+
// -----
20+
21+
#CSR = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
22+
23+
// CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
24+
// CHECK-LABEL: func.func @SpMM(
25+
// CHECK-SAME: %[[A:.*]]: !torch.vtensor<[8,16],f32,#[[$CSR]]>,
26+
// CHECK-SAME: %[[B:.*]]: !torch.vtensor<[16,8],f32>) -> !torch.vtensor<[8,8],f32>
27+
// CHECK: %[[S:.*]] = torch_c.to_builtin_tensor %[[A]] : !torch.vtensor<[8,16],f32,#[[$CSR]]> -> tensor<8x16xf32, #[[$CSR]]>
28+
// CHECK: %[[T:.*]] = torch_c.to_builtin_tensor %[[B]] : !torch.vtensor<[16,8],f32> -> tensor<16x8xf32>
29+
// CHECK: linalg.matmul ins(%[[S]], %[[T]] : tensor<8x16xf32, #[[$CSR]]>, tensor<16x8xf32>)
30+
func.func @SpMM(%arg0: !torch.vtensor<[8,16],f32,#CSR>,
31+
%arg1: !torch.vtensor<[16,8],f32>) -> !torch.vtensor<[8,8],f32> {
32+
%0 = torch.aten.matmul %arg0, %arg1
33+
: !torch.vtensor<[8,16],f32,#CSR>,
34+
!torch.vtensor<[16,8],f32> -> !torch.vtensor<[8,8],f32>
35+
return %0 : !torch.vtensor<[8,8],f32>
36+
}

0 commit comments

Comments
 (0)