Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 31 additions & 106 deletions source/source_pw/module_pwdft/hamilt_pw.cpp
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
#include "hamilt_pw.h"

#include "source_io/module_parameter/parameter.h"
#include "source_base/global_function.h"
#include "source_base/global_variable.h"
#include "source_base/parallel_reduce.h"

#include "op_pw_veff.h"
#include "op_pw_ekin.h"
#include "op_pw_exx.h"
#include "op_pw_meta.h"
#include "op_pw_nl.h"
#include "op_pw_proj.h"
#include "op_pw_exx.h"
#include "op_pw_veff.h"
#include "source_base/global_function.h"
#include "source_base/global_variable.h"
#include "source_base/parallel_reduce.h"
#include "source_hamilt/module_xc/exx_info.h" // use GlobalC::exx_info
#include "source_io/module_parameter/parameter.h"

namespace hamilt
{
Expand All @@ -22,7 +21,8 @@ HamiltPW<T, Device>::HamiltPW(elecstate::Potential* pot_in,
K_Vectors* pkv,
pseudopot_cell_vnl* nlpp,
Plus_U* p_dftu, // mohan add 2025-11-06
const UnitCell* ucell): ucell(ucell)
const UnitCell* ucell)
: ucell(ucell)
{
this->classname = "HamiltPW";
this->ppcell = nlpp;
Expand All @@ -39,7 +39,7 @@ HamiltPW<T, Device>::HamiltPW(elecstate::Potential* pot_in,
// Operator<double>* ekinetic = new Ekinetic<OperatorLCAO<double>>
Operator<T, Device>* ekinetic
= new Ekinetic<OperatorPW<T, Device>>(tpiba2, gk2, wfc_basis->nks, wfc_basis->npwk_max);
if(this->ops == nullptr)
if (this->ops == nullptr)
{
this->ops = ekinetic;
}
Expand All @@ -59,7 +59,7 @@ HamiltPW<T, Device>::HamiltPW(elecstate::Potential* pot_in,
{
pot_register_in.push_back("hartree");
}
//no variable can choose xc, maybe it is necessary
// no variable can choose xc, maybe it is necessary
pot_register_in.push_back("xc");
if (PARAM.inp.imp_sol)
{
Expand All @@ -78,20 +78,21 @@ HamiltPW<T, Device>::HamiltPW(elecstate::Potential* pot_in,
pot_register_in.push_back("ml_exx");
}
// DFT-1/2
if (PARAM.inp.dfthalf_type == 1) {
if (PARAM.inp.dfthalf_type == 1)
{
pot_register_in.push_back("dfthalf");
}
//only Potential is not empty, Veff and Meta are available
if(pot_register_in.size()>0)
// only Potential is not empty, Veff and Meta are available
if (pot_register_in.size() > 0)
{
//register Potential by gathered operator
// register Potential by gathered operator
pot_in->pot_register(pot_register_in);
Operator<T, Device>* veff = new Veff<OperatorPW<T, Device>>(isk,
pot_in->get_veff_smooth_data<Real>(),
pot_in->get_veff_smooth().nr,
pot_in->get_veff_smooth().nc,
wfc_basis);
if(this->ops == nullptr)
if (this->ops == nullptr)
{
this->ops = veff;
}
Expand All @@ -110,9 +111,8 @@ HamiltPW<T, Device>::HamiltPW(elecstate::Potential* pot_in,
}
if (PARAM.inp.vnl_in_h)
{
Operator<T, Device>* nonlocal
= new Nonlocal<OperatorPW<T, Device>>(isk, this->ppcell, ucell, wfc_basis);
if(this->ops == nullptr)
Operator<T, Device>* nonlocal = new Nonlocal<OperatorPW<T, Device>>(isk, this->ppcell, ucell, wfc_basis);
if (this->ops == nullptr)
{
this->ops = nonlocal;
}
Expand All @@ -121,11 +121,13 @@ HamiltPW<T, Device>::HamiltPW(elecstate::Potential* pot_in,
this->ops->add(nonlocal);
}
}
if(PARAM.inp.sc_mag_switch || PARAM.inp.dft_plus_u)
if (PARAM.inp.sc_mag_switch || PARAM.inp.dft_plus_u)
{
Operator<T, Device>* onsite_proj
= new OnsiteProj<OperatorPW<T, Device>>(isk, ucell, p_dftu,
PARAM.inp.sc_mag_switch, (PARAM.inp.dft_plus_u>0));
Operator<T, Device>* onsite_proj = new OnsiteProj<OperatorPW<T, Device>>(isk,
ucell,
p_dftu,
PARAM.inp.sc_mag_switch,
(PARAM.inp.dft_plus_u > 0));
this->ops->add(onsite_proj);
}
if (GlobalC::exx_info.info_global.cal_exx)
Expand All @@ -144,97 +146,21 @@ HamiltPW<T, Device>::HamiltPW(elecstate::Potential* pot_in,
return;
}

template<typename T, typename Device>
template <typename T, typename Device>
HamiltPW<T, Device>::~HamiltPW()
{
if(this->ops!= nullptr)
if (this->ops != nullptr)
{
delete this->ops;
}
}

template<typename T, typename Device>
template <typename T, typename Device>
void HamiltPW<T, Device>::updateHk(const int ik)
{
ModuleBase::TITLE("HamiltPW","updateHk");
ModuleBase::TITLE("HamiltPW", "updateHk");
this->ops->init(ik);
ModuleBase::TITLE("HamiltPW","updateHk");
}

template<typename T, typename Device>
template<typename T_in, typename Device_in>
HamiltPW<T, Device>::HamiltPW(const HamiltPW<T_in, Device_in> *hamilt)
{
this->classname = hamilt->classname;
this->ppcell = hamilt->ppcell;
this->qq_nt = hamilt->qq_nt;
this->qq_so = hamilt->qq_so;
this->vkb = hamilt->vkb;
OperatorPW<std::complex<T_in>, Device_in> * node =
reinterpret_cast<OperatorPW<std::complex<T_in>, Device_in> *>(hamilt->ops);

while(node != nullptr) {
if (node->classname == "Ekinetic") {
Operator<T, Device>* ekinetic =
new Ekinetic<OperatorPW<T, Device>>(
reinterpret_cast<const Ekinetic<OperatorPW<T_in, Device_in>>*>(node));
if(this->ops == nullptr) {
this->ops = ekinetic;
}
else {
this->ops->add(ekinetic);
}
// this->ops = reinterpret_cast<Operator<T, Device>*>(node);
}
else if (node->classname == "Nonlocal") {
Operator<T, Device>* nonlocal =
new Nonlocal<OperatorPW<T, Device>>(
reinterpret_cast<const Nonlocal<OperatorPW<T_in, Device_in>>*>(node));
if(this->ops == nullptr) {
this->ops = nonlocal;
}
else {
this->ops->add(nonlocal);
}
}
else if (node->classname == "Veff") {
Operator<T, Device>* veff =
new Veff<OperatorPW<T, Device>>(
reinterpret_cast<const Veff<OperatorPW<T_in, Device_in>>*>(node));
if(this->ops == nullptr) {
this->ops = veff;
}
else {
this->ops->add(veff);
}
}
else if (node->classname == "Meta") {
Operator<T, Device>* meta =
new Meta<OperatorPW<T, Device>>(
reinterpret_cast<const Meta<OperatorPW<T_in, Device_in>>*>(node));
if(this->ops == nullptr) {
this->ops = meta;
}
else {
this->ops->add(meta);
}
}
else if (node->classname == "OnsiteProj") {
Operator<T, Device>* onsite_proj =
new OnsiteProj<OperatorPW<T, Device>>(
reinterpret_cast<const OnsiteProj<OperatorPW<T_in, Device_in>>*>(node));
if(this->ops == nullptr) {
this->ops = onsite_proj;
}
else {
this->ops->add(onsite_proj);
}
}
else {
ModuleBase::WARNING_QUIT("HamiltPW", "Unrecognized Operator type!");
}
node = reinterpret_cast<OperatorPW<std::complex<T_in>, Device_in> *>(node->next_op);
}
ModuleBase::TITLE("HamiltPW", "updateHk");
}

// This routine applies the S matrix to m wavefunctions psi and puts
Expand Down Expand Up @@ -390,8 +316,8 @@ void HamiltPW<T, Device>::sPsi(const T* psi_in, // psi
}
}

template<typename T, typename Device>
void HamiltPW<T, Device>::set_exx_helper(Exx_Helper<T, Device> &exx_helper)
template <typename T, typename Device>
void HamiltPW<T, Device>::set_exx_helper(Exx_Helper<T, Device>& exx_helper)
{
auto op = this->ops;
while (op != nullptr)
Expand All @@ -400,7 +326,6 @@ void HamiltPW<T, Device>::set_exx_helper(Exx_Helper<T, Device> &exx_helper)
{
exx_helper.op_exx = reinterpret_cast<OperatorEXXPW<T, Device>*>(op);
exx_helper.set_op();

}
op = op->next_op;
}
Expand Down
30 changes: 13 additions & 17 deletions source/source_pw/module_pwdft/hamilt_pw.h
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
#ifndef HAMILTPW_H
#define HAMILTPW_H

#include "source_base/kernels/math_kernel_op.h"
#include "source_base/macros.h"
#include "source_cell/klist.h"
#include "source_estate/module_pot/potential_new.h"
#include "source_esolver/esolver_ks_pw.h"
#include "source_estate/module_pot/potential_new.h"
#include "source_hamilt/hamilt.h"
#include "source_pw/module_pwdft/vnl_pw.h"
#include "source_base/kernels/math_kernel_op.h"
#include "source_pw/module_pwdft/exx_helper.h"
#include "source_lcao/module_dftu/dftu.h" // mohan add 2025-11-06
#include "source_pw/module_pwdft/exx_helper.h"
#include "source_pw/module_pwdft/vnl_pw.h"

namespace hamilt
{
Expand All @@ -18,22 +18,18 @@ template <typename T, typename Device = base_device::DEVICE_CPU>
class HamiltPW : public Hamilt<T, Device>
{
private:
// Note GetTypeReal<T>::type will
// return T if T is real type(float, double),
// Note GetTypeReal<T>::type will
// return T if T is real type(float, double),
// otherwise return the real type of T(complex<float>, std::complex<double>)
using Real = typename GetTypeReal<T>::type;

public:

HamiltPW(elecstate::Potential* pot_in,
ModulePW::PW_Basis_K* wfc_basis,
K_Vectors* p_kv,
pseudopot_cell_vnl* nlpp,
Plus_U *p_dftu, // mohan add 2025-11-06
const UnitCell* ucell);

template<typename T_in, typename Device_in = Device>
explicit HamiltPW(const HamiltPW<T_in, Device_in>* hamilt);
HamiltPW(elecstate::Potential* pot_in,
ModulePW::PW_Basis_K* wfc_basis,
K_Vectors* p_kv,
pseudopot_cell_vnl* nlpp,
Plus_U* p_dftu, // mohan add 2025-11-06
const UnitCell* ucell);

~HamiltPW();

Expand All @@ -49,7 +45,7 @@ class HamiltPW : public Hamilt<T, Device>

void set_exx_helper(Exx_Helper<T, Device>& exx_helper_in);

protected:
protected:
// used in sPhi, which are calculated in hPsi or sPhi
const pseudopot_cell_vnl* ppcell = nullptr;
const UnitCell* const ucell = nullptr;
Expand Down
Loading