Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions include/dxc/DXIL/DxilFunctionProps.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ struct DxilFunctionProps {
memset(&Node, 0, sizeof(Node));
Node.LaunchType = DXIL::NodeLaunchType::Invalid;
Node.LocalRootArgumentsTableIndex = -1;
groupSharedLimitBytes = 0;
}
union {
// Geometry shader.
Expand Down Expand Up @@ -174,6 +175,8 @@ struct DxilFunctionProps {
// numThreads shared between multiple shader types and node shaders.
unsigned numThreads[3];

unsigned groupSharedLimitBytes;

struct NodeProps {
DXIL::NodeLaunchType LaunchType = DXIL::NodeLaunchType::Invalid;
bool IsProgramEntry;
Expand Down
1 change: 1 addition & 0 deletions include/dxc/DXIL/DxilMetadataHelper.h
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,7 @@ class DxilMDHelper {
static const unsigned kDxilNodeOutputsTag = 21;
static const unsigned kDxilNodeMaxDispatchGridTag = 22;
static const unsigned kDxilRangedWaveSizeTag = 23;
static const unsigned kDxilGroupSharedLimitTag = 24;

// Node Input/Output State.
static const unsigned kDxilNodeOutputIDTag = 0;
Expand Down
2 changes: 2 additions & 0 deletions include/dxc/DXIL/DxilModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,8 @@ class DxilModule {
void SetNumThreads(unsigned x, unsigned y, unsigned z);
unsigned GetNumThreads(unsigned idx) const;

unsigned GetGroupSharedLimit() const;

// Compute shader
DxilWaveSize &GetWaveSize();
const DxilWaveSize &GetWaveSize() const;
Expand Down
23 changes: 18 additions & 5 deletions include/dxc/DxilContainer/DxilPipelineStateValidation.h
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,10 @@ struct PSVRuntimeInfo3 : public PSVRuntimeInfo2 {
uint32_t EntryFunctionName;
};

struct PSVRuntimeInfo4 : public PSVRuntimeInfo3 {
uint32_t GroupSharedLimit;
};

enum class PSVResourceType {
Invalid = 0,

Expand Down Expand Up @@ -474,7 +478,7 @@ class PSVSignatureElement {
const uint32_t *SemanticIndexes) const;
};

#define MAX_PSV_VERSION 3
#define MAX_PSV_VERSION 4

struct PSVInitInfo {
PSVInitInfo(uint32_t psvVersion) : PSVVersion(psvVersion) {}
Expand All @@ -491,7 +495,7 @@ struct PSVInitInfo {
uint8_t SigPatchConstOrPrimVectors = 0;
uint8_t SigOutputVectors[PSV_GS_MAX_STREAMS] = {0, 0, 0, 0};

static_assert(MAX_PSV_VERSION == 3, "otherwise this needs updating.");
static_assert(MAX_PSV_VERSION == 4, "otherwise this needs updating.");
uint32_t RuntimeInfoSize() const {
switch (PSVVersion) {
case 0:
Expand All @@ -500,10 +504,12 @@ struct PSVInitInfo {
return sizeof(PSVRuntimeInfo1);
case 2:
return sizeof(PSVRuntimeInfo2);
case 3:
return sizeof(PSVRuntimeInfo3);
default:
break;
}
return sizeof(PSVRuntimeInfo3);
return sizeof(PSVRuntimeInfo4);
}
uint32_t ResourceBindInfoSize() const {
if (PSVVersion < 2)
Expand All @@ -519,6 +525,7 @@ class DxilPipelineStateValidation {
PSVRuntimeInfo1 *m_pPSVRuntimeInfo1 = nullptr;
PSVRuntimeInfo2 *m_pPSVRuntimeInfo2 = nullptr;
PSVRuntimeInfo3 *m_pPSVRuntimeInfo3 = nullptr;
PSVRuntimeInfo4 *m_pPSVRuntimeInfo4 = nullptr;
uint32_t m_uResourceCount = 0;
uint32_t m_uPSVResourceBindInfoSize = 0;
void *m_pPSVResourceBindInfo = nullptr;
Expand Down Expand Up @@ -634,6 +641,8 @@ class DxilPipelineStateValidation {

PSVRuntimeInfo3 *GetPSVRuntimeInfo3() const { return m_pPSVRuntimeInfo3; }

PSVRuntimeInfo4 *GetPSVRuntimeInfo4() const { return m_pPSVRuntimeInfo4; }

uint32_t GetBindCount() const { return m_uResourceCount; }

template <typename _T>
Expand Down Expand Up @@ -949,6 +958,8 @@ DxilPipelineStateValidation::ReadOrWrite(const void *pBits, uint32_t *pSize,
m_uPSVRuntimeInfoSize); // failure ok
AssignDerived(&m_pPSVRuntimeInfo3, m_pPSVRuntimeInfo0,
m_uPSVRuntimeInfoSize); // failure ok
AssignDerived(&m_pPSVRuntimeInfo4, m_pPSVRuntimeInfo0,
m_uPSVRuntimeInfoSize); // failure ok

// In RWMode::CalcSize, use temp runtime info to hold needed values from
// initInfo
Expand Down Expand Up @@ -1137,11 +1148,13 @@ void SetupPSVInitInfo(PSVInitInfo &InitInfo, const DxilModule &DM);
void SetShaderProps(PSVRuntimeInfo0 *pInfo, const DxilModule &DM);
void SetShaderProps(PSVRuntimeInfo1 *pInfo1, const DxilModule &DM);
void SetShaderProps(PSVRuntimeInfo2 *pInfo2, const DxilModule &DM);
void SetShaderProps(PSVRuntimeInfo4 *pInfo4, const DxilModule &DM);

void PrintPSVRuntimeInfo(llvm::raw_ostream &OS, PSVRuntimeInfo0 *pInfo0,
PSVRuntimeInfo1 *pInfo1, PSVRuntimeInfo2 *pInfo2,
PSVRuntimeInfo3 *pInfo3, uint8_t ShaderKind,
const char *EntryName, const char *Comment);
PSVRuntimeInfo3 *pInfo3, PSVRuntimeInfo4 *pInfo4,
uint8_t ShaderKind, const char *EntryName,
const char *Comment);

} // namespace hlsl

Expand Down
9 changes: 9 additions & 0 deletions lib/DXIL/DxilMetadataHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1624,6 +1624,10 @@ MDTuple *DxilMDHelper::EmitDxilEntryProperties(uint64_t rawShaderFlag,
}
MDVals.emplace_back(MDNode::get(m_Ctx, WaveSizeVal));
}

MDVals.emplace_back(
Uint32ToConstMD(DxilMDHelper::kDxilGroupSharedLimitTag));
MDVals.emplace_back(Uint32ToConstMD(props.groupSharedLimitBytes));
} break;
// Geometry shader.
case DXIL::ShaderKind::Geometry: {
Expand Down Expand Up @@ -1773,6 +1777,11 @@ void DxilMDHelper::LoadDxilEntryProperties(const MDOperand &MDO,
props.numThreads[2] = ConstMDToUint32(pNode->getOperand(2));
} break;

case DxilMDHelper::kDxilGroupSharedLimitTag: {
DXASSERT(props.IsCS(), "else invalid shader kind");
props.groupSharedLimitBytes = ConstMDToUint32(MDO);
} break;

case DxilMDHelper::kDxilGSStateTag: {
DXASSERT(props.IsGS(), "else invalid shader kind");
auto &GS = props.ShaderProps.GS;
Expand Down
9 changes: 9 additions & 0 deletions lib/DXIL/DxilModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,15 @@ unsigned DxilModule::GetNumThreads(unsigned idx) const {
return props.numThreads[idx];
}

unsigned DxilModule::GetGroupSharedLimit() const {
DXASSERT(m_DxilEntryPropsMap.size() == 1 &&
(m_pSM->IsCS() || m_pSM->IsMS() || m_pSM->IsAS()),
"only works for CS/MS/AS profiles");
const DxilFunctionProps &props = m_DxilEntryPropsMap.begin()->second->props;
DXASSERT_NOMSG(m_pSM->GetKind() == props.shaderKind);
return props.groupSharedLimitBytes;
}

DxilWaveSize &DxilModule::GetWaveSize() {
return const_cast<DxilWaveSize &>(
static_cast<const DxilModule *>(this)->GetWaveSize());
Expand Down
4 changes: 4 additions & 0 deletions lib/DxilContainer/DxilContainerAssembler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -798,6 +798,8 @@ class DxilPSVWriter : public DxilPartWriter {
PSVRuntimeInfo1 *pInfo1 = m_PSV.GetPSVRuntimeInfo1();
PSVRuntimeInfo2 *pInfo2 = m_PSV.GetPSVRuntimeInfo2();
PSVRuntimeInfo3 *pInfo3 = m_PSV.GetPSVRuntimeInfo3();
PSVRuntimeInfo4 *pInfo4 = m_PSV.GetPSVRuntimeInfo4();

if (pInfo)
hlsl::SetShaderProps(pInfo, m_Module);
if (pInfo1)
Expand All @@ -806,6 +808,8 @@ class DxilPSVWriter : public DxilPartWriter {
hlsl::SetShaderProps(pInfo2, m_Module);
if (pInfo3)
pInfo3->EntryFunctionName = EntryFunctionName;
if (pInfo4)
hlsl::SetShaderProps(pInfo4, m_Module);

// Set resource binding information
UINT uResIndex = 0;
Expand Down
31 changes: 28 additions & 3 deletions lib/DxilContainer/DxilPipelineStateValidation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,20 @@ void hlsl::SetShaderProps(PSVRuntimeInfo2 *pInfo2, const DxilModule &DM) {
}
}

void hlsl::SetShaderProps(PSVRuntimeInfo4 *pInfo4, const DxilModule &DM) {
assert(pInfo4);
const ShaderModel *SM = DM.GetShaderModel();
switch (SM->GetKind()) {
case ShaderModel::Kind::Compute:
case ShaderModel::Kind::Mesh:
case ShaderModel::Kind::Amplification:
pInfo4->GroupSharedLimit = DM.GetGroupSharedLimit();
break;
default:
break;
}
}

void PSVResourceBindInfo0::Print(raw_ostream &OS) const {
OS << "PSVResourceBindInfo:\n";
OS << " Space: " << Space << "\n";
Expand Down Expand Up @@ -584,8 +598,9 @@ void PSVDependencyTable::Print(raw_ostream &OS, const char *InputSetName,

void hlsl::PrintPSVRuntimeInfo(llvm::raw_ostream &OS, PSVRuntimeInfo0 *pInfo0,
PSVRuntimeInfo1 *pInfo1, PSVRuntimeInfo2 *pInfo2,
PSVRuntimeInfo3 *pInfo3, uint8_t ShaderKind,
const char *EntryName, const char *Comment) {
PSVRuntimeInfo3 *pInfo3, PSVRuntimeInfo4 *pInfo4,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO in this case following the pInfoN pattern is appropriate.

uint8_t ShaderKind, const char *EntryName,
const char *Comment) {
if (pInfo1 && pInfo1->ShaderStage != ShaderKind)
ShaderKind = pInfo1->ShaderStage;
OS << Comment << "PSVRuntimeInfo:\n";
Expand Down Expand Up @@ -808,13 +823,19 @@ void hlsl::PrintPSVRuntimeInfo(llvm::raw_ostream &OS, PSVRuntimeInfo0 *pInfo0,
OS << Comment << " NumThreads=(" << pInfo2->NumThreadsX << ","
<< pInfo2->NumThreadsY << "," << pInfo2->NumThreadsZ << ")\n";
}
if (pInfo4) {
OS << Comment << " GroupSharedLimit=" << pInfo4->GroupSharedLimit << "\n";
}
break;
case PSVShaderKind::Amplification:
OS << Comment << " Amplification Shader\n";
if (pInfo2) {
OS << Comment << " NumThreads=(" << pInfo2->NumThreadsX << ","
<< pInfo2->NumThreadsY << "," << pInfo2->NumThreadsZ << ")\n";
}
if (pInfo4) {
OS << Comment << " GroupSharedLimit=" << pInfo4->GroupSharedLimit << "\n";
}
break;
case PSVShaderKind::Mesh:
OS << Comment << " Mesh Shader\n";
Expand All @@ -841,6 +862,9 @@ void hlsl::PrintPSVRuntimeInfo(llvm::raw_ostream &OS, PSVRuntimeInfo0 *pInfo0,
OS << Comment << " NumThreads=(" << pInfo2->NumThreadsX << ","
<< pInfo2->NumThreadsY << "," << pInfo2->NumThreadsZ << ")\n";
}
if (pInfo4) {
OS << Comment << " GroupSharedLimit=" << pInfo4->GroupSharedLimit << "\n";
}
break;
case PSVShaderKind::Library:
case PSVShaderKind::Invalid:
Expand Down Expand Up @@ -887,9 +911,10 @@ void DxilPipelineStateValidation::PrintPSVRuntimeInfo(
PSVRuntimeInfo1 *pInfo1 = m_pPSVRuntimeInfo1;
PSVRuntimeInfo2 *pInfo2 = m_pPSVRuntimeInfo2;
PSVRuntimeInfo3 *pInfo3 = m_pPSVRuntimeInfo3;
PSVRuntimeInfo4 *pInfo4 = m_pPSVRuntimeInfo4;

hlsl::PrintPSVRuntimeInfo(
OS, pInfo0, pInfo1, pInfo2, pInfo3, ShaderKind,
OS, pInfo0, pInfo1, pInfo2, pInfo3, pInfo4, ShaderKind,
m_pPSVRuntimeInfo3 ? m_StringTable.Get(pInfo3->EntryFunctionName) : "",
Comment);
}
Expand Down
7 changes: 4 additions & 3 deletions lib/DxilValidation/DxilContainerValidation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -413,12 +413,13 @@ void PSVContentVerifier::VerifyEntryProperties(const ShaderModel *SM,
PSVRuntimeInfo0 *PSV0,
PSVRuntimeInfo1 *PSV1,
PSVRuntimeInfo2 *PSV2) {
PSVRuntimeInfo3 DMPSV;
memset(&DMPSV, 0, sizeof(PSVRuntimeInfo3));
PSVRuntimeInfo4 DMPSV;
memset(&DMPSV, 0, sizeof(PSVRuntimeInfo4));

hlsl::SetShaderProps((PSVRuntimeInfo0 *)&DMPSV, DM);
hlsl::SetShaderProps((PSVRuntimeInfo1 *)&DMPSV, DM);
hlsl::SetShaderProps((PSVRuntimeInfo2 *)&DMPSV, DM);
hlsl::SetShaderProps((PSVRuntimeInfo4 *)&DMPSV, DM);
if (PSV1) {
// Init things not set in InitPSVRuntimeInfo.
DMPSV.ShaderStage = static_cast<uint8_t>(SM->GetKind());
Expand Down Expand Up @@ -447,7 +448,7 @@ void PSVContentVerifier::VerifyEntryProperties(const ShaderModel *SM,
if (Mismatched) {
std::string Str;
raw_string_ostream OS(Str);
hlsl::PrintPSVRuntimeInfo(OS, &DMPSV, &DMPSV, &DMPSV, &DMPSV,
hlsl::PrintPSVRuntimeInfo(OS, &DMPSV, &DMPSV, &DMPSV, &DMPSV, &DMPSV,
static_cast<uint8_t>(SM->GetKind()),
DM.GetEntryFunctionName().c_str(), "");
OS.flush();
Expand Down
12 changes: 12 additions & 0 deletions lib/DxilValidation/DxilValidation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3921,6 +3921,18 @@ static void ValidateGlobalVariables(ValidationContext &ValCtx) {
Rule = ValidationRule::SmMaxMSSMSize;
MaxSize = DXIL::kMaxMSSMSize;
}

// Check if the entry function has attribute to override TGSM size.
if (M.HasDxilEntryProps(M.GetEntryFunction())) {
DxilEntryProps &EntryProps = M.GetDxilEntryProps(M.GetEntryFunction());
if (EntryProps.props.IsCS()) {
unsigned SpecifiedTGSMSize = EntryProps.props.groupSharedLimitBytes;
if (SpecifiedTGSMSize > 0) {
MaxSize = SpecifiedTGSMSize;
}
}
}

if (TGSMSize > MaxSize) {
Module::global_iterator GI = M.GetModule()->global_end();
GlobalVariable *GV = &*GI;
Expand Down
5 changes: 5 additions & 0 deletions tools/clang/include/clang/Basic/Attr.td
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,11 @@ def HLSLNumThreads: InheritableAttr {
let Args = [IntArgument<"X">, IntArgument<"Y">, IntArgument<"Z">];
let Documentation = [Undocumented];
}
def HLSLGroupSharedLimit: InheritableAttr {
let Spellings = [CXX11<"", "GroupSharedLimit", 2017>];
let Args = [IntArgument<"Limit">];
let Documentation = [Undocumented];
}
def HLSLRootSignature: InheritableAttr {
let Spellings = [CXX11<"", "RootSignature", 2015>];
let Args = [StringArgument<"SignatureName">];
Expand Down
30 changes: 30 additions & 0 deletions tools/clang/lib/CodeGen/CGHLSLMS.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1646,6 +1646,36 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
}
}

if (const HLSLGroupSharedLimitAttr *Attr =
FD->getAttr<HLSLGroupSharedLimitAttr>()) {
if (isEntry && !SM->IsCS() && !SM->IsMS() && !SM->IsAS()) {
unsigned DiagID = Diags.getCustomDiagID(
DiagnosticsEngine::Error,
"attribute GroupSharedLimit only valid for CS/MS/AS.");
Diags.Report(Attr->getLocation(), DiagID);
return;
}

// Only valid for SM6.10+
if (!SM->IsSM610Plus()) {
unsigned DiagID = Diags.getCustomDiagID(
DiagnosticsEngine::Error, "attribute GroupSharedLimit only valid for "
"Shader Model 6.10 and above.");
Diags.Report(Attr->getLocation(), DiagID);
return;
}

funcProps->groupSharedLimitBytes = Attr->getLimit();
} else {
if (SM->IsMS()) { // Fallback to default limits
funcProps->groupSharedLimitBytes = DXIL::kMaxMSSMSize; // 28k For MS
} else if (SM->IsAS() || SM->IsCS()) {
funcProps->groupSharedLimitBytes = DXIL::kMaxTGSMSize; // 32k For AS/CS
} else {
funcProps->groupSharedLimitBytes = 0;
}
}

// Hull shader.
if (const HLSLPatchConstantFuncAttr *Attr =
FD->getAttr<HLSLPatchConstantFuncAttr>()) {
Expand Down
1 change: 1 addition & 0 deletions tools/clang/lib/Parse/ParseDecl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -833,6 +833,7 @@ void Parser::ParseGNUAttributeArgs(IdentifierInfo *AttrName,
case AttributeList::AT_HLSLMaxVertexCount:
case AttributeList::AT_HLSLUnroll:
case AttributeList::AT_HLSLWaveSize:
case AttributeList::AT_HLSLGroupSharedLimit:
case AttributeList::AT_NoInline:
// The following are not accepted in [attribute(param)] syntax:
// case AttributeList::AT_HLSLCentroid:
Expand Down
5 changes: 5 additions & 0 deletions tools/clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14656,6 +14656,11 @@ void hlsl::HandleDeclAttributeForHLSL(Sema &S, Decl *D, const AttributeList &A,
S.Context.getAddrSpaceQualType(VD->getType(), DXIL::kTGSMAddrSpace));
}
break;
case AttributeList::AT_HLSLGroupSharedLimit:
declAttr = ::new (S.Context) HLSLGroupSharedLimitAttr(
A.getRange(), S.Context, ValidateAttributeIntArg(S, A),
A.getAttributeSpellingListIndex());
break;
case AttributeList::AT_HLSLUniform:
declAttr = ::new (S.Context) HLSLUniformAttr(
A.getRange(), S.Context, A.getAttributeSpellingListIndex());
Expand Down
Loading