Skip to content

Commit d0d74cb

Browse files
mbelickipszymich
authored andcommitted
Fix for large JointMatrix slices.
This change allows to correctly operate on largest implemented JointMatrix slices. (cherry picked from commit f38579f)
1 parent d0cdfd4 commit d0d74cb

File tree

1 file changed

+49
-8
lines changed

1 file changed

+49
-8
lines changed

IGC/Compiler/Optimizer/OpenCLPasses/JointMatrixFuncsResolutionPass.cpp

Lines changed: 49 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -647,8 +647,12 @@ Type *JointMatrixFuncsResolutionPass::ResolveType(Type *opaqueType, JointMatrixT
647647
return IGCLLVM::FixedVectorType::get(baseType, 32);
648648
}
649649
if (desc.layout == LayoutRowMajor && desc.rows == 32 && desc.columns == 64) {
650+
/* This should ideally be a vector of <i32 x 128>. However since IGC
651+
* code gen supports vector operations only on vectors up to 32
652+
* entries, we model this slice as array of [2 x <i64 x 32>]. */
650653
Type *baseType = Type::getInt64Ty(ctx);
651-
return IGCLLVM::FixedVectorType::get(baseType, 32);
654+
Type *halfTy = IGCLLVM::FixedVectorType::get(baseType, 32);
655+
return ArrayType::get(halfTy, 2);
652656
}
653657

654658
if (desc.layout == LayoutPackedA && desc.rows == 16 && desc.columns == 16) {
@@ -698,6 +702,33 @@ static uint64_t constIntValue(const Value *v) {
698702
return cast<ConstantInt>(v)->getLimitedValue();
699703
}
700704

705+
template <class BuilderT>
706+
static Instruction *loadSlice(BuilderT *builder, Type *matTy, Value *sliceArray) {
707+
IGCLLVM::FixedVectorType *sliceTy = dyn_cast<IGCLLVM::FixedVectorType>(matTy);
708+
if (sliceTy && sliceTy->getNumElements() <= 32) {
709+
return builder->CreateLoad(matTy, sliceArray);
710+
} else if (matTy->isArrayTy() && matTy->getArrayNumElements() == 2) {
711+
Type *baseType = Type::getInt64Ty(builder->getContext());
712+
Type *halfTy = IGCLLVM::FixedVectorType::get(baseType, 32);
713+
Type *halfPtrTy = halfTy->getPointerTo(ADDRESS_SPACE_PRIVATE);
714+
715+
Value *ptr0 = builder->CreateBitCast(sliceArray, halfPtrTy);
716+
Value *slice0 = builder->CreateLoad(halfTy, ptr0);
717+
718+
Value *ptr1 = builder->CreateGEP(halfTy, ptr0, { builder->getInt32(1) });
719+
Value *slice1 = builder->CreateLoad(halfTy, ptr1);
720+
721+
Value *pair = UndefValue::get(ArrayType::get(halfTy, 2));
722+
pair = builder->CreateInsertValue(pair, slice0, { 0 });
723+
pair = builder->CreateInsertValue(pair, slice1, { 1 });
724+
725+
return dyn_cast<Instruction>(pair);
726+
}
727+
728+
IGC_ASSERT_MESSAGE(false, "Unexpected number of elements in matrix slice.");
729+
return nullptr;
730+
}
731+
701732
Instruction *JointMatrixFuncsResolutionPass::ResolveLoad(CallInst *CI)
702733
{
703734
Value *ptrVal = CI->getArgOperand(0);
@@ -731,7 +762,7 @@ Instruction *JointMatrixFuncsResolutionPass::ResolveLoad(CallInst *CI)
731762
Instruction *newCall = builder.CreateCall(M->getOrInsertFunction(funcName, funcType), Args);
732763
newCall->setDebugLoc(CI->getDebugLoc());
733764

734-
newCall = builder.CreateLoad(matTy, sliceArray);
765+
newCall = loadSlice(&builder, matTy, sliceArray);
735766

736767
return newCall;
737768
}
@@ -909,7 +940,7 @@ Instruction *JointMatrixFuncsResolutionPass::ResolveMad(CallInst *CI, unsigned O
909940
Value* args[4] = { ptrA, ptrB, ptrC, ptrD };
910941

911942
builder.CreateCall(madFunc, args);
912-
dpasCall = builder.CreateLoad(cMat->getType(), sliceD);
943+
dpasCall = loadSlice(&builder, cMat->getType(), sliceD);
913944
} else {
914945
int SD = 8; // systolic depth, only 8 supported currently
915946
int RC = aDesc.rows; // repeat count, from 1 to 8
@@ -1259,11 +1290,21 @@ void JointMatrixFuncsResolutionPass::InsertPlaceholder(Value *v) {
12591290
if (Instruction *inst = dyn_cast<Instruction>(v)) {
12601291
predecesor = inst;
12611292
}
1262-
/* Using bit-casts as placeholder values. Undefs of each type are unique per
1263-
* module and cannot be used as unique placeholders. */
1264-
Instruction *placeholder =
1265-
BitCastInst::Create(Instruction::BitCast, UndefValue::get(type),
1266-
type, "tmp.value", predecesor);
1293+
1294+
Instruction *placeholder = nullptr;
1295+
if (!type->isArrayTy()) {
1296+
/* Using bit-casts as placeholder values. Undefs of each type are unique per
1297+
* module and cannot be used as unique placeholders. */
1298+
placeholder =
1299+
BitCastInst::Create(Instruction::BitCast, UndefValue::get(type),
1300+
type, "tmp.value", predecesor);
1301+
} else {
1302+
/* Array types cannot be bitcasted. Use instert element with two undefs
1303+
* to create unique placeholder for array value.*/
1304+
Value *array = UndefValue::get(type);
1305+
Value *element = UndefValue::get(type->getArrayElementType());
1306+
placeholder = InsertValueInst::Create(array, element, { 0 }, "tmp.value", predecesor);
1307+
}
12671308
ResolvedValues[v] = placeholder;
12681309
PlaceholderInstructions[v] = placeholder;
12691310
}

0 commit comments

Comments
 (0)