Skip to content

Commit 85f7653

Browse files
authored
Execution Tests: Long Vector - WaveMulti* Ops (#7925)
This PR addressed #7612. All new tests were verified against a private local build of WARP which includes fixes for WazveMulti* Ops with long vectors.
1 parent 9cb683f commit 85f7653

File tree

4 files changed

+302
-5
lines changed

4 files changed

+302
-5
lines changed

tools/clang/unittests/HLSLExec/LongVectorOps.def

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ INPUT_SET(Bitwise)
2020
INPUT_SET(SelectCond)
2121
INPUT_SET(FloatSpecial)
2222
INPUT_SET(AllOnes)
23+
INPUT_SET(WaveMultiPrefixBitwise)
2324

2425
#undef INPUT_SET
2526

@@ -213,5 +214,10 @@ OP_DEFAULT_DEFINES(Wave, WaveReadLaneAt, 1, "TestWaveReadLaneAt", "", " -DFUNC_W
213214
OP_DEFAULT_DEFINES(Wave, WaveReadLaneFirst, 1, "TestWaveReadLaneFirst", "", " -DFUNC_WAVE_READ_LANE_FIRST=1")
214215
OP_DEFAULT_DEFINES(Wave, WavePrefixSum, 1, "TestWavePrefixSum", "", " -DFUNC_WAVE_PREFIX_SUM=1 -DIS_WAVE_PREFIX_OP=1")
215216
OP_DEFAULT_DEFINES(Wave, WavePrefixProduct, 1, "TestWavePrefixProduct", "", " -DFUNC_WAVE_PREFIX_PRODUCT=1 -DIS_WAVE_PREFIX_OP=1")
217+
OP(Wave, WaveMultiPrefixSum, 1, "TestWaveMultiPrefixSum", "", " -DFUNC_WAVE_MULTI_PREFIX_SUM=1 -DIS_WAVE_PREFIX_OP=1", "LongVectorOp", Default1, Default2, Default3)
218+
OP(Wave, WaveMultiPrefixProduct, 1, "TestWaveMultiPrefixProduct", "", " -DFUNC_WAVE_MULTI_PREFIX_PRODUCT=1 -DIS_WAVE_PREFIX_OP=1", "LongVectorOp", Default1, Default2, Default3)
219+
OP(Wave, WaveMultiPrefixBitAnd, 1, "TestWaveMultiPrefixBitAnd", "", " -DFUNC_WAVE_MULTI_PREFIX_BIT_AND=1 -DIS_WAVE_PREFIX_OP=1", "LongVectorOp", WaveMultiPrefixBitwise, Default2, Default3)
220+
OP(Wave, WaveMultiPrefixBitOr, 1, "TestWaveMultiPrefixBitOr", "", " -DFUNC_WAVE_MULTI_PREFIX_BIT_OR=1 -DIS_WAVE_PREFIX_OP=1", "LongVectorOp", WaveMultiPrefixBitwise, Default2, Default3)
221+
OP(Wave, WaveMultiPrefixBitXor, 1, "TestWaveMultiPrefixBitXor", "", " -DFUNC_WAVE_MULTI_PREFIX_BIT_XOR=1 -DIS_WAVE_PREFIX_OP=1", "LongVectorOp", WaveMultiPrefixBitwise, Default2, Default3)
216222

217223
#undef OP

tools/clang/unittests/HLSLExec/LongVectorTestData.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,8 @@ INPUT_SET(InputSet::Bitwise, std::numeric_limits<int16_t>::min(), -1, 0, 1, 3,
290290
std::numeric_limits<int16_t>::max());
291291
INPUT_SET(InputSet::SelectCond, 0, 1);
292292
INPUT_SET(InputSet::AllOnes, 1);
293+
INPUT_SET(InputSet::WaveMultiPrefixBitwise, 0x0, 0x1, 0x3, 0x4, 0x10, 0x12, 0xF,
294+
-1);
293295
END_INPUT_SETS()
294296

295297
BEGIN_INPUT_SETS(int32_t)
@@ -304,6 +306,8 @@ INPUT_SET(InputSet::Bitwise, std::numeric_limits<int32_t>::min(), -1, 0, 1, 3,
304306
std::numeric_limits<int32_t>::max());
305307
INPUT_SET(InputSet::SelectCond, 0, 1);
306308
INPUT_SET(InputSet::AllOnes, 1);
309+
INPUT_SET(InputSet::WaveMultiPrefixBitwise, 0x0, 0x1, 0x3, 0x4, 0x10, 0x12, 0xF,
310+
-1);
307311
END_INPUT_SETS()
308312

309313
BEGIN_INPUT_SETS(int64_t)
@@ -318,6 +322,8 @@ INPUT_SET(InputSet::Bitwise, std::numeric_limits<int64_t>::min(), -1, 0, 1, 3,
318322
std::numeric_limits<int64_t>::max());
319323
INPUT_SET(InputSet::SelectCond, 0, 1);
320324
INPUT_SET(InputSet::AllOnes, 1);
325+
INPUT_SET(InputSet::WaveMultiPrefixBitwise, 0x0, 0x1, 0x3, 0x4, 0x10, 0x12, 0xF,
326+
-1ll);
321327
END_INPUT_SETS()
322328

323329
BEGIN_INPUT_SETS(uint16_t)
@@ -330,6 +336,8 @@ INPUT_SET(InputSet::Bitwise, 0, 1, 3, 6, 9, 0x5555, 0xAAAA, 0x8000, 127,
330336
std::numeric_limits<uint16_t>::max());
331337
INPUT_SET(InputSet::SelectCond, 0, 1);
332338
INPUT_SET(InputSet::AllOnes, 1);
339+
INPUT_SET(InputSet::WaveMultiPrefixBitwise, 0x0, 0x1, 0x3, 0x4, 0x10, 0x12, 0xF,
340+
std::numeric_limits<uint16_t>::max());
333341
END_INPUT_SETS()
334342

335343
BEGIN_INPUT_SETS(uint32_t)
@@ -342,6 +350,8 @@ INPUT_SET(InputSet::Bitwise, 0, 1, 3, 6, 9, 0x55555555, 0xAAAAAAAA, 0x80000000,
342350
127, std::numeric_limits<uint32_t>::max());
343351
INPUT_SET(InputSet::SelectCond, 0, 1);
344352
INPUT_SET(InputSet::AllOnes, 1);
353+
INPUT_SET(InputSet::WaveMultiPrefixBitwise, 0x0, 0x1, 0x3, 0x4, 0xA, 0xC, 0xF,
354+
std::numeric_limits<uint32_t>::max());
345355
END_INPUT_SETS()
346356

347357
BEGIN_INPUT_SETS(uint64_t)
@@ -355,6 +365,8 @@ INPUT_SET(InputSet::Bitwise, 0, 1, 3, 6, 9, 0x5555555555555555,
355365
std::numeric_limits<uint64_t>::max());
356366
INPUT_SET(InputSet::SelectCond, 0, 1);
357367
INPUT_SET(InputSet::AllOnes, 1);
368+
INPUT_SET(InputSet::WaveMultiPrefixBitwise, 0x0, 0x1, 0x3, 0x4, 0xA, 0xC, 0xF,
369+
std::numeric_limits<uint64_t>::max());
358370
END_INPUT_SETS()
359371

360372
BEGIN_INPUT_SETS(HLSLHalf_t)

tools/clang/unittests/HLSLExec/LongVectors.cpp

Lines changed: 109 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1392,7 +1392,7 @@ template <typename T> T waveActiveBitAnd(T A, UINT) {
13921392
WAVE_OP(OpType::WaveActiveBitAnd, (waveActiveBitAnd(A, WaveSize)));
13931393

13941394
template <typename T> T waveActiveBitOr(T A, UINT) {
1395-
// We set the LSB to 0 in one of the lanes.
1395+
// We set the LSB to 1 in one of the lanes.
13961396
return static_cast<T>(A | static_cast<T>(1));
13971397
}
13981398

@@ -1405,6 +1405,60 @@ template <typename T> T waveActiveBitXor(T A, UINT) {
14051405

14061406
WAVE_OP(OpType::WaveActiveBitXor, (waveActiveBitXor(A, WaveSize)));
14071407

1408+
WAVE_OP(OpType::WaveMultiPrefixBitAnd, waveMultiPrefixBitAnd(A, WaveSize));
1409+
1410+
template <typename T> T waveMultiPrefixBitAnd(T A, UINT) {
1411+
// All lanes in the group mask use a mask to filter for only the second and
1412+
// third LSBs.
1413+
return static_cast<T>(A & static_cast<T>(0x6));
1414+
}
1415+
1416+
WAVE_OP(OpType::WaveMultiPrefixBitOr, waveMultiPrefixBitOr(A, WaveSize));
1417+
1418+
template <typename T> T waveMultiPrefixBitOr(T A, UINT) {
1419+
// All lanes in the group mask clear the second LSB.
1420+
return static_cast<T>(A & ~static_cast<T>(0x2));
1421+
}
1422+
1423+
template <typename T>
1424+
struct Op<OpType::WaveMultiPrefixBitXor, T, 1> : StrictValidation {};
1425+
1426+
template <typename T> struct ExpectedBuilder<OpType::WaveMultiPrefixBitXor, T> {
1427+
static std::vector<T> buildExpected(Op<OpType::WaveMultiPrefixBitXor, T, 1> &,
1428+
const InputSets<T> &Inputs, UINT) {
1429+
DXASSERT_NOMSG(Inputs.size() == 1);
1430+
1431+
std::vector<T> Expected;
1432+
const size_t VectorSize = Inputs[0].size();
1433+
1434+
// We get a little creative for MultiPrefixBitXor. The mask we use for the
1435+
// group in the shader is 0xE (0b1110), which includes lanes 1, 2, and 3.
1436+
// Prefix ops don't include the value of the current lane in their result.
1437+
// So, for this test we store the result of WaveMultiPrefixBitXor from lane
1438+
// 3. This means only the values from lanes 1 and 2 contribute to the result
1439+
// at lane 3.
1440+
//
1441+
// In the shader:
1442+
// - Lane 0: Set to 0 (not in mask, shouldn't affect result)
1443+
// - Lane 1: Keeps original input values
1444+
// - Lane 2: Lower half + last element set to 0, upper half keeps input
1445+
// - Lane 3: Stores the prefix XOR result (lanes 1 XOR lanes 2)
1446+
//
1447+
// Expected result: Lower half matches input (lane 1 XOR 0), upper half is
1448+
// 0s, except last element matches input.
1449+
for (size_t I = 0; I < VectorSize / 2; ++I)
1450+
Expected.push_back(Inputs[0][I]);
1451+
for (size_t I = VectorSize / 2; I < VectorSize - 1; ++I)
1452+
Expected.push_back(0);
1453+
1454+
// We also set the last element to 0 on lane 2 so the last element in the
1455+
// output vector matches the last element in the input vector.
1456+
Expected.push_back(Inputs[0][VectorSize - 1]);
1457+
1458+
return Expected;
1459+
}
1460+
};
1461+
14081462
template <typename T>
14091463
struct Op<OpType::WaveActiveAllEqual, T, 1> : StrictValidation {};
14101464

@@ -1463,16 +1517,29 @@ template <typename T> struct ExpectedBuilder<OpType::WaveReadLaneFirst, T> {
14631517
WAVE_OP(OpType::WavePrefixSum, (wavePrefixSum(A, WaveSize)));
14641518

14651519
template <typename T> T wavePrefixSum(T A, UINT WaveSize) {
1466-
// We test the prefix sume in the 'middle' lane. This choice is arbitrary.
1467-
return static_cast<T>(A * static_cast<T>(WaveSize / 2));
1520+
// We test the prefix sum in the 'middle' lane. This choice is arbitrary.
1521+
return A * static_cast<T>(WaveSize / 2);
1522+
}
1523+
1524+
WAVE_OP(OpType::WaveMultiPrefixSum, (waveMultiPrefixSum(A, WaveSize)));
1525+
1526+
template <typename T> T waveMultiPrefixSum(T A, UINT) {
1527+
return A * static_cast<T>(2u);
14681528
}
14691529

14701530
WAVE_OP(OpType::WavePrefixProduct, (wavePrefixProduct(A, WaveSize)));
14711531

14721532
template <typename T> T wavePrefixProduct(T A, UINT) {
14731533
// We test the the prefix product in the 3rd lane to avoid overflow issues.
14741534
// So the result is A * A.
1475-
return static_cast<T>(A * A);
1535+
return A * A;
1536+
}
1537+
1538+
WAVE_OP(OpType::WaveMultiPrefixProduct, (waveMultiPrefixProduct(A, WaveSize)));
1539+
1540+
template <typename T> T waveMultiPrefixProduct(T A, UINT) {
1541+
// The group mask has 3 lanes.
1542+
return A * A;
14761543
}
14771544

14781545
#undef WAVE_OP
@@ -2404,6 +2471,11 @@ class DxilConf_SM69_Vectorized {
24042471
HLK_WAVEOP_TEST(WaveReadLaneFirst, int16_t);
24052472
HLK_WAVEOP_TEST(WavePrefixSum, int16_t);
24062473
HLK_WAVEOP_TEST(WavePrefixProduct, int16_t);
2474+
HLK_WAVEOP_TEST(WaveMultiPrefixSum, int16_t);
2475+
HLK_WAVEOP_TEST(WaveMultiPrefixProduct, int16_t);
2476+
HLK_WAVEOP_TEST(WaveMultiPrefixBitAnd, int16_t);
2477+
HLK_WAVEOP_TEST(WaveMultiPrefixBitOr, int16_t);
2478+
HLK_WAVEOP_TEST(WaveMultiPrefixBitXor, int16_t);
24072479
HLK_WAVEOP_TEST(WaveActiveSum, int32_t);
24082480
HLK_WAVEOP_TEST(WaveActiveMin, int32_t);
24092481
HLK_WAVEOP_TEST(WaveActiveMax, int32_t);
@@ -2412,7 +2484,12 @@ class DxilConf_SM69_Vectorized {
24122484
HLK_WAVEOP_TEST(WaveReadLaneAt, int32_t);
24132485
HLK_WAVEOP_TEST(WaveReadLaneFirst, int32_t);
24142486
HLK_WAVEOP_TEST(WavePrefixSum, int32_t);
2487+
HLK_WAVEOP_TEST(WaveMultiPrefixSum, int32_t);
2488+
HLK_WAVEOP_TEST(WaveMultiPrefixProduct, int32_t);
24152489
HLK_WAVEOP_TEST(WavePrefixProduct, int32_t);
2490+
HLK_WAVEOP_TEST(WaveMultiPrefixBitAnd, int32_t);
2491+
HLK_WAVEOP_TEST(WaveMultiPrefixBitOr, int32_t);
2492+
HLK_WAVEOP_TEST(WaveMultiPrefixBitXor, int32_t);
24162493
HLK_WAVEOP_TEST(WaveActiveSum, int64_t);
24172494
HLK_WAVEOP_TEST(WaveActiveMin, int64_t);
24182495
HLK_WAVEOP_TEST(WaveActiveMax, int64_t);
@@ -2422,7 +2499,14 @@ class DxilConf_SM69_Vectorized {
24222499
HLK_WAVEOP_TEST(WaveReadLaneFirst, int64_t);
24232500
HLK_WAVEOP_TEST(WavePrefixSum, int64_t);
24242501
HLK_WAVEOP_TEST(WavePrefixProduct, int64_t);
2502+
HLK_WAVEOP_TEST(WaveMultiPrefixSum, int64_t);
2503+
HLK_WAVEOP_TEST(WaveMultiPrefixProduct, int64_t);
2504+
HLK_WAVEOP_TEST(WaveMultiPrefixBitAnd, int64_t);
2505+
HLK_WAVEOP_TEST(WaveMultiPrefixBitOr, int64_t);
2506+
HLK_WAVEOP_TEST(WaveMultiPrefixBitXor, int64_t);
24252507

2508+
// Note: WaveActiveBit* ops don't support uint16_t in HLSL
2509+
// But the WaveMultiPrefixBit ops support all int and uint types
24262510
HLK_WAVEOP_TEST(WaveActiveSum, uint16_t);
24272511
HLK_WAVEOP_TEST(WaveActiveMin, uint16_t);
24282512
HLK_WAVEOP_TEST(WaveActiveMax, uint16_t);
@@ -2432,11 +2516,15 @@ class DxilConf_SM69_Vectorized {
24322516
HLK_WAVEOP_TEST(WaveReadLaneFirst, uint16_t);
24332517
HLK_WAVEOP_TEST(WavePrefixSum, uint16_t);
24342518
HLK_WAVEOP_TEST(WavePrefixProduct, uint16_t);
2519+
HLK_WAVEOP_TEST(WaveMultiPrefixSum, uint16_t);
2520+
HLK_WAVEOP_TEST(WaveMultiPrefixProduct, uint16_t);
2521+
HLK_WAVEOP_TEST(WaveMultiPrefixBitAnd, uint16_t);
2522+
HLK_WAVEOP_TEST(WaveMultiPrefixBitOr, uint16_t);
2523+
HLK_WAVEOP_TEST(WaveMultiPrefixBitXor, uint16_t);
24352524
HLK_WAVEOP_TEST(WaveActiveSum, uint32_t);
24362525
HLK_WAVEOP_TEST(WaveActiveMin, uint32_t);
24372526
HLK_WAVEOP_TEST(WaveActiveMax, uint32_t);
24382527
HLK_WAVEOP_TEST(WaveActiveProduct, uint32_t);
2439-
// Note: WaveActiveBit* ops don't support uint16_t in HLSL
24402528
HLK_WAVEOP_TEST(WaveActiveBitAnd, uint32_t);
24412529
HLK_WAVEOP_TEST(WaveActiveBitOr, uint32_t);
24422530
HLK_WAVEOP_TEST(WaveActiveBitXor, uint32_t);
@@ -2445,6 +2533,11 @@ class DxilConf_SM69_Vectorized {
24452533
HLK_WAVEOP_TEST(WaveReadLaneFirst, uint32_t);
24462534
HLK_WAVEOP_TEST(WavePrefixSum, uint32_t);
24472535
HLK_WAVEOP_TEST(WavePrefixProduct, uint32_t);
2536+
HLK_WAVEOP_TEST(WaveMultiPrefixSum, uint32_t);
2537+
HLK_WAVEOP_TEST(WaveMultiPrefixProduct, uint32_t);
2538+
HLK_WAVEOP_TEST(WaveMultiPrefixBitAnd, uint32_t);
2539+
HLK_WAVEOP_TEST(WaveMultiPrefixBitOr, uint32_t);
2540+
HLK_WAVEOP_TEST(WaveMultiPrefixBitXor, uint32_t);
24482541
HLK_WAVEOP_TEST(WaveActiveSum, uint64_t);
24492542
HLK_WAVEOP_TEST(WaveActiveMin, uint64_t);
24502543
HLK_WAVEOP_TEST(WaveActiveMax, uint64_t);
@@ -2457,6 +2550,11 @@ class DxilConf_SM69_Vectorized {
24572550
HLK_WAVEOP_TEST(WaveReadLaneFirst, uint64_t);
24582551
HLK_WAVEOP_TEST(WavePrefixSum, uint64_t);
24592552
HLK_WAVEOP_TEST(WavePrefixProduct, uint64_t);
2553+
HLK_WAVEOP_TEST(WaveMultiPrefixSum, uint64_t);
2554+
HLK_WAVEOP_TEST(WaveMultiPrefixProduct, uint64_t);
2555+
HLK_WAVEOP_TEST(WaveMultiPrefixBitAnd, uint64_t);
2556+
HLK_WAVEOP_TEST(WaveMultiPrefixBitOr, uint64_t);
2557+
HLK_WAVEOP_TEST(WaveMultiPrefixBitXor, uint64_t);
24602558

24612559
HLK_WAVEOP_TEST(WaveActiveSum, HLSLHalf_t);
24622560
HLK_WAVEOP_TEST(WaveActiveMin, HLSLHalf_t);
@@ -2467,6 +2565,8 @@ class DxilConf_SM69_Vectorized {
24672565
HLK_WAVEOP_TEST(WaveReadLaneFirst, HLSLHalf_t);
24682566
HLK_WAVEOP_TEST(WavePrefixSum, HLSLHalf_t);
24692567
HLK_WAVEOP_TEST(WavePrefixProduct, HLSLHalf_t);
2568+
HLK_WAVEOP_TEST(WaveMultiPrefixSum, HLSLHalf_t);
2569+
HLK_WAVEOP_TEST(WaveMultiPrefixProduct, HLSLHalf_t);
24702570
HLK_WAVEOP_TEST(WaveActiveSum, float);
24712571
HLK_WAVEOP_TEST(WaveActiveMin, float);
24722572
HLK_WAVEOP_TEST(WaveActiveMax, float);
@@ -2476,6 +2576,8 @@ class DxilConf_SM69_Vectorized {
24762576
HLK_WAVEOP_TEST(WaveReadLaneFirst, float);
24772577
HLK_WAVEOP_TEST(WavePrefixSum, float);
24782578
HLK_WAVEOP_TEST(WavePrefixProduct, float);
2579+
HLK_WAVEOP_TEST(WaveMultiPrefixSum, float);
2580+
HLK_WAVEOP_TEST(WaveMultiPrefixProduct, float);
24792581
HLK_WAVEOP_TEST(WaveActiveSum, double);
24802582
HLK_WAVEOP_TEST(WaveActiveMin, double);
24812583
HLK_WAVEOP_TEST(WaveActiveMax, double);
@@ -2485,6 +2587,8 @@ class DxilConf_SM69_Vectorized {
24852587
HLK_WAVEOP_TEST(WaveReadLaneFirst, double);
24862588
HLK_WAVEOP_TEST(WavePrefixSum, double);
24872589
HLK_WAVEOP_TEST(WavePrefixProduct, double);
2590+
HLK_WAVEOP_TEST(WaveMultiPrefixSum, double);
2591+
HLK_WAVEOP_TEST(WaveMultiPrefixProduct, double);
24882592

24892593
private:
24902594
bool Initialized = false;

0 commit comments

Comments
 (0)