Skip to content
19 changes: 19 additions & 0 deletions deepmd/dpmodel/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ def forward_common_atomic(
charge_spin=charge_spin,
)
ret_dict = self.apply_out_stat(ret_dict, atype)
ret_dict = self._apply_aparam_output_gate_after_out_stat(ret_dict, aparam)

# nf x nloc
atom_mask = xp_take_first_n(ext_atom_mask, 1, nloc)
Expand Down Expand Up @@ -617,6 +618,24 @@ def deserialize(cls, data: dict) -> "BaseAtomicModel":
obj[kk] = variables[kk]
return obj

def _apply_aparam_output_gate_after_out_stat(
self,
ret_dict: dict[str, Array],
aparam: Array | None,
) -> dict[str, Array]:
"""Gate atomic outputs after out_bias has been applied."""
fitting = getattr(self, "fitting_net", None)
if fitting is None or not fitting.use_aparam_output_gate:
return ret_dict
var_name = fitting.var_name
if var_name not in ret_dict:
return ret_dict
ret_dict[var_name] = fitting.apply_aparam_output_gate_to_atomic_output(
ret_dict[var_name],
aparam,
)
return ret_dict

def apply_out_stat(
self,
ret: dict[str, Array],
Expand Down
93 changes: 86 additions & 7 deletions deepmd/dpmodel/fitting/general_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,9 @@ def __init__(
precision: str = DEFAULT_PRECISION,
layer_name: list[str | None] | None = None,
use_aparam_as_mask: bool = False,
use_aparam_output_gate: bool = False,
aparam_gate_norm: float = 1.0,
aparam_gate_clamp: bool = True,
spin: Any = None,
mixed_types: bool = True,
exclude_types: list[int] = [],
Expand All @@ -138,6 +141,15 @@ def __init__(
seed: int | list[int] | None = None,
default_fparam: list[float] | None = None,
) -> None:
if use_aparam_output_gate and numb_aparam <= 0:
raise ValueError(
"use_aparam_output_gate requires numb_aparam > 0, "
f"got numb_aparam={numb_aparam}"
)
if aparam_gate_norm <= 0.0:
raise ValueError(
f"aparam_gate_norm must be positive, got {aparam_gate_norm}"
)
self.var_name = var_name
self.ntypes = ntypes
self.dim_descrpt = dim_descrpt
Expand All @@ -164,6 +176,9 @@ def __init__(
self.prec = PRECISION_DICT[self.precision.lower()]
self.layer_name = layer_name
self.use_aparam_as_mask = use_aparam_as_mask
self.use_aparam_output_gate = use_aparam_output_gate
self.aparam_gate_norm = aparam_gate_norm
self.aparam_gate_clamp = aparam_gate_clamp
self.spin = spin
self.mixed_types = mixed_types
# order matters, should be place after the assignment of ntypes
Expand Down Expand Up @@ -594,6 +609,9 @@ def serialize(self) -> dict:
"trainable": self.trainable,
"layer_name": self.layer_name,
"use_aparam_as_mask": self.use_aparam_as_mask,
"use_aparam_output_gate": self.use_aparam_output_gate,
"aparam_gate_norm": self.aparam_gate_norm,
"aparam_gate_clamp": self.aparam_gate_clamp,
"spin": self.spin,
}

Expand All @@ -610,6 +628,58 @@ def deserialize(cls, data: dict) -> "GeneralFitting":
obj.nets = NetworkCollection.deserialize(nets)
return obj

def _compute_aparam_output_gate(self, aparam_raw: Array) -> Array:
"""Hard-coded gate g = a^2 / (sigma^2 * norm) from raw aparam."""
xp = array_api_compat.array_namespace(aparam_raw)
assert self.aparam_inv_std is not None
sigma = 1.0 / self.aparam_inv_std
gate = (aparam_raw * aparam_raw) / (
sigma * sigma * self.aparam_gate_norm + 1e-12
)
if self.numb_aparam > 1:
gate = xp.prod(gate, axis=-1, keepdims=True)
if self.aparam_gate_clamp:
gate = xp.clip(gate, 0.0, 1.0)
return gate

def _apply_aparam_output_gate(
self,
outs: Array,
aparam_raw: Array | None,
) -> Array:
if not self.use_aparam_output_gate:
return outs
if aparam_raw is None:
raise ValueError(
"aparam is required when use_aparam_output_gate is enabled"
)
gate = self._compute_aparam_output_gate(aparam_raw)
return outs * gate

def apply_aparam_output_gate_to_atomic_output(
self,
outs: Array,
aparam: Array | None,
) -> Array:
"""Apply the aparam gate to atomic outputs after out_bias is added."""
if not self.use_aparam_output_gate:
return outs
if aparam is None:
raise ValueError(
"aparam is required when use_aparam_output_gate is enabled"
)
xp = array_api_compat.array_namespace(outs, aparam)
try:
aparam_raw = xp.reshape(
aparam, (outs.shape[0], outs.shape[1], self.numb_aparam)
)
except (ValueError, RuntimeError) as e:
raise ValueError(
f"input aparam: cannot reshape {aparam.shape} "
f"into ({outs.shape[0]}, {outs.shape[1]}, {self.numb_aparam})."
) from e
return self._apply_aparam_output_gate(outs, aparam_raw)

def _call_common(
self,
descriptor: Array,
Expand Down Expand Up @@ -693,24 +763,33 @@ def _call_common(
[xx_zeros, fparam],
axis=-1,
)
# check aparam dim, concate to input descriptor
if self.numb_aparam > 0 and not self.use_aparam_as_mask:
assert aparam is not None, "aparam should not be None"
aparam_raw: Array | None = None
if self.numb_aparam > 0 and aparam is not None:
try:
aparam = xp.reshape(aparam, (nf, nloc, self.numb_aparam))
aparam_raw = xp.reshape(aparam, (nf, nloc, self.numb_aparam))
except (ValueError, RuntimeError) as e:
raise ValueError(
f"input aparam: cannot reshape {aparam.shape} "
f"into ({nf}, {nloc}, {self.numb_aparam})."
) from e
aparam = (aparam - self.aparam_avg[...]) * self.aparam_inv_std[...]
if self.use_aparam_output_gate and aparam_raw is None:
raise ValueError(
"aparam is required when use_aparam_output_gate is enabled"
)

# check aparam dim, concate to input descriptor
if self.numb_aparam > 0 and not self.use_aparam_as_mask:
assert aparam_raw is not None, "aparam should not be None"
aparam_embed = (aparam_raw - self.aparam_avg[...]) * self.aparam_inv_std[
...
]
xx = xp.concat(
[xx, aparam],
[xx, aparam_embed],
axis=-1,
)
if xx_zeros is not None:
xx_zeros = xp.concat(
[xx_zeros, aparam],
[xx_zeros, aparam_embed],
axis=-1,
)

Expand Down
21 changes: 21 additions & 0 deletions deepmd/pt/model/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,7 @@ def forward_common_atomic(
charge_spin=charge_spin,
)
ret_dict = self.apply_out_stat(ret_dict, atype)
ret_dict = self._apply_aparam_output_gate_after_out_stat(ret_dict, aparam)

# nf x nloc
atom_mask = ext_atom_mask[:, :nloc].to(torch.int32)
Expand Down Expand Up @@ -539,6 +540,26 @@ def compute_or_load_out_stat(
bias_adjust_mode="set-by-statistic",
)

def _apply_aparam_output_gate_after_out_stat(
self,
ret_dict: dict[str, torch.Tensor],
aparam: torch.Tensor | None,
) -> dict[str, torch.Tensor]:
"""Gate atomic outputs after out_bias has been applied."""
if not hasattr(self, "fitting_net"):
return ret_dict
fitting = self.fitting_net
if not fitting.use_aparam_output_gate:
return ret_dict
var_name = fitting.var_name
if var_name not in ret_dict:
return ret_dict
ret_dict[var_name] = fitting.apply_aparam_output_gate_to_atomic_output(
ret_dict[var_name],
aparam,
)
return ret_dict

def apply_out_stat(
self,
ret: dict[str, torch.Tensor],
Expand Down
32 changes: 24 additions & 8 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import inspect
import logging
from collections.abc import (
Callable,
Expand Down Expand Up @@ -73,6 +74,9 @@ def __init__(
self.enable_eval_fitting_last_layer_hook = False
self.eval_descriptor_list = []
self.eval_fitting_last_layer_list = []
self._descriptor_accepts_aparam = (
"aparam" in inspect.signature(self.descriptor.forward).parameters
)

eval_descriptor_list: list[torch.Tensor]
eval_fitting_last_layer_list: list[torch.Tensor]
Expand Down Expand Up @@ -281,14 +285,26 @@ def forward_atomic(
default_cs_tensor = default_cs_tensor.to(device=extended_coord.device)
charge_spin = torch.tile(default_cs_tensor.unsqueeze(0), [nframes, 1])

descriptor, rot_mat, g2, h2, sw = self.descriptor(
extended_coord,
extended_atype,
nlist,
mapping=mapping,
comm_dict=comm_dict,
charge_spin=charge_spin if self.add_chg_spin_ebd else None,
)
charge_spin_arg = charge_spin if self.add_chg_spin_ebd else None
if self._descriptor_accepts_aparam:
descriptor, rot_mat, g2, h2, sw = self.descriptor(
extended_coord,
extended_atype,
nlist,
mapping=mapping,
comm_dict=comm_dict,
charge_spin=charge_spin_arg,
aparam=aparam,
)
else:
descriptor, rot_mat, g2, h2, sw = self.descriptor(
extended_coord,
extended_atype,
nlist,
mapping=mapping,
comm_dict=comm_dict,
charge_spin=charge_spin_arg,
)
assert descriptor is not None
if self.enable_eval_descriptor_hook:
self.eval_descriptor_list.append(descriptor.detach())
Expand Down
6 changes: 6 additions & 0 deletions deepmd/pt/model/descriptor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@
DescrptBlockSeA,
DescrptSeA,
)
from .se_a_vg import (
DescrptBlockSeAVg,
DescrptSeAVg,
)
from .se_atten_v2 import (
DescrptSeAttenV2,
)
Expand All @@ -51,13 +55,15 @@
"DescriptorBlock",
"DescrptBlockRepformers",
"DescrptBlockSeA",
"DescrptBlockSeAVg",
"DescrptBlockSeAtten",
"DescrptBlockSeTTebd",
"DescrptDPA1",
"DescrptDPA2",
"DescrptDPA3",
"DescrptHybrid",
"DescrptSeA",
"DescrptSeAVg",
"DescrptSeAttenV2",
"DescrptSeR",
"DescrptSeT",
Expand Down
Loading