diff --git a/quest/src/comm/comm_routines.cpp b/quest/src/comm/comm_routines.cpp index cf6956454..e37aec37d 100644 --- a/quest/src/comm/comm_routines.cpp +++ b/quest/src/comm/comm_routines.cpp @@ -21,6 +21,10 @@ #include "quest/src/gpu/gpu_config.hpp" #include "quest/src/comm/comm_config.hpp" #include "quest/src/comm/comm_indices.hpp" +#include "quest/src/cpu/cpu_subroutines.hpp" +#include "quest/src/core/utilities.hpp" +#include +#include #if QUEST_COMPILE_MPI #include @@ -827,3 +831,52 @@ vector comm_gatherStringsToRoot(char* localChars, int maxNumLocalChars) return {}; #endif } + +void comm_exchangeFusedMultiSwap(Qureg qureg, ConstList64 ctrls, ConstList64 ctrlStates, const std::map& swapMap) { + assert_commQuregIsDistributed(qureg); + +#if QUEST_COMPILE_MPI + int k = swapMap.size(); + if (k == 0) return; + + // GPU fallback: sync GPU amps to CPU, perform fused swap on CPU, sync back + if (qureg.isGpuAccelerated) + syncQuregFromGpu(qureg); + + qindex chunkSize = qureg.numAmpsPerNode >> k; + + std::vector prefixTargs; + for (auto const& [s, p] : swapMap) { + prefixTargs.push_back(p); + } + + int myPrefixBits = 0; + for (int i = 0; i < k; i++) { + if (util_getRankBitOfQubit(prefixTargs[i], qureg)) { + myPrefixBits |= (1 << i); + } + } + + qcomp* sendBuffer = qureg.cpuCommBuffer; + qcomp* recvBuffer = qureg.cpuCommBuffer + chunkSize; + + for (int s = 1; s < (1 << k); s++) { + int target_m = myPrefixBits ^ s; + int pairRank = qureg.rank; + for (int i = 0; i < k; i++) { + if ((s >> i) & 1) { + pairRank = flipBit(pairRank, util_getPrefixInd(prefixTargs[i], qureg)); + } + } + + cpu_statevec_packFusedMultiSwapBuffers(qureg, swapMap, target_m, sendBuffer); + exchangeArrays(sendBuffer, recvBuffer, chunkSize, pairRank); + cpu_statevec_unpackFusedMultiSwapBuffers(qureg, swapMap, target_m, recvBuffer); + } + + if (qureg.isGpuAccelerated) + syncQuregToGpu(qureg); +#else + error_commButEnvNotDistributed(); +#endif +} diff --git a/quest/src/comm/comm_routines.hpp b/quest/src/comm/comm_routines.hpp index e75e889f6..e5bdad685 100644 --- a/quest/src/comm/comm_routines.hpp +++ b/quest/src/comm/comm_routines.hpp @@ -16,6 +16,7 @@ #include #include +#include using std::vector; @@ -39,6 +40,7 @@ void comm_combineAmpsIntoBuffer(Qureg receiver, Qureg sender); void comm_combineElemsIntoBuffer(Qureg receiver, FullStateDiagMatr sender); +void comm_exchangeFusedMultiSwap(Qureg qureg, ConstList64 ctrls, ConstList64 ctrlStates, const std::map& swapMap); /* diff --git a/quest/src/core/localiser.cpp b/quest/src/core/localiser.cpp index 83a23b921..0892f5293 100644 --- a/quest/src/core/localiser.cpp +++ b/quest/src/core/localiser.cpp @@ -23,6 +23,8 @@ #include "quest/src/core/paulilogic.hpp" #include "quest/src/core/localiser.hpp" #include "quest/src/core/accelerator.hpp" + +#include #include "quest/src/comm/comm_config.hpp" #include "quest/src/comm/comm_routines.hpp" #include "quest/src/cpu/cpu_config.hpp" @@ -909,16 +911,17 @@ void anyCtrlMultiSwapBetweenPrefixAndSuffix(Qureg qureg, ConstList64 ctrls, Cons /// although the latter requires substantially more work like setting up /// a communicator which may be inelegant alongside our own distribution scheme. - // perform necessary swaps to move all targets into suffix, each of which invokes communication - for (size_t i=0; i swapMap; + for (size_t i = 0; i < targsA.size(); i++) { + if (targsA[i] != targsB[i]) { + swapMap[std::min(targsA[i], targsB[i])] = std::max(targsA[i], targsB[i]); + } + } + if (swapMap.empty()) { + return; } + comm_exchangeFusedMultiSwap(qureg, ctrls, ctrlStates, swapMap); } diff --git a/quest/src/cpu/cpu_subroutines.cpp b/quest/src/cpu/cpu_subroutines.cpp index 59df946e9..4bb1a1b54 100644 --- a/quest/src/cpu/cpu_subroutines.cpp +++ b/quest/src/cpu/cpu_subroutines.cpp @@ -284,6 +284,82 @@ qindex cpu_statevec_packPairSummedAmpsIntoBuffer(Qureg qureg, int qubit1, int qu INSTANTIATE_FUNC_OPTIMISED_FOR_NUM_TARGS( qindex, cpu_statevec_packAmpsIntoBuffer, (Qureg, ConstList64, ConstList64) ) +// Pack local amplitudes whose suffix bits (at swap positions) match target_m +// into a contiguous buffer for MPI exchange with the partner node. +template +void cpu_statevec_packFusedMultiSwapBuffers_sub(Qureg qureg, const std::map& swapMap, int target_m, qcomp* buffer) { + + List64 sortedSufTargs = lists_getEmptyList64(); + List64 targetStates = lists_getEmptyList64(); + int bitIndex = 0; + + for (auto const& [suf, pre] : swapMap) { + sortedSufTargs.push_back(suf); + targetStates.push_back((target_m >> bitIndex) & 1); + bitIndex++; + } + + SET_VAR_AT_COMPILE_TIME(int, k, NumQubits, swapMap.size()); + qindex numIts = qureg.numAmpsPerNode >> k; + qindex qubitStateMask = util_getBitMask(sortedSufTargs, targetStates); + + cpu_qcomp* amps = getCpuQcompPtr(qureg.cpuAmps); + cpu_qcomp* outBuffer = getCpuQcompPtr(buffer); + + #pragma omp parallel for schedule(static) if(qureg.isMultithreaded) + for (qindex n = 0; n < numIts; n++) { + qindex i = insertBitsWithMaskedValues(n, sortedSufTargs.data(), k, qubitStateMask); + outBuffer[n] = amps[i]; + } +} + +void cpu_statevec_packFusedMultiSwapBuffers(Qureg qureg, const std::map& swapMap, int target_m, qcomp* buffer) { + int k = swapMap.size(); + if (k == 1) cpu_statevec_packFusedMultiSwapBuffers_sub<1>(qureg, swapMap, target_m, buffer); + else if (k == 2) cpu_statevec_packFusedMultiSwapBuffers_sub<2>(qureg, swapMap, target_m, buffer); + else if (k == 3) cpu_statevec_packFusedMultiSwapBuffers_sub<3>(qureg, swapMap, target_m, buffer); + else if (k == 4) cpu_statevec_packFusedMultiSwapBuffers_sub<4>(qureg, swapMap, target_m, buffer); + else if (k == 5) cpu_statevec_packFusedMultiSwapBuffers_sub<5>(qureg, swapMap, target_m, buffer); + else cpu_statevec_packFusedMultiSwapBuffers_sub<-1>(qureg, swapMap, target_m, buffer); +} + +template +void cpu_statevec_unpackFusedMultiSwapBuffers_sub(Qureg qureg, const std::map& swapMap, int target_m, qcomp* buffer) { + + List64 sortedSufTargs = lists_getEmptyList64(); + List64 targetStates = lists_getEmptyList64(); + int bitIndex = 0; + + for (auto const& [suf, pre] : swapMap) { + sortedSufTargs.push_back(suf); + targetStates.push_back((target_m >> bitIndex) & 1); + bitIndex++; + } + + SET_VAR_AT_COMPILE_TIME(int, k, NumQubits, swapMap.size()); + qindex numIts = qureg.numAmpsPerNode >> k; + qindex qubitStateMask = util_getBitMask(sortedSufTargs, targetStates); + + cpu_qcomp* amps = getCpuQcompPtr(qureg.cpuAmps); + cpu_qcomp* inBuffer = getCpuQcompPtr(buffer); + + #pragma omp parallel for schedule(static) if(qureg.isMultithreaded) + for (qindex n = 0; n < numIts; n++) { + qindex i = insertBitsWithMaskedValues(n, sortedSufTargs.data(), k, qubitStateMask); + amps[i] = inBuffer[n]; + } +} + +void cpu_statevec_unpackFusedMultiSwapBuffers(Qureg qureg, const std::map& swapMap, int target_m, qcomp* buffer) { + int k = swapMap.size(); + if (k == 1) cpu_statevec_unpackFusedMultiSwapBuffers_sub<1>(qureg, swapMap, target_m, buffer); + else if (k == 2) cpu_statevec_unpackFusedMultiSwapBuffers_sub<2>(qureg, swapMap, target_m, buffer); + else if (k == 3) cpu_statevec_unpackFusedMultiSwapBuffers_sub<3>(qureg, swapMap, target_m, buffer); + else if (k == 4) cpu_statevec_unpackFusedMultiSwapBuffers_sub<4>(qureg, swapMap, target_m, buffer); + else if (k == 5) cpu_statevec_unpackFusedMultiSwapBuffers_sub<5>(qureg, swapMap, target_m, buffer); + else cpu_statevec_unpackFusedMultiSwapBuffers_sub<-1>(qureg, swapMap, target_m, buffer); +} + /* * SWAPS diff --git a/quest/src/cpu/cpu_subroutines.hpp b/quest/src/cpu/cpu_subroutines.hpp index 3dbae057b..ac642ce4d 100644 --- a/quest/src/cpu/cpu_subroutines.hpp +++ b/quest/src/cpu/cpu_subroutines.hpp @@ -16,6 +16,7 @@ #include "quest/src/core/utilities.hpp" #include +#include using std::vector; @@ -48,6 +49,10 @@ template qindex cpu_statevec_packAmpsIntoBuffer(Qureg qureg, Con qindex cpu_statevec_packPairSummedAmpsIntoBuffer(Qureg qureg, int qubit1, int qubit2, int qubit3, int bit2); +void cpu_statevec_packFusedMultiSwapBuffers(Qureg qureg, const std::map& swapMap, int target_m, qcomp* buffer); + +void cpu_statevec_unpackFusedMultiSwapBuffers(Qureg qureg, const std::map& swapMap, int target_m, qcomp* buffer); + /* * SWAPS