Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 53 additions & 36 deletions source/source_basis/module_pw/pw_gatherscatter.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,41 @@

namespace ModulePW
{
namespace detail
{
// Copy complex buffers through the interleaved scalar stream so compilers can
// vectorize the contiguous real/imaginary data movement.
template <typename T>
inline void copy_complex_buffer(const std::complex<T>* in, std::complex<T>* out, const int count)
{
const T* __restrict__ in_r = reinterpret_cast<const T*>(in);
T* __restrict__ out_r = reinterpret_cast<T*>(out);
#ifdef __GNUC__
#pragma GCC ivdep
#endif
for (int i = 0; i < 2 * count; ++i)
{
out_r[i] = in_r[i];
}
Comment on lines +20 to +23
}

// Top-level transform copies own the OpenMP parallel region; gather/scatter
// loops call the non-parallel helper inside their existing parallel regions.
template <typename T>
inline void copy_complex_buffer_parallel(const std::complex<T>* in, std::complex<T>* out, const int count)
{
const T* __restrict__ in_r = reinterpret_cast<const T*>(in);
T* __restrict__ out_r = reinterpret_cast<T*>(out);
#ifdef _OPENMP
#pragma omp parallel for schedule(static)
#endif
for (int i = 0; i < 2 * count; ++i)
Comment on lines +29 to +36
{
out_r[i] = in_r[i];
}
}
} // namespace detail

/**
* @brief gather planes and scatter sticks
* @param in: (nplane,fftny,fftnx)
Expand All @@ -27,12 +62,9 @@ void PW_Basis::gatherp_scatters(std::complex<T>* in, std::complex<T>* out) const
for(int is = 0 ; is < nst_ ; ++is)
{
int ixy = istot2ixy_[is];
std::complex<T> *outp = &out[is*nz_];
std::complex<T> *inp = &in[ixy*nz_];
for(int iz = 0 ; iz < nz_ ; ++iz)
{
outp[iz] = inp[iz];
}
std::complex<T>* outp = &out[is*nz_];
const std::complex<T>* inp = &in[ixy*nz_];
detail::copy_complex_buffer(inp, outp, nz_);
}
return;
}
Expand All @@ -50,12 +82,9 @@ void PW_Basis::gatherp_scatters(std::complex<T>* in, std::complex<T>* out) const
for (int istot = 0; istot < nstot_gps; ++istot)
{
int ixy = istot2ixy_gps[istot];
std::complex<T> *outp = &out[istot * nplane_gps];
std::complex<T> *inp = &in[ixy * nplane_gps];
for (int iz = 0; iz < nplane_gps; ++iz)
{
outp[iz] = inp[iz];
}
std::complex<T>* outp = &out[istot * nplane_gps];
const std::complex<T>* inp = &in[ixy * nplane_gps];
detail::copy_complex_buffer(inp, outp, nplane_gps);
}

//exchange data
Expand Down Expand Up @@ -90,12 +119,9 @@ void PW_Basis::gatherp_scatters(std::complex<T>* in, std::complex<T>* out) const
int nzip = numz_gps[ip];
std::complex<T> *outp0 = &out[startz_gps[ip]];
std::complex<T> *inp0 = &in[startg_gps[ip]];
std::complex<T> *outp = &outp0[is * nz_gps];
std::complex<T> *inp = &inp0[is * nzip ];
for (int izip = 0; izip < nzip; ++izip)
{
outp[izip] = inp[izip];
}
std::complex<T>* outp = &outp0[is * nz_gps];
const std::complex<T>* inp = &inp0[is * nzip ];
detail::copy_complex_buffer(inp, outp, nzip);
}
}
#endif
Expand Down Expand Up @@ -132,12 +158,9 @@ void PW_Basis::gathers_scatterp(std::complex<T>* in, std::complex<T>* out) const
for(int is = 0 ; is < nst_ ; ++is)
{
int ixy = istot2ixy_[is];
std::complex<T> *outp = &out[ixy*nz_];
std::complex<T> *inp = &in[is*nz_];
for(int iz = 0 ; iz < nz_ ; ++iz)
{
outp[iz] = inp[iz];
}
std::complex<T>* outp = &out[ixy*nz_];
const std::complex<T>* inp = &in[is*nz_];
detail::copy_complex_buffer(inp, outp, nz_);
}
return;
}
Expand All @@ -162,12 +185,9 @@ void PW_Basis::gathers_scatterp(std::complex<T>* in, std::complex<T>* out) const
int nzip = numz_[ip];
std::complex<T> *outp0 = &out[startg_[ip]];
std::complex<T> *inp0 = &in[startz_[ip]];
std::complex<T> *outp = &outp0[is * nzip];
std::complex<T> *inp = &inp0[is * nz_ ];
for (int izip = 0; izip < nzip; ++izip)
{
outp[izip] = inp[izip];
}
std::complex<T>* outp = &outp0[is * nzip];
const std::complex<T>* inp = &inp0[is * nz_ ];
detail::copy_complex_buffer(inp, outp, nzip);
}
}

Expand Down Expand Up @@ -205,12 +225,9 @@ void PW_Basis::gathers_scatterp(std::complex<T>* in, std::complex<T>* out) const
{
int ixy = istot2ixy[istot];
//int ixy = (ixy / fftny)*ny + ixy % fftny;
std::complex<T> *outp = &out[ixy * nplane];
std::complex<T> *inp = &in[istot * nplane];
for (int iz = 0; iz < nplane; ++iz)
{
outp[iz] = inp[iz];
}
std::complex<T>* outp = &out[ixy * nplane];
const std::complex<T>* inp = &in[istot * nplane];
detail::copy_complex_buffer(inp, outp, nplane);
}
#endif
return;
Expand Down
18 changes: 3 additions & 15 deletions source/source_basis/module_pw/pw_transform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,7 @@ void PW_Basis::real2recip(const std::complex<FPTYPE>* in,
const int npw_ = this->npw;
const int nxyz_ = this->nxyz;
const int* ig2isz_ = this->ig2isz;
#ifdef _OPENMP
#pragma omp parallel for schedule(static)
#endif
for (int ir = 0; ir < nrxx_; ++ir)
{
this->fft_bundle.get_auxr_data<FPTYPE>()[ir] = in[ir];
}
detail::copy_complex_buffer_parallel(in, this->fft_bundle.get_auxr_data<FPTYPE>(), nrxx_);
this->fft_bundle.fftxyfor(fft_bundle.get_auxr_data<FPTYPE>(), fft_bundle.get_auxr_data<FPTYPE>());

this->gatherp_scatters(this->fft_bundle.get_auxr_data<FPTYPE>(), this->fft_bundle.get_auxg_data<FPTYPE>());
Expand Down Expand Up @@ -199,13 +193,7 @@ void PW_Basis::recip2real(const std::complex<FPTYPE>* in,
}
else
{
#ifdef _OPENMP
#pragma omp parallel for schedule(static)
#endif
for (int ir = 0; ir < nrxx_; ++ir)
{
out[ir] = this->fft_bundle.get_auxr_data<FPTYPE>()[ir];
}
detail::copy_complex_buffer_parallel(this->fft_bundle.get_auxr_data<FPTYPE>(), out, nrxx_);
}
ModuleBase::timer::end(this->classname, "recip2real");
}
Expand Down Expand Up @@ -340,4 +328,4 @@ template void PW_Basis::recip2real<double>(const std::complex<double>* in,
std::complex<double>* out,
const bool add,
const double factor) const;
} // namespace ModulePW
} // namespace ModulePW
16 changes: 2 additions & 14 deletions source/source_basis/module_pw/pw_transform_k.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,7 @@ void PW_Basis_K::real2recip(const std::complex<FPTYPE>* in,

assert(this->gamma_only == false);
auto* auxr = this->fft_bundle.get_auxr_data<FPTYPE>();
#ifdef _OPENMP
#pragma omp parallel for schedule(static)
#endif
for (int ir = 0; ir < this->nrxx; ++ir)
{
auxr[ir] = in[ir];
}
detail::copy_complex_buffer_parallel(in, auxr, this->nrxx);
this->fft_bundle.fftxyfor(fft_bundle.get_auxr_data<FPTYPE>(), fft_bundle.get_auxr_data<FPTYPE>());

this->gatherp_scatters(this->fft_bundle.get_auxr_data<FPTYPE>(), this->fft_bundle.get_auxg_data<FPTYPE>());
Expand Down Expand Up @@ -200,13 +194,7 @@ void PW_Basis_K::recip2real(const std::complex<FPTYPE>* in,
}
else
{
#ifdef _OPENMP
#pragma omp parallel for schedule(static)
#endif
for (int ir = 0; ir < this->nrxx; ++ir)
{
out[ir] = auxr[ir];
}
detail::copy_complex_buffer_parallel(auxr, out, this->nrxx);
}
ModuleBase::timer::end(this->classname, "recip2real");
}
Expand Down
40 changes: 40 additions & 0 deletions source/source_basis/module_pw/test_serial/pw_basis_k_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include "source_base/global_function.h"
#include "source_base/constants.h"
#include "source_base/matrix3.h"
#include <vector>

/************************************************
* serial unit test of functions in pw_basis.cpp
Expand Down Expand Up @@ -188,4 +189,43 @@ TEST_F(PWBasisKTEST, CollectLocalPW)
EXPECT_EQ(basis_k.npwk_max,2721);
}

TEST_F(PWBasisKTEST, ComplexTransformRoundTrip)
{
ModulePW::PW_Basis_K basis_k(device_flag, precision_double);
double lat0 = 2.0;
ModuleBase::Matrix3 latvec(1.0,0.0,1.0,
0.0,2.0,0.0,
0.0,0.0,2.0);
double gridecut = 30.0;
const bool gamma_only_in = false;
const double gk_ecut_in = 20.0;
const int nks_in = 1;
const ModuleBase::Vector3<double> kvec_d_in[1] = { {0.0, 0.0, 0.0} };
const int distribution_type_in = 2;
const bool xprime_in = false;

basis_k.initgrids(lat0, latvec, gridecut);
basis_k.initparameters(gamma_only_in, gk_ecut_in, nks_in, kvec_d_in, distribution_type_in, xprime_in);
ASSERT_NO_THROW(basis_k.setuptransform());

// Use reciprocal-space input because arbitrary real-space data is projected
// by the plane-wave cutoff and is not exactly recoverable.
std::vector<std::complex<double>> recip_in(basis_k.npwk[0]);
std::vector<std::complex<double>> real_space(basis_k.nrxx);
std::vector<std::complex<double>> recip_out(basis_k.npwk[0]);
Comment on lines +213 to +215
for (int ig = 0; ig < basis_k.npwk[0]; ++ig)
{
const double real_part = (ig % 17 - 8) / 11.0;
const double imag_part = (ig % 19 - 9) / 13.0;
recip_in[ig] = std::complex<double>(real_part, imag_part);
}

basis_k.recip2real(recip_in.data(), real_space.data(), 0);
basis_k.real2recip(real_space.data(), recip_out.data(), 0);

for (int ig = 0; ig < basis_k.npwk[0]; ++ig)
{
EXPECT_NEAR(recip_in[ig].real(), recip_out[ig].real(), 1e-10);
EXPECT_NEAR(recip_in[ig].imag(), recip_out[ig].imag(), 1e-10);
}
}
39 changes: 39 additions & 0 deletions source/source_basis/module_pw/test_serial/pw_basis_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include "source_base/global_function.h"
#include "source_base/constants.h"
#include "source_base/matrix3.h"
#include <vector>

/************************************************
* serial unit test of functions in pw_basis.cpp
Expand Down Expand Up @@ -362,3 +363,41 @@ TEST_F(PWBasisTEST,CollectUniqgg)
pwb.collect_uniqgg();
EXPECT_EQ(pwb.ngg,78);
}

TEST_F(PWBasisTEST,ComplexTransformRoundTrip)
{
double lat0 = 2.0;
ModuleBase::Matrix3 latvec(1.0,0.0,1.0,
0.0,2.0,0.0,
0.0,0.0,2.0);
double gridecut = 30.0;
bool gamma_only_in = false;
double pwecut_in = 20.0;
int distribution_type_in = 2;
bool xprime_in = false;

pwb.initgrids(lat0, latvec, gridecut);
pwb.initparameters(gamma_only_in, pwecut_in, distribution_type_in, xprime_in);
ASSERT_NO_THROW(pwb.setuptransform());

// Use reciprocal-space input because arbitrary real-space data is projected
// by the plane-wave cutoff and is not exactly recoverable.
std::vector<std::complex<double>> recip_in(pwb.npw);
std::vector<std::complex<double>> real_space(pwb.nrxx);
std::vector<std::complex<double>> recip_out(pwb.npw);
for (int ig = 0; ig < pwb.npw; ++ig)
{
const double real_part = (ig % 11 - 5) / 7.0;
const double imag_part = (ig % 13 - 6) / 9.0;
recip_in[ig] = std::complex<double>(real_part, imag_part);
}

pwb.recip2real(recip_in.data(), real_space.data());
pwb.real2recip(real_space.data(), recip_out.data());

for (int ig = 0; ig < pwb.npw; ++ig)
{
EXPECT_NEAR(recip_in[ig].real(), recip_out[ig].real(), 1e-10);
EXPECT_NEAR(recip_in[ig].imag(), recip_out[ig].imag(), 1e-10);
}
Comment on lines +398 to +402
}
Loading