Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 116 additions & 0 deletions source/source_base/truncated_func.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
#ifndef MODULE_BASE_TRUNCATED_FUNC_H
#define MODULE_BASE_TRUNCATED_FUNC_H

#include "source_base/libm/libm.h"
#include <cstdint>
#include <cstring>
#include <complex>

namespace ModuleBase
{

/**
* @brief Truncated exponential function to avoid underflow.
*
* This function returns 0 if the real part of the input is less than -230.0,
* otherwise it calls ModuleBase::libm::exp(x).
*
* @tparam FPTYPE The floating point type (float, double, or complex).
* @param x The input value.
* @return FPTYPE The result of the exponential function.
*/
template <typename FPTYPE>
inline FPTYPE truncated_exp(FPTYPE x)
{
if (std::real(x) < -230.0)
{
return static_cast<FPTYPE>(0.0);
}
return ModuleBase::libm::exp(x);
}

/**
* @brief Truncated complementary error function to avoid underflow for large arguments.
*
* This function returns 0 if the real part of the input is greater than 20.0,
* otherwise it calls std::erfc(x).
*
* @tparam FPTYPE The floating point type (float, double, or complex).
* @param x The input value.
* @return FPTYPE The result of the erfc function.
*/
template <typename FPTYPE>
inline FPTYPE truncated_erfc(FPTYPE x)
{
if (std::real(x) > 20.0)
{
return static_cast<FPTYPE>(0.0);
}
return std::erfc(x);
}

/**
* @brief Truncated value function to avoid underflow.
*
* This function returns 0 if the absolute value of the input is less than 1.0e-30,
* otherwise it returns the input x.
*
* @tparam FPTYPE The floating point type (float, double, or complex).
* @param x The input value.
* @return FPTYPE The result of the truncation.
*/
/**
* @brief Truncated value function to avoid underflow.
*
* This function modifies the input to 0 if its absolute value is less than 1.0e-30.
*
* @tparam FPTYPE The floating point type (float, double, or complex).
* @param x The input value to be checked and possibly truncated.
*/
template <typename FPTYPE>
inline void truncated_underflow(FPTYPE& x)
{
if (std::abs(x) < 1.0e-30)
{
x = static_cast<FPTYPE>(0.0);
}
}

template <>
inline void truncated_underflow(double& x)
{
const uint64_t u = *reinterpret_cast<const uint64_t*>(&x);
// The exponent bits are 52-62 (11 bits). The bias is 1023.
// 1e-30 corresponds to -100 in base-2 exponent roughly.
// 923 = 1023 - 100.
if (((u >> 52) & 0x7FF) <= 923)
{
x = 0.0;
}
}

template <>
inline void truncated_underflow(float& x)
{
const uint32_t u = *reinterpret_cast<const uint32_t*>(&x);
// The exponent bits are 23-30 (8 bits). The bias is 127.
// 1e-30 corresponds to -100 in base-2 exponent roughly.
// 27 = 127 - 100.
if (((u >> 23) & 0xFF) <= 27)
{
x = 0.0f;
}
}

template <typename T>
inline void truncated_underflow(std::complex<T>& x)
{
T* ptr = reinterpret_cast<T*>(&x);
truncated_underflow(ptr[0]);
truncated_underflow(ptr[1]);
}


} // namespace ModuleBase

#endif // MODULE_BASE_TRUNCATED_FUNC_H
7 changes: 4 additions & 3 deletions source/source_pw/module_pwdft/forces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
// new
#include "source_base/complexmatrix.h"
#include "source_base/libm/libm.h"
#include "source_base/truncated_func.h"
#include "source_base/math_integral.h"
#include "source_base/mathzone.h"
#include "source_base/timer.h"
Expand Down Expand Up @@ -537,8 +538,7 @@ void Forces<FPTYPE, Device>::cal_force_ew(const UnitCell& ucell,
{
ModuleBase::WARNING_QUIT("ewald", "Can't find optimal alpha.");
}
upperbound = 2.0 * charge * charge * sqrt(2.0 * alpha / ModuleBase::TWO_PI)
* erfc(sqrt(ucell.tpiba2 * rho_basis->ggecut / 4.0 / alpha));
upperbound = 2.0 * charge * charge * sqrt(2.0 * alpha / ModuleBase::TWO_PI)* ModuleBase::truncated_erfc(sqrt( ucell.tpiba2 * rho_basis->ggecut / 4.0 / alpha));
} while (upperbound > 1.0e-6);
const int ig0 = rho_basis->ig_gge0;
#pragma omp parallel for
Expand All @@ -548,7 +548,8 @@ void Forces<FPTYPE, Device>::cal_force_ew(const UnitCell& ucell,
{
continue; // skip G=0
}
aux[ig] *= ModuleBase::libm::exp(-1.0 * rho_basis->gg[ig] * ucell.tpiba2 / alpha / 4.0)
aux[ig] *= ModuleBase::truncated_exp
(-1.0 * rho_basis->gg[ig] * ucell.tpiba2 / alpha / 4.0)
/ (rho_basis->gg[ig] * ucell.tpiba2);
}

Expand Down
11 changes: 8 additions & 3 deletions source/source_pw/module_pwdft/kernels/force_op.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include "source_pw/module_pwdft/kernels/force_op.h"

#include "source_base/truncated_func.h"

#ifdef _OPENMP
#include <omp.h>
#endif
Expand Down Expand Up @@ -109,9 +111,12 @@ struct cal_force_nl_op<FPTYPE, base_device::DEVICE_CPU>

for (int ipol = 0; ipol < 3; ipol++)
{
const FPTYPE dbb
= (conj(dbecp[ipol * nbands * nkb + ib * nkb + inkb]) * becp[ib * nkb + inkb])
.real();
#ifdef __SW
ModuleBase::truncated_underflow(dbecp[ipol * nbands * nkb + ib * nkb + inkb]);
ModuleBase::truncated_underflow(becp[ib * nkb + inkb]);
ModuleBase::truncated_underflow(local_force[ipol]);
#endif
const FPTYPE dbb = (conj(dbecp[ipol * nbands * nkb + ib * nkb + inkb]) * becp[ib * nkb + inkb]).real();
local_force[ipol] -= ps * fac * dbb;
// cf[iat*3+ipol] += ps * fac * dbb;
}
Expand Down
13 changes: 10 additions & 3 deletions source/source_pw/module_pwdft/kernels/stress_op.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "source_pw/module_pwdft/kernels/stress_op.h"

#include "source_base/truncated_func.h"
#include "source_base/constants.h"
#include "source_base/libm/libm.h"
#include "source_base/math_polyint.h"
Expand Down Expand Up @@ -140,8 +141,10 @@ struct cal_stress_nl_op<FPTYPE, base_device::DEVICE_CPU>
FPTYPE ps = deeq[((spin * deeq_2 + iat + ia) * deeq_3 + ip1) * deeq_4 + ip2] + ps_qq;
const int inkb1 = sum + ia * nproj + ip1;
const int inkb2 = sum + ia * nproj + ip2;
// out<<"\n ps = "<<ps;

#ifdef __SW
ModuleBase::truncated_underflow(dbecp[ib * nkb + inkb1]);
ModuleBase::truncated_underflow(becp[ib * nkb + inkb2]);
#endif
const FPTYPE dbb = (conj(dbecp[ib * nkb + inkb1]) * becp[ib * nkb + inkb2]).real();
local_stress -= ps * fac * dbb;
}
Expand Down Expand Up @@ -618,7 +621,8 @@ struct cal_stress_drhoc_aux_op<FPTYPE, base_device::DEVICE_CPU> {
{
rhocg1 *= ModuleBase::FOUR_PI / omega / 2.0 / gx_arr[igl];
FPTYPE g2a = (gx_arr[igl]*gx_arr[igl]) / 4.0;
rhocg1 += ModuleBase::FOUR_PI / omega * gx_arr[ngg] * ModuleBase::libm::exp(-g2a) * (g2a + 1)
rhocg1 += ModuleBase::FOUR_PI / omega * gx_arr[ngg] *
ModuleBase::truncated_exp(-g2a) * (g2a + 1)
/ pow(gx_arr[igl] * gx_arr[igl], 2);
drhocg [igl] = rhocg1;
}
Expand All @@ -644,6 +648,9 @@ struct cal_multi_dot_op<FPTYPE, base_device::DEVICE_CPU> {
#endif
for (int i = 0; i < npw; i++)
{
#ifdef __SW
ModuleBase::truncated_underflow(psi[i]);
#endif
sum += fac * gk1[i] * gk2[i] * d_kfac[i] * std::norm(psi[i]);
}
return sum;
Expand Down
4 changes: 2 additions & 2 deletions source/source_pw/module_pwdft/vl_pw.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "vl_pw.h"
#include "source_io/module_parameter/parameter.h"
#include "source_base/libm/libm.h"
#include "source_base/truncated_func.h"
#include "source_base/math_integral.h"
#include "source_base/timer.h"

Expand Down Expand Up @@ -226,8 +227,7 @@ void pseudopot_cell_vl::vloc_of_g(const int& msh,
aux [ir] = aux1 [ir] * ModuleBase::libm::sin(gx * r [ir]) / gx;
}
ModuleBase::Integral::Simpson_Integral(msh, aux, rab, vloc_1d[ig] );
// here we add the analytic fourier transform of the erf function
vloc_1d[ig] -= fac * ModuleBase::libm::exp(- gx2 * 0.25)/ gx2;
vloc_1d[ig] -= fac * ModuleBase::truncated_exp(- gx2 * 0.25)/ gx2;
} // enddo

const double d_fpi_omega = ModuleBase::FOUR_PI/ucell.omega;//mohan add 2008-06-04
Expand Down
Loading