Skip to content

Commit cf0c745

Browse files
JaroszPiotrigcbot
authored andcommitted
Add functions to enable and disable MidThread Preemption
This change adds functions that allow enabling and disabling MidThread Preemption during shader execution.
1 parent 40d3812 commit cf0c745

File tree

11 files changed

+179
-9
lines changed

11 files changed

+179
-9
lines changed

IGC/AdaptorCommon/RayTracing/RTBuilder.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -597,6 +597,7 @@ Value* RTBuilder::CreateSyncStackPtrIntrinsic(
597597
return StackPtr;
598598
}
599599

600+
600601
RTBuilder::SWStackPtrVal* RTBuilder::getSWStackPointer(const Twine& Name)
601602
{
602603
auto* CI = this->CreateSWStackPtrIntrinsic(nullptr, false, Name);

IGC/AdaptorCommon/RayTracing/RTBuilder.h

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,8 @@ class RTBuilder : public IGCIRBuilder<>
385385
Value* CreateSWHotZonePtrIntrinsic(Value *Addr, Type *PtrTy, bool AddDecoration);
386386
Value* CreateAsyncStackPtrIntrinsic(Value *Addr, Type *PtrTy, bool AddDecoration);
387387
Value* CreateSyncStackPtrIntrinsic(Value* Addr, Type* PtrTy, bool AddDecoration);
388+
389+
388390
CallInst* CreateSWStackPtrIntrinsic(
389391
Value *Addr, bool AddDecoration, const Twine &Name = "");
390392
SWStackPtrVal* getSWStackPointer(
@@ -705,12 +707,6 @@ class RTBuilder : public IGCIRBuilder<>
705707
uint32_t dim,
706708
IGC::CallableShaderTypeMD ShaderTy);
707709

708-
Value* getTraceRayPayload(
709-
Value* bvhLevel,
710-
Value* traceRayCtrl,
711-
bool isRayQuery,
712-
const Twine& PayloadName = "");
713-
714710
Value* emitStateRegID(uint32_t BitStart, uint32_t BitEnd);
715711
Value* getSliceID();
716712
Value* getSubsliceID();
@@ -727,6 +723,13 @@ class RTBuilder : public IGCIRBuilder<>
727723
const IGC::RayDispatchShaderContext& RtCtx() const;
728724
//printf
729725
public:
726+
727+
Value* getTraceRayPayload(
728+
Value* bvhLevel,
729+
Value* traceRayCtrl,
730+
bool isRayQuery,
731+
const Twine& PayloadName = "");
732+
730733
void printTraceRay(const TraceRayAsyncHLIntrinsic* trace);
731734
void printDispatchRayIndex(const std::vector<Value*>& Indices);
732735
public:

IGC/Compiler/CISACodeGen/CISABuilder.cpp

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3006,6 +3006,61 @@ namespace IGC
30063006
return RMEncoding::RoundToZero_int;
30073007
}
30083008

3009+
// Translates an EPreemptionMode value into the encoder's PreemptionEncoding.
// PREEMPTION_ENABLED maps to PreemptionEnabled and PREEMPTION_DISABLED maps to
// PreemptionDisabled. Any other value takes the default branch and falls
// through to the final return, which yields PreemptionEnabled.
CEncoder::PreemptionEncoding CEncoder::getEncoderPreemptionMode(EPreemptionMode preemptionMode)
3010+
{
3011+
switch (preemptionMode)
3012+
{
3013+
default:
3014+
break;
3015+
case PREEMPTION_ENABLED:
3016+
return PreemptionEncoding::PreemptionEnabled;
3017+
case PREEMPTION_DISABLED:
3018+
return PreemptionEncoding::PreemptionDisabled;
3019+
}
3020+
3021+
return PreemptionEncoding::PreemptionEnabled;
3022+
}
3023+
3024+
// Changes the preemption mode from actualPreemptionMode to newPreemptionMode.
// When the two modes are equal nothing is emitted. Otherwise both modes are
// converted with getEncoderPreemptionMode() and the work is forwarded to the
// PreemptionEncoding overload of SetPreemptionMode().
void CEncoder::SetPreemptionMode(EPreemptionMode actualPreemptionMode, EPreemptionMode newPreemptionMode)
3025+
{
3026+
if (actualPreemptionMode != newPreemptionMode)
3027+
{
3028+
PreemptionEncoding actualPreemptionMode_en = getEncoderPreemptionMode(actualPreemptionMode);
3029+
PreemptionEncoding newPreemptionMode_en = getEncoderPreemptionMode(newPreemptionMode);
3030+
SetPreemptionMode(actualPreemptionMode_en, newPreemptionMode_en);
3031+
}
3032+
}
3033+
3034+
// Emits a single vISA XOR instruction on the predefined CR0 variable to switch
// between the two preemption encodings. The immediate operand is
// (actualPreemptionMode ^ newPreemptionMode), i.e. the bits that differ between
// the two encodings, so XOR-ing it into CR0 flips exactly those bits.
// Asserts that the two encodings differ and that vKernel is non-null.
void CEncoder::SetPreemptionMode(PreemptionEncoding actualPreemptionMode, PreemptionEncoding newPreemptionMode)
3035+
{
3036+
IGC_ASSERT_MESSAGE(
3037+
(actualPreemptionMode != newPreemptionMode),
3038+
"Only setting PreemptionMode if the new PreemptionMode is different from the current PreemptionMode!");
3039+
3040+
VISA_VectorOpnd* src0_Opnd = nullptr;
3041+
VISA_VectorOpnd* src1_Opnd = nullptr;
3042+
VISA_VectorOpnd* dst_Opnd = nullptr;
3043+
VISA_GenVar* cr0_var = nullptr;
3044+
3045+
// Bit mask of the encoding bits that must be toggled in CR0.
uint preemptionMode = actualPreemptionMode ^ newPreemptionMode;
3046+
3047+
IGC_ASSERT(nullptr != vKernel);
3048+
3049+
// CR0 is used as both the first source and the destination of the XOR.
V(vKernel->GetPredefinedVar(cr0_var, PREDEFINED_CR0));
3050+
V(vKernel->CreateVISASrcOperand(src0_Opnd, cr0_var, MODIFIER_NONE, 0, 1, 0, 0, 0));
3051+
V(vKernel->CreateVISAImmediate(src1_Opnd, &preemptionMode, ISA_TYPE_UD));
3052+
V(vKernel->CreateVISADstOperand(dst_Opnd, cr0_var, 1, 0, 0));
3053+
// Scalar instruction: EXEC_SIZE_1 with the vISA_EMASK_M1_NM mask and no predicate.
V(vKernel->AppendVISAArithmeticInst(
3054+
ISA_XOR,
3055+
nullptr,
3056+
false,
3057+
vISA_EMASK_M1_NM,
3058+
EXEC_SIZE_1,
3059+
dst_Opnd,
3060+
src0_Opnd,
3061+
src1_Opnd));
3062+
}
3063+
30093064
VISA_LabelOpnd* CEncoder::GetLabel(uint label)
30103065
{
30113066
VISA_LabelOpnd* visaLabel = labelMap[label];

IGC/Compiler/CISACodeGen/CISABuilder.hpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,8 @@ namespace IGC
471471
void SetRoundingMode_FP(ERoundingMode actualRM, ERoundingMode newRM);
472472
void SetRoundingMode_FPCvtInt(ERoundingMode actualRM, ERoundingMode newRM);
473473

474+
void SetPreemptionMode(EPreemptionMode actualPreemptionMode, EPreemptionMode newPreemptionMode);
475+
474476
static uint GetCISADataTypeSize(VISA_Type type) {return CVariable::GetCISADataTypeSize(type);}
475477
static e_alignment GetCISADataTypeAlignment(VISA_Type type) {return CVariable::GetCISADataTypeAlignment(type);}
476478

@@ -670,6 +672,16 @@ namespace IGC
670672
RMEncoding getEncoderRoundingMode_FP(ERoundingMode FP_RM);
671673
RMEncoding getEncoderRoundingMode_FPCvtInt(ERoundingMode FCvtI_RM);
672674

675+
// Encoder-level representation of the preemption mode, as returned by
// getEncoderPreemptionMode() and consumed by SetPreemptionMode().
// NOTE(review): 0x800 is presumably the preemption control bit within CR0 —
// confirm against the hardware documentation.
enum PreemptionEncoding
676+
{
677+
PreemptionDisabled = 0x000,
678+
PreemptionEnabled = 0x800
679+
};
680+
681+
PreemptionEncoding getEncoderPreemptionMode(EPreemptionMode preemptionMode);
682+
683+
void SetPreemptionMode(PreemptionEncoding actualPreemptionMode, PreemptionEncoding newPreemptionMode);
684+
673685
unsigned GetRawOpndSplitOffset(VISA_Exec_Size fromExecSize,
674686
VISA_Exec_Size toExecSize,
675687
unsigned thePart, CVariable* var) const;

IGC/Compiler/CISACodeGen/ComputeShaderCodeGen.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,11 @@ namespace IGC
3737
m_disableMidThreadPreemption = true;
3838
}
3939

40+
// Returns the current value of the m_disableMidThreadPreemption flag.
bool GetDisableMidthreadPreemption()
41+
{
42+
return m_disableMidThreadPreemption;
43+
}
44+
4045
bool HasSLM() const { return m_hasSLM; }
4146
bool HasFullDispatchMask() override;
4247

IGC/Compiler/CISACodeGen/EmitVISAPass.cpp

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ EmitPass::EmitPass(CShaderProgram::KernelShaderMap& shaders, SIMDMode mode, bool
126126
m_canAbortOnSpill(canAbortOnSpill),
127127
m_roundingMode_FP(ERoundingMode::ROUND_TO_NEAREST_EVEN),
128128
m_roundingMode_FPCvtInt(ERoundingMode::ROUND_TO_ZERO),
129+
m_preemptionMode(EPreemptionMode::PREEMPTION_ENABLED),
129130
m_pSignature(pSignature),
130131
m_isDuplicate(false)
131132
{
@@ -1238,6 +1239,8 @@ bool EmitPass::runOnFunction(llvm::Function& F)
12381239
(m_currShader->GetContext()->m_instrTypes.numLoopInsts == 0) &&
12391240
(m_currShader->ProgramOutput()->m_InstructionCount < IGC_GET_FLAG_VALUE(MidThreadPreemptionDisableThreshold)))
12401241
{
1242+
m_preemptionMode = PREEMPTION_DISABLED;
1243+
12411244
if (m_currShader->GetShaderType() == ShaderType::COMPUTE_SHADER)
12421245
{
12431246
CComputeShader* csProgram = static_cast<CComputeShader*>(m_currShader);
@@ -16759,6 +16762,15 @@ void EmitPass::SetRoundingMode_FPCvtInt(ERoundingMode newRM_FPCvtInt)
1675916762
}
1676016763
}
1676116764

16765+
// Switches the pass's tracked preemption mode to newPreemptionMode.
// If the requested mode differs from m_preemptionMode, the encoder is asked to
// emit the change and m_preemptionMode is updated; otherwise this is a no-op.
void EmitPass::SetPreemptionMode(EPreemptionMode newPreemptionMode)
16766+
{
16767+
if (newPreemptionMode != m_preemptionMode)
16768+
{
16769+
m_encoder->SetPreemptionMode(m_preemptionMode, newPreemptionMode);
16770+
m_preemptionMode = newPreemptionMode;
16771+
}
16772+
}
16773+
1676216774
// Return true if inst needs specific rounding mode; false otherwise.
1676316775
//
1676416776
// Currently, only gen intrinsic needs rounding mode other than the default.
@@ -22585,8 +22597,12 @@ void EmitPass::emitTraceRay(TraceRayIntrinsic* I, bool RayQueryEnable)
2258522597
offsetof(RTStackFormat::TraceRayMessage::Header, rayQueryLocation) / sizeof(DWORD);
2258622598
static_assert(RayQueryDword == 4, "header change?");
2258722599

22588-
CVariable* RayQueryVal =
22589-
m_currShader->ImmToVariable(RayQueryEnable ? 1 : 0, ISA_TYPE_UD);
22600+
uint64_t rayQueryHeader = 0x0;
22601+
22602+
rayQueryHeader |= RayQueryEnable ? 1 : 0;
22603+
22604+
CVariable* RayQueryVal =
22605+
m_currShader->ImmToVariable(rayQueryHeader, ISA_TYPE_UD);
2259022606

2259122607
m_encoder->SetSimdSize(SIMDMode::SIMD1);
2259222608
m_encoder->SetNoMask();
@@ -22631,6 +22647,7 @@ void EmitPass::emitTraceRay(TraceRayIntrinsic* I, bool RayQueryEnable)
2263122647
m_currShader->m_SIMDSize >= SIMDMode::SIMD16 ? 1 : 0,
2263222648
true,
2263322649
RayQueryEnable);
22650+
2263422651
CVariable* pMessDesc =
2263522652
m_currShader->ImmToVariable(messageSpecificControl, ISA_TYPE_UD);
2263622653

@@ -22666,7 +22683,9 @@ void EmitPass::emitTraceRay(TraceRayIntrinsic* I, bool RayQueryEnable)
2266622683
payload,
2266722684
extDescriptor,
2266822685
exDesc,
22669-
pMessDesc);
22686+
pMessDesc,
22687+
false);
22688+
2267022689
m_encoder->Push();
2267122690
}
2267222691

@@ -22715,6 +22734,7 @@ void EmitPass::emitReadTraceRaySync(llvm::GenIntrinsicInst* I)
2271522734
m_encoder->Push();
2271622735
}
2271722736

22737+
2271822738
void EmitPass::emitBTD(
2271922739
CVariable* GlobalBufferPtr,
2272022740
CVariable* StackID,

IGC/Compiler/CISACodeGen/EmitVISAPass.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,8 @@ class EmitPass : public llvm::FunctionPass
471471
void emitSyncStackID(llvm::GenIntrinsicInst* I);
472472
void emitTraceRay(llvm::TraceRayIntrinsic *I, bool RayQueryEnable);
473473
void emitReadTraceRaySync(llvm::GenIntrinsicInst* I);
474+
475+
474476
void emitBTD(
475477
CVariable* GlobalBufferPtr,
476478
CVariable* StackID,
@@ -806,6 +808,8 @@ class EmitPass : public llvm::FunctionPass
806808
ERoundingMode m_roundingMode_FP;
807809
ERoundingMode m_roundingMode_FPCvtInt;
808810

811+
EPreemptionMode m_preemptionMode;
812+
809813
uint m_currentBlock = (uint) -1;
810814

811815
bool m_currFuncHasSubroutine = false;
@@ -877,6 +881,9 @@ class EmitPass : public llvm::FunctionPass
877881
void SetRoundingMode_FPCvtInt(ERoundingMode RM_FPCvtInt);
878882
bool setRMExplicitly(llvm::Instruction* inst);
879883
void ResetRoundingMode(llvm::Instruction* inst);
884+
885+
void SetPreemptionMode(EPreemptionMode newPreemptionMode);
886+
880887
// returns true if the instruction does not care about the rounding mode settings
881888
bool ignoreRoundingMode(llvm::Instruction* inst) const;
882889

IGC/Compiler/CISACodeGen/helper.cpp

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,26 @@ namespace IGC
223223
return type;
224224
}
225225

226+
///
227+
/// returns the number of exit blocks in the given function.
/// A block counts as an exit block when its terminator is a ReturnInst.
228+
///
229+
unsigned getNumberOfExitBlocks(llvm::Function& function)
230+
{
231+
unsigned numberOfExitBlocks = 0;
232+
233+
for (llvm::BasicBlock& block : function.getBasicBlockList())
234+
{
235+
llvm::Instruction* terminator = block.getTerminator();
236+
237+
// isa_and_nonnull also handles a block whose terminator is null.
if (llvm::isa_and_nonnull<ReturnInst>(terminator))
238+
{
239+
++numberOfExitBlocks;
240+
}
241+
}
242+
243+
return numberOfExitBlocks;
244+
}
245+
226246
///
227247
/// returns constant buffer load offset
228248
///
@@ -1413,6 +1433,40 @@ namespace IGC
14131433
return false;
14141434
}
14151435

1436+
// Returns true when I is a GenIntrinsicInst whose intrinsic ID is one of the
// barrier intrinsics listed in the switch below; returns false for any other
// intrinsic and for instructions that are not GenIntrinsicInst.
bool isBarrierIntrinsic(const llvm::Instruction* I)
1437+
{
1438+
const GenIntrinsicInst* GII = dyn_cast<GenIntrinsicInst>(I);
1439+
if (!GII)
1440+
return false;
1441+
1442+
switch (GII->getIntrinsicID())
1443+
{
1444+
case GenISAIntrinsic::GenISA_threadgroupbarrier:
1445+
case GenISAIntrinsic::GenISA_threadgroupbarrier_signal:
1446+
case GenISAIntrinsic::GenISA_threadgroupbarrier_wait:
1447+
case GenISAIntrinsic::GenISA_threadgroupnamedbarriers_signal:
1448+
case GenISAIntrinsic::GenISA_threadgroupnamedbarriers_wait:
1449+
case GenISAIntrinsic::GenISA_wavebarrier:
1450+
return true;
1451+
default:
1452+
return false;
1453+
}
1454+
}
1455+
1456+
// Returns true when I is a CallInst that is either an indirect call or a call
// to a function for which isIntrinsic() is false. Returns false when I is not
// a CallInst.
bool isUserFunctionCall(const llvm::Instruction* I)
1457+
{
1458+
const CallInst* callInst = dyn_cast<CallInst>(I);
1459+
1460+
// Return true if:
1461+
// 1. callInst->getCalledFunction() == nullptr, this means indirect function call
1462+
// OR
1463+
// 2. Called function is not an Intrinsic.
1464+
bool isUserFunction = (callInst != nullptr) &&
1465+
((callInst->getCalledFunction() == nullptr) || !callInst->getCalledFunction()->isIntrinsic());
1466+
1467+
return isUserFunction;
1468+
}
1469+
14161470
bool isURBWriteIntrinsic(const llvm::Instruction* I)
14171471
{
14181472
const GenIntrinsicInst* GII = dyn_cast<GenIntrinsicInst>(I);

IGC/Compiler/CISACodeGen/helper.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,10 @@ namespace IGC
178178
bool isSubGroupIntrinsicPVC(const llvm::Instruction* I);
179179
bool hasSubGroupIntrinsicPVC(llvm::Function& F);
180180

181+
bool isBarrierIntrinsic(const llvm::Instruction* I);
182+
183+
bool isUserFunctionCall(const llvm::Instruction* I);
184+
181185
bool IsStatelessMemLoadIntrinsic(llvm::GenISAIntrinsic::ID id);
182186
bool IsStatelessMemStoreIntrinsic(llvm::GenISAIntrinsic::ID id);
183187
bool IsStatelessMemAtomicIntrinsic(llvm::GenIntrinsicInst& inst, llvm::GenISAIntrinsic::ID id);
@@ -198,6 +202,8 @@ namespace IGC
198202
BufferType DecodeBufferType(unsigned addrSpace);
199203
int getConstantBufferLoadOffset(llvm::LoadInst* ld);
200204

205+
unsigned getNumberOfExitBlocks(llvm::Function& function);
206+
201207
bool isDummyBasicBlock(llvm::BasicBlock* BB);
202208

203209
bool IsDirectIdx(unsigned addrSpace);

IGC/Compiler/CodeGenPublicEnums.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,12 @@ namespace IGC
264264
ROUND_TO_ANY // dont care
265265
};
266266

267+
// Preemption mode requested for emitted code: enabled or disabled.
enum EPreemptionMode
268+
{
269+
PREEMPTION_ENABLED,
270+
PREEMPTION_DISABLED
271+
};
272+
267273

268274

269275
enum DISPATCH_SHADER_RAY_INFO_TYPE : unsigned int

0 commit comments

Comments
 (0)