From 7c58a45ced643ebcb869408c6722b8babdb44382 Mon Sep 17 00:00:00 2001 From: Aunixt <2400011033@stu.pku.edu.cn> Date: Mon, 25 May 2026 21:24:41 +0800 Subject: [PATCH 1/4] have a try --- .../source_basis/module_pw/pw_gatherscatter.h | 78 +++++++++++++------ 1 file changed, 54 insertions(+), 24 deletions(-) diff --git a/source/source_basis/module_pw/pw_gatherscatter.h b/source/source_basis/module_pw/pw_gatherscatter.h index 207320f4268..bd017b23096 100644 --- a/source/source_basis/module_pw/pw_gatherscatter.h +++ b/source/source_basis/module_pw/pw_gatherscatter.h @@ -27,11 +27,16 @@ void PW_Basis::gatherp_scatters(std::complex* in, std::complex* out) const for(int is = 0 ; is < nst_ ; ++is) { int ixy = istot2ixy_[is]; - std::complex *outp = &out[is*nz_]; - std::complex *inp = &in[ixy*nz_]; - for(int iz = 0 ; iz < nz_ ; ++iz) + std::complex* outp = &out[is*nz_]; + std::complex* inp = &in[ixy*nz_]; + T* __restrict__ outp_r = reinterpret_cast(outp); + const T* __restrict__ inp_r = reinterpret_cast(inp); +#ifdef __GNUC__ +#pragma GCC ivdep +#endif + for(int iz = 0 ; iz < 2 * nz_ ; ++iz) { - outp[iz] = inp[iz]; + outp_r[iz] = inp_r[iz]; } } return; @@ -50,11 +55,16 @@ void PW_Basis::gatherp_scatters(std::complex* in, std::complex* out) const for (int istot = 0; istot < nstot_gps; ++istot) { int ixy = istot2ixy_gps[istot]; - std::complex *outp = &out[istot * nplane_gps]; - std::complex *inp = &in[ixy * nplane_gps]; - for (int iz = 0; iz < nplane_gps; ++iz) + std::complex* outp = &out[istot * nplane_gps]; + std::complex* inp = &in[ixy * nplane_gps]; + T* __restrict__ outp_r = reinterpret_cast(outp); + const T* __restrict__ inp_r = reinterpret_cast(inp); +#ifdef __GNUC__ +#pragma GCC ivdep +#endif + for (int iz = 0; iz < 2 * nplane_gps; ++iz) { - outp[iz] = inp[iz]; + outp_r[iz] = inp_r[iz]; } } @@ -90,11 +100,16 @@ void PW_Basis::gatherp_scatters(std::complex* in, std::complex* out) const int nzip = numz_gps[ip]; std::complex *outp0 = &out[startz_gps[ip]]; std::complex *inp0 = &in[startg_gps[ip]]; - std::complex *outp = &outp0[is * nz_gps]; - std::complex *inp = &inp0[is * nzip ]; - for (int izip = 0; izip < nzip; ++izip) + std::complex* outp = &outp0[is * nz_gps]; + std::complex* inp = &inp0[is * nzip ]; + T* __restrict__ outp_r = reinterpret_cast(outp); + const T* __restrict__ inp_r = reinterpret_cast(inp); +#ifdef __GNUC__ +#pragma GCC ivdep +#endif + for (int izip = 0; izip < 2 * nzip; ++izip) { - outp[izip] = inp[izip]; + outp_r[izip] = inp_r[izip]; } } } @@ -132,11 +147,16 @@ void PW_Basis::gathers_scatterp(std::complex* in, std::complex* out) const for(int is = 0 ; is < nst_ ; ++is) { int ixy = istot2ixy_[is]; - std::complex *outp = &out[ixy*nz_]; - std::complex *inp = &in[is*nz_]; - for(int iz = 0 ; iz < nz_ ; ++iz) + std::complex* outp = &out[ixy*nz_]; + std::complex* inp = &in[is*nz_]; + T* __restrict__ outp_r = reinterpret_cast(outp); + const T* __restrict__ inp_r = reinterpret_cast(inp); +#ifdef __GNUC__ +#pragma GCC ivdep +#endif + for(int iz = 0 ; iz < 2 * nz_ ; ++iz) { - outp[iz] = inp[iz]; + outp_r[iz] = inp_r[iz]; } } return; @@ -162,11 +182,16 @@ void PW_Basis::gathers_scatterp(std::complex* in, std::complex* out) const int nzip = numz_[ip]; std::complex *outp0 = &out[startg_[ip]]; std::complex *inp0 = &in[startz_[ip]]; - std::complex *outp = &outp0[is * nzip]; - std::complex *inp = &inp0[is * nz_ ]; - for (int izip = 0; izip < nzip; ++izip) + std::complex* outp = &outp0[is * nzip]; + std::complex* inp = &inp0[is * nz_ ]; + T* __restrict__ outp_r = reinterpret_cast(outp); + const T* __restrict__ inp_r = reinterpret_cast(inp); +#ifdef __GNUC__ +#pragma GCC ivdep +#endif + for (int izip = 0; izip < 2 * nzip; ++izip) { - outp[izip] = inp[izip]; + outp_r[izip] = inp_r[izip]; } } } @@ -205,11 +230,16 @@ void PW_Basis::gathers_scatterp(std::complex* in, std::complex* out) const { int ixy = istot2ixy[istot]; //int ixy = (ixy / fftny)*ny + ixy % fftny; - std::complex *outp = &out[ixy * nplane]; - std::complex *inp = &in[istot * nplane]; - for (int iz = 0; iz < nplane; ++iz) + std::complex* outp = &out[ixy * nplane]; + std::complex* inp = &in[istot * nplane]; + T* __restrict__ outp_r = reinterpret_cast(outp); + const T* __restrict__ inp_r = reinterpret_cast(inp); +#ifdef __GNUC__ +#pragma GCC ivdep +#endif + for (int iz = 0; iz < 2 * nplane; ++iz) { - outp[iz] = inp[iz]; + outp_r[iz] = inp_r[iz]; } } #endif From c268969f90cbd1b83c1585419f467b7c27f7c98c Mon Sep 17 00:00:00 2001 From: Aunixt <2400011033@stu.pku.edu.cn> Date: Sun, 31 May 2026 00:35:09 +0800 Subject: [PATCH 2/4] refine complex buffer copies in module_pw --- .../source_basis/module_pw/pw_gatherscatter.h | 103 ++++++++---------- .../source_basis/module_pw/pw_transform.cpp | 18 +-- .../source_basis/module_pw/pw_transform_k.cpp | 16 +-- 3 files changed, 48 insertions(+), 89 deletions(-) diff --git a/source/source_basis/module_pw/pw_gatherscatter.h b/source/source_basis/module_pw/pw_gatherscatter.h index bd017b23096..52c24449009 100644 --- a/source/source_basis/module_pw/pw_gatherscatter.h +++ b/source/source_basis/module_pw/pw_gatherscatter.h @@ -5,6 +5,37 @@ namespace ModulePW { +namespace detail +{ +template +inline void copy_complex_buffer(const std::complex* in, std::complex* out, const int count) +{ + const T* __restrict__ in_r = reinterpret_cast(in); + T* __restrict__ out_r = reinterpret_cast(out); +#ifdef __GNUC__ +#pragma GCC ivdep +#endif + for (int i = 0; i < 2 * count; ++i) + { + out_r[i] = in_r[i]; + } +} + +template +inline void copy_complex_buffer_parallel(const std::complex* in, std::complex* out, const int count) +{ + const T* __restrict__ in_r = reinterpret_cast(in); + T* __restrict__ out_r = reinterpret_cast(out); +#ifdef _OPENMP +#pragma omp parallel for schedule(static) +#endif + for (int i = 0; i < 2 * count; ++i) + { + out_r[i] = in_r[i]; + } +} +} // namespace detail + /** * @brief gather planes and scatter sticks * @param in: (nplane,fftny,fftnx) @@ -28,16 +59,8 @@ void PW_Basis::gatherp_scatters(std::complex* in, std::complex* out) const { int ixy = istot2ixy_[is]; std::complex* outp = &out[is*nz_]; - std::complex* inp = &in[ixy*nz_]; - T* __restrict__ outp_r = reinterpret_cast(outp); - const T* __restrict__ inp_r = reinterpret_cast(inp); -#ifdef __GNUC__ -#pragma GCC ivdep -#endif - for(int iz = 0 ; iz < 2 * nz_ ; ++iz) - { - outp_r[iz] = inp_r[iz]; - } + const std::complex* inp = &in[ixy*nz_]; + detail::copy_complex_buffer(inp, outp, nz_); } return; } @@ -56,16 +79,8 @@ void PW_Basis::gatherp_scatters(std::complex* in, std::complex* out) const { int ixy = istot2ixy_gps[istot]; std::complex* outp = &out[istot * nplane_gps]; - std::complex* inp = &in[ixy * nplane_gps]; - T* __restrict__ outp_r = reinterpret_cast(outp); - const T* __restrict__ inp_r = reinterpret_cast(inp); -#ifdef __GNUC__ -#pragma GCC ivdep -#endif - for (int iz = 0; iz < 2 * nplane_gps; ++iz) - { - outp_r[iz] = inp_r[iz]; - } + const std::complex* inp = &in[ixy * nplane_gps]; + detail::copy_complex_buffer(inp, outp, nplane_gps); } //exchange data @@ -101,16 +116,8 @@ void PW_Basis::gatherp_scatters(std::complex* in, std::complex* out) const std::complex *outp0 = &out[startz_gps[ip]]; std::complex *inp0 = &in[startg_gps[ip]]; std::complex* outp = &outp0[is * nz_gps]; - std::complex* inp = &inp0[is * nzip ]; - T* __restrict__ outp_r = reinterpret_cast(outp); - const T* __restrict__ inp_r = reinterpret_cast(inp); -#ifdef __GNUC__ -#pragma GCC ivdep -#endif - for (int izip = 0; izip < 2 * nzip; ++izip) - { - outp_r[izip] = inp_r[izip]; - } + const std::complex* inp = &inp0[is * nzip ]; + detail::copy_complex_buffer(inp, outp, nzip); } } #endif @@ -148,16 +155,8 @@ void PW_Basis::gathers_scatterp(std::complex* in, std::complex* out) const { int ixy = istot2ixy_[is]; std::complex* outp = &out[ixy*nz_]; - std::complex* inp = &in[is*nz_]; - T* __restrict__ outp_r = reinterpret_cast(outp); - const T* __restrict__ inp_r = reinterpret_cast(inp); -#ifdef __GNUC__ -#pragma GCC ivdep -#endif - for(int iz = 0 ; iz < 2 * nz_ ; ++iz) - { - outp_r[iz] = inp_r[iz]; - } + const std::complex* inp = &in[is*nz_]; + detail::copy_complex_buffer(inp, outp, nz_); } return; } @@ -183,16 +182,8 @@ void PW_Basis::gathers_scatterp(std::complex* in, std::complex* out) const std::complex *outp0 = &out[startg_[ip]]; std::complex *inp0 = &in[startz_[ip]]; std::complex* outp = &outp0[is * nzip]; - std::complex* inp = &inp0[is * nz_ ]; - T* __restrict__ outp_r = reinterpret_cast(outp); - const T* __restrict__ inp_r = reinterpret_cast(inp); -#ifdef __GNUC__ -#pragma GCC ivdep -#endif - for (int izip = 0; izip < 2 * nzip; ++izip) - { - outp_r[izip] = inp_r[izip]; - } + const std::complex* inp = &inp0[is * nz_ ]; + detail::copy_complex_buffer(inp, outp, nzip); } } @@ -231,16 +222,8 @@ void PW_Basis::gathers_scatterp(std::complex* in, std::complex* out) const int ixy = istot2ixy[istot]; //int ixy = (ixy / fftny)*ny + ixy % fftny; std::complex* outp = &out[ixy * nplane]; - std::complex* inp = &in[istot * nplane]; - T* __restrict__ outp_r = reinterpret_cast(outp); - const T* __restrict__ inp_r = reinterpret_cast(inp); -#ifdef __GNUC__ -#pragma GCC ivdep -#endif - for (int iz = 0; iz < 2 * nplane; ++iz) - { - outp_r[iz] = inp_r[iz]; - } + const std::complex* inp = &in[istot * nplane]; + detail::copy_complex_buffer(inp, outp, nplane); } #endif return; diff --git a/source/source_basis/module_pw/pw_transform.cpp b/source/source_basis/module_pw/pw_transform.cpp index 220b353e9d4..dc867e0cd86 100644 --- a/source/source_basis/module_pw/pw_transform.cpp +++ b/source/source_basis/module_pw/pw_transform.cpp @@ -34,13 +34,7 @@ void PW_Basis::real2recip(const std::complex* in, const int npw_ = this->npw; const int nxyz_ = this->nxyz; const int* ig2isz_ = this->ig2isz; -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (int ir = 0; ir < nrxx_; ++ir) - { - this->fft_bundle.get_auxr_data()[ir] = in[ir]; - } + detail::copy_complex_buffer_parallel(in, this->fft_bundle.get_auxr_data(), nrxx_); this->fft_bundle.fftxyfor(fft_bundle.get_auxr_data(), fft_bundle.get_auxr_data()); this->gatherp_scatters(this->fft_bundle.get_auxr_data(), this->fft_bundle.get_auxg_data()); @@ -199,13 +193,7 @@ void PW_Basis::recip2real(const std::complex* in, } else { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (int ir = 0; ir < nrxx_; ++ir) - { - out[ir] = this->fft_bundle.get_auxr_data()[ir]; - } + detail::copy_complex_buffer_parallel(this->fft_bundle.get_auxr_data(), out, nrxx_); } ModuleBase::timer::end(this->classname, "recip2real"); } @@ -340,4 +328,4 @@ template void PW_Basis::recip2real(const std::complex* in, std::complex* out, const bool add, const double factor) const; -} // namespace ModulePW \ No newline at end of file +} // namespace ModulePW diff --git a/source/source_basis/module_pw/pw_transform_k.cpp b/source/source_basis/module_pw/pw_transform_k.cpp index a09aa2b686f..8c45e3d9b22 100644 --- a/source/source_basis/module_pw/pw_transform_k.cpp +++ b/source/source_basis/module_pw/pw_transform_k.cpp @@ -33,13 +33,7 @@ void PW_Basis_K::real2recip(const std::complex* in, assert(this->gamma_only == false); auto* auxr = this->fft_bundle.get_auxr_data(); -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (int ir = 0; ir < this->nrxx; ++ir) - { - auxr[ir] = in[ir]; - } + detail::copy_complex_buffer_parallel(in, auxr, this->nrxx); this->fft_bundle.fftxyfor(fft_bundle.get_auxr_data(), fft_bundle.get_auxr_data()); this->gatherp_scatters(this->fft_bundle.get_auxr_data(), this->fft_bundle.get_auxg_data()); @@ -200,13 +194,7 @@ void PW_Basis_K::recip2real(const std::complex* in, } else { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (int ir = 0; ir < this->nrxx; ++ir) - { - out[ir] = auxr[ir]; - } + detail::copy_complex_buffer_parallel(auxr, out, this->nrxx); } ModuleBase::timer::end(this->classname, "recip2real"); } From 754fe85bb75b7dc141258b4fce5697dc908c8c81 Mon Sep 17 00:00:00 2001 From: Aunixt <2400011033@stu.pku.edu.cn> Date: Sun, 31 May 2026 00:44:07 +0800 Subject: [PATCH 3/4] add module_pw complex transform round-trip tests --- .../module_pw/test_serial/pw_basis_k_test.cpp | 38 +++++++++++++++++++ .../module_pw/test_serial/pw_basis_test.cpp | 37 ++++++++++++++++++ 2 files changed, 75 insertions(+) diff --git a/source/source_basis/module_pw/test_serial/pw_basis_k_test.cpp b/source/source_basis/module_pw/test_serial/pw_basis_k_test.cpp index 84932bae2ff..71026855674 100644 --- a/source/source_basis/module_pw/test_serial/pw_basis_k_test.cpp +++ b/source/source_basis/module_pw/test_serial/pw_basis_k_test.cpp @@ -2,6 +2,7 @@ #include "source_base/global_function.h" #include "source_base/constants.h" #include "source_base/matrix3.h" +#include /************************************************ * serial unit test of functions in pw_basis.cpp @@ -188,4 +189,41 @@ TEST_F(PWBasisKTEST, CollectLocalPW) EXPECT_EQ(basis_k.npwk_max,2721); } +TEST_F(PWBasisKTEST, ComplexTransformRoundTrip) +{ + ModulePW::PW_Basis_K basis_k(device_flag, precision_double); + double lat0 = 2.0; + ModuleBase::Matrix3 latvec(1.0,0.0,1.0, + 0.0,2.0,0.0, + 0.0,0.0,2.0); + double gridecut = 30.0; + const bool gamma_only_in = false; + const double gk_ecut_in = 20.0; + const int nks_in = 1; + const ModuleBase::Vector3 kvec_d_in[1] = { {0.0, 0.0, 0.0} }; + const int distribution_type_in = 2; + const bool xprime_in = false; + + basis_k.initgrids(lat0, latvec, gridecut); + basis_k.initparameters(gamma_only_in, gk_ecut_in, nks_in, kvec_d_in, distribution_type_in, xprime_in); + ASSERT_NO_THROW(basis_k.setuptransform()); + std::vector> recip_in(basis_k.npwk[0]); + std::vector> real_space(basis_k.nrxx); + std::vector> recip_out(basis_k.npwk[0]); + for (int ig = 0; ig < basis_k.npwk[0]; ++ig) + { + const double real_part = (ig % 17 - 8) / 11.0; + const double imag_part = (ig % 19 - 9) / 13.0; + recip_in[ig] = std::complex(real_part, imag_part); + } + + basis_k.recip2real(recip_in.data(), real_space.data(), 0); + basis_k.real2recip(real_space.data(), recip_out.data(), 0); + + for (int ig = 0; ig < basis_k.npwk[0]; ++ig) + { + EXPECT_NEAR(recip_in[ig].real(), recip_out[ig].real(), 1e-10); + EXPECT_NEAR(recip_in[ig].imag(), recip_out[ig].imag(), 1e-10); + } +} diff --git a/source/source_basis/module_pw/test_serial/pw_basis_test.cpp b/source/source_basis/module_pw/test_serial/pw_basis_test.cpp index ea678b9d97c..13ba7252046 100644 --- a/source/source_basis/module_pw/test_serial/pw_basis_test.cpp +++ b/source/source_basis/module_pw/test_serial/pw_basis_test.cpp @@ -2,6 +2,7 @@ #include "source_base/global_function.h" #include "source_base/constants.h" #include "source_base/matrix3.h" +#include /************************************************ * serial unit test of functions in pw_basis.cpp @@ -362,3 +363,39 @@ TEST_F(PWBasisTEST,CollectUniqgg) pwb.collect_uniqgg(); EXPECT_EQ(pwb.ngg,78); } + +TEST_F(PWBasisTEST,ComplexTransformRoundTrip) +{ + double lat0 = 2.0; + ModuleBase::Matrix3 latvec(1.0,0.0,1.0, + 0.0,2.0,0.0, + 0.0,0.0,2.0); + double gridecut = 30.0; + bool gamma_only_in = false; + double pwecut_in = 20.0; + int distribution_type_in = 2; + bool xprime_in = false; + + pwb.initgrids(lat0, latvec, gridecut); + pwb.initparameters(gamma_only_in, pwecut_in, distribution_type_in, xprime_in); + ASSERT_NO_THROW(pwb.setuptransform()); + + std::vector> recip_in(pwb.npw); + std::vector> real_space(pwb.nrxx); + std::vector> recip_out(pwb.npw); + for (int ig = 0; ig < pwb.npw; ++ig) + { + const double real_part = (ig % 11 - 5) / 7.0; + const double imag_part = (ig % 13 - 6) / 9.0; + recip_in[ig] = std::complex(real_part, imag_part); + } + + pwb.recip2real(recip_in.data(), real_space.data()); + pwb.real2recip(real_space.data(), recip_out.data()); + + for (int ig = 0; ig < pwb.npw; ++ig) + { + EXPECT_NEAR(recip_in[ig].real(), recip_out[ig].real(), 1e-10); + EXPECT_NEAR(recip_in[ig].imag(), recip_out[ig].imag(), 1e-10); + } +} From 25ebe2e3064d4976a3544802284239621c32a0e4 Mon Sep 17 00:00:00 2001 From: Aunixt <2400011033@stu.pku.edu.cn> Date: Sun, 31 May 2026 00:50:28 +0800 Subject: [PATCH 4/4] document module_pw copy helpers and tests --- source/source_basis/module_pw/pw_gatherscatter.h | 4 ++++ source/source_basis/module_pw/test_serial/pw_basis_k_test.cpp | 2 ++ source/source_basis/module_pw/test_serial/pw_basis_test.cpp | 2 ++ 3 files changed, 8 insertions(+) diff --git a/source/source_basis/module_pw/pw_gatherscatter.h b/source/source_basis/module_pw/pw_gatherscatter.h index 52c24449009..eddeddb0f76 100644 --- a/source/source_basis/module_pw/pw_gatherscatter.h +++ b/source/source_basis/module_pw/pw_gatherscatter.h @@ -7,6 +7,8 @@ namespace ModulePW { namespace detail { +// Copy complex buffers through the interleaved scalar stream so compilers can +// vectorize the contiguous real/imaginary data movement. template inline void copy_complex_buffer(const std::complex* in, std::complex* out, const int count) { @@ -21,6 +23,8 @@ inline void copy_complex_buffer(const std::complex* in, std::complex* out, } } +// Top-level transform copies own the OpenMP parallel region; gather/scatter +// loops call the non-parallel helper inside their existing parallel regions. template inline void copy_complex_buffer_parallel(const std::complex* in, std::complex* out, const int count) { diff --git a/source/source_basis/module_pw/test_serial/pw_basis_k_test.cpp b/source/source_basis/module_pw/test_serial/pw_basis_k_test.cpp index 71026855674..ad3e1764dd1 100644 --- a/source/source_basis/module_pw/test_serial/pw_basis_k_test.cpp +++ b/source/source_basis/module_pw/test_serial/pw_basis_k_test.cpp @@ -208,6 +208,8 @@ TEST_F(PWBasisKTEST, ComplexTransformRoundTrip) basis_k.initparameters(gamma_only_in, gk_ecut_in, nks_in, kvec_d_in, distribution_type_in, xprime_in); ASSERT_NO_THROW(basis_k.setuptransform()); + // Use reciprocal-space input because arbitrary real-space data is projected + // by the plane-wave cutoff and is not exactly recoverable. std::vector> recip_in(basis_k.npwk[0]); std::vector> real_space(basis_k.nrxx); std::vector> recip_out(basis_k.npwk[0]); diff --git a/source/source_basis/module_pw/test_serial/pw_basis_test.cpp b/source/source_basis/module_pw/test_serial/pw_basis_test.cpp index 13ba7252046..57ac8f06554 100644 --- a/source/source_basis/module_pw/test_serial/pw_basis_test.cpp +++ b/source/source_basis/module_pw/test_serial/pw_basis_test.cpp @@ -380,6 +380,8 @@ TEST_F(PWBasisTEST,ComplexTransformRoundTrip) pwb.initparameters(gamma_only_in, pwecut_in, distribution_type_in, xprime_in); ASSERT_NO_THROW(pwb.setuptransform()); + // Use reciprocal-space input because arbitrary real-space data is projected + // by the plane-wave cutoff and is not exactly recoverable. std::vector> recip_in(pwb.npw); std::vector> real_space(pwb.nrxx); std::vector> recip_out(pwb.npw);