Skip to content
Open
16 changes: 16 additions & 0 deletions source/source_basis/module_pw/pw_basis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ PW_Basis:: ~PW_Basis()
delete[] startr;
delete[] ig2igg;
delete[] gg_uniq;
this->comm_workbuf_float_.reset();
this->comm_workbuf_double_.reset();
#if defined(__CUDA) || defined(__ROCM)
if (this->device == "gpu")
{
Expand Down Expand Up @@ -124,9 +126,23 @@ void PW_Basis::getstartgr()
{
this->startr[ip] = this->startr[ip-1] + this->numr[ip-1];
}
this->allocate_comm_buffers();
return;
}

/// @brief (Re)allocate the single- and double-precision communication work buffers.
///
/// Both buffers get the same element count:
///   (startr[last] + numr[last]) + (startg[last] + numg[last]),  last = poolnproc - 1.
/// NOTE(review): presumably these are the total real-space and reciprocal-space
/// extents exchanged across the pool -- confirm against the callers of
/// acquire_comm_workbuf().
///
/// Does nothing when poolnproc <= 0. Any previously held buffers are released by
/// the unique_ptr reset.
void PW_Basis::allocate_comm_buffers()
{
    // No ranks in the pool: there is no "last" entry to read.
    if (this->poolnproc <= 0)
    {
        return;
    }

    const int last = this->poolnproc - 1;

    // Widen each term to std::size_t *before* adding. Casting only the final
    // sum would let the addition itself overflow in the narrower operand type.
    const std::size_t r_extent = static_cast<std::size_t>(this->startr[last])
                                 + static_cast<std::size_t>(this->numr[last]);
    const std::size_t g_extent = static_cast<std::size_t>(this->startg[last])
                                 + static_cast<std::size_t>(this->numg[last]);
    const std::size_t max_size = r_extent + g_extent;

    this->comm_workbuf_float_.reset(new std::complex<float>[max_size]);
    this->comm_workbuf_double_.reset(new std::complex<double>[max_size]);
}

///
/// Collect the plane waves on the current core and construct gg, gdirect, and gcar according to ig2isz and is2fftixy.
/// known: ig2isz, is2fftixy
Expand Down
26 changes: 25 additions & 1 deletion source/source_basis/module_pw/pw_basis.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@
#include "source_base/vector3.h"
#include <complex>
#include "source_base/module_fft/fft_bundle.h"
#include <cassert>
#include <cstring>
#include <memory>
#include <vector>
#ifdef __MPI
#include "mpi.h"
#endif
Expand Down Expand Up @@ -148,7 +151,7 @@ class PW_Basis

//prepare for MPI_Alltoall
void getstartgr();

void allocate_comm_buffers();

public:
//collect gdirect, gcar, gg
Expand Down Expand Up @@ -420,6 +423,9 @@ class PW_Basis
template <typename T>
void gathers_scatterp(std::complex<T>* in, std::complex<T>* out) const;

template <typename T>
std::complex<T>* acquire_comm_workbuf(const int size) const;

public:
//get fftixy2is;
void getfftixy2is(int * fftixy2is) const;
Expand All @@ -441,7 +447,25 @@ class PW_Basis
std::string precision = "double"; ///< single, double, mixing
bool double_data_ = true; ///< if has double data
bool float_data_ = false; ///< if has float data
std::unique_ptr<std::complex<float>[]> comm_workbuf_float_;
std::unique_ptr<std::complex<double>[]> comm_workbuf_double_;
};

/// @brief Single-precision specialization: hand out the preallocated float work buffer.
/// @param size Requested element count; currently unused (not checked against the allocation).
/// @return Raw, non-owning pointer to comm_workbuf_float_; asserted non-null in debug builds.
template <>
inline std::complex<float>* PW_Basis::acquire_comm_workbuf<float>(const int size) const
{
    static_cast<void>(size); // silence the unused-parameter warning
    std::complex<float>* const workbuf = this->comm_workbuf_float_.get();
    assert(workbuf != nullptr);
    return workbuf;
}

/// @brief Double-precision specialization: hand out the preallocated double work buffer.
/// @param size Requested element count; currently unused (not checked against the allocation).
/// @return Raw, non-owning pointer to comm_workbuf_double_; asserted non-null in debug builds.
template <>
inline std::complex<double>* PW_Basis::acquire_comm_workbuf<double>(const int size) const
{
    static_cast<void>(size); // silence the unused-parameter warning
    std::complex<double>* const workbuf = this->comm_workbuf_double_.get();
    assert(workbuf != nullptr);
    return workbuf;
}
}
#endif // PWBASIS_H
#include "pw_basis_sup.h"
Expand Down
Loading
Loading