Skip to content

Commit 754ca1f

Browse files
committed
[VULKAN] Add support for specialization constants
This commit introduces support for specialization constants in the Vulkan backend. Key changes: - Added struct to to represent a specialization constant with its ID, type, and value. - Updated YAML mapping in to parse specialization constants from the test configuration. - Modified to create and use when creating the compute pipeline, allowing specialization constants to be passed to the shader. - Added a new test case in to verify the functionality of specialization constants with various data types (bool, int, uint, float). Fixes llvm/llvm-project#142992
1 parent 32734ca commit 754ca1f

File tree

5 files changed

+281
-8
lines changed

5 files changed

+281
-8
lines changed

include/Support/Pipeline.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,11 +282,18 @@ struct IOBindings {
282282
}
283283
};
284284

285+
struct SpecializationConstant {
286+
uint32_t ConstantID;
287+
DataFormat Type;
288+
std::string Value;
289+
};
290+
285291
struct Shader {
286292
Stages Stage;
287293
std::string Entry;
288294
std::unique_ptr<llvm::MemoryBuffer> Shader;
289295
int DispatchSize[3];
296+
llvm::SmallVector<SpecializationConstant> SpecializationConstants;
290297
};
291298

292299
struct Pipeline {
@@ -335,6 +342,7 @@ LLVM_YAML_IS_SEQUENCE_VECTOR(offloadtest::Shader)
335342
LLVM_YAML_IS_SEQUENCE_VECTOR(offloadtest::dx::RootParameter)
336343
LLVM_YAML_IS_SEQUENCE_VECTOR(offloadtest::Result)
337344
LLVM_YAML_IS_SEQUENCE_VECTOR(offloadtest::VertexAttribute)
345+
LLVM_YAML_IS_SEQUENCE_VECTOR(offloadtest::SpecializationConstant)
338346

339347
namespace llvm {
340348
namespace yaml {
@@ -399,6 +407,10 @@ template <> struct MappingTraits<offloadtest::RuntimeSettings> {
399407
static void mapping(IO &I, offloadtest::RuntimeSettings &S);
400408
};
401409

410+
template <> struct MappingTraits<offloadtest::SpecializationConstant> {
411+
static void mapping(IO &I, offloadtest::SpecializationConstant &C);
412+
};
413+
402414
template <> struct ScalarEnumerationTraits<offloadtest::Rule> {
403415
static void enumeration(IO &I, offloadtest::Rule &V) {
404416
#define ENUM_CASE(Val) I.enumCase(V, #Val, offloadtest::Rule::Val)

lib/API/VK/Device.cpp

Lines changed: 108 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,22 +20,34 @@
2020

2121
using namespace offloadtest;
2222

23-
#define VKFormats(FMT) \
23+
#define VKFormats(FMT, BITS) \
2424
if (Channels == 1) \
25-
return VK_FORMAT_R32_##FMT; \
25+
return VK_FORMAT_R##BITS##_##FMT; \
2626
if (Channels == 2) \
27-
return VK_FORMAT_R32G32_##FMT; \
27+
return VK_FORMAT_R##BITS##G##BITS##_##FMT; \
2828
if (Channels == 3) \
29-
return VK_FORMAT_R32G32B32_##FMT; \
29+
return VK_FORMAT_R##BITS##G##BITS##B##BITS##_##FMT; \
3030
if (Channels == 4) \
31-
return VK_FORMAT_R32G32B32A32_##FMT;
31+
return VK_FORMAT_R##BITS##G##BITS##B##BITS##A##BITS##_##FMT;
3232

3333
static VkFormat getVKFormat(DataFormat Format, int Channels) {
3434
switch (Format) {
35+
case DataFormat::Int16:
36+
VKFormats(SINT, 16) break;
37+
case DataFormat::UInt16:
38+
VKFormats(UINT, 16) break;
3539
case DataFormat::Int32:
36-
VKFormats(SINT) break;
40+
VKFormats(SINT, 32) break;
41+
case DataFormat::UInt32:
42+
VKFormats(UINT, 32) break;
3743
case DataFormat::Float32:
38-
VKFormats(SFLOAT) break;
44+
VKFormats(SFLOAT, 32) break;
45+
case DataFormat::Int64:
46+
VKFormats(SINT, 64) break;
47+
case DataFormat::UInt64:
48+
VKFormats(UINT, 64) break;
49+
case DataFormat::Float64:
50+
VKFormats(SFLOAT, 64) break;
3951
default:
4052
llvm_unreachable("Unsupported Resource format specified");
4153
}
@@ -1273,6 +1285,76 @@ class VKDevice : public offloadtest::Device {
12731285
return llvm::Error::success();
12741286
}
12751287

1288+
static void
1289+
parseSpecializationConstant(const SpecializationConstant &SpecConst,
1290+
VkSpecializationMapEntry &Entry,
1291+
llvm::SmallVector<char> &SpecData) {
1292+
Entry.constantID = SpecConst.ConstantID;
1293+
Entry.offset = SpecData.size();
1294+
switch (SpecConst.Type) {
1295+
case DataFormat::Float32: {
1296+
float Value = 0.0f;
1297+
double Tmp = 0.0;
1298+
llvm::StringRef(SpecConst.Value).getAsDouble(Tmp);
1299+
Value = static_cast<float>(Tmp);
1300+
Entry.size = sizeof(float);
1301+
SpecData.resize(SpecData.size() + sizeof(float));
1302+
memcpy(SpecData.data() + Entry.offset, &Value, sizeof(float));
1303+
break;
1304+
}
1305+
case DataFormat::Float64: {
1306+
double Value = 0.0;
1307+
llvm::StringRef(SpecConst.Value).getAsDouble(Value);
1308+
Entry.size = sizeof(double);
1309+
SpecData.resize(SpecData.size() + sizeof(double));
1310+
memcpy(SpecData.data() + Entry.offset, &Value, sizeof(double));
1311+
break;
1312+
}
1313+
case DataFormat::Int16: {
1314+
int16_t Value = 0;
1315+
llvm::StringRef(SpecConst.Value).getAsInteger(0, Value);
1316+
Entry.size = sizeof(int16_t);
1317+
SpecData.resize(SpecData.size() + sizeof(int16_t));
1318+
memcpy(SpecData.data() + Entry.offset, &Value, sizeof(int16_t));
1319+
break;
1320+
}
1321+
case DataFormat::UInt16: {
1322+
uint16_t Value = 0;
1323+
llvm::StringRef(SpecConst.Value).getAsInteger(0, Value);
1324+
Entry.size = sizeof(uint16_t);
1325+
SpecData.resize(SpecData.size() + sizeof(uint16_t));
1326+
memcpy(SpecData.data() + Entry.offset, &Value, sizeof(uint16_t));
1327+
break;
1328+
}
1329+
case DataFormat::Int32: {
1330+
int32_t Value = 0;
1331+
llvm::StringRef(SpecConst.Value).getAsInteger(0, Value);
1332+
Entry.size = sizeof(int32_t);
1333+
SpecData.resize(SpecData.size() + sizeof(int32_t));
1334+
memcpy(SpecData.data() + Entry.offset, &Value, sizeof(int32_t));
1335+
break;
1336+
}
1337+
case DataFormat::UInt32: {
1338+
uint32_t Value = 0;
1339+
llvm::StringRef(SpecConst.Value).getAsInteger(0, Value);
1340+
Entry.size = sizeof(uint32_t);
1341+
SpecData.resize(SpecData.size() + sizeof(uint32_t));
1342+
memcpy(SpecData.data() + Entry.offset, &Value, sizeof(uint32_t));
1343+
break;
1344+
}
1345+
case DataFormat::Bool: {
1346+
bool Value = false;
1347+
llvm::StringRef(SpecConst.Value).getAsInteger(0, Value);
1348+
Entry.size = sizeof(bool);
1349+
SpecData.resize(SpecData.size() + sizeof(bool));
1350+
memcpy(SpecData.data() + Entry.offset, &Value, sizeof(bool));
1351+
break;
1352+
}
1353+
default:
1354+
llvm_unreachable("Unsupported specialization constant type");
1355+
}
1356+
}
1357+
12761358
llvm::Error createPipeline(Pipeline &P, InvocationState &IS) {
12771359
VkPipelineCacheCreateInfo CacheCreateInfo = {};
12781360
CacheCreateInfo.sType = VK_STRUCTURE_TYPE_PIPELINE_CACHE_CREATE_INFO;
@@ -1282,15 +1364,33 @@ class VKDevice : public offloadtest::Device {
12821364
"Failed to create pipeline cache.");
12831365

12841366
if (P.isCompute()) {
1285-
const CompiledShader &S = IS.Shaders[0];
1367+
const offloadtest::Shader &Shader = P.Shaders[0];
12861368
assert(IS.Shaders.size() == 1 &&
12871369
"Currently only support one compute shader");
1370+
const CompiledShader &S = IS.Shaders[0];
12881371
VkPipelineShaderStageCreateInfo StageInfo = {};
12891372
StageInfo.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
12901373
StageInfo.stage = VK_SHADER_STAGE_COMPUTE_BIT;
12911374
StageInfo.module = S.Shader;
12921375
StageInfo.pName = S.Entry.c_str();
12931376

1377+
llvm::SmallVector<VkSpecializationMapEntry> SpecEntries;
1378+
llvm::SmallVector<char> SpecData;
1379+
VkSpecializationInfo SpecInfo = {};
1380+
if (!Shader.SpecializationConstants.empty()) {
1381+
for (const auto &SpecConst : Shader.SpecializationConstants) {
1382+
VkSpecializationMapEntry Entry;
1383+
parseSpecializationConstant(SpecConst, Entry, SpecData);
1384+
SpecEntries.push_back(Entry);
1385+
}
1386+
1387+
SpecInfo.mapEntryCount = SpecEntries.size();
1388+
SpecInfo.pMapEntries = SpecEntries.data();
1389+
SpecInfo.dataSize = SpecData.size();
1390+
SpecInfo.pData = SpecData.data();
1391+
StageInfo.pSpecializationInfo = &SpecInfo;
1392+
}
1393+
12941394
VkComputePipelineCreateInfo PipelineCreateInfo = {};
12951395
PipelineCreateInfo.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
12961396
PipelineCreateInfo.stage = StageInfo;

lib/Support/Pipeline.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,7 @@ void MappingTraits<offloadtest::Shader>::mapping(IO &I,
372372
offloadtest::Shader &S) {
373373
I.mapRequired("Stage", S.Stage);
374374
I.mapRequired("Entry", S.Entry);
375+
I.mapOptional("SpecializationConstants", S.SpecializationConstants);
375376

376377
if (S.Stage == Stages::Compute) {
377378
// Stage-specific data, not sure if this should be optional
@@ -380,6 +381,7 @@ void MappingTraits<offloadtest::Shader>::mapping(IO &I,
380381
I.mapRequired("DispatchSize", MutableDispatchSize);
381382
}
382383
}
384+
383385
void MappingTraits<offloadtest::Result>::mapping(IO &I,
384386
offloadtest::Result &R) {
385387
I.mapRequired("Result", R.Name);
@@ -402,5 +404,13 @@ void MappingTraits<offloadtest::Result>::mapping(IO &I,
402404
break;
403405
}
404406
}
407+
408+
void MappingTraits<offloadtest::SpecializationConstant>::mapping(
409+
IO &I, offloadtest::SpecializationConstant &C) {
410+
I.mapRequired("ConstantID", C.ConstantID);
411+
I.mapRequired("Type", C.Type);
412+
I.mapRequired("Value", C.Value);
413+
}
414+
405415
} // namespace yaml
406416
} // namespace llvm
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
#--- simple.hlsl
2+
// bool
3+
[[vk::constant_id(0)]]
4+
const bool spec_bool = false;
5+
RWBuffer<uint> OutBool : register(u0, space0);
6+
7+
// int
8+
[[vk::constant_id(1)]]
9+
const int spec_int = 0;
10+
RWBuffer<int> OutInt : register(u1, space0);
11+
12+
// unsigned int
13+
[[vk::constant_id(2)]]
14+
const uint spec_uint = 0;
15+
RWBuffer<uint> OutUInt : register(u2, space0);
16+
17+
// float
18+
[[vk::constant_id(3)]]
19+
const float spec_float = 0.0;
20+
RWBuffer<float> OutFloat : register(u3, space0);
21+
22+
// int with default value
23+
[[vk::constant_id(4)]]
24+
const int spec_int_default = 1234;
25+
RWBuffer<int> OutIntDefault : register(u4, space0);
26+
27+
[numthreads(1,1,1)]
28+
void main(uint GI : SV_GroupIndex) {
29+
OutBool[GI] = (uint)spec_bool;
30+
OutInt[GI] = spec_int;
31+
OutUInt[GI] = spec_uint;
32+
OutFloat[GI] = spec_float;
33+
OutIntDefault[GI] = spec_int_default;
34+
}
35+
#--- simple.yaml
36+
---
37+
Shaders:
38+
- Stage: Compute
39+
Entry: main
40+
DispatchSize: [1, 1, 1]
41+
SpecializationConstants:
42+
- { ConstantID: 0, Value: 1, Type: Bool }
43+
- { ConstantID: 1, Value: 42, Type: Int32 }
44+
- { ConstantID: 2, Value: 0xDEADBEEF, Type: UInt32 }
45+
- { ConstantID: 3, Value: 3.14, Type: Float32 }
46+
Buffers:
47+
- { Name: OutBool, Format: Int32, FillSize: 4 }
48+
- { Name: OutInt, Format: Int32, FillSize: 4 }
49+
- { Name: OutUInt, Format: UInt32, FillSize: 4 }
50+
- { Name: OutFloat, Format: Float32, FillSize: 4 }
51+
- { Name: OutIntDefault, Format: Int32, FillSize: 4 }
52+
DescriptorSets:
53+
- Resources:
54+
- Name: OutBool
55+
Kind: RWBuffer
56+
DirectXBinding: { Register: 0, Space: 0 }
57+
VulkanBinding: { Binding: 0 }
58+
- Name: OutInt
59+
Kind: RWBuffer
60+
DirectXBinding: { Register: 1, Space: 0 }
61+
VulkanBinding: { Binding: 1 }
62+
- Name: OutUInt
63+
Kind: RWBuffer
64+
DirectXBinding: { Register: 2, Space: 0 }
65+
VulkanBinding: { Binding: 2 }
66+
- Name: OutFloat
67+
Kind: RWBuffer
68+
DirectXBinding: { Register: 3, Space: 0 }
69+
VulkanBinding: { Binding: 3 }
70+
- Name: OutIntDefault
71+
Kind: RWBuffer
72+
DirectXBinding: { Register: 4, Space: 0 }
73+
VulkanBinding: { Binding: 4 }
74+
...
75+
#--- end
76+
77+
# REQUIRES: Vulkan
78+
79+
# RUN: split-file %s %t
80+
# RUN: %dxc_target -T cs_6_0 -Fo %t.o %t/simple.hlsl
81+
# RUN: %offloader %t/simple.yaml %t.o | FileCheck %s
82+
83+
# CHECK: Data: [ 1 ]
84+
# CHECK: Data: [ 42 ]
85+
# CHECK: Data: [ 3735928559 ]
86+
# CHECK: Data: [ {{3.14.*}} ]
87+
# CHECK: Data: [ 1234 ]
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
#--- simple_64bit.hlsl
2+
// double
3+
[[vk::constant_id(0)]]
4+
const double spec_double = 0.0;
5+
RWStructuredBuffer<double> OutDouble : register(u0, space0);
6+
7+
// short
8+
[[vk::constant_id(1)]]
9+
const int16_t spec_short = 0;
10+
RWStructuredBuffer<int16_t> OutShort : register(u1, space0);
11+
12+
// ushort
13+
[[vk::constant_id(2)]]
14+
const uint16_t spec_ushort = 0;
15+
RWStructuredBuffer<uint16_t> OutUShort : register(u2, space0);
16+
17+
[numthreads(1,1,1)]
18+
void main(uint GI : SV_GroupIndex) {
19+
OutDouble[GI] = spec_double;
20+
OutShort[GI] = spec_short;
21+
OutUShort[GI] = spec_ushort;
22+
}
23+
#--- simple_64bit.yaml
24+
---
25+
Shaders:
26+
- Stage: Compute
27+
Entry: main
28+
DispatchSize: [1, 1, 1]
29+
SpecializationConstants:
30+
- { ConstantID: 0, Value: 2.718, Type: Float64 }
31+
- { ConstantID: 1, Value: 123, Type: Int16 }
32+
- { ConstantID: 2, Value: 456, Type: UInt16 }
33+
Buffers:
34+
- { Name: OutDouble, Format: Float64, Stride: 8, FillSize: 8 }
35+
- { Name: OutShort, Format: Int16, Stride: 2, FillSize: 2 }
36+
- { Name: OutUShort, Format: UInt16, Stride: 2, FillSize: 2 }
37+
DescriptorSets:
38+
- Resources:
39+
- Name: OutDouble
40+
Kind: RWStructuredBuffer
41+
DirectXBinding: { Register: 0, Space: 0 }
42+
VulkanBinding: { Binding: 0 }
43+
- Name: OutShort
44+
Kind: RWStructuredBuffer
45+
DirectXBinding: { Register: 1, Space: 0 }
46+
VulkanBinding: { Binding: 1 }
47+
- Name: OutUShort
48+
Kind: RWStructuredBuffer
49+
DirectXBinding: { Register: 2, Space: 0 }
50+
VulkanBinding: { Binding: 2 }
51+
...
52+
#--- end
53+
54+
# REQUIRES: Vulkan
55+
56+
# XFAIL: DXC
57+
58+
# RUN: split-file %s %t
59+
# RUN: %dxc_target -T cs_6_2 -enable-16bit-types -Fo %t.o %t/simple_64bit.hlsl
60+
# RUN: %offloader %t/simple_64bit.yaml %t.o | FileCheck %s
61+
62+
# CHECK: Data: [ {{2.718}} ]
63+
# CHECK: Data: [ 123 ]
64+
# CHECK: Data: [ 456 ]

0 commit comments

Comments
 (0)