Review notes (free text ahead of the first "diff --git" header):
- NOTE(review): the reviewed text had lost its line breaks, indentation and every
  angle-bracket payload. Template argument lists (<T>, <FPTYPE>, <double>), the
  added "#include <vector>", blank-line placement and indentation below are
  reconstructed from the hunk line counts and surrounding code - confirm them
  against the upstream files before applying.
- NOTE(review): the first hunk of pw_gatherscatter.h was revised (wider loop
  counter, clang-safe pragma guard, doc comments), so its line counts and the
  "+" start lines of the later hunks in that file were recomputed, and the
  post-image hash on its "index" line no longer describes the patched file.
- NOTE(review): pw_transform.cpp and pw_transform_k.cpp call detail:: helpers
  that this patch defines in pw_gatherscatter.h - confirm both include it.

diff --git a/source/source_basis/module_pw/pw_gatherscatter.h b/source/source_basis/module_pw/pw_gatherscatter.h
index 207320f4268..eddeddb0f76 100644
--- a/source/source_basis/module_pw/pw_gatherscatter.h
+++ b/source/source_basis/module_pw/pw_gatherscatter.h
@@ -5,6 +5,64 @@
 namespace ModulePW
 {
 
+namespace detail
+{
+/**
+ * @brief Copy complex values on the calling thread without opening an OpenMP region.
+ *
+ * The buffers are walked as interleaved real/imaginary scalars so compilers can
+ * vectorize the contiguous data movement.
+ *
+ * @param in: source buffer, read for count complex values
+ * @param out: destination buffer, written for count complex values; it must not
+ *             overlap in because both are accessed through restrict pointers
+ * @param count: number of complex values to copy; nothing is copied if count <= 0
+ */
+template <typename T>
+inline void copy_complex_buffer(const std::complex<T>* in, std::complex<T>* out, const int count)
+{
+    const T* __restrict__ in_r = reinterpret_cast<const T*>(in);
+    T* __restrict__ out_r = reinterpret_cast<T*>(out);
+    // Widen before doubling: 2 * count overflows int once count exceeds INT_MAX / 2.
+    const long long nscalar = 2LL * count;
+// GCC-specific hint; skip it for clang, which also defines __GNUC__.
+#if defined(__GNUC__) && !defined(__clang__)
+#pragma GCC ivdep
+#endif
+    for (long long i = 0; i < nscalar; ++i)
+    {
+        out_r[i] = in_r[i];
+    }
+}
+
+/**
+ * @brief Copy complex values with an OpenMP parallel loop when _OPENMP is defined.
+ *
+ * Top-level transform copies own the OpenMP parallel region; gather/scatter
+ * loops call copy_complex_buffer inside their existing parallel regions.
+ *
+ * @param in: source buffer, read for count complex values
+ * @param out: destination buffer, written for count complex values; it must not
+ *             overlap in because both are accessed through restrict pointers
+ * @param count: number of complex values to copy; nothing is copied if count <= 0
+ */
+template <typename T>
+inline void copy_complex_buffer_parallel(const std::complex<T>* in, std::complex<T>* out, const int count)
+{
+    const T* __restrict__ in_r = reinterpret_cast<const T*>(in);
+    T* __restrict__ out_r = reinterpret_cast<T*>(out);
+    // Widen before doubling: 2 * count overflows int once count exceeds INT_MAX / 2.
+    const long long nscalar = 2LL * count;
+#ifdef _OPENMP
+#pragma omp parallel for schedule(static)
+#endif
+    for (long long i = 0; i < nscalar; ++i)
+    {
+        out_r[i] = in_r[i];
+    }
+}
+} // namespace detail
+
 /**
  * @brief gather planes and scatter sticks
  * @param in: (nplane,fftny,fftnx)
@@ -27,12 +85,9 @@ void PW_Basis::gatherp_scatters(std::complex<T>* in, std::complex<T>* out) const
         for(int is = 0 ; is < nst_ ; ++is)
         {
             int ixy = istot2ixy_[is];
-            std::complex<T> *outp = &out[is*nz_];
-            std::complex<T> *inp = &in[ixy*nz_];
-            for(int iz = 0 ; iz < nz_ ; ++iz)
-            {
-                outp[iz] = inp[iz];
-            }
+            std::complex<T>* outp = &out[is*nz_];
+            const std::complex<T>* inp = &in[ixy*nz_];
+            detail::copy_complex_buffer(inp, outp, nz_);
         }
         return;
     }
@@ -50,12 +105,9 @@ void PW_Basis::gatherp_scatters(std::complex<T>* in, std::complex<T>* out) const
     for (int istot = 0; istot < nstot_gps; ++istot)
     {
         int ixy = istot2ixy_gps[istot];
-        std::complex<T> *outp = &out[istot * nplane_gps];
-        std::complex<T> *inp = &in[ixy * nplane_gps];
-        for (int iz = 0; iz < nplane_gps; ++iz)
-        {
-            outp[iz] = inp[iz];
-        }
+        std::complex<T>* outp = &out[istot * nplane_gps];
+        const std::complex<T>* inp = &in[ixy * nplane_gps];
+        detail::copy_complex_buffer(inp, outp, nplane_gps);
     }
 
     //exchange data
@@ -90,12 +142,9 @@ void PW_Basis::gatherp_scatters(std::complex<T>* in, std::complex<T>* out) const
             int nzip = numz_gps[ip];
             std::complex<T> *outp0 = &out[startz_gps[ip]];
             std::complex<T> *inp0 = &in[startg_gps[ip]];
-            std::complex<T> *outp = &outp0[is * nz_gps];
-            std::complex<T> *inp = &inp0[is * nzip ];
-            for (int izip = 0; izip < nzip; ++izip)
-            {
-                outp[izip] = inp[izip];
-            }
+            std::complex<T>* outp = &outp0[is * nz_gps];
+            const std::complex<T>* inp = &inp0[is * nzip ];
+            detail::copy_complex_buffer(inp, outp, nzip);
         }
     }
 #endif
@@ -132,12 +181,9 @@ void PW_Basis::gathers_scatterp(std::complex<T>* in, std::complex<T>* out) const
         for(int is = 0 ; is < nst_ ; ++is)
         {
             int ixy = istot2ixy_[is];
-            std::complex<T> *outp = &out[ixy*nz_];
-            std::complex<T> *inp = &in[is*nz_];
-            for(int iz = 0 ; iz < nz_ ; ++iz)
-            {
-                outp[iz] = inp[iz];
-            }
+            std::complex<T>* outp = &out[ixy*nz_];
+            const std::complex<T>* inp = &in[is*nz_];
+            detail::copy_complex_buffer(inp, outp, nz_);
         }
         return;
     }
@@ -162,12 +208,9 @@ void PW_Basis::gathers_scatterp(std::complex<T>* in, std::complex<T>* out) const
             int nzip = numz_[ip];
             std::complex<T> *outp0 = &out[startg_[ip]];
             std::complex<T> *inp0 = &in[startz_[ip]];
-            std::complex<T> *outp = &outp0[is * nzip];
-            std::complex<T> *inp = &inp0[is * nz_ ];
-            for (int izip = 0; izip < nzip; ++izip)
-            {
-                outp[izip] = inp[izip];
-            }
+            std::complex<T>* outp = &outp0[is * nzip];
+            const std::complex<T>* inp = &inp0[is * nz_ ];
+            detail::copy_complex_buffer(inp, outp, nzip);
         }
     }
 
@@ -205,12 +248,9 @@ void PW_Basis::gathers_scatterp(std::complex<T>* in, std::complex<T>* out) const
     {
         int ixy = istot2ixy[istot];
         //int ixy = (ixy / fftny)*ny + ixy % fftny;
-        std::complex<T> *outp = &out[ixy * nplane];
-        std::complex<T> *inp = &in[istot * nplane];
-        for (int iz = 0; iz < nplane; ++iz)
-        {
-            outp[iz] = inp[iz];
-        }
+        std::complex<T>* outp = &out[ixy * nplane];
+        const std::complex<T>* inp = &in[istot * nplane];
+        detail::copy_complex_buffer(inp, outp, nplane);
     }
 #endif
     return;
diff --git a/source/source_basis/module_pw/pw_transform.cpp b/source/source_basis/module_pw/pw_transform.cpp
index 220b353e9d4..dc867e0cd86 100644
--- a/source/source_basis/module_pw/pw_transform.cpp
+++ b/source/source_basis/module_pw/pw_transform.cpp
@@ -34,13 +34,7 @@ void PW_Basis::real2recip(const std::complex<FPTYPE>* in,
     const int npw_ = this->npw;
     const int nxyz_ = this->nxyz;
     const int* ig2isz_ = this->ig2isz;
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
-    for (int ir = 0; ir < nrxx_; ++ir)
-    {
-        this->fft_bundle.get_auxr_data<FPTYPE>()[ir] = in[ir];
-    }
+    detail::copy_complex_buffer_parallel(in, this->fft_bundle.get_auxr_data<FPTYPE>(), nrxx_);
     this->fft_bundle.fftxyfor(fft_bundle.get_auxr_data<FPTYPE>(), fft_bundle.get_auxr_data<FPTYPE>());
 
     this->gatherp_scatters(this->fft_bundle.get_auxr_data<FPTYPE>(), this->fft_bundle.get_auxg_data<FPTYPE>());
@@ -199,13 +193,7 @@ void PW_Basis::recip2real(const std::complex<FPTYPE>* in,
     }
     else
     {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
-        for (int ir = 0; ir < nrxx_; ++ir)
-        {
-            out[ir] = this->fft_bundle.get_auxr_data<FPTYPE>()[ir];
-        }
+        detail::copy_complex_buffer_parallel(this->fft_bundle.get_auxr_data<FPTYPE>(), out, nrxx_);
     }
     ModuleBase::timer::end(this->classname, "recip2real");
 }
@@ -340,4 +328,4 @@ template void PW_Basis::recip2real<double>(const std::complex<double>* in,
                                            std::complex<double>* out,
                                            const bool add,
                                            const double factor) const;
-} // namespace ModulePW
\ No newline at end of file
+} // namespace ModulePW
diff --git a/source/source_basis/module_pw/pw_transform_k.cpp b/source/source_basis/module_pw/pw_transform_k.cpp
index a09aa2b686f..8c45e3d9b22 100644
--- a/source/source_basis/module_pw/pw_transform_k.cpp
+++ b/source/source_basis/module_pw/pw_transform_k.cpp
@@ -33,13 +33,7 @@ void PW_Basis_K::real2recip(const std::complex<FPTYPE>* in,
 
     assert(this->gamma_only == false);
     auto* auxr = this->fft_bundle.get_auxr_data<FPTYPE>();
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
-    for (int ir = 0; ir < this->nrxx; ++ir)
-    {
-        auxr[ir] = in[ir];
-    }
+    detail::copy_complex_buffer_parallel(in, auxr, this->nrxx);
     this->fft_bundle.fftxyfor(fft_bundle.get_auxr_data<FPTYPE>(), fft_bundle.get_auxr_data<FPTYPE>());
 
     this->gatherp_scatters(this->fft_bundle.get_auxr_data<FPTYPE>(), this->fft_bundle.get_auxg_data<FPTYPE>());
@@ -200,13 +194,7 @@ void PW_Basis_K::recip2real(const std::complex<FPTYPE>* in,
     }
     else
     {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
-        for (int ir = 0; ir < this->nrxx; ++ir)
-        {
-            out[ir] = auxr[ir];
-        }
+        detail::copy_complex_buffer_parallel(auxr, out, this->nrxx);
     }
     ModuleBase::timer::end(this->classname, "recip2real");
 }
diff --git a/source/source_basis/module_pw/test_serial/pw_basis_k_test.cpp b/source/source_basis/module_pw/test_serial/pw_basis_k_test.cpp
index 84932bae2ff..ad3e1764dd1 100644
--- a/source/source_basis/module_pw/test_serial/pw_basis_k_test.cpp
+++ b/source/source_basis/module_pw/test_serial/pw_basis_k_test.cpp
@@ -2,6 +2,7 @@
 #include "source_base/global_function.h"
 #include "source_base/constants.h"
 #include "source_base/matrix3.h"
+#include <vector>
 
 /************************************************
  * serial unit test of functions in pw_basis.cpp
@@ -188,4 +189,43 @@ TEST_F(PWBasisKTEST, CollectLocalPW)
     EXPECT_EQ(basis_k.npwk_max,2721);
 }
 
+TEST_F(PWBasisKTEST, ComplexTransformRoundTrip)
+{
+    ModulePW::PW_Basis_K basis_k(device_flag, precision_double);
+    double lat0 = 2.0;
+    ModuleBase::Matrix3 latvec(1.0,0.0,1.0,
+                               0.0,2.0,0.0,
+                               0.0,0.0,2.0);
+    double gridecut = 30.0;
+    const bool gamma_only_in = false;
+    const double gk_ecut_in = 20.0;
+    const int nks_in = 1;
+    const ModuleBase::Vector3<double> kvec_d_in[1] = { {0.0, 0.0, 0.0} };
+    const int distribution_type_in = 2;
+    const bool xprime_in = false;
+
+    basis_k.initgrids(lat0, latvec, gridecut);
+    basis_k.initparameters(gamma_only_in, gk_ecut_in, nks_in, kvec_d_in, distribution_type_in, xprime_in);
+    ASSERT_NO_THROW(basis_k.setuptransform());
+    // Use reciprocal-space input because arbitrary real-space data is projected
+    // by the plane-wave cutoff and is not exactly recoverable.
+    std::vector<std::complex<double>> recip_in(basis_k.npwk[0]);
+    std::vector<std::complex<double>> real_space(basis_k.nrxx);
+    std::vector<std::complex<double>> recip_out(basis_k.npwk[0]);
+    for (int ig = 0; ig < basis_k.npwk[0]; ++ig)
+    {
+        const double real_part = (ig % 17 - 8) / 11.0;
+        const double imag_part = (ig % 19 - 9) / 13.0;
+        recip_in[ig] = std::complex<double>(real_part, imag_part);
+    }
+
+    basis_k.recip2real(recip_in.data(), real_space.data(), 0);
+    basis_k.real2recip(real_space.data(), recip_out.data(), 0);
+
+    for (int ig = 0; ig < basis_k.npwk[0]; ++ig)
+    {
+        EXPECT_NEAR(recip_in[ig].real(), recip_out[ig].real(), 1e-10);
+        EXPECT_NEAR(recip_in[ig].imag(), recip_out[ig].imag(), 1e-10);
+    }
+}
 
diff --git a/source/source_basis/module_pw/test_serial/pw_basis_test.cpp b/source/source_basis/module_pw/test_serial/pw_basis_test.cpp
index ea678b9d97c..57ac8f06554 100644
--- a/source/source_basis/module_pw/test_serial/pw_basis_test.cpp
+++ b/source/source_basis/module_pw/test_serial/pw_basis_test.cpp
@@ -2,6 +2,7 @@
 #include "source_base/global_function.h"
 #include "source_base/constants.h"
 #include "source_base/matrix3.h"
+#include <vector>
 
 /************************************************
  * serial unit test of functions in pw_basis.cpp
@@ -362,3 +363,41 @@ TEST_F(PWBasisTEST,CollectUniqgg)
     pwb.collect_uniqgg();
     EXPECT_EQ(pwb.ngg,78);
 }
+
+TEST_F(PWBasisTEST,ComplexTransformRoundTrip)
+{
+    double lat0 = 2.0;
+    ModuleBase::Matrix3 latvec(1.0,0.0,1.0,
+                               0.0,2.0,0.0,
+                               0.0,0.0,2.0);
+    double gridecut = 30.0;
+    bool gamma_only_in = false;
+    double pwecut_in = 20.0;
+    int distribution_type_in = 2;
+    bool xprime_in = false;
+
+    pwb.initgrids(lat0, latvec, gridecut);
+    pwb.initparameters(gamma_only_in, pwecut_in, distribution_type_in, xprime_in);
+    ASSERT_NO_THROW(pwb.setuptransform());
+
+    // Use reciprocal-space input because arbitrary real-space data is projected
+    // by the plane-wave cutoff and is not exactly recoverable.
+    std::vector<std::complex<double>> recip_in(pwb.npw);
+    std::vector<std::complex<double>> real_space(pwb.nrxx);
+    std::vector<std::complex<double>> recip_out(pwb.npw);
+    for (int ig = 0; ig < pwb.npw; ++ig)
+    {
+        const double real_part = (ig % 11 - 5) / 7.0;
+        const double imag_part = (ig % 13 - 6) / 9.0;
+        recip_in[ig] = std::complex<double>(real_part, imag_part);
+    }
+
+    pwb.recip2real(recip_in.data(), real_space.data());
+    pwb.real2recip(real_space.data(), recip_out.data());
+
+    for (int ig = 0; ig < pwb.npw; ++ig)
+    {
+        EXPECT_NEAR(recip_in[ig].real(), recip_out[ig].real(), 1e-10);
+        EXPECT_NEAR(recip_in[ig].imag(), recip_out[ig].imag(), 1e-10);
+    }
+}