diff --git a/source/source_base/truncated_func.h b/source/source_base/truncated_func.h new file mode 100644 index 0000000000..55ce64953a --- /dev/null +++ b/source/source_base/truncated_func.h @@ -0,0 +1,116 @@ +#ifndef MODULE_BASE_TRUNCATED_FUNC_H +#define MODULE_BASE_TRUNCATED_FUNC_H + +#include "source_base/libm/libm.h" +#include +#include +#include + +namespace ModuleBase +{ + +/** + * @brief Truncated exponential function to avoid underflow. + * + * This function returns 0 if the real part of the input is less than -230.0, + * otherwise it calls ModuleBase::libm::exp(x). + * + * @tparam FPTYPE The floating point type (float, double, or complex). + * @param x The input value. + * @return FPTYPE 0 when the input is below the cutoff, otherwise the result of the exponential function. + */ +template +inline FPTYPE truncated_exp(FPTYPE x) +{ + if (std::real(x) < -230.0) + { + return static_cast(0.0); + } + return ModuleBase::libm::exp(x); +} + +/** + * @brief Truncated complementary error function to avoid underflow for large arguments. + * + * This function returns 0 if the real part of the input is greater than 20.0, + * otherwise it calls std::erfc(x). + * + * @tparam FPTYPE The floating point type (float or double; std::erfc has no std::complex overload). + * @param x The input value. + * @return FPTYPE 0 when the input is above the cutoff, otherwise the result of the erfc function. + */ +template +inline FPTYPE truncated_erfc(FPTYPE x) +{ + if (std::real(x) > 20.0) + { + return static_cast(0.0); + } + return std::erfc(x); +} + +/** + * @brief Overview of the truncated_underflow family defined below. + * + * The generic template zeroes its argument in place when std::abs(x) < 1.0e-30, the double and float + * specializations test the biased exponent field instead, and the std::complex version handles each component. + * + * @tparam FPTYPE The value type of the argument. + * @param x The value to check; it is modified in place. + * @return Nothing; every version returns void. + */ +/** + * @brief Generic in-place flush of tiny values to zero to avoid underflow. + * + * This function modifies the input to 0 if its absolute value is less than 1.0e-30. 
+ * + * @tparam FPTYPE The value type; double and float use the specializations below, std::complex the overload below. + * @param x The input value to be checked and possibly truncated. + */ +template +inline void truncated_underflow(FPTYPE& x) +{ + if (std::abs(x) < 1.0e-30) + { + x = static_cast(0.0); + } +} + +template <> +inline void truncated_underflow(double& x) +{ + const uint64_t u = *reinterpret_cast(&x); + // The exponent bits are 52-62 (11 bits). The bias is 1023. + // A biased exponent <= 923 (923 = 1023 - 100) means |x| < 2^-99, about 1.6e-30; zeros and subnormals (field 0) match too. + // NaN/Inf (field 0x7FF) are left untouched. NOTE(review): the reinterpret_cast read is type punning; std::memcpy is the defined alternative - confirm. + if (((u >> 52) & 0x7FF) <= 923) + { + x = 0.0; + } +} + +template <> +inline void truncated_underflow(float& x) +{ + const uint32_t u = *reinterpret_cast(&x); + // The exponent bits are 23-30 (8 bits). The bias is 127. + // A biased exponent <= 27 (27 = 127 - 100) means |x| < 2^-99, about 1.6e-30; zeros and subnormals (field 0) match too. + // NaN/Inf (field 0xFF) are left untouched. NOTE(review): the reinterpret_cast read is type punning; std::memcpy is the defined alternative - confirm. + if (((u >> 23) & 0xFF) <= 27) + { + x = 0.0f; + } +} + +template +inline void truncated_underflow(std::complex& x) +{ + T* ptr = reinterpret_cast(&x); + truncated_underflow(ptr[0]); + truncated_underflow(ptr[1]); +} + + +} // namespace ModuleBase + +#endif // MODULE_BASE_TRUNCATED_FUNC_H \ No newline at end of file diff --git a/source/source_pw/module_pwdft/forces.cpp b/source/source_pw/module_pwdft/forces.cpp index b5cef87e20..a6894c49ca 100644 --- a/source/source_pw/module_pwdft/forces.cpp +++ b/source/source_pw/module_pwdft/forces.cpp @@ -5,6 +5,7 @@ // new #include "source_base/complexmatrix.h" #include "source_base/libm/libm.h" +#include "source_base/truncated_func.h" #include "source_base/math_integral.h" #include "source_base/mathzone.h" #include "source_base/timer.h" @@ -537,8 +538,7 @@ void Forces::cal_force_ew(const UnitCell& ucell, { ModuleBase::WARNING_QUIT("ewald", "Can't find optimal alpha."); } - upperbound = 2.0 * charge * charge * sqrt(2.0 * alpha / ModuleBase::TWO_PI) - * erfc(sqrt(ucell.tpiba2 * rho_basis->ggecut / 4.0 / alpha)); + upperbound = 2.0 * charge * charge * sqrt(2.0 * alpha / ModuleBase::TWO_PI)* ModuleBase::truncated_erfc(sqrt( 
ucell.tpiba2 * rho_basis->ggecut / 4.0 / alpha)); } while (upperbound > 1.0e-6); const int ig0 = rho_basis->ig_gge0; #pragma omp parallel for @@ -548,7 +548,8 @@ void Forces::cal_force_ew(const UnitCell& ucell, { continue; // skip G=0 } - aux[ig] *= ModuleBase::libm::exp(-1.0 * rho_basis->gg[ig] * ucell.tpiba2 / alpha / 4.0) + aux[ig] *= ModuleBase::truncated_exp + (-1.0 * rho_basis->gg[ig] * ucell.tpiba2 / alpha / 4.0) / (rho_basis->gg[ig] * ucell.tpiba2); } diff --git a/source/source_pw/module_pwdft/kernels/force_op.cpp b/source/source_pw/module_pwdft/kernels/force_op.cpp index 029afdbf7c..0e0c34ccdd 100644 --- a/source/source_pw/module_pwdft/kernels/force_op.cpp +++ b/source/source_pw/module_pwdft/kernels/force_op.cpp @@ -1,5 +1,7 @@ #include "source_pw/module_pwdft/kernels/force_op.h" +#include "source_base/truncated_func.h" + #ifdef _OPENMP #include #endif @@ -109,9 +111,12 @@ struct cal_force_nl_op for (int ipol = 0; ipol < 3; ipol++) { - const FPTYPE dbb - = (conj(dbecp[ipol * nbands * nkb + ib * nkb + inkb]) * becp[ib * nkb + inkb]) - .real(); +#ifdef __SW + ModuleBase::truncated_underflow(dbecp[ipol * nbands * nkb + ib * nkb + inkb]); + ModuleBase::truncated_underflow(becp[ib * nkb + inkb]); + ModuleBase::truncated_underflow(local_force[ipol]); +#endif + const FPTYPE dbb = (conj(dbecp[ipol * nbands * nkb + ib * nkb + inkb]) * becp[ib * nkb + inkb]).real(); local_force[ipol] -= ps * fac * dbb; // cf[iat*3+ipol] += ps * fac * dbb; } diff --git a/source/source_pw/module_pwdft/kernels/stress_op.cpp b/source/source_pw/module_pwdft/kernels/stress_op.cpp index 46fccdeed6..6034a52710 100644 --- a/source/source_pw/module_pwdft/kernels/stress_op.cpp +++ b/source/source_pw/module_pwdft/kernels/stress_op.cpp @@ -1,5 +1,6 @@ #include "source_pw/module_pwdft/kernels/stress_op.h" +#include "source_base/truncated_func.h" #include "source_base/constants.h" #include "source_base/libm/libm.h" #include "source_base/math_polyint.h" @@ -140,8 +141,10 @@ struct cal_stress_nl_op 
FPTYPE ps = deeq[((spin * deeq_2 + iat + ia) * deeq_3 + ip1) * deeq_4 + ip2] + ps_qq; const int inkb1 = sum + ia * nproj + ip1; const int inkb2 = sum + ia * nproj + ip2; - // out<<"\n ps = "< { { rhocg1 *= ModuleBase::FOUR_PI / omega / 2.0 / gx_arr[igl]; FPTYPE g2a = (gx_arr[igl]*gx_arr[igl]) / 4.0; - rhocg1 += ModuleBase::FOUR_PI / omega * gx_arr[ngg] * ModuleBase::libm::exp(-g2a) * (g2a + 1) + rhocg1 += ModuleBase::FOUR_PI / omega * gx_arr[ngg] * + ModuleBase::truncated_exp(-g2a) * (g2a + 1) / pow(gx_arr[igl] * gx_arr[igl], 2); drhocg [igl] = rhocg1; } @@ -644,6 +648,9 @@ struct cal_multi_dot_op { #endif for (int i = 0; i < npw; i++) { +#ifdef __SW + ModuleBase::truncated_underflow(psi[i]); +#endif sum += fac * gk1[i] * gk2[i] * d_kfac[i] * std::norm(psi[i]); } return sum; diff --git a/source/source_pw/module_pwdft/vl_pw.cpp b/source/source_pw/module_pwdft/vl_pw.cpp index 672fbd6185..76ee526fbc 100644 --- a/source/source_pw/module_pwdft/vl_pw.cpp +++ b/source/source_pw/module_pwdft/vl_pw.cpp @@ -1,6 +1,7 @@ #include "vl_pw.h" #include "source_io/module_parameter/parameter.h" #include "source_base/libm/libm.h" +#include "source_base/truncated_func.h" #include "source_base/math_integral.h" #include "source_base/timer.h" @@ -226,8 +227,7 @@ void pseudopot_cell_vl::vloc_of_g(const int& msh, aux [ir] = aux1 [ir] * ModuleBase::libm::sin(gx * r [ir]) / gx; } ModuleBase::Integral::Simpson_Integral(msh, aux, rab, vloc_1d[ig] ); - // here we add the analytic fourier transform of the erf function - vloc_1d[ig] -= fac * ModuleBase::libm::exp(- gx2 * 0.25)/ gx2; + vloc_1d[ig] -= fac * ModuleBase::truncated_exp(- gx2 * 0.25)/ gx2; } // enddo const double d_fpi_omega = ModuleBase::FOUR_PI/ucell.omega;//mohan add 2008-06-04