[mlir][nvgpu] Fix crash when mmaShape size is not three #173490
Conversation
@llvm/pr-subscribers-mlir-nvgpu @llvm/pr-subscribers-mlir

Author: Longsheng Mou (CoTinker)

Changes: This PR fixes a crash when the mmaShape attribute of `nvgpu.mma.sync` does not have exactly three elements.

Patch is 51.61 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/173490.diff

14 Files Affected:
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUOps.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUOps.td
index 73d86283a5940..5b9ae8bb7a518 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUOps.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUOps.td
@@ -66,16 +66,6 @@ class NVGPU_MmaSyncOp<string mnemonic> :
NVGPU_Op<mnemonic, [Pure,
PredOpTrait<"matrixA and matrixB have same element type",
TCopVTEtIsSameAs<0, 1>>]> {
- code extraBaseClassDeclaration = [{
- std::array<int64_t, 3> getMmaShapeAsArray() {
- ArrayAttr mmaShape = this->getMmaShape();
- assert(mmaShape.size() == 3 && "mmaShape should be three integers");
- return {::llvm::cast<IntegerAttr>(mmaShape[0]).getInt(),
- ::llvm::cast<IntegerAttr>(mmaShape[1]).getInt(),
- ::llvm::cast<IntegerAttr>(mmaShape[2]).getInt()};
- }
- }];
-
let hasVerifier = 1;
}
@@ -96,14 +86,14 @@ def NVGPU_MmaSyncOp : NVGPU_MmaSyncOp<"mma.sync"> {
Example:
```mlir
- %res = nvgpu.mma.sync (%matrixA, %matrixB, %matrixC) {mmaShape = [16, 8, 16]} :
+ %res = nvgpu.mma.sync (%matrixA, %matrixB, %matrixC) {mmaShape = array<i64: 16, 8, 16>} :
(vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf32>) -> vector<2x2xf32>
```
}];
let arguments = (ins AnyVectorOfNonZeroRank:$matrixA,
AnyVectorOfNonZeroRank:$matrixB,
AnyVectorOfNonZeroRank:$matrixC,
- I64ArrayAttr:$mmaShape,
+ DenseI64ArrayAttr:$mmaShape,
OptionalAttr<UnitAttr>:$tf32Enabled);
let results = (outs AnyVectorOfNonZeroRank:$res);
@@ -112,7 +102,7 @@ def NVGPU_MmaSyncOp : NVGPU_MmaSyncOp<"mma.sync"> {
OpBuilder<(ins "Value":$matrixA,
"Value":$matrixB,
"Value":$matrixC,
- "ArrayAttr":$mmaShape)>,
+ "DenseI64ArrayAttr":$mmaShape)>,
OpBuilder<(ins "Value":$matrixA,
"Value":$matrixB,
"Value":$matrixC,
@@ -124,8 +114,6 @@ def NVGPU_MmaSyncOp : NVGPU_MmaSyncOp<"mma.sync"> {
`(` $matrixA`,` $matrixB`,` $matrixC `)` attr-dict
`:` `(` type($matrixA) `,` type($matrixB) `,` type($matrixC) `)` `->` type($res)
}];
-
- let extraClassDeclaration = extraBaseClassDeclaration;
}
def NVGPU_MmaSparseSyncMetadataType : FixedVectorOfLengthAndType<[2], [I16]>,
@@ -151,7 +139,7 @@ def NVGPU_MmaSparseSyncOp : NVGPU_MmaSyncOp<"mma.sp.sync"> {
 Example (targeting the f16 16x8x32 `mma.sp` PTX instruction):
```mlir
- nvgpu.mma.sp.sync (%a, %b, %c) metadata (%meta) {mmaShape = [16, 8, 32]} :
+ nvgpu.mma.sp.sync (%a, %b, %c) metadata (%meta) {mmaShape = array<i64: 16, 8, 32>} :
(vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
```
}];
@@ -160,7 +148,7 @@ def NVGPU_MmaSparseSyncOp : NVGPU_MmaSyncOp<"mma.sp.sync"> {
AnyVectorOfNonZeroRank:$matrixB,
AnyVectorOfNonZeroRank:$matrixC,
NVGPU_MmaSparseSyncMetadataType:$sparseMetadata,
- I64ArrayAttr:$mmaShape,
+ DenseI64ArrayAttr:$mmaShape,
DefaultValuedAttr<I32Attr, "0">:$sparsitySelector,
OptionalAttr<UnitAttr>:$tf32Enabled
);
@@ -179,8 +167,6 @@ def NVGPU_MmaSparseSyncOp : NVGPU_MmaSyncOp<"mma.sp.sync"> {
`(` $matrixA`,` $matrixB`,` $matrixC `)` `metadata` `(` $sparseMetadata `)` attr-dict
`:` `(` type($matrixA) `,` type($matrixB) `,` type($matrixC) `)` `->` type($res)
}];
-
- let extraClassDeclaration = extraBaseClassDeclaration;
}
def NVGPU_DeviceAsyncCopyOp : NVGPU_Op<"device_async_copy", [
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 6edc8f5c86dd3..7cd59b0f135fe 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -340,7 +340,7 @@ struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
VectorType bType = op.getMatrixA().getType();
VectorType cType = op.getMatrixC().getType();
- std::array<int64_t, 3> gemmShape = op.getMmaShapeAsArray();
+ ArrayRef<int64_t> gemmShape = op.getMmaShape();
// Tensor Cores (mma.sync) on F32 works only with TensorFloat32 (TF32).
bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
@@ -485,7 +485,7 @@ static std::string buildMmaSparseAsmConstraintString(unsigned matASize,
/// it's expected that the provided parameters correspond to a valid
/// instruction.
static std::string buildMmaSparseAsmString(
- const std::array<int64_t, 3> &shape, unsigned matASize, unsigned matBSize,
+ ArrayRef<int64_t> shape, unsigned matASize, unsigned matBSize,
unsigned matCSize, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
std::optional<NVVM::MMAIntOverflow> overflow, unsigned metaDataSelector) {
@@ -526,7 +526,7 @@ static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
std::optional<NVVM::MMAIntOverflow> overflow, ArrayRef<Value> unpackedAData,
ArrayRef<Value> unpackedB, ArrayRef<Value> unpackedC, Value indexData,
- int64_t metadataSelector, const std::array<int64_t, 3> &shape,
+ int64_t metadataSelector, ArrayRef<int64_t> shape,
Type intrinsicResultType) {
auto asmDialectAttr =
LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_ATT);
@@ -618,7 +618,7 @@ struct NVGPUMmaSparseSyncLowering
FailureOr<LLVM::InlineAsmOp> intrinsicResult = emitMmaSparseSyncOpAsm(
b, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB,
- matC, sparseMetadata, op.getSparsitySelector(), op.getMmaShapeAsArray(),
+ matC, sparseMetadata, op.getSparsitySelector(), op.getMmaShape(),
intrinsicResTy);
if (failed(intrinsicResult))
return failure();
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 98434357f826f..d4aa893a5b420 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -1058,8 +1058,8 @@ convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op,
int64_t m = cast<VectorType>(op.getLhs().getType()).getShape()[0];
int64_t n = cast<VectorType>(op.getRhs().getType()).getShape()[0];
int64_t k = cast<VectorType>(op.getLhs().getType()).getShape()[1];
- Value matmul = nvgpu::MmaSyncOp::create(rewriter, op.getLoc(), opA, opB, opC,
- rewriter.getI64ArrayAttr({m, n, k}));
+ Value matmul =
+ nvgpu::MmaSyncOp::create(rewriter, op.getLoc(), opA, opB, opC, {m, n, k});
valueMapping[op.getResult()] = matmul;
return success();
}
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index 237aab4d7f309..eb2de91d30988 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -119,7 +119,8 @@ LogicalResult DeviceAsyncCopyOp::verify() {
//===----------------------------------------------------------------------===//
void MmaSyncOp::build(::mlir::OpBuilder &odsBuilder,
::mlir::OperationState &odsState, Value matrixA,
- Value matrixB, Value matrixC, ArrayAttr mmaShape) {
+ Value matrixB, Value matrixC,
+ DenseI64ArrayAttr mmaShape) {
build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
mmaShape, UnitAttr());
}
@@ -129,8 +130,7 @@ void MmaSyncOp::build(::mlir::OpBuilder &odsBuilder,
Value matrixB, Value matrixC, ArrayRef<int64_t> mmaShape,
bool tf32Enabled) {
build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
- odsBuilder.getI64ArrayAttr(mmaShape),
- tf32Enabled ? odsBuilder.getUnitAttr() : UnitAttr());
+ mmaShape, tf32Enabled ? odsBuilder.getUnitAttr() : UnitAttr());
}
/// Performs verification for MmaSyncOp and MmaSparseSyncOp.
@@ -138,7 +138,7 @@ static LogicalResult verifyMmaSyncOp(Operation *op,
TypedValue<VectorType> matrixA,
TypedValue<VectorType> matrixB,
TypedValue<VectorType> matrixC,
- const std::array<int64_t, 3> &mmaShape,
+ ArrayRef<int64_t> mmaShape,
bool tf32Enabled, bool sparse = false) {
// The verification for mma.sync covering various shapes and data types is
// based on the fundamental tensor core shape.
@@ -209,7 +209,12 @@ static LogicalResult verifyMmaSyncOp(Operation *op,
return op->emitError() << "matrixC must be 2 dimensional vector";
}
- auto [m, n, k] = mmaShape;
+ if (mmaShape.size() != 3) {
+ return op->emitError() << "mmaShape should be three integers";
+ }
+ int64_t m = mmaShape[0];
+ int64_t n = mmaShape[1];
+ int64_t k = mmaShape[2];
// verify warp-wide size for vector a
int64_t sparseFactor = sparse ? 2 : 1;
@@ -262,7 +267,7 @@ static LogicalResult verifyMmaSyncOp(Operation *op,
LogicalResult MmaSyncOp::verify() {
return verifyMmaSyncOp(this->getOperation(), getMatrixA(), getMatrixB(),
- getMatrixC(), getMmaShapeAsArray(),
+ getMatrixC(), getMmaShape(),
getOperation()->hasAttr(getTf32EnabledAttrName()));
}
@@ -274,17 +279,16 @@ void MmaSparseSyncOp::build(::mlir::OpBuilder &odsBuilder,
Value matrixB, Value matrixC, Value sparseMetadata,
ArrayRef<int64_t> mmaShape) {
build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
- sparseMetadata, odsBuilder.getI64ArrayAttr(mmaShape), 0, UnitAttr());
+ sparseMetadata, mmaShape, 0, UnitAttr());
}
LogicalResult MmaSparseSyncOp::verify() {
unsigned sparsitySelector = getSparsitySelector();
if (sparsitySelector > 1)
return emitOpError() << "sparsity selector should be 0 or 1";
- return verifyMmaSyncOp(this->getOperation(), getMatrixA(), getMatrixB(),
- getMatrixC(), getMmaShapeAsArray(),
- getOperation()->hasAttr(getTf32EnabledAttrName()),
- true);
+ return verifyMmaSyncOp(
+ this->getOperation(), getMatrixA(), getMatrixB(), getMatrixC(),
+ getMmaShape(), getOperation()->hasAttr(getTf32EnabledAttrName()), true);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index 0eb44789fe31d..74bf571b4a64e 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -14,7 +14,7 @@ func.func @m16n8k16_fp16(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2:
// CHECK-NOT: llvm.extractvalue
// CHECK: [[d:%.+]] = nvvm.mma.sync
// CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 16>
- %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 16>} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
// CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
// CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
// CHECK: llvm.mlir.poison : !llvm.array<2 x vector<2xf16>>
@@ -31,7 +31,7 @@ func.func @m16n8k16_fp16_fp32(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %a
// CHECK: [[d:%.+]] = nvvm.mma.sync
// CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 16>
// CHECK-SAME: (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
- %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf32>) -> vector<2x2xf32>
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 16>} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf32>) -> vector<2x2xf32>
// CHECK: [[undef:%.+]] = llvm.mlir.poison : vector<2xf32>
// CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(f32, f32, f32, f32)>
// CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(f32, f32, f32, f32)>
@@ -59,7 +59,7 @@ func.func @m16n8k8_fp16(%arg0: vector<2x2xf16>, %arg1: vector<1x2xf16>, %arg2: v
// CHECK-NOT: llvm.extractvalue
// CHECK: [[d:%.+]] = nvvm.mma.sync
// CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 8>
- %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 8]} : (vector<2x2xf16>, vector<1x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 8>} : (vector<2x2xf16>, vector<1x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
// CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
// CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
// CHECK: llvm.mlir.poison : !llvm.array<2 x vector<2xf16>>
@@ -90,7 +90,7 @@ func.func @m16n8k32_int8(%arg0: vector<4x4xi8>, %arg1: vector<2x4xi8>, %arg2: ve
// CHECK-SAME: multiplicandAPtxType = #nvvm.mma_type<s8>
// CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type<s8>
// CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 32>
- %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 32]} : (vector<4x4xi8>, vector<2x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 32>} : (vector<4x4xi8>, vector<2x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
return %d : vector<2x2xi32>
}
@@ -109,7 +109,7 @@ func.func @m16n8k32_i4(%arg0: vector<2x8xi4>, %arg1: vector<1x8xi4>, %arg2: vect
// CHECK-SAME: multiplicandAPtxType = #nvvm.mma_type<s4>
// CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type<s4>
// CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 32>
- %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 32]} : (vector<2x8xi4>, vector<1x8xi4>, vector<2x2xi32>) -> vector<2x2xi32>
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 32>} : (vector<2x8xi4>, vector<1x8xi4>, vector<2x2xi32>) -> vector<2x2xi32>
return %d : vector<2x2xi32>
}
@@ -134,7 +134,7 @@ func.func @m16n8k64_i4(%arg0: vector<4x8xi4>, %arg1: vector<2x8xi4>, %arg2: vect
// CHECK-SAME: multiplicandAPtxType = #nvvm.mma_type<s4>
// CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type<s4>
// CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 64>
- %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 64]} : (vector<4x8xi4>, vector<2x8xi4>, vector<2x2xi32>) -> vector<2x2xi32>
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 64>} : (vector<4x8xi4>, vector<2x8xi4>, vector<2x2xi32>) -> vector<2x2xi32>
return %d : vector<2x2xi32>
}
@@ -145,7 +145,7 @@ func.func @m8n8k4_f64(%arg0: vector<1x1xf64>, %arg1: vector<1x1xf64>, %arg2: vec
// CHECK: llvm.extractvalue
// CHECK: [[d:%.+]] = nvvm.mma.sync A[{{%.+}}] B[{{%.+}}] C[{{%.+}}, {{%.+}}]
// CHECK-SAME: shape = #nvvm.shape<m = 8, n = 8, k = 4>
- %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [8, 8, 4]} : (vector<1x1xf64>, vector<1x1xf64>, vector<1x2xf64>) -> vector<1x2xf64>
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 8, 8, 4>} : (vector<1x1xf64>, vector<1x1xf64>, vector<1x2xf64>) -> vector<1x2xf64>
// CHECK: llvm.mlir.poison : vector<2xf64>
// CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(f64, f64)>
// CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(f64, f64)>
@@ -201,7 +201,7 @@ func.func @m16n8k4_tf32(%arg0: vector<2x1xf32>, %arg1: vector<1x1xf32>, %arg2: v
// CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type<tf32>
// CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 4>
// CHECK-SAME: -> !llvm.struct<(f32, f32, f32, f32)>
- %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 4], tf32Enabled} : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = array<i64: 16, 8, 4>, tf32Enabled} : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
// CHECK: [[undef:%.+]] = llvm.mlir.poison : vector<2xf32>
// CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(f32, f32, f32, f32)>
// CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(f32, f32, f32, f32)>
@@ -370,7 +370,7 @@ func.func @mma_sp_sync_f16_16832(%arg0: vector<4x2xf16>,
// CHECK-SAME: %[[sparseMetadata]] :
// CHECK-SAME: -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
- %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = [16, 8, 32]} :
+ %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = array<i64: 16, 8, 32>} :
(vector<4x2xf16>, vector<4x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
// CHECK-DAG: llvm.extractvalue %[[d]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
@@ -406,7 +406,7 @@ func.func @mma_sp_sync_f16_16816(%arg0: vector<2x2xf16>,
// CHECK-SAME: %[[sparseMetadata]] :
// CHECK-SAME: -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
- %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = [16, 8, 16]} :
+ %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = array<i64: 16, 8, 16>} :
(vector<2x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
return %d : vector<2x2xf16>
}
@@ -427,7 +427,7 @@ func.func @mma_sp_sync_f16_16816_01(%arg0: vector<2x2xf16>,
// CHECK-SAME: -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
%d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3)
- {mmaShape = [16, 8, 16], sparsitySelector = 1 : i32} :
+ {mmaShape = array<i64: 16, 8, 16>, sparsitySelector = 1 : i32} :
(vector<2x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
return %d : vector<2x2xf16>
}
@@ -465,7 +465,7 @@ func.func @mma_sp_sync_i8_16864(%arg0: vector<4x4xi8>,
// CHECK-SAME: %[[sparseMetadata]] :
// CHECK-SAME: -> !llvm.struct<(i32, i32, i32, i32)
- %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = [16, 8, 64]} :
+ %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = array<i64: 16, 8, 64>} :
(vector<4x4xi8>, vector<4x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
return %d : vector<2x2xi32>
}
diff --git a/mlir/test/Conversion/VectorToGPU/fold-arith-vector-to-mma-ops-mma-sync.mlir b/mlir/test/Conversion/VectorToGPU/fold-arith-vector-to-mma-ops-mma-sync.mlir
index 0afaa19d59d15..c42f5add697f0 100644
--- a/mlir/test/Conversion/VectorToGPU/fold-arith-vector-to-mma-ops-mma-sync.mlir
+++ b/mlir/test/Conversion/VectorToGPU/fold-arith-vector-to-mma-ops-mma-sync.mlir
@@ -29,7 +29,7 @@ func.func @m16n8k16_mmasync16816_f16_f16_f32_row_row_row(%arg0: memref<42x32xf16
%B0_f32 = arith.extf %B0 : vector<8x16xf16> to vector<8x16xf32>
%C0 = vector.extract_strided_slice %C {offsets = [0, 0], sizes = [16, 8], strides = [1, 1]} : vector<16x16xf32> to vector<16x8xf32>
- // CHECK-DAG: nvgpu.mma.sync({{.*}}) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf32>) -> vector<2x2xf32>
+ // CHECK-DAG: nvgpu.mma.sync({{.*}}) {mmaShape = array<i64: 16, 8, 16>} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf32>) -> vector<2x2xf32>
%D0 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A_f32, %B0_f32, %C0 : vector<16x16xf32>, vector<8x16xf32> into vector<16x8xf32>
vector.transfer_write %D0, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf32>, memref<42x64xf32, #gpu.address_space<workgroup>>
@@ -38,7 +38,7 @@ func.func @m16n8k16_mmasync16816_f16_f16_f32_row_row_row(%arg0: memref<42x32xf16
%B1_f32 = arith.extf %B1 : vector<8x16xf16> to vector<8x16xf32>
...
[truncated]
This PR fixes a crash when the mmaShape attribute of `nvgpu.mma.sync` does not have exactly three elements. The change replaces the ArrayAttr-based mmaShape with DenseI64ArrayAttr and adds verifier checks to ensure the attribute has three elements.
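For context, a minimal IR reproducer of the failure mode might look like the following (the two-element shape is made up for illustration; the diagnostic text matches the check added in this patch):

```mlir
func.func @mma_sync_bad_shape(%a: vector<4x2xf16>, %b: vector<2x2xf16>,
                              %c: vector<2x2xf32>) -> vector<2x2xf32> {
  // Only two entries instead of three: previously this hit the assert in
  // getMmaShapeAsArray(); now the verifier rejects it with a proper error.
  // expected-error @below {{mmaShape should be three integers}}
  %d = nvgpu.mma.sync (%a, %b, %c) {mmaShape = array<i64: 16, 8>} :
       (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf32>) -> vector<2x2xf32>
  return %d : vector<2x2xf32>
}
```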
Why do you replace the ArrayAttr-based mmaShape with DenseI64ArrayAttr? I think I64ArrayAttr is good. I think the DenseI64ArrayAttr approach makes the IR more complicated; previously, we didn't need to write `i64:`.
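For reference, the two spellings under discussion, both taken from this diff, look like this:

```mlir
// I64ArrayAttr spelling (before this PR):
%d0 = nvgpu.mma.sync (%a, %b, %c) {mmaShape = [16, 8, 16]} :
      (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf32>) -> vector<2x2xf32>

// DenseI64ArrayAttr spelling (after this PR):
%d1 = nvgpu.mma.sync (%a, %b, %c) {mmaShape = array<i64: 16, 8, 16>} :
      (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf32>) -> vector<2x2xf32>
```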
Because we don't need to add a …
Actually, I don't think the print format is a big problem, because we don't need to 'write' it most of the time. If you like, I think the best way is to write a custom assembly format.
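If someone did go that route, a rough sketch (untested; the directive and helper names here are hypothetical and not part of this PR) could keep the old-looking `mmaShape = [16, 8, 16]` surface syntax on top of DenseI64ArrayAttr, via a custom directive in the op's assemblyFormat plus a small parser/printer pair:

```cpp
// TableGen side (sketch): spell the attribute with a custom directive, e.g.
//   `mmaShape` `=` `[` custom<MmaShape>($mmaShape) `]`
// instead of letting attr-dict print the array<i64: ...> form.

// C++ side (sketch), e.g. in NVGPUDialect.cpp:
static ParseResult parseMmaShape(OpAsmParser &parser,
                                 DenseI64ArrayAttr &mmaShape) {
  SmallVector<int64_t> values;
  // Parse a bare comma-separated integer list; the surrounding brackets are
  // handled by the assembly format above.
  if (parser.parseCommaSeparatedList([&]() -> ParseResult {
        int64_t value;
        if (parser.parseInteger(value))
          return failure();
        values.push_back(value);
        return success();
      }))
    return failure();
  mmaShape = DenseI64ArrayAttr::get(parser.getContext(), values);
  return success();
}

static void printMmaShape(OpAsmPrinter &printer, Operation *,
                          DenseI64ArrayAttr mmaShape) {
  llvm::interleaveComma(mmaShape.asArrayRef(), printer.getStream());
}
```

This is only one possible shape for such a format; nothing in this PR adds it.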
I think this makes the usability of the NVGPU dialect worse. I would like to hear @grypp's thoughts.
Ping~ |
This PR fixes a crash when the mmaShape attribute of `nvgpu.mma.sync` does not have exactly three elements. The change replaces the ArrayAttr-based mmaShape with DenseI64ArrayAttr and adds verifier checks to ensure the attribute has three elements. Fixes #173378.
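For downstream code that builds these ops in C++, the update is mechanical; the VectorToGPU change in this diff shows the pattern (variable names here follow that call site):

```cpp
// Before this PR: the shape had to be wrapped into an I64ArrayAttr.
//   Value matmul = nvgpu::MmaSyncOp::create(
//       rewriter, op.getLoc(), opA, opB, opC,
//       rewriter.getI64ArrayAttr({m, n, k}));

// After this PR: the ArrayRef<int64_t> builder takes the integers directly.
Value matmul =
    nvgpu::MmaSyncOp::create(rewriter, op.getLoc(), opA, opB, opC, {m, n, k});
```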