Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 74 additions & 20 deletions source/source_md/md_func.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "md_func.h"
#include "md_statistics.h"

#include "source_base/global_variable.h"
#include "source_base/timer.h"
Expand Down Expand Up @@ -52,28 +53,84 @@ double kinetic_energy(const int& natom, const ModuleBase::Vector3<double>* vel,
return ke;
}

void compute_stress(const UnitCell& unit_in,
const ModuleBase::Vector3<double>* vel,
const double* allmass,
const bool& cal_stress,
const ModuleBase::matrix& virial,
ModuleBase::matrix& stress)
// ============================================================================
// === New: calc_kinetic_state — pure function =================================
// ============================================================================

MDKineticState calc_kinetic_state(const int natom,
const int frozen_freedom,
const double* allmass,
const ModuleBase::Vector3<double>* vel)
{
if (cal_stress)
MDKineticState state;
if (3 * natom == frozen_freedom)
{
state.kinetic = 0.0;
state.temperature = 0.0;
}
else
{
ModuleBase::matrix t_vector;
state.kinetic = kinetic_energy(natom, vel, allmass);
state.temperature = 2.0 * state.kinetic / static_cast<double>(3 * natom - frozen_freedom);
}
return state;
}

temp_vector(unit_in.nat, vel, allmass, t_vector);
// ============================================================================
// === New: calc_stress_state — pure function ==================================
// ============================================================================

MDStressState calc_stress_state(const UnitCell& unit_in,
const ModuleBase::Vector3<double>* vel,
const double* allmass,
const ModuleBase::matrix& virial)
{
MDStressState state;
// create(3,3) zeros the matrix (flag_zero defaults to true), so += below is safe
state.t_vector.create(3, 3);
state.stress.create(3, 3);

// compute temperature tensor (same logic as temp_vector)
for (int ion = 0; ion < unit_in.nat; ++ion)
{
for (int i = 0; i < 3; ++i)
{
for (int j = 0; j < 3; ++j)
{
stress(i, j) = virial(i, j) + t_vector(i, j) / unit_in.omega;
state.t_vector(i, j) += allmass[ion] * vel[ion][i] * vel[ion][j];
}
}
}

// total stress = virial + t_vector/omega
for (int i = 0; i < 3; ++i)
{
for (int j = 0; j < 3; ++j)
{
state.stress(i, j) = virial(i, j) + state.t_vector(i, j) / unit_in.omega;
}
}

return state;
}

// ============================================================================
// === Old interface: compute_stress — now a wrapper around calc_stress_state ===
// ============================================================================

void compute_stress(const UnitCell& unit_in,
const ModuleBase::Vector3<double>* vel,
const double* allmass,
const bool& cal_stress,
const ModuleBase::matrix& virial,
ModuleBase::matrix& stress)
{
if (cal_stress)
{
MDStressState state = calc_stress_state(unit_in, vel, allmass, virial);
stress = state.stress;
}

return;
}

Expand Down Expand Up @@ -450,22 +507,19 @@ double target_temp(const int& istep, const int& nstep, const double& tfirst, con
return tfirst + delta * (tlast - tfirst);
}

// ============================================================================
// === Old interface: current_temp — now a wrapper around calc_kinetic_state ===
// ============================================================================

double current_temp(double& kinetic,
const int& natom,
const int& frozen_freedom,
const double* allmass,
const ModuleBase::Vector3<double>* vel)
{
if (3 * natom == frozen_freedom)
{
kinetic = 0.0;
return 0.0;
}
else
{
kinetic = kinetic_energy(natom, vel, allmass);
return 2 * kinetic / (3 * natom - frozen_freedom);
}
MDKineticState state = calc_kinetic_state(natom, frozen_freedom, allmass, vel);
kinetic = state.kinetic;
return state.temperature;
}

void temp_vector(const int& natom,
Expand Down
43 changes: 43 additions & 0 deletions source/source_md/md_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define MD_FUNC_H

#include "source_esolver/esolver.h"
#include "md_statistics.h"

class Parameter;

Expand Down Expand Up @@ -117,6 +118,48 @@ void force_virial(ModuleESolver::ESolver* p_esolver,
*/
double kinetic_energy(const int& natom, const ModuleBase::Vector3<double>* vel, const double* allmass);

// ============================================================================
// === New: pure-function versions — read-only inputs, explicit struct return ===
// ============================================================================

/**
* @brief compute kinetic energy and temperature as a pure function (no side effects)
*
* Pure function: does not modify any external state.
* Safer than current_temp(kinetic, ...) for unit testing and parallel calls.
*
* @param natom number of atoms
* @param frozen_freedom number of frozen degrees of freedom
* @param allmass atomic mass array
* @param vel atomic velocity array
* @return MDKineticState containing kinetic energy and temperature
*/
MDKineticState calc_kinetic_state(const int natom,
const int frozen_freedom,
const double* allmass,
const ModuleBase::Vector3<double>* vel);

/**
* @brief compute ionic kinetic stress contribution and total stress as a pure function
*
* Pure function: does not modify virial/stress references.
* The caller decides how to use the returned struct.
*
* @param unit_in unitcell information
* @param vel atomic velocity array
* @param allmass atomic mass array
* @param virial lattice virial tensor
* @return MDStressState containing t_vector and total stress
*/
MDStressState calc_stress_state(const UnitCell& unit_in,
const ModuleBase::Vector3<double>* vel,
const double* allmass,
const ModuleBase::matrix& virial);

// ============================================================================
// === Old write-back interfaces preserved as wrappers =========================
// ============================================================================

/**
* @brief calculate the total stress tensor
*
Expand Down
52 changes: 52 additions & 0 deletions source/source_md/md_statistics.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
#ifndef MD_STATISTICS_H
#define MD_STATISTICS_H

#include "source_base/matrix.h"

/**
* @brief Kinetic energy and temperature statistics result - pure data structure
*
* Replaces the kinetic energy write-back via reference parameter
* in current_temp(kinetic, natom, frozen_freedom, allmass, vel).
*/
struct MDKineticState
{
double kinetic = 0.0; ///< kinetic energy (Hartree)
double temperature = 0.0; ///< temperature (Hartree)

/// Convenience conversion to Kelvin
double temperature_kelvin(double hartree_to_k) const
{
return temperature * hartree_to_k;
}
};

/**
* @brief Stress statistics result - pure data structure
*
* Replaces the implicit write-back of both virial and stress
* through reference parameters in compute_stress().
* Separates the ionic kinetic contribution tensor t_vector
* from the total stress.
*/
struct MDStressState
{
ModuleBase::matrix t_vector; ///< ionic kinetic contribution tensor (3x3)
ModuleBase::matrix stress; ///< total stress = virial + t_vector/omega (3x3)
};

/**
* @brief FIRE optimizer projection statistics - pure data structure
*
* Replaces the scattered P, sumforce, normvel local variables
* in FIRE::check_fire().
*/
struct FIREProjection
{
double power = 0.0; ///< P = sum v_i · f_i
double force_norm = 0.0; ///< |f| = sqrt(sum |f_i|^2)
double velocity_norm = 0.0; ///< |v| = sqrt(sum |v_i|^2)
double max_force = 0.0; ///< max |f_i| component
};

#endif // MD_STATISTICS_H
23 changes: 19 additions & 4 deletions source/source_md/msst.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,15 @@ void MSST::setup(ModuleESolver::ESolver* p_esolver, const std::string& global_re
}
}

MD_func::compute_stress(ucell, vel, allmass, cal_stress, virial, stress);
t_current = MD_func::current_temp(kinetic, ucell.nat, frozen_freedom_, allmass, vel);
if (cal_stress)
{
MDStressState sstate = MD_func::calc_stress_state(ucell, vel, allmass, virial);
stress = sstate.stress;
}

MDKineticState kstate = MD_func::calc_kinetic_state(ucell.nat, frozen_freedom_, allmass, vel);
kinetic = kstate.kinetic;
t_current = kstate.temperature;
}

ModuleBase::timer::end("MSST", "setup");
Expand Down Expand Up @@ -144,8 +151,16 @@ void MSST::second_half()
propagate_vel();

vsum = vel_sum();
MD_func::compute_stress(ucell, vel, allmass, cal_stress, virial, stress);
t_current = MD_func::current_temp(kinetic, ucell.nat, frozen_freedom_, allmass, vel);

if (cal_stress)
{
MDStressState sstate = MD_func::calc_stress_state(ucell, vel, allmass, virial);
stress = sstate.stress;
}

MDKineticState kstate = MD_func::calc_kinetic_state(ucell.nat, frozen_freedom_, allmass, vel);
kinetic = kstate.kinetic;
t_current = kstate.temperature;

/// propagate the time derivative of volume 1/2 step
propagate_voldot();
Expand Down
21 changes: 17 additions & 4 deletions source/source_md/nhchain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -232,8 +232,15 @@ void Nose_Hoover::first_half(std::ofstream& ofs)
if (npt_flag)
{
/// update temperature and stress due to velocity rescaling
t_current = MD_func::current_temp(kinetic, ucell.nat, frozen_freedom_, allmass, vel);
MD_func::compute_stress(ucell, vel, allmass, cal_stress, virial, stress);
MDKineticState kstate = MD_func::calc_kinetic_state(ucell.nat, frozen_freedom_, allmass, vel);
kinetic = kstate.kinetic;
t_current = kstate.temperature;

if (cal_stress)
{
MDStressState sstate = MD_func::calc_stress_state(ucell, vel, allmass, virial);
stress = sstate.stress;
}

/// couple stress component due to md_pcouple
couple_stress();
Expand Down Expand Up @@ -287,12 +294,18 @@ void Nose_Hoover::second_half()
}

/// update temperature and kinetic energy due to velocity rescaling
t_current = MD_func::current_temp(kinetic, ucell.nat, frozen_freedom_, allmass, vel);
MDKineticState kstate = MD_func::calc_kinetic_state(ucell.nat, frozen_freedom_, allmass, vel);
kinetic = kstate.kinetic;
t_current = kstate.temperature;

if (npt_flag)
{
/// update stress due to velocity rescaling
MD_func::compute_stress(ucell, vel, allmass, cal_stress, virial, stress);
if (cal_stress)
{
MDStressState sstate = MD_func::calc_stress_state(ucell, vel, allmass, virial);
stress = sstate.stress;
}

/// couple stress component due to md_pcouple
couple_stress();
Expand Down
Loading