Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions quest/src/comm/comm_routines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
#include "quest/src/gpu/gpu_config.hpp"
#include "quest/src/comm/comm_config.hpp"
#include "quest/src/comm/comm_indices.hpp"
#include "quest/src/cpu/cpu_subroutines.hpp"
#include "quest/src/core/utilities.hpp"
#include <map>
#include <vector>

#if QUEST_COMPILE_MPI
#include <mpi.h>
Expand Down Expand Up @@ -827,3 +831,52 @@ vector<string> comm_gatherStringsToRoot(char* localChars, int maxNumLocalChars)
return {};
#endif
}

// Performs every (suffix qubit -> prefix qubit) swap listed in swapMap using
// pairwise amplitude exchanges between nodes. With k = swapMap.size(), this
// node exchanges one chunk of (numAmpsPerNode >> k) amplitudes with each of
// the (2^k - 1) ranks obtained by flipping a non-empty subset of the swapped
// prefix bits of its own rank.
//
//   qureg       distributed register whose CPU amplitudes are rearranged
//   ctrls       control qubits (see review note below)
//   ctrlStates  control states (see review note below)
//   swapMap     maps each suffix target (key) to its paired prefix target (value)
//
// NOTE(review): ctrls and ctrlStates are accepted but never read in this
// function, so amplitudes are exchanged irrespective of any control qubits —
// confirm that callers do not rely on control-conditioned swapping.
void comm_exchangeFusedMultiSwap(Qureg qureg, ConstList64 ctrls, ConstList64 ctrlStates, const std::map<int, int>& swapMap) {
    assert_commQuregIsDistributed(qureg);

#if QUEST_COMPILE_MPI
    int k = swapMap.size();
    if (k == 0) return;

    // GPU fallback: sync GPU amps to CPU, perform fused swap on CPU, sync back
    if (qureg.isGpuAccelerated)
        syncQuregFromGpu(qureg);

    qindex chunkSize = qureg.numAmpsPerNode >> k;

    // prefix targets, ordered by their paired suffix target (std::map iterates
    // in ascending key order); the pack/unpack routines consume swapMap in this
    // same order, so bit i of every mask below refers to the same swap pair
    std::vector<int> prefixTargs;
    prefixTargs.reserve(k);
    for (auto const& entry : swapMap)
        prefixTargs.push_back(entry.second);

    // these depend only on the targets, not on the partner being visited, so
    // they are computed once here rather than inside the exchange loop
    std::vector<int> prefixInds;
    prefixInds.reserve(k);
    int myPrefixBits = 0;
    for (int i = 0; i < k; i++) {
        prefixInds.push_back(util_getPrefixInd(prefixTargs[i], qureg));
        if (util_getRankBitOfQubit(prefixTargs[i], qureg))
            myPrefixBits |= (1 << i);
    }

    // adjacent, non-overlapping regions of the communication buffer
    qcomp* sendBuffer = qureg.cpuCommBuffer;
    qcomp* recvBuffer = qureg.cpuCommBuffer + chunkSize;

    // s enumerates every non-empty subset of swap pairs; bit i of s set means
    // the partner's rank differs from ours in the i-th swapped prefix bit
    const int numSubsets = 1 << k;
    for (int s = 1; s < numSubsets; s++) {

        // suffix-bit pattern of the local amps exchanged with this partner
        int target_m = myPrefixBits ^ s;

        int pairRank = qureg.rank;
        for (int i = 0; i < k; i++)
            if ((s >> i) & 1)
                pairRank = flipBit(pairRank, prefixInds[i]);

        // the received chunk overwrites exactly the local amps that were packed and sent
        cpu_statevec_packFusedMultiSwapBuffers(qureg, swapMap, target_m, sendBuffer);
        exchangeArrays(sendBuffer, recvBuffer, chunkSize, pairRank);
        cpu_statevec_unpackFusedMultiSwapBuffers(qureg, swapMap, target_m, recvBuffer);
    }

    if (qureg.isGpuAccelerated)
        syncQuregToGpu(qureg);
#else
    error_commButEnvNotDistributed();
#endif
}
2 changes: 2 additions & 0 deletions quest/src/comm/comm_routines.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include <vector>
#include <string>
#include <map>

using std::vector;

Expand All @@ -39,6 +40,7 @@ void comm_combineAmpsIntoBuffer(Qureg receiver, Qureg sender);

void comm_combineElemsIntoBuffer(Qureg receiver, FullStateDiagMatr sender);

// Performs all swaps in swapMap (suffix target key -> prefix target value) via
// pairwise exchanges of amplitude chunks between nodes.
// NOTE(review): the implementation visible in comm_routines.cpp does not read
// ctrls or ctrlStates — confirm whether control-conditioning is intended.
void comm_exchangeFusedMultiSwap(Qureg qureg, ConstList64 ctrls, ConstList64 ctrlStates, const std::map<int, int>& swapMap);


/*
Expand Down
21 changes: 12 additions & 9 deletions quest/src/core/localiser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
#include "quest/src/core/paulilogic.hpp"
#include "quest/src/core/localiser.hpp"
#include "quest/src/core/accelerator.hpp"

#include <map>
#include "quest/src/comm/comm_config.hpp"
#include "quest/src/comm/comm_routines.hpp"
#include "quest/src/cpu/cpu_config.hpp"
Expand Down Expand Up @@ -909,16 +911,17 @@ void anyCtrlMultiSwapBetweenPrefixAndSuffix(Qureg qureg, ConstList64 ctrls, Cons
/// although the latter requires substantially more work like setting up
/// a communicator which may be inelegant alongside our own distribution scheme.

// perform necessary swaps to move all targets into suffix, each of which invokes communication
for (size_t i=0; i<targsA.size(); i++) {

if (targsA[i] == targsB[i])
continue;

int suffixTarg = std::min(targsA[i], targsB[i]);
int prefixTarg = std::max(targsA[i], targsB[i]);
anyCtrlSwapBetweenPrefixAndSuffix(qureg, ctrls, ctrlStates, suffixTarg, prefixTarg);
// Use swapMap to record, for each swapped pair, the lower-index target (key) and the higher-index target (value)
std::map<int, int> swapMap;
for (size_t i = 0; i < targsA.size(); i++) {
if (targsA[i] != targsB[i]) {
swapMap[std::min(targsA[i], targsB[i])] = std::max(targsA[i], targsB[i]);
}
}
if (swapMap.empty()) {
return;
}
comm_exchangeFusedMultiSwap(qureg, ctrls, ctrlStates, swapMap);
}


Expand Down
76 changes: 76 additions & 0 deletions quest/src/cpu/cpu_subroutines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,82 @@ qindex cpu_statevec_packPairSummedAmpsIntoBuffer(Qureg qureg, int qubit1, int qu
INSTANTIATE_FUNC_OPTIMISED_FOR_NUM_TARGS( qindex, cpu_statevec_packAmpsIntoBuffer, (Qureg, ConstList64, ConstList64) )


// Gathers into the contiguous array `buffer` every local amplitude whose bits
// at the suffix qubits (the keys of swapMap) equal the matching bits of
// target_m, ready for exchange with a partner node. Bit j of target_m is
// paired with the j-th smallest key, because std::map iterates in ascending
// key order.
template <int NumQubits>
void cpu_statevec_packFusedMultiSwapBuffers_sub(Qureg qureg, const std::map<int, int>& swapMap, int target_m, qcomp* buffer) {

    List64 suffixQubits = lists_getEmptyList64();
    List64 suffixBits = lists_getEmptyList64();

    // walk the map in key order, peeling one bit off target_m per suffix qubit
    int pos = 0;
    for (auto it = swapMap.begin(); it != swapMap.end(); ++it) {
        suffixQubits.push_back(it->first);
        suffixBits.push_back((target_m >> pos) & 1);
        ++pos;
    }

    // presumably resolves to NumQubits when that is known at compile time, and
    // to swapMap.size() otherwise — confirm against the macro's definition
    SET_VAR_AT_COMPILE_TIME(int, numSwaps, NumQubits, swapMap.size());

    const qindex numPacked = qureg.numAmpsPerNode >> numSwaps;
    const qindex stateMask = util_getBitMask(suffixQubits, suffixBits);

    cpu_qcomp* localAmps = getCpuQcompPtr(qureg.cpuAmps);
    cpu_qcomp* packed = getCpuQcompPtr(buffer);

    #pragma omp parallel for schedule(static) if(qureg.isMultithreaded)
    for (qindex j = 0; j < numPacked; j++) {

        // j-th local index having the fixed bit pattern at the suffix qubits
        qindex src = insertBitsWithMaskedValues(j, suffixQubits.data(), numSwaps, stateMask);
        packed[j] = localAmps[src];
    }
}

// Dispatches to the packing routine instantiated for the number of swaps in
// swapMap; sizes 1 to 5 use dedicated instantiations and every other size
// uses the <-1> instantiation.
void cpu_statevec_packFusedMultiSwapBuffers(Qureg qureg, const std::map<int, int>& swapMap, int target_m, qcomp* buffer) {
    int numSwaps = swapMap.size();

    switch (numSwaps) {
        case 1:
            cpu_statevec_packFusedMultiSwapBuffers_sub<1>(qureg, swapMap, target_m, buffer);
            break;
        case 2:
            cpu_statevec_packFusedMultiSwapBuffers_sub<2>(qureg, swapMap, target_m, buffer);
            break;
        case 3:
            cpu_statevec_packFusedMultiSwapBuffers_sub<3>(qureg, swapMap, target_m, buffer);
            break;
        case 4:
            cpu_statevec_packFusedMultiSwapBuffers_sub<4>(qureg, swapMap, target_m, buffer);
            break;
        case 5:
            cpu_statevec_packFusedMultiSwapBuffers_sub<5>(qureg, swapMap, target_m, buffer);
            break;
        default:
            cpu_statevec_packFusedMultiSwapBuffers_sub<-1>(qureg, swapMap, target_m, buffer);
            break;
    }
}

// Scatters the contiguous array `buffer` back into the local amplitudes,
// overwriting every amplitude whose bits at the suffix qubits (the keys of
// swapMap) equal the matching bits of target_m. Bit j of target_m is paired
// with the j-th smallest key, because std::map iterates in ascending key order.
template <int NumQubits>
void cpu_statevec_unpackFusedMultiSwapBuffers_sub(Qureg qureg, const std::map<int, int>& swapMap, int target_m, qcomp* buffer) {

    List64 suffixQubits = lists_getEmptyList64();
    List64 suffixBits = lists_getEmptyList64();

    // walk the map in key order, peeling one bit off target_m per suffix qubit
    int pos = 0;
    for (auto it = swapMap.begin(); it != swapMap.end(); ++it) {
        suffixQubits.push_back(it->first);
        suffixBits.push_back((target_m >> pos) & 1);
        ++pos;
    }

    // presumably resolves to NumQubits when that is known at compile time, and
    // to swapMap.size() otherwise — confirm against the macro's definition
    SET_VAR_AT_COMPILE_TIME(int, numSwaps, NumQubits, swapMap.size());

    const qindex numUnpacked = qureg.numAmpsPerNode >> numSwaps;
    const qindex stateMask = util_getBitMask(suffixQubits, suffixBits);

    cpu_qcomp* localAmps = getCpuQcompPtr(qureg.cpuAmps);
    cpu_qcomp* received = getCpuQcompPtr(buffer);

    #pragma omp parallel for schedule(static) if(qureg.isMultithreaded)
    for (qindex j = 0; j < numUnpacked; j++) {

        // j-th local index having the fixed bit pattern at the suffix qubits
        qindex dst = insertBitsWithMaskedValues(j, suffixQubits.data(), numSwaps, stateMask);
        localAmps[dst] = received[j];
    }
}

// Dispatches to the unpacking routine instantiated for the number of swaps in
// swapMap; sizes 1 to 5 use dedicated instantiations and every other size
// uses the <-1> instantiation.
void cpu_statevec_unpackFusedMultiSwapBuffers(Qureg qureg, const std::map<int, int>& swapMap, int target_m, qcomp* buffer) {
    int numSwaps = swapMap.size();

    switch (numSwaps) {
        case 1:
            cpu_statevec_unpackFusedMultiSwapBuffers_sub<1>(qureg, swapMap, target_m, buffer);
            break;
        case 2:
            cpu_statevec_unpackFusedMultiSwapBuffers_sub<2>(qureg, swapMap, target_m, buffer);
            break;
        case 3:
            cpu_statevec_unpackFusedMultiSwapBuffers_sub<3>(qureg, swapMap, target_m, buffer);
            break;
        case 4:
            cpu_statevec_unpackFusedMultiSwapBuffers_sub<4>(qureg, swapMap, target_m, buffer);
            break;
        case 5:
            cpu_statevec_unpackFusedMultiSwapBuffers_sub<5>(qureg, swapMap, target_m, buffer);
            break;
        default:
            cpu_statevec_unpackFusedMultiSwapBuffers_sub<-1>(qureg, swapMap, target_m, buffer);
            break;
    }
}


/*
* SWAPS
Expand Down
5 changes: 5 additions & 0 deletions quest/src/cpu/cpu_subroutines.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "quest/src/core/utilities.hpp"

#include <vector>
#include <map>

using std::vector;

Expand Down Expand Up @@ -48,6 +49,10 @@ template <int NumQubits> qindex cpu_statevec_packAmpsIntoBuffer(Qureg qureg, Con

qindex cpu_statevec_packPairSummedAmpsIntoBuffer(Qureg qureg, int qubit1, int qubit2, int qubit3, int bit2);

// Copies into `buffer`, contiguously, the local amplitudes whose bits at the
// qubits given by swapMap's keys match the corresponding bits of target_m
// (bit j of target_m pairs with the j-th smallest key).
void cpu_statevec_packFusedMultiSwapBuffers(Qureg qureg, const std::map<int, int>& swapMap, int target_m, qcomp* buffer);

// Overwrites, from the contiguous `buffer`, the local amplitudes whose bits at
// the qubits given by swapMap's keys match the corresponding bits of target_m
// (bit j of target_m pairs with the j-th smallest key).
void cpu_statevec_unpackFusedMultiSwapBuffers(Qureg qureg, const std::map<int, int>& swapMap, int target_m, qcomp* buffer);


/*
* SWAPS
Expand Down