diff --git a/source/source_basis/module_pw/pw_basis.cpp b/source/source_basis/module_pw/pw_basis.cpp index 549fec8e5a4..d3a85f49bda 100644 --- a/source/source_basis/module_pw/pw_basis.cpp +++ b/source/source_basis/module_pw/pw_basis.cpp @@ -39,6 +39,8 @@ PW_Basis:: ~PW_Basis() delete[] startr; delete[] ig2igg; delete[] gg_uniq; + this->comm_workbuf_float_.reset(); + this->comm_workbuf_double_.reset(); #if defined(__CUDA) || defined(__ROCM) if (this->device == "gpu") { @@ -124,9 +126,23 @@ void PW_Basis::getstartgr() { this->startr[ip] = this->startr[ip-1] + this->numr[ip-1]; } + this->allocate_comm_buffers(); return; } +void PW_Basis::allocate_comm_buffers() +{ + if (this->poolnproc <= 0) + { + return; + } + const std::size_t max_size = static_cast( + this->startr[this->poolnproc - 1] + this->numr[this->poolnproc - 1] + + this->startg[this->poolnproc - 1] + this->numg[this->poolnproc - 1]); + this->comm_workbuf_float_.reset(new std::complex[max_size]); + this->comm_workbuf_double_.reset(new std::complex[max_size]); +} + /// /// Collect planewaves on current core, and construct gg, gdirect, gcar according to ig2isz and is2fftixy. /// known: ig2isz, is2fftixy diff --git a/source/source_basis/module_pw/pw_basis.h b/source/source_basis/module_pw/pw_basis.h index b834cb0e0f4..d1700e62b6f 100644 --- a/source/source_basis/module_pw/pw_basis.h +++ b/source/source_basis/module_pw/pw_basis.h @@ -8,7 +8,10 @@ #include "source_base/vector3.h" #include #include "source_base/module_fft/fft_bundle.h" +#include #include +#include +#include #ifdef __MPI #include "mpi.h" #endif @@ -148,7 +151,7 @@ class PW_Basis //prepare for MPI_Alltoall void getstartgr(); - + void allocate_comm_buffers(); public: //collect gdirect, gcar, gg @@ -420,6 +423,9 @@ class PW_Basis template void gathers_scatterp(std::complex* in, std::complex* out) const; + template + std::complex* acquire_comm_workbuf(const int size) const; + public: //get fftixy2is; void getfftixy2is(int * fftixy2is) const; @@ -441,7 +447,25 @@ class PW_Basis std::string precision = "double"; ///< single, double, mixing bool double_data_ = true; ///< if has double data bool float_data_ = false; ///< if has float data + std::unique_ptr[]> comm_workbuf_float_; + std::unique_ptr[]> comm_workbuf_double_; }; + +template <> +inline std::complex* PW_Basis::acquire_comm_workbuf(const int size) const +{ + (void)size; + assert(this->comm_workbuf_float_ != nullptr); + return this->comm_workbuf_float_.get(); +} + +template <> +inline std::complex* PW_Basis::acquire_comm_workbuf(const int size) const +{ + (void)size; + assert(this->comm_workbuf_double_ != nullptr); + return this->comm_workbuf_double_.get(); +} } #endif // PWBASIS_H #include "pw_basis_sup.h" diff --git a/source/source_basis/module_pw/pw_gatherscatter.h b/source/source_basis/module_pw/pw_gatherscatter.h index 207320f4268..55695c9acf7 100644 --- a/source/source_basis/module_pw/pw_gatherscatter.h +++ b/source/source_basis/module_pw/pw_gatherscatter.h @@ -2,6 +2,7 @@ #include "source_base/global_function.h" #include "source_base/timer.h" #include +#include namespace ModulePW { @@ -15,8 +16,9 @@ namespace ModulePW template void PW_Basis::gatherp_scatters(std::complex* in, std::complex* out) const { - - if(this->poolnproc == 1) //In this case nst=nstot, nz = nplane, + ModuleBase::timer::start(this->classname, "gatherp_scatters"); + + if(this->poolnproc == 1) //In this case nst=nstot, nz = nplane, { const int nst_ = this->nst; const int nz_ = this->nz; @@ -34,6 +36,7 @@ void PW_Basis::gatherp_scatters(std::complex* in, std::complex* out) const outp[iz] = inp[iz]; } } + ModuleBase::timer::end(this->classname, "gatherp_scatters"); return; } @@ -41,64 +44,148 @@ void PW_Basis::gatherp_scatters(std::complex* in, std::complex* out) const #ifdef __MPI //change (nplane fftnxy) to (nplane,nstot) // Hence, we can send them at one time. + ModuleBase::timer::start(this->classname, "gatherp_pack"); const int nstot_gps = this->nstot; const int nplane_gps = this->nplane; const int* istot2ixy_gps = this->istot2ixy; + const int* numg_gps = this->numg; + const int* numr_gps = this->numr; + const int* startg_gps = this->startg; + const int* startr_gps = this->startr; + const int poolrank_gps = this->poolrank; + const int poolnproc_gps = this->poolnproc; + const int send_count_gps = startr_gps[poolnproc_gps - 1] + numr_gps[poolnproc_gps - 1]; + const int recv_count_gps = startg_gps[poolnproc_gps - 1] + numg_gps[poolnproc_gps - 1]; + std::complex* commbuf = this->acquire_comm_workbuf(send_count_gps + recv_count_gps); + std::complex* sendbuf = commbuf; + // Keep a dedicated receive slice so ranks with zero local planes do not + // need their logical input array to also satisfy the receive-buffer bound. + std::complex* recvbuf = commbuf + send_count_gps; + if (nplane_gps > 0) + { #ifdef _OPENMP - #pragma omp parallel for + #pragma omp parallel for #endif - for (int istot = 0; istot < nstot_gps; ++istot) - { - int ixy = istot2ixy_gps[istot]; - std::complex *outp = &out[istot * nplane_gps]; - std::complex *inp = &in[ixy * nplane_gps]; - for (int iz = 0; iz < nplane_gps; ++iz) + for (int istot = 0; istot < nstot_gps; ++istot) { - outp[iz] = inp[iz]; + int ixy = istot2ixy_gps[istot]; + std::complex *outp = &sendbuf[istot * nplane_gps]; + std::complex *inp = &in[ixy * nplane_gps]; + for (int iz = 0; iz < nplane_gps; ++iz) + { + outp[iz] = inp[iz]; + } } } + ModuleBase::timer::end(this->classname, "gatherp_pack"); //exchange data //(nplane,nstot) to (numz[ip],ns, poolnproc) + MPI_Datatype mpi_type = MPI_DATATYPE_NULL; if(typeid(T) == typeid(double)) { - MPI_Alltoallv(out, numr, startr, MPI_DOUBLE_COMPLEX, in, numg, startg, MPI_DOUBLE_COMPLEX, this->pool_world); + mpi_type = MPI_DOUBLE_COMPLEX; } else if(typeid(T) == typeid(float)) { - MPI_Alltoallv(out, numr, startr, MPI_COMPLEX, in, numg, startg, MPI_COMPLEX, this->pool_world); + mpi_type = MPI_COMPLEX; } else { - ModuleBase::WARNING_QUIT("PW_Basis::gatherp_scatters", "Unsupported data type for MPI_Alltoallv"); + ModuleBase::WARNING_QUIT("PW_Basis::gatherp_scatters", "Unsupported data type for MPI gather/scatter"); + } + std::vector recv_requests(poolnproc_gps, MPI_REQUEST_NULL); + std::vector send_requests(poolnproc_gps, MPI_REQUEST_NULL); + std::vector recv_status(poolnproc_gps); + std::vector recv_indices(poolnproc_gps, MPI_UNDEFINED); + int active_recvs = 0; + int active_sends = 0; + + ModuleBase::timer::start(this->classname, "gatherp_alltoallv"); + for (int ip = 0; ip < poolnproc_gps; ++ip) + { + if (ip == poolrank_gps || numg_gps[ip] == 0) + { + continue; + } + MPI_Irecv(&recvbuf[startg_gps[ip]], numg_gps[ip], mpi_type, ip, 0, this->pool_world, &recv_requests[ip]); + ++active_recvs; + } + for (int ip = 0; ip < poolnproc_gps; ++ip) + { + if (ip == poolrank_gps || numr_gps[ip] == 0) + { + continue; + } + MPI_Isend(&sendbuf[startr_gps[ip]], numr_gps[ip], mpi_type, ip, 0, this->pool_world, &send_requests[ip]); + ++active_sends; } + ModuleBase::timer::end(this->classname, "gatherp_alltoallv"); // change (nz,ns) to (numz[ip],ns, poolnproc) - const int poolnproc_gps = this->poolnproc; const int nst_gps = this->nst; const int nz_gps = this->nz; const int* numz_gps = this->numz; - const int* startg_gps = this->startg; const int* startz_gps = this->startz; + auto unpack_peer = [&](const int ip) + { + const int nzip = numz_gps[ip]; #ifdef _OPENMP - #pragma omp parallel for collapse(2) + #pragma omp parallel for #endif - for (int ip = 0; ip < poolnproc_gps ;++ip) - { for (int is = 0; is < nst_gps; ++is) { - int nzip = numz_gps[ip]; - std::complex *outp0 = &out[startz_gps[ip]]; - std::complex *inp0 = &in[startg_gps[ip]]; - std::complex *outp = &outp0[is * nz_gps]; - std::complex *inp = &inp0[is * nzip ]; + std::complex *outp = &out[is * nz_gps + startz_gps[ip]]; + std::complex *inp = &recvbuf[startg_gps[ip] + is * nzip]; for (int izip = 0; izip < nzip; ++izip) { outp[izip] = inp[izip]; } } + }; + + ModuleBase::timer::start(this->classname, "gatherp_unpack"); +#ifdef _OPENMP + #pragma omp parallel for +#endif + for (int i = 0; i < numg_gps[poolrank_gps]; ++i) + { + recvbuf[startg_gps[poolrank_gps] + i] = sendbuf[startr_gps[poolrank_gps] + i]; + } + unpack_peer(poolrank_gps); + ModuleBase::timer::end(this->classname, "gatherp_unpack"); + + while (active_recvs > 0) + { + int outcount = 0; + ModuleBase::timer::start(this->classname, "gatherp_alltoallv"); + MPI_Waitsome(poolnproc_gps, + recv_requests.data(), + &outcount, + recv_indices.data(), + recv_status.data()); + ModuleBase::timer::end(this->classname, "gatherp_alltoallv"); + if (outcount == MPI_UNDEFINED) + { + break; + } + for (int idx = 0; idx < outcount; ++idx) + { + ModuleBase::timer::start(this->classname, "gatherp_unpack"); + unpack_peer(recv_indices[idx]); + ModuleBase::timer::end(this->classname, "gatherp_unpack"); + } + active_recvs -= outcount; + } + + if (active_sends > 0) + { + ModuleBase::timer::start(this->classname, "gatherp_alltoallv"); + MPI_Waitall(poolnproc_gps, send_requests.data(), MPI_STATUSES_IGNORE); + ModuleBase::timer::end(this->classname, "gatherp_alltoallv"); } #endif + ModuleBase::timer::end(this->classname, "gatherp_scatters"); return; } @@ -112,7 +199,8 @@ void PW_Basis::gatherp_scatters(std::complex* in, std::complex* out) const template void PW_Basis::gathers_scatterp(std::complex* in, std::complex* out) const { - if(this->poolnproc == 1) //In this case nrxx=fftnx*fftny*nz, nst = nstot, + ModuleBase::timer::start(this->classname, "gathers_scatterp"); + if(this->poolnproc == 1) //In this case nrxx=fftnx*fftny*nz, nst = nstot, { const int nrxx_ = this->nrxx; const int nst_ = this->nst; @@ -139,19 +227,29 @@ void PW_Basis::gathers_scatterp(std::complex* in, std::complex* out) const outp[iz] = inp[iz]; } } + ModuleBase::timer::end(this->classname, "gathers_scatterp"); return; } #ifdef __MPI // change (nz,ns) to (numz[ip],ns, poolnproc) - // Hence, we can send them at one time. + // Hence, we can send them at one time. + ModuleBase::timer::start(this->classname, "gathers_pack"); const int poolnproc_ = this->poolnproc; const int nst_ = this->nst; const int nz_ = this->nz; const int* numz_ = this->numz; const int* startg_ = this->startg; const int* startz_ = this->startz; + const int* nst_per_ = this->nst_per; + const int* startr_ = this->startr; + const int poolrank_ = this->poolrank; + const int send_count_ = startg_[poolnproc_ - 1] + this->numg[poolnproc_ - 1]; + const int recv_count_ = startr_[poolnproc_ - 1] + this->numr[poolnproc_ - 1]; + std::complex* commbuf = this->acquire_comm_workbuf(send_count_ + recv_count_); + std::complex* sendbuf = commbuf; + std::complex* recvbuf = commbuf + send_count_; #ifdef _OPENMP #pragma omp parallel for collapse(2) #endif @@ -160,7 +258,7 @@ void PW_Basis::gathers_scatterp(std::complex* in, std::complex* out) const for (int is = 0; is < nst_; ++is) { int nzip = numz_[ip]; - std::complex *outp0 = &out[startg_[ip]]; + std::complex *outp0 = &sendbuf[startg_[ip]]; std::complex *inp0 = &in[startz_[ip]]; std::complex *outp = &outp0[is * nzip]; std::complex *inp = &inp0[is * nz_ ]; @@ -170,22 +268,52 @@ void PW_Basis::gathers_scatterp(std::complex* in, std::complex* out) const } } } + ModuleBase::timer::end(this->classname, "gathers_pack"); //exchange data //(numz[ip],ns, poolnproc) to (nplane,nstot) + MPI_Datatype mpi_type = MPI_DATATYPE_NULL; if(typeid(T) == typeid(double)) { - MPI_Alltoallv(out, numg, startg, MPI_DOUBLE_COMPLEX, in, numr, startr, MPI_DOUBLE_COMPLEX, this->pool_world); + mpi_type = MPI_DOUBLE_COMPLEX; } else if(typeid(T) == typeid(float)) { - MPI_Alltoallv(out, numg, startg, MPI_COMPLEX, in, numr, startr, MPI_COMPLEX, this->pool_world); + mpi_type = MPI_COMPLEX; } else { - ModuleBase::WARNING_QUIT("PW_Basis::gathers_scatterp", "Unsupported data type for MPI_Alltoallv"); + ModuleBase::WARNING_QUIT("PW_Basis::gathers_scatterp", "Unsupported data type for MPI gather/scatter"); + } + std::vector recv_requests(poolnproc_, MPI_REQUEST_NULL); + std::vector send_requests(poolnproc_, MPI_REQUEST_NULL); + std::vector recv_status(poolnproc_); + std::vector recv_indices(poolnproc_, MPI_UNDEFINED); + int active_recvs = 0; + int active_sends = 0; + + ModuleBase::timer::start(this->classname, "gathers_alltoallv"); + for (int ip = 0; ip < poolnproc_; ++ip) + { + if (ip == poolrank_ || this->numr[ip] == 0) + { + continue; + } + MPI_Irecv(&recvbuf[startr_[ip]], this->numr[ip], mpi_type, ip, 0, this->pool_world, &recv_requests[ip]); + ++active_recvs; + } + for (int ip = 0; ip < poolnproc_; ++ip) + { + if (ip == poolrank_ || this->numg[ip] == 0) + { + continue; + } + MPI_Isend(&sendbuf[startg_[ip]], this->numg[ip], mpi_type, ip, 0, this->pool_world, &send_requests[ip]); + ++active_sends; } + ModuleBase::timer::end(this->classname, "gathers_alltoallv"); + ModuleBase::timer::start(this->classname, "gathers_clear"); const int nrxx_gsp = this->nrxx; #ifdef _OPENMP #pragma omp parallel for schedule(static) @@ -194,25 +322,82 @@ void PW_Basis::gathers_scatterp(std::complex* in, std::complex* out) const { out[i] = std::complex(0, 0); } + ModuleBase::timer::end(this->classname, "gathers_clear"); + //change (nplane,nstot) to (nplane fftnxy) - const int nstot = this->nstot; const int nplane = this->nplane; const int* istot2ixy = this->istot2ixy; + std::vector istot_offsets(poolnproc_, 0); + for (int ip = 1; ip < poolnproc_; ++ip) + { + istot_offsets[ip] = istot_offsets[ip - 1] + nst_per_[ip - 1]; + } + auto unpack_peer = [&](const int ip) + { + const int peer_nst = nst_per_[ip]; + if (peer_nst == 0 || nplane == 0) + { + return; + } + const int istot0 = istot_offsets[ip]; #ifdef _OPENMP -#pragma omp parallel for + #pragma omp parallel for +#endif + for (int is = 0; is < peer_nst; ++is) + { + const int istot = istot0 + is; + const int ixy = istot2ixy[istot]; + std::complex *outp = &out[ixy * nplane]; + std::complex *inp = &recvbuf[startr_[ip] + is * nplane]; + for (int iz = 0; iz < nplane; ++iz) + { + outp[iz] = inp[iz]; + } + } + }; + + ModuleBase::timer::start(this->classname, "gathers_unpack"); +#ifdef _OPENMP + #pragma omp parallel for #endif - for (int istot = 0;istot < nstot; ++istot) + for (int i = 0; i < this->numr[poolrank_]; ++i) { - int ixy = istot2ixy[istot]; - //int ixy = (ixy / fftny)*ny + ixy % fftny; - std::complex *outp = &out[ixy * nplane]; - std::complex *inp = &in[istot * nplane]; - for (int iz = 0; iz < nplane; ++iz) + recvbuf[startr_[poolrank_] + i] = sendbuf[startg_[poolrank_] + i]; + } + unpack_peer(poolrank_); + ModuleBase::timer::end(this->classname, "gathers_unpack"); + + while (active_recvs > 0) + { + int outcount = 0; + ModuleBase::timer::start(this->classname, "gathers_alltoallv"); + MPI_Waitsome(poolnproc_, + recv_requests.data(), + &outcount, + recv_indices.data(), + recv_status.data()); + ModuleBase::timer::end(this->classname, "gathers_alltoallv"); + if (outcount == MPI_UNDEFINED) { - outp[iz] = inp[iz]; + break; } + for (int idx = 0; idx < outcount; ++idx) + { + ModuleBase::timer::start(this->classname, "gathers_unpack"); + unpack_peer(recv_indices[idx]); + ModuleBase::timer::end(this->classname, "gathers_unpack"); + } + active_recvs -= outcount; + } + + if (active_sends > 0) + { + ModuleBase::timer::start(this->classname, "gathers_alltoallv"); + MPI_Waitall(poolnproc_, send_requests.data(), MPI_STATUSES_IGNORE); + ModuleBase::timer::end(this->classname, "gathers_alltoallv"); } #endif + ModuleBase::timer::end(this->classname, "gathers_scatterp"); return; } diff --git a/source/source_basis/module_pw/test/CMakeLists.txt b/source/source_basis/module_pw/test/CMakeLists.txt index b126791088f..0ffab75dcb9 100644 --- a/source/source_basis/module_pw/test/CMakeLists.txt +++ b/source/source_basis/module_pw/test/CMakeLists.txt @@ -15,7 +15,7 @@ AddTest( test6-1-1.cpp test6-1-2.cpp test6-2-1.cpp test6-2-2.cpp test6-3-1.cpp test6-4-1.cpp test6-4-2.cpp test7-1.cpp test6-2-1.cpp test7-3-1.cpp test7-3-2.cpp test8-1.cpp test8-2-1.cpp test8-3-1.cpp test8-3-2.cpp - test_tool.cpp test-big.cpp test-other.cpp test_sup.cpp + test_tool.cpp test-big.cpp test-other.cpp test_sup.cpp test_comm_roundtrip.cpp ) add_test(NAME MODULE_PW_pw_test_parallel diff --git a/source/source_basis/module_pw/test/test-big.cpp b/source/source_basis/module_pw/test/test-big.cpp index f1c2082d0b2..03b9520be23 100644 --- a/source/source_basis/module_pw/test/test-big.cpp +++ b/source/source_basis/module_pw/test/test-big.cpp @@ -53,7 +53,7 @@ TEST_F(PWTEST,test_big) pwktest.initgrids(lat0,latvec, pwtest.nx, pwtest.ny, pwtest.nz); pwtest.initparameters(gamma_only,wfcecut,distribution_type,xprime); pwktest.initparameters(gamma_only,wfcecut,nks,kvec_d,distribution_type, xprime); - static_cast(pwtest).setuptransform(); + static_cast(pwtest).setuptransform(); pwktest.setuptransform(); EXPECT_EQ(pwtest.nx%2, 0); EXPECT_EQ(pwtest.ny%2, 0); @@ -85,7 +85,7 @@ TEST_F(PWTEST,test_big) class TestPW_Basis_Big : public ::testing::Test { public: - ModulePW::PW_Basis_Big pwtest = ModulePW::PW_Basis_Big(); + ModulePW::PW_Basis_Big pwtest; }; // Test the function with nproc = 0 (bx and by) @@ -157,4 +157,4 @@ TEST_F(TestPW_Basis_Big, BzNprocNoResultTest) { int nproc = 5; pwtest.autoset_big_cell_size(b_size, nc_size, nproc); EXPECT_EQ(b_size, 3); -} \ No newline at end of file +} diff --git a/source/source_basis/module_pw/test/test_comm_roundtrip.cpp b/source/source_basis/module_pw/test/test_comm_roundtrip.cpp new file mode 100644 index 00000000000..1fb13bfaf3f --- /dev/null +++ b/source/source_basis/module_pw/test/test_comm_roundtrip.cpp @@ -0,0 +1,202 @@ +#include "../pw_basis.h" +#ifdef __MPI +#include "source_base/parallel_global.h" +#include "mpi.h" +#include "test_tool.h" +#endif +#include "source_base/global_function.h" +#include "pw_test.h" + +extern int nproc_in_pool, rank_in_pool; + +namespace +{ +class PW_Basis_Comm_Accessor : public ModulePW::PW_Basis +{ + public: + PW_Basis_Comm_Accessor(const std::string& device_, const std::string& precision_) + : ModulePW::PW_Basis(device_, precision_) + { + } + + using ModulePW::PW_Basis::gatherp_scatters; + using ModulePW::PW_Basis::gathers_scatterp; +}; + +template +void zero_complex_buffer(std::complex* data, const int size) +{ + for (int i = 0; i < size; ++i) + { + data[i] = std::complex(0.0, 0.0); + } +} + +template +void fill_plane_major_sticks(const BasisType& pw, std::complex* plane) +{ + zero_complex_buffer(plane, pw.nrxx); + for (int istot = 0; istot < pw.nstot; ++istot) + { + const int ixy = pw.istot2ixy[istot]; + for (int iz = 0; iz < pw.nplane; ++iz) + { + const int gz = pw.startz_current + iz; + const double real = (rank_in_pool + 1) * 1000.0 + istot * 10.0 + gz; + const double imag = (ixy + 1) * 0.25 + gz * 0.5; + plane[ixy * pw.nplane + iz] = std::complex(real, imag); + } + } +} + +template +void expect_plane_major_equal(const BasisType& pw, + const std::complex* expected, + const std::complex* actual) +{ + for (int ir = 0; ir < pw.nrxx; ++ir) + { + EXPECT_DOUBLE_EQ(expected[ir].real(), actual[ir].real()); + EXPECT_DOUBLE_EQ(expected[ir].imag(), actual[ir].imag()); + } +} + +template +int comm_roundtrip_work_size(const BasisType& pw) +{ + const int gather_size = pw.nst * pw.nz; + const int scatter_recv_size = pw.startr[pw.poolnproc - 1] + pw.numr[pw.poolnproc - 1]; + return std::max(gather_size, scatter_recv_size); +} + +template +void expect_stick_major_equal(const BasisType& pw, const std::complex* sticks) +{ + int istot0 = 0; + for (int ip = 0; ip < pw.poolrank; ++ip) + { + istot0 += pw.nst_per[ip]; + } + + for (int is = 0; is < pw.nst; ++is) + { + const int global_istot = istot0 + is; + const int ixy = pw.is2fftixy[is]; + EXPECT_EQ(ixy, pw.istot2ixy[global_istot]); + for (int iz = 0; iz < pw.nz; ++iz) + { + int owner = -1; + for (int ip = 0; ip < pw.poolnproc; ++ip) + { + if (iz >= pw.startz[ip] && iz < pw.startz[ip] + pw.numz[ip]) + { + owner = ip; + break; + } + } + EXPECT_GE(owner, 0); + const double real = (owner + 1) * 1000.0 + global_istot * 10.0 + iz; + const double imag = (ixy + 1) * 0.25 + iz * 0.5; + EXPECT_DOUBLE_EQ(real, sticks[is * pw.nz + iz].real()); + EXPECT_DOUBLE_EQ(imag, sticks[is * pw.nz + iz].imag()); + } + } +} + +bool case_has_zero_plane_stress(const int nx, const int ny, const int nz) +{ + PW_Basis_Comm_Accessor pwtest(device_flag, precision_flag); + ModuleBase::Matrix3 latvec(1, 0, 0, 0, 1, 0, 0, 0, 1); + const double lat0 = 4.0; + const double wfcecut = 20.0; + +#ifdef __MPI + pwtest.initmpi(nproc_in_pool, rank_in_pool, POOL_WORLD); +#endif + pwtest.initgrids(lat0, latvec, nx, ny, nz); + pwtest.initparameters(false, wfcecut, 1, true); + pwtest.setuptransform(); + +#ifdef __MPI + const int local_stress = (pwtest.nplane == 0 && pwtest.nst > 0) ? 1 : 0; + int any_stress = 0; + MPI_Allreduce(&local_stress, &any_stress, 1, MPI_INT, MPI_MAX, POOL_WORLD); + return any_stress == 1; +#else + return false; +#endif +} + +void run_comm_roundtrip_case(const int nx, const int ny, const int nz) +{ + PW_Basis_Comm_Accessor pwtest(device_flag, precision_flag); + ModuleBase::Matrix3 latvec(1, 0, 0, 0, 1, 0, 0, 0, 1); + const double lat0 = 4.0; + const double wfcecut = 20.0; + +#ifdef __MPI + pwtest.initmpi(nproc_in_pool, rank_in_pool, POOL_WORLD); +#endif + pwtest.initgrids(lat0, latvec, nx, ny, nz); + pwtest.initparameters(false, wfcecut, 1, true); + pwtest.setuptransform(); + + std::complex* plane_in = new std::complex[pwtest.nrxx]; + std::complex* plane_ref = new std::complex[pwtest.nrxx]; + std::complex* plane_out = new std::complex[pwtest.nrxx]; + const int sticks_work_size = comm_roundtrip_work_size(pwtest); + std::complex* sticks = new std::complex[sticks_work_size]; + + fill_plane_major_sticks(pwtest, plane_in); + for (int ir = 0; ir < pwtest.nrxx; ++ir) + { + plane_ref[ir] = plane_in[ir]; + } + zero_complex_buffer(plane_out, pwtest.nrxx); + zero_complex_buffer(sticks, sticks_work_size); + + pwtest.gatherp_scatters(plane_in, sticks); + expect_stick_major_equal(pwtest, sticks); + pwtest.gathers_scatterp(sticks, plane_out); + + expect_plane_major_equal(pwtest, plane_ref, plane_out); + + delete[] plane_in; + delete[] plane_ref; + delete[] plane_out; + delete[] sticks; +} +} // namespace + +TEST_F(PWTEST, test_comm_roundtrip_pw_basis) +{ + run_comm_roundtrip_case(10, 10, 10); +} + +TEST_F(PWTEST, test_comm_roundtrip_pw_basis_zero_plane_pressure) +{ + const int candidate_cases[][3] = { + {10, 10, 2}, + {16, 16, 2}, + {20, 20, 2}, + {24, 24, 2}, + {20, 20, 3}, + {24, 24, 3}, + {32, 16, 2}, + {32, 32, 2}, + }; + + for (const auto& candidate_case : candidate_cases) + { + const int nx = candidate_case[0]; + const int ny = candidate_case[1]; + const int nz = candidate_case[2]; + if (case_has_zero_plane_stress(nx, ny, nz)) + { + run_comm_roundtrip_case(nx, ny, nz); + return; + } + } + + GTEST_SKIP() << "No zero-plane/stick stress layout found for the current MPI decomposition."; +}