Skip to content

Commit 5e556e6

Browse files
michalpaszkowskipszymich
authored andcommitted
Support new SPV_INTEL_joint_matrix spec in SPIRVReader
This change adds support for the new specification of the SPV_INTEL_joint_matrix extension in SPIRVReader.
1 parent a808d3f commit 5e556e6

File tree

8 files changed

+85
-31
lines changed

8 files changed

+85
-31
lines changed

IGC/AdaptorOCL/SPIRV/SPIRVReader.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2297,6 +2297,7 @@ SPIRVToLLVM::transType(SPIRVType *T) {
22972297
SPIRAddressSpace::SPIRAS_Global));
22982298
}
22992299
case OpTypeJointMatrixINTEL:
2300+
case OpTypeJointMatrixINTEL_OLD:
23002301
{
23012302
SPIRVTypeJointMatrixINTEL *MT = static_cast<SPIRVTypeJointMatrixINTEL *>(T);
23022303
std::string typeName = "intel.joint_matrix_" + MT->getMangledName() + "_t";
@@ -3174,6 +3175,7 @@ SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
31743175
}
31753176
break;
31763177
case OpTypeJointMatrixINTEL:
3178+
case OpTypeJointMatrixINTEL_OLD:
31773179
{
31783180
IGC_ASSERT(CV.size() == 1);
31793181

@@ -3663,7 +3665,8 @@ SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
36633665
case OpVectorExtractDynamic: {
36643666
auto CE = static_cast<SPIRVVectorExtractDynamic *>(BV);
36653667
IGC_ASSERT_MESSAGE(BB, "Invalid BB");
3666-
if (CE->getVector()->getType()->getOpCode() == OpTypeJointMatrixINTEL)
3668+
Op VectorTypeOpCode = CE->getVector()->getType()->getOpCode();
3669+
if (VectorTypeOpCode == OpTypeJointMatrixINTEL || VectorTypeOpCode == OpTypeJointMatrixINTEL_OLD)
36673670
{
36683671
Value *matrix = transValue(CE->getVector(), F, BB);
36693672
Value *index = transValue(CE->getIndex(), F, BB);
@@ -3719,7 +3722,8 @@ SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
37193722
case OpVectorInsertDynamic: {
37203723
auto CI = static_cast<SPIRVVectorInsertDynamic *>(BV);
37213724
IGC_ASSERT_MESSAGE(BB, "Invalid BB");
3722-
if (CI->getVector()->getType()->getOpCode() == OpTypeJointMatrixINTEL)
3725+
Op VectorTypeOpCode = CI->getVector()->getType()->getOpCode();
3726+
if (VectorTypeOpCode == OpTypeJointMatrixINTEL || VectorTypeOpCode == OpTypeJointMatrixINTEL_OLD)
37233727
{
37243728
Value *matrix = transValue(CI->getVector(), F, BB);
37253729
Value *component = transValue(CI->getComponent(), F, BB);
@@ -3982,7 +3986,7 @@ SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
39823986
case OpJointMatrixLoadINTEL: {
39833987
SPIRVJointMatrixLoadINTEL *ML = static_cast<SPIRVJointMatrixLoadINTEL *>(BV);
39843988
std::vector<SPIRVValue *> BArgs = ML->getOperands();
3985-
enum SPVIdx { Pointer, Stride, Layout, Scope, MemOp };
3989+
enum SPVIdx { Pointer, Stride, Layout, MemOp };
39863990

39873991
SPIRVTypeJointMatrixINTEL *MatTy = static_cast<SPIRVTypeJointMatrixINTEL *>(ML->getType());
39883992
const unsigned loadLayout = (unsigned)BM->get<SPIRVConstant>(BArgs[Layout]->getId())->getZExtIntValue();
@@ -4028,7 +4032,7 @@ SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
40284032
case OpJointMatrixStoreINTEL: {
40294033
SPIRVJointMatrixStoreINTEL *MS = static_cast<SPIRVJointMatrixStoreINTEL *>(BV);
40304034
std::vector<SPIRVValue *> BArgs = MS->getOperands();
4031-
enum SPVIdx { Pointer, Object, Stride, Layout, Scope, MemOp };
4035+
enum SPVIdx { Pointer, Object, Stride, Layout, MemOp };
40324036

40334037
SPIRVTypeJointMatrixINTEL *MatTy = static_cast<SPIRVTypeJointMatrixINTEL *>(BArgs[Object]->getType());
40344038
const unsigned storeLayout = (unsigned)BM->get<SPIRVConstant>(BArgs[Layout]->getId())->getZExtIntValue();
@@ -4079,7 +4083,7 @@ SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
40794083
SPIRVInstruction *MI = static_cast<SPIRVInstruction *>(BV);
40804084
std::vector<SPIRVValue *> BArgs = MI->getOperands();
40814085

4082-
enum SPVIdx { A, B, C, Scope };
4086+
enum SPVIdx { A, B, C };
40834087
auto *MatATy = static_cast<SPIRVTypeJointMatrixINTEL *>(BArgs[A]->getType());
40844088
auto *MatBTy = static_cast<SPIRVTypeJointMatrixINTEL *>(BArgs[B]->getType());
40854089
auto *MatCTy = static_cast<SPIRVTypeJointMatrixINTEL *>(BArgs[C]->getType());

IGC/AdaptorOCL/SPIRV/libSPIRV/SPIRVEntry.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ SPIRVEntry::create(Op OpCode) {
6464
#define _SPIRV_OP(x,...) case Op##x: return igc_spv::create<SPIRV##x>();
6565
#include "SPIRVOpCodeEnum.h"
6666
#undef _SPIRV_OP
67+
case OpTypeJointMatrixINTEL_OLD: // Remove once OpTypeJointMatrixINTEL_OLD is removed
68+
return new SPIRVTypeJointMatrixINTEL(OpTypeJointMatrixINTEL_OLD);
6769
default:
6870
break;
6971
}

IGC/AdaptorOCL/SPIRV/libSPIRV/SPIRVInstruction.h

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1500,7 +1500,8 @@ class SPIRVVectorExtractDynamic:public SPIRVInstruction {
15001500
if (getValue(VectorId)->isForward())
15011501
return;
15021502
IGC_ASSERT(getValueType(VectorId)->isTypeVector()
1503-
|| getValue(VectorId)->getType()->getOpCode() == OpTypeJointMatrixINTEL);
1503+
|| getValue(VectorId)->getType()->getOpCode() == OpTypeJointMatrixINTEL
1504+
|| getValue(VectorId)->getType()->getOpCode() == OpTypeJointMatrixINTEL_OLD);
15041505
}
15051506
SPIRVId VectorId;
15061507
SPIRVId IndexId;
@@ -1523,7 +1524,8 @@ class SPIRVVectorInsertDynamic :public SPIRVInstruction {
15231524
if (getValue(VectorId)->isForward())
15241525
return;
15251526
IGC_ASSERT(getValueType(VectorId)->isTypeVector()
1526-
|| getValue(VectorId)->getType()->getOpCode() == OpTypeJointMatrixINTEL);
1527+
|| getValue(VectorId)->getType()->getOpCode() == OpTypeJointMatrixINTEL
1528+
|| getValue(VectorId)->getType()->getOpCode() == OpTypeJointMatrixINTEL_OLD);
15271529
}
15281530
SPIRVId VectorId;
15291531
SPIRVId IndexId;
@@ -2335,10 +2337,10 @@ class SPIRVJointMatrixINTELInst: public SPIRVInstTemplateBase {
23352337

23362338
_SPIRV_OP(JointMatrixLoad, true, 6, true)
23372339
_SPIRV_OP(JointMatrixStore, false, 5, true)
2338-
_SPIRV_OP(JointMatrixMad, true, 7)
2339-
_SPIRV_OP(JointMatrixSUMad, true, 7)
2340-
_SPIRV_OP(JointMatrixUSMad, true, 7)
2341-
_SPIRV_OP(JointMatrixUUMad, true, 7)
2340+
_SPIRV_OP(JointMatrixMad, true, 6)
2341+
_SPIRV_OP(JointMatrixSUMad, true, 6)
2342+
_SPIRV_OP(JointMatrixUSMad, true, 6)
2343+
_SPIRV_OP(JointMatrixUUMad, true, 6)
23422344
_SPIRV_OP(JointMatrixWorkItemLength, true, 4)
23432345
#undef _SPIRV_OP
23442346
}

IGC/AdaptorOCL/SPIRV/libSPIRV/SPIRVOpCode.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ inline bool isTypeOpCode(Op OpCode) {
195195
unsigned OC = OpCode;
196196
return (OpTypeVoid <= OC && OC <= OpTypePipe) ||
197197
isSubgroupAvcINTELTypeOpCode(OpCode) || OC == OpTypeVmeImageINTEL ||
198-
OC == OpTypeJointMatrixINTEL ||
198+
OC == OpTypeJointMatrixINTEL || OC == OpTypeJointMatrixINTEL_OLD ||
199199
isVCOpCode(OpCode) || OC == OpTypeTokenINTEL;
200200
}
201201

IGC/AdaptorOCL/SPIRV/libSPIRV/SPIRVOpCodeEnum.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -516,7 +516,8 @@ _SPIRV_OP(TypeTokenINTEL, 6113)
516516
_SPIRV_OP(ConvertFToBF16INTEL, 6116)
517517
_SPIRV_OP(ConvertBF16ToFINTEL, 6117)
518518
// SPV_INTEL_matrix
519-
_SPIRV_OP(TypeJointMatrixINTEL, 6119)
519+
//_SPIRV_OP(TypeJointMatrixINTEL_OLD, 6119) Replaced by 6184
520+
_SPIRV_OP(TypeJointMatrixINTEL, 6184)
520521
_SPIRV_OP(JointMatrixLoadINTEL, 6120)
521522
_SPIRV_OP(JointMatrixStoreINTEL, 6121)
522523
_SPIRV_OP(JointMatrixMadINTEL, 6122)

IGC/AdaptorOCL/SPIRV/libSPIRV/SPIRVType.cpp

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -353,30 +353,50 @@ void SPIRVTypeForwardPointer::decode(std::istream& I) {
353353
Decoder >> PointerId >> SC;
354354
}
355355

356+
unsigned SPIRVTypeJointMatrixINTEL::getRows() const {
357+
return (unsigned)get<SPIRVConstant>(Args[0])->getZExtIntValue();
358+
}
359+
360+
unsigned SPIRVTypeJointMatrixINTEL::getColumns() const {
361+
return (unsigned)get<SPIRVConstant>(Args[1])->getZExtIntValue();
362+
}
363+
356364
unsigned SPIRVTypeJointMatrixINTEL::getLayout() const {
357-
return (unsigned)get<SPIRVConstant>(Args[2])->getZExtIntValue();
365+
if (isLayoutParameterPresent())
366+
return (unsigned)get<SPIRVConstant>(Args[2])->getZExtIntValue();
367+
return 0;
368+
}
369+
370+
unsigned SPIRVTypeJointMatrixINTEL::getScope() const {
371+
return getOpCode() == OpTypeJointMatrixINTEL
372+
? (unsigned)get<SPIRVConstant>(Args[2])->getZExtIntValue()
373+
: (unsigned)get<SPIRVConstant>(Args[3])->getZExtIntValue();
358374
}
359375

360376
unsigned SPIRVTypeJointMatrixINTEL::getUse() const {
361377
if (isUseParameterPresent())
362-
return (unsigned)get<SPIRVConstant>(Args[4])->getZExtIntValue();
378+
return getOpCode() == OpTypeJointMatrixINTEL
379+
? (unsigned)get<SPIRVConstant>(Args[3])->getZExtIntValue()
380+
: (unsigned)get<SPIRVConstant>(Args[4])->getZExtIntValue();
363381
return 0;
364382
}
365383

366-
unsigned SPIRVTypeJointMatrixINTEL::getRows() const {
367-
return (unsigned)get<SPIRVConstant>(Args[0])->getZExtIntValue();
384+
unsigned SPIRVTypeJointMatrixINTEL::getComponentTypeInterpretation() const {
385+
if (isComponentTypeInterpretationParameterPresent())
386+
return (unsigned)get<SPIRVConstant>(Args[4])->getZExtIntValue();
387+
return 0;
368388
}
369389

370-
unsigned SPIRVTypeJointMatrixINTEL::getColumns() const {
371-
return (unsigned)get<SPIRVConstant>(Args[1])->getZExtIntValue();
390+
bool SPIRVTypeJointMatrixINTEL::isLayoutParameterPresent() const {
391+
return getOpCode() == OpTypeJointMatrixINTEL_OLD;
372392
}
373393

374-
unsigned SPIRVTypeJointMatrixINTEL::getScope() const {
375-
return (unsigned)get<SPIRVConstant>(Args[3])->getZExtIntValue();
394+
bool SPIRVTypeJointMatrixINTEL::isUseParameterPresent() const {
395+
return getOpCode() == OpTypeJointMatrixINTEL || Args.size() > 4;
376396
}
377397

378-
bool SPIRVTypeJointMatrixINTEL::isUseParameterPresent() const {
379-
return Args.size() > 4;
398+
bool SPIRVTypeJointMatrixINTEL::isComponentTypeInterpretationParameterPresent() const {
399+
return getOpCode() == OpTypeJointMatrixINTEL && Args.size() > 4;
380400
}
381401

382402
std::string SPIRVTypeJointMatrixINTEL::getMangledName() const {
@@ -413,9 +433,9 @@ std::string SPIRVTypeJointMatrixINTEL::getMangledName() const {
413433
name += "_";
414434

415435
if (ElemType->isTypeFloat()) {
416-
name += "f";
436+
name += "f";
417437
} else {
418-
name += "i";
438+
name += "i";
419439
}
420440
name += std::to_string(ElemType->getBitWidth());
421441
return std::move(name);

IGC/AdaptorOCL/SPIRV/libSPIRV/SPIRVType.h

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -679,18 +679,25 @@ class SPIRVTypeNamedBarrier :public SPIRVType {
679679

680680
class SPIRVTypeJointMatrixINTEL : public SPIRVType {
681681
public:
682-
const static Op OC = OpTypeJointMatrixINTEL;
683682
const static SPIRVWord FixedWC = 3;
684683
// Complete constructor
685-
SPIRVTypeJointMatrixINTEL(SPIRVModule *M, SPIRVId TheId, SPIRVType *ElemType,
686-
std::vector<SPIRVId> Args)
684+
SPIRVTypeJointMatrixINTEL(Op OC, SPIRVModule *M, SPIRVId TheId,
685+
SPIRVType *ElemType, std::vector<SPIRVId> Args)
687686
: SPIRVType(M, FixedWC, OC, TheId), ElemType(ElemType), Args(Args) {
687+
bool ValidOpcode = OC == OpTypeJointMatrixINTEL || OC == OpTypeJointMatrixINTEL_OLD;
688+
IGC_ASSERT_EXIT_MESSAGE(ValidOpcode, "Invalid opcode for TypeJointMatrixINTEL");
688689
validate();
689690
}
690691

691-
// Incomplete constructor
692+
// Incomplete constructors
692693
SPIRVTypeJointMatrixINTEL()
694+
: SPIRVType(OpTypeJointMatrixINTEL), ElemType(0), Args({0, 0, 0, 0}) {
695+
}
696+
// Remove once OpTypeJointMatrixINTEL_OLD is removed
697+
SPIRVTypeJointMatrixINTEL(Op OC)
693698
: SPIRVType(OC), ElemType(0), Args({0, 0, 0, 0}) {
699+
bool ValidOpcode = OC == OpTypeJointMatrixINTEL || OC == OpTypeJointMatrixINTEL_OLD;
700+
IGC_ASSERT_EXIT_MESSAGE(ValidOpcode, "Invalid opcode for TypeJointMatrixINTEL");
694701
}
695702

696703
CapVec getRequiredCapability() const override {
@@ -704,14 +711,17 @@ class SPIRVTypeJointMatrixINTEL : public SPIRVType {
704711

705712
SPIRVType *getElemType() const { return ElemType; }
706713

707-
unsigned getLayout() const;
708-
unsigned getUse() const;
709714
unsigned getRows() const;
710715
unsigned getColumns() const;
716+
unsigned getLayout() const;
711717
unsigned getScope() const;
718+
unsigned getUse() const;
719+
unsigned getComponentTypeInterpretation() const;
712720

713721
std::string getMangledName() const;
722+
bool isLayoutParameterPresent() const;
714723
bool isUseParameterPresent() const;
724+
bool isComponentTypeInterpretationParameterPresent() const;
715725

716726
enum {
717727
LayoutColumnMajor = 0,
@@ -729,6 +739,15 @@ class SPIRVTypeJointMatrixINTEL : public SPIRVType {
729739
UseMAX
730740
};
731741

742+
enum {
743+
CTINone = 0,
744+
CTITF32 = 1,
745+
CTIBfloat16 = 2,
746+
CTIPackedInt2 = 3,
747+
CTIPackedInt4 = 4,
748+
CTIMAX
749+
};
750+
732751
protected:
733752
_SPIRV_DEF_DEC3_OVERRIDE(Id, ElemType, Args)
734753
void validate() const override {
@@ -738,6 +757,7 @@ class SPIRVTypeJointMatrixINTEL : public SPIRVType {
738757
IGC_ASSERT_EXIT_MESSAGE(getColumns() <= 64, "Unsupported columns size.");
739758
IGC_ASSERT_EXIT_MESSAGE(getLayout() < LayoutMAX, "Unsupported layout.");
740759
IGC_ASSERT_EXIT_MESSAGE(getUse() < UseMAX, "Unsupported use parameter.");
760+
IGC_ASSERT_EXIT_MESSAGE(getComponentTypeInterpretation() < CTIMAX, "Unsupported component type interpretation parameter." );
741761
}
742762

743763
private:

IGC/AdaptorOCL/SPIRV/libSPIRV/spirv.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -627,6 +627,10 @@ enum Capability {
627627
CapabilityJointMatrixINTEL = 6118,
628628
CapabilityHWThreadQueryINTEL = 6134,
629629
CapabilityGlobalVariableDecorationsINTEL = 6146,
630+
JointMatrixTF32ComponentTypeINTEL = 6436,
631+
JointMatrixBF16ComponentTypeINTEL = 6437,
632+
JointMatrixPackedInt2ComponentTypeINTEL = 6438,
633+
JointMatrixPackedInt4ComponentTypeINTEL = 6439,
630634
};
631635

632636
enum PackedVectorFormat {
@@ -635,6 +639,7 @@ enum PackedVectorFormat {
635639
};
636640

637641
enum Op {
642+
OpTypeJointMatrixINTEL_OLD = 6119, // Replaced by 6184
638643
#define _SPIRV_OP(x, num) Op##x = num,
639644
#include "SPIRVOpCodeEnum.h"
640645
#undef _SPIRV_OP

0 commit comments

Comments
 (0)