@@ -2297,6 +2297,7 @@ SPIRVToLLVM::transType(SPIRVType *T) {
22972297 SPIRAddressSpace::SPIRAS_Global));
22982298 }
22992299 case OpTypeJointMatrixINTEL:
2300+ case OpTypeJointMatrixINTEL_OLD:
23002301 {
23012302 SPIRVTypeJointMatrixINTEL *MT = static_cast <SPIRVTypeJointMatrixINTEL *>(T);
23022303 std::string typeName = " intel.joint_matrix_" + MT->getMangledName () + " _t" ;
@@ -3174,6 +3175,7 @@ SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
31743175 }
31753176 break ;
31763177 case OpTypeJointMatrixINTEL:
3178+ case OpTypeJointMatrixINTEL_OLD:
31773179 {
31783180 IGC_ASSERT (CV.size () == 1 );
31793181
@@ -3663,7 +3665,8 @@ SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
36633665 case OpVectorExtractDynamic: {
36643666 auto CE = static_cast <SPIRVVectorExtractDynamic *>(BV);
36653667 IGC_ASSERT_MESSAGE (BB, " Invalid BB" );
3666- if (CE->getVector ()->getType ()->getOpCode () == OpTypeJointMatrixINTEL)
3668+ Op VectorTypeOpCode = CE->getVector ()->getType ()->getOpCode ();
3669+ if (VectorTypeOpCode == OpTypeJointMatrixINTEL || VectorTypeOpCode == OpTypeJointMatrixINTEL_OLD)
36673670 {
36683671 Value *matrix = transValue (CE->getVector (), F, BB);
36693672 Value *index = transValue (CE->getIndex (), F, BB);
@@ -3719,7 +3722,8 @@ SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
37193722 case OpVectorInsertDynamic: {
37203723 auto CI = static_cast <SPIRVVectorInsertDynamic *>(BV);
37213724 IGC_ASSERT_MESSAGE (BB, " Invalid BB" );
3722- if (CI->getVector ()->getType ()->getOpCode () == OpTypeJointMatrixINTEL)
3725+ Op VectorTypeOpCode = CI->getVector ()->getType ()->getOpCode ();
3726+ if (VectorTypeOpCode == OpTypeJointMatrixINTEL || VectorTypeOpCode == OpTypeJointMatrixINTEL_OLD)
37233727 {
37243728 Value *matrix = transValue (CI->getVector (), F, BB);
37253729 Value *component = transValue (CI->getComponent (), F, BB);
@@ -3982,7 +3986,7 @@ SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
39823986 case OpJointMatrixLoadINTEL: {
39833987 SPIRVJointMatrixLoadINTEL *ML = static_cast <SPIRVJointMatrixLoadINTEL *>(BV);
39843988 std::vector<SPIRVValue *> BArgs = ML->getOperands ();
3985- enum SPVIdx { Pointer, Stride, Layout, Scope, MemOp };
3989+ enum SPVIdx { Pointer, Stride, Layout, MemOp };
39863990
39873991 SPIRVTypeJointMatrixINTEL *MatTy = static_cast <SPIRVTypeJointMatrixINTEL *>(ML->getType ());
39883992 const unsigned loadLayout = (unsigned )BM->get <SPIRVConstant>(BArgs[Layout]->getId ())->getZExtIntValue ();
@@ -4028,7 +4032,7 @@ SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
40284032 case OpJointMatrixStoreINTEL: {
40294033 SPIRVJointMatrixStoreINTEL *MS = static_cast <SPIRVJointMatrixStoreINTEL *>(BV);
40304034 std::vector<SPIRVValue *> BArgs = MS->getOperands ();
4031- enum SPVIdx { Pointer, Object, Stride, Layout, Scope, MemOp };
4035+ enum SPVIdx { Pointer, Object, Stride, Layout, MemOp };
40324036
40334037 SPIRVTypeJointMatrixINTEL *MatTy = static_cast <SPIRVTypeJointMatrixINTEL *>(BArgs[Object]->getType ());
40344038 const unsigned storeLayout = (unsigned )BM->get <SPIRVConstant>(BArgs[Layout]->getId ())->getZExtIntValue ();
@@ -4079,7 +4083,7 @@ SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
40794083 SPIRVInstruction *MI = static_cast <SPIRVInstruction *>(BV);
40804084 std::vector<SPIRVValue *> BArgs = MI->getOperands ();
40814085
4082- enum SPVIdx { A, B, C, Scope };
4086+ enum SPVIdx { A, B, C };
40834087 auto *MatATy = static_cast <SPIRVTypeJointMatrixINTEL *>(BArgs[A]->getType ());
40844088 auto *MatBTy = static_cast <SPIRVTypeJointMatrixINTEL *>(BArgs[B]->getType ());
40854089 auto *MatCTy = static_cast <SPIRVTypeJointMatrixINTEL *>(BArgs[C]->getType ());
0 commit comments