diff --git a/deepmd/dpmodel/atomic_model/base_atomic_model.py b/deepmd/dpmodel/atomic_model/base_atomic_model.py index cf59af94db..74dffd3d4a 100644 --- a/deepmd/dpmodel/atomic_model/base_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/base_atomic_model.py @@ -304,6 +304,7 @@ def forward_common_atomic( charge_spin=charge_spin, ) ret_dict = self.apply_out_stat(ret_dict, atype) + ret_dict = self._apply_aparam_output_gate_after_out_stat(ret_dict, aparam) # nf x nloc atom_mask = xp_take_first_n(ext_atom_mask, 1, nloc) @@ -617,6 +618,24 @@ def deserialize(cls, data: dict) -> "BaseAtomicModel": obj[kk] = variables[kk] return obj + def _apply_aparam_output_gate_after_out_stat( + self, + ret_dict: dict[str, Array], + aparam: Array | None, + ) -> dict[str, Array]: + """Gate atomic outputs after out_bias has been applied.""" + fitting = getattr(self, "fitting_net", None) + if fitting is None or not fitting.use_aparam_output_gate: + return ret_dict + var_name = fitting.var_name + if var_name not in ret_dict: + return ret_dict + ret_dict[var_name] = fitting.apply_aparam_output_gate_to_atomic_output( + ret_dict[var_name], + aparam, + ) + return ret_dict + def apply_out_stat( self, ret: dict[str, Array], diff --git a/deepmd/dpmodel/fitting/general_fitting.py b/deepmd/dpmodel/fitting/general_fitting.py index b9129a4364..e985dc650a 100644 --- a/deepmd/dpmodel/fitting/general_fitting.py +++ b/deepmd/dpmodel/fitting/general_fitting.py @@ -130,6 +130,9 @@ def __init__( precision: str = DEFAULT_PRECISION, layer_name: list[str | None] | None = None, use_aparam_as_mask: bool = False, + use_aparam_output_gate: bool = False, + aparam_gate_norm: float = 1.0, + aparam_gate_clamp: bool = True, spin: Any = None, mixed_types: bool = True, exclude_types: list[int] = [], @@ -138,6 +141,15 @@ def __init__( seed: int | list[int] | None = None, default_fparam: list[float] | None = None, ) -> None: + if use_aparam_output_gate and numb_aparam <= 0: + raise ValueError( + "use_aparam_output_gate requires numb_aparam > 0, " + f"got numb_aparam={numb_aparam}" + ) + if aparam_gate_norm <= 0.0: + raise ValueError( + f"aparam_gate_norm must be positive, got {aparam_gate_norm}" + ) self.var_name = var_name self.ntypes = ntypes self.dim_descrpt = dim_descrpt @@ -164,6 +176,9 @@ def __init__( self.prec = PRECISION_DICT[self.precision.lower()] self.layer_name = layer_name self.use_aparam_as_mask = use_aparam_as_mask + self.use_aparam_output_gate = use_aparam_output_gate + self.aparam_gate_norm = aparam_gate_norm + self.aparam_gate_clamp = aparam_gate_clamp self.spin = spin self.mixed_types = mixed_types # order matters, should be place after the assignment of ntypes @@ -594,6 +609,9 @@ def serialize(self) -> dict: "trainable": self.trainable, "layer_name": self.layer_name, "use_aparam_as_mask": self.use_aparam_as_mask, + "use_aparam_output_gate": self.use_aparam_output_gate, + "aparam_gate_norm": self.aparam_gate_norm, + "aparam_gate_clamp": self.aparam_gate_clamp, "spin": self.spin, } @@ -610,6 +628,58 @@ def deserialize(cls, data: dict) -> "GeneralFitting": obj.nets = NetworkCollection.deserialize(nets) return obj + def _compute_aparam_output_gate(self, aparam_raw: Array) -> Array: + """Hard-coded gate g = a^2 / (sigma^2 * norm) from raw aparam.""" + xp = array_api_compat.array_namespace(aparam_raw) + assert self.aparam_inv_std is not None + sigma = 1.0 / self.aparam_inv_std + gate = (aparam_raw * aparam_raw) / ( + sigma * sigma * self.aparam_gate_norm + 1e-12 + ) + if self.numb_aparam > 1: + gate = xp.prod(gate, axis=-1, keepdims=True) + if self.aparam_gate_clamp: + gate = xp.clip(gate, 0.0, 1.0) + return gate + + def _apply_aparam_output_gate( + self, + outs: Array, + aparam_raw: Array | None, + ) -> Array: + if not self.use_aparam_output_gate: + return outs + if aparam_raw is None: + raise ValueError( + "aparam is required when use_aparam_output_gate is enabled" + ) + gate = self._compute_aparam_output_gate(aparam_raw) + return outs * gate + + def apply_aparam_output_gate_to_atomic_output( + self, + outs: Array, + aparam: Array | None, + ) -> Array: + """Apply the aparam gate to atomic outputs after out_bias is added.""" + if not self.use_aparam_output_gate: + return outs + if aparam is None: + raise ValueError( + "aparam is required when use_aparam_output_gate is enabled" + ) + xp = array_api_compat.array_namespace(outs, aparam) + try: + aparam_raw = xp.reshape( + aparam, (outs.shape[0], outs.shape[1], self.numb_aparam) + ) + except (ValueError, RuntimeError) as e: + raise ValueError( + f"input aparam: cannot reshape {aparam.shape} " + f"into ({outs.shape[0]}, {outs.shape[1]}, {self.numb_aparam})." + ) from e + return self._apply_aparam_output_gate(outs, aparam_raw) + def _call_common( self, descriptor: Array, @@ -693,24 +763,33 @@ def _call_common( [xx_zeros, fparam], axis=-1, ) - # check aparam dim, concate to input descriptor - if self.numb_aparam > 0 and not self.use_aparam_as_mask: - assert aparam is not None, "aparam should not be None" + aparam_raw: Array | None = None + if self.numb_aparam > 0 and aparam is not None: try: - aparam = xp.reshape(aparam, (nf, nloc, self.numb_aparam)) + aparam_raw = xp.reshape(aparam, (nf, nloc, self.numb_aparam)) except (ValueError, RuntimeError) as e: raise ValueError( f"input aparam: cannot reshape {aparam.shape} " f"into ({nf}, {nloc}, {self.numb_aparam})." ) from e - aparam = (aparam - self.aparam_avg[...]) * self.aparam_inv_std[...] + if self.use_aparam_output_gate and aparam_raw is None: + raise ValueError( + "aparam is required when use_aparam_output_gate is enabled" + ) + + # check aparam dim, concate to input descriptor + if self.numb_aparam > 0 and not self.use_aparam_as_mask: + assert aparam_raw is not None, "aparam should not be None" + aparam_embed = (aparam_raw - self.aparam_avg[...]) * self.aparam_inv_std[ + ... + ] xx = xp.concat( - [xx, aparam], + [xx, aparam_embed], axis=-1, ) if xx_zeros is not None: xx_zeros = xp.concat( - [xx_zeros, aparam], + [xx_zeros, aparam_embed], axis=-1, ) diff --git a/deepmd/pt/model/atomic_model/base_atomic_model.py b/deepmd/pt/model/atomic_model/base_atomic_model.py index 8605db9359..87ed3856b7 100644 --- a/deepmd/pt/model/atomic_model/base_atomic_model.py +++ b/deepmd/pt/model/atomic_model/base_atomic_model.py @@ -377,6 +377,7 @@ def forward_common_atomic( charge_spin=charge_spin, ) ret_dict = self.apply_out_stat(ret_dict, atype) + ret_dict = self._apply_aparam_output_gate_after_out_stat(ret_dict, aparam) # nf x nloc atom_mask = ext_atom_mask[:, :nloc].to(torch.int32) @@ -539,6 +540,26 @@ def compute_or_load_out_stat( bias_adjust_mode="set-by-statistic", ) + def _apply_aparam_output_gate_after_out_stat( + self, + ret_dict: dict[str, torch.Tensor], + aparam: torch.Tensor | None, + ) -> dict[str, torch.Tensor]: + """Gate atomic outputs after out_bias has been applied.""" + if not hasattr(self, "fitting_net"): + return ret_dict + fitting = self.fitting_net + if not fitting.use_aparam_output_gate: + return ret_dict + var_name = fitting.var_name + if var_name not in ret_dict: + return ret_dict + ret_dict[var_name] = fitting.apply_aparam_output_gate_to_atomic_output( + ret_dict[var_name], + aparam, + ) + return ret_dict + def apply_out_stat( self, ret: dict[str, torch.Tensor], diff --git a/deepmd/pt/model/atomic_model/dp_atomic_model.py b/deepmd/pt/model/atomic_model/dp_atomic_model.py index 783ee9e766..2b4f0bcbde 100644 --- a/deepmd/pt/model/atomic_model/dp_atomic_model.py +++ b/deepmd/pt/model/atomic_model/dp_atomic_model.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import inspect import logging from collections.abc import ( Callable, @@ -73,6 +74,9 @@ def __init__( self.enable_eval_fitting_last_layer_hook = False self.eval_descriptor_list = [] self.eval_fitting_last_layer_list = [] + self._descriptor_accepts_aparam = ( + "aparam" in inspect.signature(self.descriptor.forward).parameters + ) eval_descriptor_list: list[torch.Tensor] eval_fitting_last_layer_list: list[torch.Tensor] @@ -281,14 +285,26 @@ def forward_atomic( default_cs_tensor = default_cs_tensor.to(device=extended_coord.device) charge_spin = torch.tile(default_cs_tensor.unsqueeze(0), [nframes, 1]) - descriptor, rot_mat, g2, h2, sw = self.descriptor( - extended_coord, - extended_atype, - nlist, - mapping=mapping, - comm_dict=comm_dict, - charge_spin=charge_spin if self.add_chg_spin_ebd else None, - ) + charge_spin_arg = charge_spin if self.add_chg_spin_ebd else None + if self._descriptor_accepts_aparam: + descriptor, rot_mat, g2, h2, sw = self.descriptor( + extended_coord, + extended_atype, + nlist, + mapping=mapping, + comm_dict=comm_dict, + charge_spin=charge_spin_arg, + aparam=aparam, + ) + else: + descriptor, rot_mat, g2, h2, sw = self.descriptor( + extended_coord, + extended_atype, + nlist, + mapping=mapping, + comm_dict=comm_dict, + charge_spin=charge_spin_arg, + ) assert descriptor is not None if self.enable_eval_descriptor_hook: self.eval_descriptor_list.append(descriptor.detach()) diff --git a/deepmd/pt/model/descriptor/__init__.py b/deepmd/pt/model/descriptor/__init__.py index 6da6c3b864..b38b797eae 100644 --- a/deepmd/pt/model/descriptor/__init__.py +++ b/deepmd/pt/model/descriptor/__init__.py @@ -29,6 +29,10 @@ DescrptBlockSeA, DescrptSeA, ) +from .se_a_vg import ( + DescrptBlockSeAVg, + DescrptSeAVg, +) from .se_atten_v2 import ( DescrptSeAttenV2, ) @@ -51,6 +55,7 @@ "DescriptorBlock", "DescrptBlockRepformers", "DescrptBlockSeA", + "DescrptBlockSeAVg", "DescrptBlockSeAtten", "DescrptBlockSeTTebd", "DescrptDPA1", @@ -58,6 +63,7 @@ "DescrptDPA3", "DescrptHybrid", "DescrptSeA", + "DescrptSeAVg", "DescrptSeAttenV2", "DescrptSeR", "DescrptSeT", diff --git a/deepmd/pt/model/descriptor/env_mat_vg.py b/deepmd/pt/model/descriptor/env_mat_vg.py new file mode 100644 index 0000000000..eb28dfb051 --- /dev/null +++ b/deepmd/pt/model/descriptor/env_mat_vg.py @@ -0,0 +1,182 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Environment matrix for variational-Gaussian (VG) smooth descriptors.""" + +import torch + +from deepmd.pt.utils.preprocess import ( + compute_smooth_weight, +) + +VG_ENV_DIM: int = 5 + + +def vg_gaussian_radial_phi( + length: torch.Tensor, + sigma_ij: torch.Tensor, + protection: float = 0.0, +) -> torch.Tensor: + """Gaussian-averaged 1/r kernel: (1/r) * erf(r / (sqrt(2)*sigma)).""" + r = length + protection + sigma = sigma_ij + 1e-12 + # literal sqrt(2) for TorchScript (module-level floats are not allowed) + return (1.0 / r) * torch.erf(r / (1.4142135623730951 * sigma)) + + +def vg_smooth_radial( + length: torch.Tensor, + sigma_ij: torch.Tensor, + rcut_smth: float, + rcut: float, + protection: float = 0.0, +) -> torch.Tensor: + """Radial kernel s(r, sigma) with the same smooth cutoff as DP-SE.""" + phi = vg_gaussian_radial_phi(length, sigma_ij, protection=protection) + weight = compute_smooth_weight(length, rcut_smth, rcut) + return phi * weight + + +def _normalize_aparam_vg( + aparam: torch.Tensor, + nlist: torch.Tensor, + nall: int, +) -> torch.Tensor: + """Ensure aparam is [nf, nloc, 1] with nloc matching nlist.""" + nf, nloc, _ = nlist.shape + if aparam.ndim == 2: + aparam = aparam.unsqueeze(-1) + aparam = aparam[..., :1] + if aparam.shape[1] == nloc: + return aparam + if aparam.shape[1] == nall: + return aparam[:, :nloc, :] + return aparam.reshape(nf, nloc, 1) + + +def _gather_neighbor_sigma( + aparam: torch.Tensor, + nlist: torch.Tensor, + nloc: int, + nall: int, +) -> torch.Tensor: + """Map per-atom aparam to neighbor-list sigma values.""" + nf, _, nnei = nlist.shape + mask = nlist >= 0 + sigma_loc = aparam[:, :nloc, 0] + # Pad one slot so invalid neighbors (-1) can map to index nall, as in env_mat. + sigma_ext = torch.zeros( + (nf, nall + 1), + dtype=sigma_loc.dtype, + device=sigma_loc.device, + ) + sigma_ext[:, :nloc] = sigma_loc + nlist_safe = torch.where(mask, nlist, nall).to(torch.int64) + index = nlist_safe.reshape(nf, -1) + sigma_nei = torch.gather(sigma_ext, 1, index) + sigma_nei = sigma_nei.view(nf, nloc, nnei) + return torch.where(mask, sigma_nei, torch.zeros_like(sigma_nei)) + + +def _make_env_mat_vg( + nlist: torch.Tensor, + coord: torch.Tensor, + aparam: torch.Tensor, + rcut: float, + rcut_smth: float, + protection: float = 0.0, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Build the 5-column VG local environment matrix.""" + bsz, natoms, nnei = nlist.shape + coord = coord.view(bsz, -1, 3) + nall = coord.shape[1] + aparam = _normalize_aparam_vg(aparam, nlist, nall) + mask = nlist >= 0 + nlist_safe = torch.where(mask, nlist, nall) + + coord_l = coord[:, :natoms].view(bsz, -1, 1, 3) + index = nlist_safe.view(bsz, -1).to(torch.int64).unsqueeze(-1).expand(-1, -1, 3) + coord_pad = torch.concat([coord, coord[:, -1:, :] + rcut], dim=1) + coord_r = torch.gather(coord_pad, 1, index).view(bsz, natoms, nnei, 3) + diff = coord_r - coord_l + length = torch.linalg.norm(diff, dim=-1, keepdim=True) + length = length + (~mask).unsqueeze(-1) + + sigma_loc = aparam[:, :natoms, 0] + sigma_neighbor = _gather_neighbor_sigma(aparam, nlist, natoms, nall) + # [bsz, natoms, 1] + [bsz, natoms, nnei] -> [bsz, natoms, nnei] + sigma_ij = torch.sqrt( + sigma_loc.unsqueeze(-1).square() + sigma_neighbor.square() + ).unsqueeze(-1) + + s_val = vg_smooth_radial( + length, + sigma_ij, + rcut_smth, + rcut, + protection=protection, + ) + s_val = s_val * mask.unsqueeze(-1) + x_hat = diff / (length + protection) + sigma_col = s_val * sigma_ij / (length + protection) + + env_mat = torch.cat( + [ + s_val, + s_val * x_hat, + sigma_col, + ], + dim=-1, + ) + return env_mat, diff * mask.unsqueeze(-1), s_val + + +def prod_env_mat_vg( + extended_coord: torch.Tensor, + nlist: torch.Tensor, + atype: torch.Tensor, + aparam: torch.Tensor, + mean: torch.Tensor, + stddev: torch.Tensor, + rcut: float, + rcut_smth: float, + protection: float = 0.0, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Normalized VG environment matrix for descriptor embedding.""" + env_mat, diff, switch = _make_env_mat_vg( + nlist, + extended_coord, + aparam, + rcut, + rcut_smth, + protection=protection, + ) + t_avg = mean[atype] + t_std = stddev[atype] + env_mat = (env_mat - t_avg) / t_std + return env_mat, diff, switch + + +def tabulate_fusion_se_a_vg( + table: torch.Tensor, + table_info: torch.Tensor, + em_x: torch.Tensor, + em: torch.Tensor, + last_layer_size: int, +) -> torch.Tensor: + """Tabulate the VG 5-column env mat via two 4-column fusion calls.""" + gr4 = torch.ops.deepmd.tabulate_fusion_se_a( + table, + table_info, + em_x, + em[..., :4].contiguous(), + last_layer_size, + )[0] + em5 = torch.zeros_like(em[..., :4]) + em5[..., 0:1] = em[..., 4:5] + gr5 = torch.ops.deepmd.tabulate_fusion_se_a( + table, + table_info, + em_x, + em5.contiguous(), + last_layer_size, + )[0] + return torch.cat([gr4, gr5[:, 0:1, :]], dim=1) diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index d840c8c001..4bcc74895d 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -322,6 +322,7 @@ def forward( comm_dict: dict[str, torch.Tensor] | None = None, fparam: torch.Tensor | None = None, charge_spin: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, ) -> tuple[ torch.Tensor, torch.Tensor | None, @@ -361,6 +362,7 @@ def forward( The smooth switch function. """ + del aparam # cast the input to internal precsion coord_ext = coord_ext.to(dtype=self.prec) g1, rot_mat, g2, h2, sw = self.sea.forward( diff --git a/deepmd/pt/model/descriptor/se_a_vg.py b/deepmd/pt/model/descriptor/se_a_vg.py new file mode 100644 index 0000000000..8dadca8748 --- /dev/null +++ b/deepmd/pt/model/descriptor/se_a_vg.py @@ -0,0 +1,635 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Variational-Gaussian smooth descriptor (se_a_vg) for DeepMD PT backend.""" + +from __future__ import ( + annotations, +) + +import itertools +from collections.abc import ( + Callable, +) +from typing import ( + Any, + ClassVar, + Final, +) + +import numpy as np +import torch +import torch.nn as nn + +from deepmd.dpmodel.utils import EnvMat as DPEnvMat +from deepmd.dpmodel.utils.seed import ( + child_seed, +) +from deepmd.pt.model.descriptor import ( + DescriptorBlock, +) +from deepmd.pt.model.descriptor.base_descriptor import ( + BaseDescriptor, +) +from deepmd.pt.model.descriptor.env_mat_vg import ( + VG_ENV_DIM, + prod_env_mat_vg, + tabulate_fusion_se_a_vg, +) +from deepmd.pt.model.descriptor.se_a import ( + DescrptSeA, +) +from deepmd.pt.model.network.mlp import ( + EmbeddingNet, + NetworkCollection, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.env import ( + PRECISION_DICT, + RESERVED_PRECISION_DICT, +) +from deepmd.pt.utils.exclude_mask import ( + PairExcludeMask, +) +from deepmd.pt.utils.nlist import ( + extend_input_and_build_neighbor_list, +) +from deepmd.pt.utils.tabulate import ( + DPTabulate, +) +from deepmd.pt.utils.update_sel import ( + UpdateSel, +) +from deepmd.pt.utils.utils import ( + ActivationFn, + to_numpy_array, +) +from deepmd.utils.data_system import ( + DeepmdDataSystem, +) +from deepmd.utils.path import ( + DPPath, +) + +if not hasattr(torch.ops.deepmd, "tabulate_fusion_se_a"): + + def tabulate_fusion_se_a( + argument0: torch.Tensor, + argument1: torch.Tensor, + argument2: torch.Tensor, + argument3: torch.Tensor, + argument4: int, + ) -> list[torch.Tensor]: + raise NotImplementedError( + "tabulate_fusion_se_a is not available since customized PyTorch OP library is not built when freezing the model. " + "See documentation for model compression for details." + ) + + torch.ops.deepmd.tabulate_fusion_se_a = tabulate_fusion_se_a + + +@DescriptorBlock.register("se_a_vg") +class DescrptBlockSeAVg(DescriptorBlock): + """DP-SE descriptor block with VGM Gaussian-averaged radial kernel (5-column env mat).""" + + ndescrpt: Final[int] + __constants__: ClassVar[list] = ["ndescrpt"] + + def __init__( + self, + rcut: float, + rcut_smth: float, + sel: int | list[int], + neuron: list[int] | None = None, + axis_neuron: int = 16, + set_davg_zero: bool = False, + activation_function: str = "tanh", + precision: str = "float64", + resnet_dt: bool = False, + exclude_types: list[tuple[int, int]] | None = None, + env_protection: float = 0.0, + type_one_side: bool = True, + trainable: bool = True, + seed: int | list[int] | None = None, + **kwargs: Any, + ) -> None: + del kwargs + super().__init__() + if neuron is None: + neuron = [25, 50, 100] + if exclude_types is None: + exclude_types = [] + self.rcut = float(rcut) + self.rcut_smth = float(rcut_smth) + self.neuron = neuron + self.filter_neuron = self.neuron + self.axis_neuron = axis_neuron + self.set_davg_zero = set_davg_zero + self.activation_function = activation_function + self.precision = precision + self.prec = PRECISION_DICT[self.precision] + self.resnet_dt = resnet_dt + self.env_protection = env_protection + self.ntypes = len(sel) + self.type_one_side = type_one_side + self.seed = seed + self.reinit_exclude(exclude_types) + + self.sel = sel if isinstance(sel, list) else [sel] + self.sec = [0, *np.cumsum(self.sel).tolist()] + self.nnei = sum(self.sel) + self.ndescrpt = self.nnei * VG_ENV_DIM + + wanted_shape = (self.ntypes, self.nnei, VG_ENV_DIM) + mean = torch.zeros(wanted_shape, dtype=self.prec, device=env.DEVICE) + stddev = torch.ones(wanted_shape, dtype=self.prec, device=env.DEVICE) + self.register_buffer("mean", mean) + self.register_buffer("stddev", stddev) + + ndim = 1 if self.type_one_side else 2 + filter_layers = NetworkCollection( + ndim=ndim, ntypes=len(self.sel), network_type="embedding_network" + ) + for ii, embedding_idx in enumerate( + itertools.product(range(self.ntypes), repeat=ndim) + ): + filter_layers[embedding_idx] = EmbeddingNet( + 1, + self.filter_neuron, + activation_function=self.activation_function, + precision=self.precision, + resnet_dt=self.resnet_dt, + seed=child_seed(self.seed, ii), + trainable=trainable, + ) + self.filter_layers = filter_layers + self.stats = None + self.trainable = trainable + for param in self.parameters(): + param.requires_grad = trainable + self.compress = False + self.compress_info = nn.ParameterList( + [ + nn.Parameter(torch.zeros(0, dtype=self.prec, device="cpu")) + for _ in range(len(self.filter_layers.networks)) + ] + ) + self.compress_data = nn.ParameterList( + [ + nn.Parameter(torch.zeros(0, dtype=self.prec, device=env.DEVICE)) + for _ in range(len(self.filter_layers.networks)) + ] + ) + + def get_rcut(self) -> float: + return self.rcut + + def get_rcut_smth(self) -> float: + return self.rcut_smth + + def get_nsel(self) -> int: + return self.nnei + + def get_sel(self) -> list[int]: + return self.sel + + def get_ntypes(self) -> int: + return self.ntypes + + def get_dim_out(self) -> int: + return self.dim_out + + def get_dim_emb(self) -> int: + return self.neuron[-1] + + def get_dim_in(self) -> int: + return 0 + + def mixed_types(self) -> bool: + return False + + def get_env_protection(self) -> float: + return self.env_protection + + @property + def dim_out(self) -> int: + return self.filter_neuron[-1] * self.axis_neuron + + def __setitem__(self, key: str, value: torch.Tensor) -> None: + if key in ("avg", "data_avg", "davg"): + self.mean = value + elif key in ("std", "data_std", "dstd"): + self.stddev = value + else: + raise KeyError(key) + + def __getitem__(self, key: str) -> torch.Tensor: + if key in ("avg", "data_avg", "davg"): + return self.mean + elif key in ("std", "data_std", "dstd"): + return self.stddev + else: + raise KeyError(key) + + def reinit_exclude( + self, + exclude_types: list[tuple[int, int]] | None = None, + ) -> None: + if exclude_types is None: + exclude_types = [] + self.exclude_types = exclude_types + self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types) + + def compute_input_stats( + self, + merged: Callable[[], list[dict]] | list[dict], + path: DPPath | None = None, + ) -> None: + if callable(merged): + sampled = merged() + else: + sampled = merged + sumv = np.zeros((self.ntypes, self.nnei, VG_ENV_DIM), dtype=np.float64) + sumv2 = np.zeros_like(sumv) + sumn = np.zeros((self.ntypes, self.nnei), dtype=np.float64) + for system in sampled: + coord = to_numpy_array(system["coord"]) + atype_raw = system["atype"] + if isinstance(atype_raw, torch.Tensor): + atype = atype_raw.detach().cpu().numpy().astype(np.int32) + else: + atype = np.asarray(atype_raw, dtype=np.int32) + box_raw = system.get("box") + box = to_numpy_array(box_raw) if box_raw is not None else None + nframes, nloc = atype.shape[:2] + aparam = system.get("aparam") + if aparam is None: + aparam_np = np.zeros((nframes, nloc, 1), dtype=np.float64) + else: + aparam_flat = to_numpy_array(aparam).reshape(-1) + expected = nframes * nloc + if aparam_flat.size != expected: + raise ValueError( + f"aparam size {aparam_flat.size} != nframes*nloc " + f"({nframes}*{nloc}={expected}); check numb_aparam " + "and training data aparam layout" + ) + aparam_np = aparam_flat.reshape(nframes, nloc, 1) + coord_t = torch.tensor(coord, dtype=self.prec, device=env.DEVICE) + atype_t = torch.tensor(atype, dtype=torch.long, device=env.DEVICE) + aparam_t = torch.tensor(aparam_np, dtype=self.prec, device=env.DEVICE) + box_t = None + if box is not None: + box_t = torch.tensor(box, dtype=self.prec, device=env.DEVICE) + extended_coord, extended_atype, _, nlist = ( + extend_input_and_build_neighbor_list( + coord_t, + atype_t, + self.rcut, + self.sel, + mixed_types=False, + box=box_t, + ) + ) + nloc_nlist = nlist.shape[1] + env_mat, _, _ = prod_env_mat_vg( + extended_coord, + nlist, + extended_atype[:, :nloc_nlist], + aparam_t, + self.mean, + torch.ones_like(self.stddev), + self.rcut, + self.rcut_smth, + protection=self.env_protection, + ) + nnei_nlist = nlist.shape[2] + env_mat = ( + env_mat.detach() + .cpu() + .numpy() + .reshape(nframes, nloc_nlist, nnei_nlist, VG_ENV_DIM) + ) + for ff in range(nframes): + for ii in range(nloc_nlist): + ti = int(atype[ff, ii]) + sumv[ti] += env_mat[ff, ii] + sumv2[ti] += env_mat[ff, ii] * env_mat[ff, ii] + sumn[ti] += 1.0 + sumn_safe = np.maximum(sumn, 1.0)[..., None] + mean = sumv / sumn_safe + var = sumv2 / sumn_safe - mean * mean + stddev = np.sqrt(np.maximum(var, 1e-2)) + if not self.set_davg_zero: + self.mean.copy_(torch.tensor(mean, dtype=self.prec, device=env.DEVICE)) + self.stddev.copy_(torch.tensor(stddev, dtype=self.prec, device=env.DEVICE)) + + def enable_compression( + self, + table_data: dict[str, torch.Tensor], + table_config: list[int | float], + lower: dict[str, int], + upper: dict[str, int], + ) -> None: + for embedding_idx, ll in enumerate(self.filter_layers.networks): + del ll + if self.type_one_side: + ii = embedding_idx + ti = -1 + else: + ii = embedding_idx // self.ntypes + ti = embedding_idx % self.ntypes + if self.type_one_side: + net = "filter_-1_net_" + str(ii) + else: + net = "filter_" + str(ti) + "_net_" + str(ii) + info_ii = torch.as_tensor( + [ + lower[net], + upper[net], + upper[net] * table_config[0], + table_config[1], + table_config[2], + table_config[3], + ], + dtype=self.prec, + device="cpu", + ) + tensor_data_ii = table_data[net].to(device=env.DEVICE, dtype=self.prec) + self.compress_data[embedding_idx] = tensor_data_ii + self.compress_info[embedding_idx] = info_ii + self.compress = True + + def forward( + self, + nlist: torch.Tensor, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + extended_atype_embd: torch.Tensor | None = None, + mapping: torch.Tensor | None = None, + type_embedding: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + ) -> tuple[ + torch.Tensor, + torch.Tensor | None, + torch.Tensor | None, + torch.Tensor | None, + torch.Tensor | None, + ]: + del extended_atype_embd, mapping, type_embedding + nf = nlist.shape[0] + nloc = nlist.shape[1] + atype = extended_atype[:, :nloc] + if aparam is None: + aparam = torch.zeros( + (nf, nloc, 1), + dtype=self.prec, + device=extended_coord.device, + ) + else: + aparam = aparam.to(dtype=self.prec, device=extended_coord.device) + aparam = aparam[..., :1] + if aparam.shape[1] != nloc: + aparam = aparam.reshape(nf, nloc, 1) + + dmatrix, diff, sw = prod_env_mat_vg( + extended_coord, + nlist, + atype, + aparam, + self.mean, + self.stddev, + self.rcut, + self.rcut_smth, + protection=self.env_protection, + ) + # literal 5 (= VG_ENV_DIM) for TorchScript + dmatrix = dmatrix.view(-1, self.nnei, 5) + nfnl = dmatrix.shape[0] + xyz_scatter = torch.zeros( + [nfnl, 5, self.filter_neuron[-1]], + dtype=self.prec, + device=extended_coord.device, + ) + exclude_mask = self.emask(nlist, extended_atype).view(nfnl, self.nnei) + for embedding_idx, (ll, compress_data_ii, compress_info_ii) in enumerate( + zip( + self.filter_layers.networks, + self.compress_data, + self.compress_info, + ) + ): + if self.type_one_side: + ii = embedding_idx + ti_mask = None + else: + ii = embedding_idx // self.ntypes + ti = embedding_idx % self.ntypes + ti_mask = atype.ravel().eq(ti) + if ti_mask is not None: + mm = exclude_mask[ti_mask, self.sec[ii] : self.sec[ii + 1]] + rr = dmatrix[ti_mask, self.sec[ii] : self.sec[ii + 1], :] + else: + mm = exclude_mask[:, self.sec[ii] : self.sec[ii + 1]] + rr = dmatrix[:, self.sec[ii] : self.sec[ii + 1], :] + rr = rr * mm[:, :, None] + ss = rr[:, :, :1] + if self.compress: + ss = ss.reshape(-1, 1) + gr = tabulate_fusion_se_a_vg( + compress_data_ii.contiguous(), + compress_info_ii.cpu().contiguous(), + ss.contiguous(), + rr.contiguous(), + self.filter_neuron[-1], + ) + else: + gg = ll.forward(ss) + gr = torch.matmul(rr.permute(0, 2, 1), gg) + if ti_mask is not None: + xyz_scatter[ti_mask] += gr + else: + xyz_scatter += gr + + xyz_scatter /= self.nnei + xyz_scatter_1 = xyz_scatter.permute(0, 2, 1) + rot_mat = xyz_scatter_1[:, :, 1:4] + xyz_scatter_2 = xyz_scatter[:, :, 0 : self.axis_neuron] + result = torch.matmul(xyz_scatter_1, xyz_scatter_2) + result = result.view(nf, nloc, self.filter_neuron[-1] * self.axis_neuron) + rot_mat = rot_mat.view([nf, nloc] + list(rot_mat.shape[1:])) + return result, rot_mat, None, None, sw + + def has_message_passing(self) -> bool: + return False + + def need_sorted_nlist_for_lower(self) -> bool: + return False + + +@BaseDescriptor.register("se_a_vg") +@BaseDescriptor.register("se_e2_a_vg") +class DescrptSeAVg(DescrptSeA): + """VG-aware wrapper around :class:`DescrptBlockSeAVg`.""" + + def __init__( + self, + rcut: float, + rcut_smth: float, + sel: list[int] | int, + neuron: list[int] | None = None, + axis_neuron: int = 16, + set_davg_zero: bool = False, + activation_function: str = "tanh", + precision: str = "float64", + resnet_dt: bool = False, + exclude_types: list[tuple[int, int]] | None = None, + env_protection: float = 0.0, + type_one_side: bool = True, + trainable: bool = True, + seed: int | list[int] | None = None, + ntypes: int | None = None, + type_map: list[str] | None = None, + spin: Any | None = None, + ) -> None: + del ntypes, spin + nn.Module.__init__(self) + BaseDescriptor.__init__(self) + self.type_map = type_map + self.compress = False + self.prec = PRECISION_DICT[precision] + self.sea = DescrptBlockSeAVg( + rcut, + rcut_smth, + sel, + neuron=neuron or [25, 50, 100], + axis_neuron=axis_neuron, + set_davg_zero=set_davg_zero, + activation_function=activation_function, + precision=precision, + resnet_dt=resnet_dt, + exclude_types=exclude_types or [], + env_protection=env_protection, + type_one_side=type_one_side, + trainable=trainable, + seed=seed, + ) + + def forward( + self, + coord_ext: torch.Tensor, + atype_ext: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor | None = None, + comm_dict: dict[str, torch.Tensor] | None = None, + fparam: torch.Tensor | None = None, + charge_spin: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + ) -> tuple[ + torch.Tensor, + torch.Tensor | None, + torch.Tensor | None, + torch.Tensor | None, + torch.Tensor | None, + ]: + del comm_dict, fparam, charge_spin + coord_ext = coord_ext.to(dtype=self.prec) + g1, rot_mat, g2, h2, sw = self.sea.forward( + nlist, + coord_ext, + atype_ext, + aparam=aparam, + mapping=mapping, + ) + return ( + g1.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION) + if rot_mat is not None + else None, + None, + None, + sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION) if sw is not None else None, + ) + + def enable_compression( + self, + min_nbor_dist: float, + table_extrapolate: float = 5, + table_stride_1: float = 0.01, + table_stride_2: float = 0.1, + check_frequency: int = -1, + ) -> None: + if self.compress: + raise ValueError("Compression is already enabled.") + data = self.serialize() + self.table = DPTabulate( + self, + data["neuron"], + data["type_one_side"], + data["exclude_types"], + ActivationFn(data["activation_function"]), + ) + self.table_config = [ + table_extrapolate, + table_stride_1, + table_stride_2, + check_frequency, + ] + self.lower, self.upper = self.table.build( + min_nbor_dist, table_extrapolate, table_stride_1, table_stride_2 + ) + self.sea.enable_compression( + self.table.data, self.table_config, self.lower, self.upper + ) + self.compress = True + + def serialize(self) -> dict: + obj = self.sea + return { + "@class": "Descriptor", + "type": "se_a_vg", + "@version": 2, + "rcut": obj.rcut, + "rcut_smth": obj.rcut_smth, + "sel": obj.sel, + "neuron": obj.neuron, + "axis_neuron": obj.axis_neuron, + "resnet_dt": obj.resnet_dt, + "set_davg_zero": obj.set_davg_zero, + "activation_function": obj.activation_function, + "precision": RESERVED_PRECISION_DICT[obj.prec], + "embeddings": obj.filter_layers.serialize(), + "env_mat": DPEnvMat( + obj.rcut, obj.rcut_smth, obj.env_protection + ).serialize(), + "exclude_types": obj.exclude_types, + "env_protection": obj.env_protection, + "@variables": { + "davg": obj["davg"].detach().cpu().numpy(), + "dstd": obj["dstd"].detach().cpu().numpy(), + }, + "type_map": self.type_map, + "trainable": True, + "type_one_side": obj.type_one_side, + "spin": None, + } + + def change_type_map( + self, type_map: list[str], model_with_new_type_stat: Any | None = None + ) -> None: + raise NotImplementedError( + "Descriptor se_a_vg does not support changing type related params yet." + ) + + @classmethod + def update_sel( + cls, + train_data: DeepmdDataSystem, + type_map: list[str] | None, + local_jdata: dict, + ) -> tuple[dict, float | None]: + local_jdata_cpy = local_jdata.copy() + min_nbor_dist, local_jdata_cpy["sel"] = UpdateSel().update_one_sel( + train_data, type_map, local_jdata_cpy["rcut"], local_jdata_cpy["sel"], False + ) + return local_jdata_cpy, min_nbor_dist diff --git a/deepmd/pt/model/model/__init__.py b/deepmd/pt/model/model/__init__.py index 08a70bfd93..f8377ab63d 100644 --- a/deepmd/pt/model/model/__init__.py +++ b/deepmd/pt/model/model/__init__.py @@ -145,8 +145,13 @@ def get_spin_model(model_params: dict) -> SpinModel: or model_params["descriptor"]["env_protection"] == 0.0 ): model_params["descriptor"]["env_protection"] = 0.01 - if model_params["descriptor"]["type"] in ["se_e2_a"]: - # only expand sel for se_e2_a + if model_params["descriptor"]["type"] in [ + "se_e2_a", + "se_a", + "se_a_vg", + "se_e2_a_vg", + ]: + # only expand sel for se_a family descriptors model_params["descriptor"]["sel"] += model_params["descriptor"]["sel"] backbone_model = get_standard_model(model_params) return SpinEnergyModel(backbone_model=backbone_model, spin=spin) diff --git a/deepmd/pt/model/task/fitting.py b/deepmd/pt/model/task/fitting.py index 439d3d11d9..79c9ee0fd5 100644 --- a/deepmd/pt/model/task/fitting.py +++ b/deepmd/pt/model/task/fitting.py @@ -404,6 +404,13 @@ class GeneralFitting(Fitting): A list of strings. Give the name to each type of atoms. use_aparam_as_mask: bool If True, the aparam will not be used in fitting net for embedding. + use_aparam_output_gate: bool + If True and numb_aparam > 0, multiply the fitting output by + g = a^2 / (sigma^2 * aparam_gate_norm) using raw aparam. + aparam_gate_norm: float + Normalization factor in the aparam output gate denominator. + aparam_gate_clamp: bool + If True, clamp the aparam output gate to [0, 1]. default_fparam: list[float], optional The default frame parameter. If set, when `fparam.npy` files are not included in the data system, this value will be used as the default value for the frame parameter in the fitting net. @@ -430,10 +437,22 @@ def __init__( remove_vaccum_contribution: list[bool] | None = None, type_map: list[str] | None = None, use_aparam_as_mask: bool = False, + use_aparam_output_gate: bool = False, + aparam_gate_norm: float = 1.0, + aparam_gate_clamp: bool = True, default_fparam: list[float] | None = None, **kwargs: Any, ) -> None: super().__init__() + if use_aparam_output_gate and numb_aparam <= 0: + raise ValueError( + "use_aparam_output_gate requires numb_aparam > 0, " + f"got numb_aparam={numb_aparam}" + ) + if aparam_gate_norm <= 0.0: + raise ValueError( + f"aparam_gate_norm must be positive, got {aparam_gate_norm}" + ) self.var_name = var_name self.ntypes = ntypes self.dim_descrpt = dim_descrpt @@ -451,6 +470,9 @@ def __init__( self.seed = seed self.type_map = type_map self.use_aparam_as_mask = use_aparam_as_mask + self.use_aparam_output_gate = use_aparam_output_gate + self.aparam_gate_norm = aparam_gate_norm + self.aparam_gate_clamp = aparam_gate_clamp # order matters, should be place after the assignment of ntypes self.reinit_exclude(exclude_types) self.trainable = trainable @@ -622,6 +644,9 @@ def serialize(self) -> dict: "trainable": [self.trainable] * (len(self.neuron) + 1), "layer_name": None, "use_aparam_as_mask": self.use_aparam_as_mask, + "use_aparam_output_gate": self.use_aparam_output_gate, + "aparam_gate_norm": self.aparam_gate_norm, + "aparam_gate_clamp": self.aparam_gate_clamp, "spin": None, } @@ -736,6 +761,60 @@ def _extend_f_avg_std(self, xx: torch.Tensor, nb: int) -> torch.Tensor: def _extend_a_avg_std(self, xx: torch.Tensor, nb: int, nloc: int) -> torch.Tensor: return torch.tile(xx.view([1, 1, self.numb_aparam]), [nb, nloc, 1]) + def _compute_aparam_output_gate( + self, + aparam_raw: torch.Tensor, + ) -> torch.Tensor: + """Hard-coded gate g = a^2 / (sigma^2 * norm) from raw aparam.""" + assert self.aparam_inv_std is not None + sigma = 1.0 / self.aparam_inv_std + gate = (aparam_raw * aparam_raw) / ( + sigma * sigma * self.aparam_gate_norm + 1e-12 + ) + if self.numb_aparam > 1: + gate = gate.prod(dim=-1, keepdim=True) + if self.aparam_gate_clamp: + gate = gate.clamp(0.0, 1.0) + return gate + + def _apply_aparam_output_gate( + self, + outs: torch.Tensor, + aparam_raw: torch.Tensor | None, + ) -> torch.Tensor: + if not self.use_aparam_output_gate: + return outs + if aparam_raw is None: + raise ValueError( + "aparam is required when use_aparam_output_gate is enabled" + ) + gate = self._compute_aparam_output_gate(aparam_raw) + return outs * gate + + @torch.jit.export + def apply_aparam_output_gate_to_atomic_output( + self, + outs: torch.Tensor, + aparam: torch.Tensor | None, + ) -> torch.Tensor: + """Apply the aparam gate to atomic outputs after out_bias is added.""" + if not self.use_aparam_output_gate: + return outs + if aparam is None: + raise ValueError( + "aparam is required when use_aparam_output_gate is enabled" + ) + nf, nloc = outs.shape[0], outs.shape[1] + aparam_raw = aparam.to(self.prec) + expected = nf * nloc * self.numb_aparam + if aparam_raw.numel() != expected: + raise ValueError( + f"input aparam: cannot reshape {list(aparam_raw.shape)} " + f"into ({nf}, {nloc}, {self.numb_aparam})." + ) + aparam_raw = aparam_raw.reshape(nf, nloc, self.numb_aparam) + return self._apply_aparam_output_gate(outs, aparam_raw) + def _forward_common( self, descriptor: torch.Tensor, @@ -756,7 +835,20 @@ def _forward_common( fparam = torch.tile(self.default_fparam_tensor.unsqueeze(0), [nf, 1]) fparam = fparam.to(self.prec) if fparam is not None else None - aparam = aparam.to(self.prec) if aparam is not None else None + aparam_raw: torch.Tensor | None = None + if aparam is not None: + aparam = aparam.to(self.prec) + if self.numb_aparam > 0: + if aparam.numel() % (nf * self.numb_aparam) != 0: + raise ValueError( + f"input aparam: cannot reshape {list(aparam.shape)} " + f"into ({nf}, nloc, {self.numb_aparam})." + ) + aparam_raw = aparam.view([nf, -1, self.numb_aparam]) + if self.use_aparam_output_gate and aparam_raw is None: + raise ValueError( + "aparam is required when use_aparam_output_gate is enabled" + ) if self.remove_vaccum_contribution is not None: # TODO: compute the input for vaccm when remove_vaccum_contribution is set @@ -801,26 +893,20 @@ def _forward_common( ) # check aparam dim, concate to input descriptor if self.numb_aparam > 0 and not self.use_aparam_as_mask: - assert aparam is not None, "aparam should not be None" + assert aparam_raw is not None, "aparam should not be None" assert self.aparam_avg is not None assert self.aparam_inv_std is not None - if aparam.numel() % (nf * self.numb_aparam) != 0: - raise ValueError( - f"input aparam: cannot reshape {list(aparam.shape)} " - f"into ({nf}, nloc, {self.numb_aparam})." - ) - aparam = aparam.view([nf, -1, self.numb_aparam]) - nb, nloc, _ = aparam.shape + nb, nloc, _ = aparam_raw.shape t_aparam_avg = self._extend_a_avg_std(self.aparam_avg, nb, nloc) t_aparam_inv_std = self._extend_a_avg_std(self.aparam_inv_std, nb, nloc) - aparam = (aparam - t_aparam_avg) * t_aparam_inv_std + aparam_embed = (aparam_raw - t_aparam_avg) * t_aparam_inv_std xx = torch.cat( - [xx, aparam], + [xx, aparam_embed], dim=-1, ) if xx_zeros is not None: xx_zeros = torch.cat( - [xx_zeros, aparam], + [xx_zeros, aparam_embed], dim=-1, ) diff --git a/deepmd/pt/model/task/sezm_ener.py b/deepmd/pt/model/task/sezm_ener.py index c6af12fb5f..75928035bf 100644 --- a/deepmd/pt/model/task/sezm_ener.py +++ b/deepmd/pt/model/task/sezm_ener.py @@ -691,7 +691,20 @@ def _forward_case_film( assert self.default_fparam_tensor is not None fparam = torch.tile(self.default_fparam_tensor.unsqueeze(0), [nf, 1]) fparam = fparam.to(self.prec) if fparam is not None else None - aparam = aparam.to(self.prec) if aparam is not None else None + aparam_raw: torch.Tensor | None = None + if aparam is not None: + aparam = aparam.to(self.prec) + if self.numb_aparam > 0: + if aparam.numel() % (nf * self.numb_aparam) != 0: + raise ValueError( + f"input aparam: cannot reshape {list(aparam.shape)} " + f"into ({nf}, nloc, {self.numb_aparam})." + ) + aparam_raw = aparam.view([nf, -1, self.numb_aparam]) + if self.use_aparam_output_gate and aparam_raw is None: + raise ValueError( + "aparam is required when use_aparam_output_gate is enabled" + ) if self.remove_vaccum_contribution is not None: xx_zeros = torch.zeros_like(xx) @@ -725,22 +738,16 @@ def _forward_case_film( xx_zeros = torch.cat([xx_zeros, fparam], dim=-1) if self.numb_aparam > 0 and not self.use_aparam_as_mask: - assert aparam is not None, "aparam should not be None" + assert aparam_raw is not None, "aparam should not be None" assert self.aparam_avg is not None assert self.aparam_inv_std is not None - if aparam.numel() % (nf * self.numb_aparam) != 0: - raise ValueError( - f"input aparam: cannot reshape {list(aparam.shape)} " - f"into ({nf}, nloc, {self.numb_aparam})." - ) - aparam = aparam.view([nf, -1, self.numb_aparam]) - nb, nloc, _ = aparam.shape + nb, nloc, _ = aparam_raw.shape t_aparam_avg = self._extend_a_avg_std(self.aparam_avg, nb, nloc) t_aparam_inv_std = self._extend_a_avg_std(self.aparam_inv_std, nb, nloc) - aparam = (aparam - t_aparam_avg) * t_aparam_inv_std - xx = torch.cat([xx, aparam], dim=-1) + aparam_embed = (aparam_raw - t_aparam_avg) * t_aparam_inv_std + xx = torch.cat([xx, aparam_embed], dim=-1) if xx_zeros is not None: - xx_zeros = torch.cat([xx_zeros, aparam], dim=-1) + xx_zeros = torch.cat([xx_zeros, aparam_embed], dim=-1) assert self.case_embd is not None outs = torch.zeros( diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 762347fde7..2f1ac20592 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -46,6 +46,13 @@ # descriptors doc_loc_frame = "Defines a local frame at each atom, and computes the descriptor as local coordinates under this frame." doc_se_e2_a = "Used by the smooth edition of Deep Potential. The full relative coordinates are used to construct the descriptor." +doc_se_a_vg = ( + "Variational-Gaussian smooth edition of Deep Potential (VGM II). " + "Gaussian width sigma enters the descriptor via " + "sigma_ij = sqrt(sigma_i^2 + sigma_j^2) in the radial kernel and an extra " + "environment-matrix column. Requires aparam (one sigma per atom) in training data. " + "Compatible with use_aparam_output_gate in the fitting net." +) doc_se_e2_r = "Used by the smooth edition of Deep Potential. Only the distance between atoms is used to construct the descriptor." doc_se_e3 = "Used by the smooth edition of Deep Potential. The full relative coordinates are used to construct the descriptor. Three-body embedding will be used by this descriptor." doc_se_a_tpe = "Used by the smooth edition of Deep Potential. The full relative coordinates are used to construct the descriptor. Type embedding will be used by this descriptor." @@ -344,6 +351,13 @@ def descrpt_se_a_args() -> list[Argument]: ] +@descrpt_args_plugin.register( + "se_a_vg", alias=["se_e2_a_vg"], doc=doc_only_pt_supported + doc_se_a_vg +) +def descrpt_se_a_vg_args() -> list[Argument]: + return descrpt_se_a_args() + + @descrpt_args_plugin.register( "dpa4", alias=["DPA4", "SeZM", "sezm"], @@ -2241,6 +2255,49 @@ def descrpt_variant_type_args(exclude_hybrid: bool = False) -> Variant: fitting_args_plugin = ArgsPlugin() +def fitting_aparam_output_gate_args() -> list[Argument]: + """Arguments for the hard-coded aparam output gate in fitting nets.""" + doc_use_aparam_output_gate = ( + doc_only_pt_supported + + "If True and numb_aparam > 0, multiply the fitting output by " + "g = a^2 / (sigma^2 * aparam_gate_norm), where a is the raw aparam " + "and sigma is the standard deviation from training statistics. " + "g is 0 when a = 0. aparam is still embedded in the fitting net unless " + "use_aparam_as_mask is True." + ) + doc_aparam_gate_norm = ( + doc_only_pt_supported + + "Normalization factor in the aparam output gate denominator " + "(sigma^2 * aparam_gate_norm)." + ) + doc_aparam_gate_clamp = ( + doc_only_pt_supported + "If True, clamp the aparam output gate to [0, 1]." + ) + return [ + Argument( + "use_aparam_output_gate", + bool, + optional=True, + default=False, + doc=doc_use_aparam_output_gate, + ), + Argument( + "aparam_gate_norm", + float, + optional=True, + default=1.0, + doc=doc_aparam_gate_norm, + ), + Argument( + "aparam_gate_clamp", + bool, + optional=True, + default=True, + doc=doc_aparam_gate_clamp, + ), + ] + + @fitting_args_plugin.register("ener", doc=doc_ener) def fitting_ener() -> list[Argument]: doc_numb_fparam = "The dimension of the frame parameter. If set to >0, file `fparam.npy` should be included to provided the input fparams." @@ -2330,6 +2387,7 @@ def fitting_ener() -> list[Argument]: default=False, doc=doc_use_aparam_as_mask, ), + *fitting_aparam_output_gate_args(), ] @@ -2433,6 +2491,7 @@ def fitting_sezm_ener() -> list[Argument]: default=False, doc=doc_only_pt_supported + doc_case_film_embd, ), + *fitting_aparam_output_gate_args(), ] @@ -2801,7 +2860,10 @@ def model_compression_type_args() -> Variant: return Variant( "type", - [Argument("se_e2_a", dict, model_compression(), alias=["se_a"])], + [ + Argument("se_e2_a", dict, model_compression(), alias=["se_a"]), + Argument("se_a_vg", dict, model_compression(), alias=["se_e2_a_vg"]), + ], optional=True, default_tag="se_e2_a", doc=doc_compress_type, diff --git a/examples/fparam/train/input_aparam.json b/examples/fparam/train/input_aparam.json index 32420068e6..9ade016777 100644 --- a/examples/fparam/train/input_aparam.json +++ b/examples/fparam/train/input_aparam.json @@ -28,6 +28,10 @@ ], "resnet_dt": true, "numb_aparam": 1, + "use_aparam_as_mask": false, + "use_aparam_output_gate": true, + "aparam_gate_norm": 1.0, + "aparam_gate_clamp": true, "precision": "float64", "seed": 1 } diff --git a/source/tests/pt/model/test_aparam_output_gate.py b/source/tests/pt/model/test_aparam_output_gate.py new file mode 100644 index 0000000000..0874516788 --- /dev/null +++ b/source/tests/pt/model/test_aparam_output_gate.py @@ -0,0 +1,125 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import torch + +from deepmd.pt.model.task.invar_fitting import ( + InvarFitting, +) +from deepmd.pt.utils import ( + env, +) + +dtype = env.GLOBAL_PT_FLOAT_PRECISION +device = env.DEVICE + + +class TestAparamOutputGate(unittest.TestCase): + def test_zero_aparam_zeros_output_with_out_bias(self) -> None: + nf, nloc, dim_descrpt = 1, 2, 8 + sigma = 2.0 + fitting = InvarFitting( + "energy", + ntypes=1, + dim_descrpt=dim_descrpt, + dim_out=1, + neuron=[4, 4], + numb_aparam=1, + mixed_types=True, + use_aparam_output_gate=True, + aparam_gate_norm=1.0, + aparam_gate_clamp=True, + ).to(device) + fitting.aparam_inv_std.copy_(torch.tensor([1.0 / sigma], dtype=dtype)) + + descriptor = torch.randn(nf, nloc, dim_descrpt, dtype=dtype, device=device) + atype = torch.zeros(nf, nloc, dtype=torch.int64, device=device) + aparam_zero = torch.zeros(nf, nloc, 1, dtype=dtype, device=device) + aparam_sigma = torch.full((nf, nloc, 1), sigma, dtype=dtype, device=device) + + raw_zero = fitting(descriptor, atype, aparam=aparam_zero)["energy"] + raw_sigma = fitting(descriptor, atype, aparam=aparam_sigma)["energy"] + fake_out_bias = torch.full((nf, nloc, 1), 1.2, dtype=dtype, device=device) + + out_zero = fitting.apply_aparam_output_gate_to_atomic_output( + raw_zero + fake_out_bias, aparam_zero + ) + out_sigma = fitting.apply_aparam_output_gate_to_atomic_output( + raw_sigma + fake_out_bias, aparam_sigma + ) + + self.assertTrue(torch.allclose(out_zero, torch.zeros_like(out_zero))) + self.assertGreater(out_sigma.abs().max().item(), 0.0) + + def test_gate_matches_formula(self) -> None: + nf, nloc, dim_descrpt = 1, 1, 4 + sigma = 3.0 + norm = 2.0 + fitting = InvarFitting( + "energy", + ntypes=1, + dim_descrpt=dim_descrpt, + dim_out=1, + neuron=[4], + numb_aparam=1, + mixed_types=True, + use_aparam_output_gate=True, + aparam_gate_norm=norm, + aparam_gate_clamp=False, + ).to(device) + fitting.aparam_inv_std.copy_(torch.tensor([1.0 / sigma], dtype=dtype)) + + descriptor = torch.randn(nf, nloc, dim_descrpt, dtype=dtype, device=device) + atype = torch.zeros(nf, nloc, dtype=torch.int64, device=device) + a_val = 1.5 + aparam = torch.full((nf, nloc, 1), a_val, dtype=dtype, device=device) + + fitting_gate = fitting._compute_aparam_output_gate(aparam) + expected = (a_val * a_val) / (sigma * sigma * norm) + self.assertTrue( + torch.allclose(fitting_gate, torch.tensor(expected, dtype=dtype)) + ) + + def test_gate_reshape_flat_aparam(self) -> None: + nf, nloc, dim_descrpt = 1, 3, 4 + sigma = 2.0 + fitting = InvarFitting( + "energy", + ntypes=1, + dim_descrpt=dim_descrpt, + dim_out=1, + neuron=[4], + numb_aparam=1, + mixed_types=True, + use_aparam_output_gate=True, + aparam_gate_norm=1.0, + aparam_gate_clamp=False, + ).to(device) + fitting.aparam_inv_std.copy_(torch.tensor([1.0 / sigma], dtype=dtype)) + + outs = torch.ones(nf, nloc, 1, dtype=dtype, device=device) + aparam_flat = torch.full((nf * nloc,), sigma, dtype=dtype, device=device) + out = fitting.apply_aparam_output_gate_to_atomic_output(outs, aparam_flat) + expected_gate = 1.0 + self.assertTrue(torch.allclose(out, torch.full_like(out, expected_gate))) + + def test_serialize_roundtrip(self) -> None: + fitting = InvarFitting( + "energy", + ntypes=1, + dim_descrpt=4, + dim_out=1, + neuron=[4], + numb_aparam=1, + use_aparam_output_gate=True, + aparam_gate_norm=1.5, + aparam_gate_clamp=False, + ) + restored = InvarFitting.deserialize(fitting.serialize()) + self.assertTrue(restored.use_aparam_output_gate) + self.assertEqual(restored.aparam_gate_norm, 1.5) + self.assertFalse(restored.aparam_gate_clamp) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt/model/test_se_a_vg.py b/source/tests/pt/model/test_se_a_vg.py new file mode 100644 index 0000000000..9ac2be0c8b --- /dev/null +++ b/source/tests/pt/model/test_se_a_vg.py @@ -0,0 +1,230 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import numpy as np +import torch + +from deepmd.pt.model.atomic_model.dp_atomic_model import ( + DPAtomicModel, +) +from deepmd.pt.model.descriptor.env_mat import ( + prod_env_mat, +) +from deepmd.pt.model.descriptor.env_mat_vg import ( + VG_ENV_DIM, + prod_env_mat_vg, +) +from deepmd.pt.model.descriptor.se_a_vg import ( + DescrptSeAVg, +) +from deepmd.pt.model.task.invar_fitting import ( + InvarFitting, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.env import ( + PRECISION_DICT, +) + +from ...seed import ( + GLOBAL_SEED, +) +from .test_env_mat import ( + TestCaseSingleFrameWithNlist, +) + +dtype = env.GLOBAL_PT_FLOAT_PRECISION + + +class TestDescrptSeAVg(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self) -> None: + TestCaseSingleFrameWithNlist.setUp(self) + + def test_sigma_zero_matches_se_a_env_mat(self) -> None: + """At sigma=0 the VG radial kernel reduces to the standard 1/r form.""" + prec = "float64" + pt_dtype = PRECISION_DICT[prec] + nf, nloc, nnei = self.nlist.shape + mean = torch.zeros( + (self.nt, nnei, VG_ENV_DIM), dtype=pt_dtype, device=env.DEVICE + ) + stddev = torch.ones( + (self.nt, nnei, VG_ENV_DIM), dtype=pt_dtype, device=env.DEVICE + ) + mean_se = torch.zeros((self.nt, nnei, 4), dtype=pt_dtype, device=env.DEVICE) + stddev_se = torch.ones((self.nt, nnei, 4), dtype=pt_dtype, device=env.DEVICE) + + coord = torch.tensor(self.coord_ext, dtype=pt_dtype, device=env.DEVICE) + nlist = torch.tensor(self.nlist, dtype=torch.int64, device=env.DEVICE) + atype = torch.tensor( + self.atype_ext[:, :nloc], dtype=torch.int64, device=env.DEVICE + ) + aparam_zero = torch.zeros((nf, nloc, 1), dtype=pt_dtype, device=env.DEVICE) + + vg_mat, _, _ = prod_env_mat_vg( + coord, + nlist, + atype, + aparam_zero, + mean, + stddev, + self.rcut, + self.rcut_smth, + ) + se_mat, _, _ = prod_env_mat( + coord, + nlist, + atype, + mean_se, + stddev_se, + self.rcut, + self.rcut_smth, + ) + np.testing.assert_allclose( + vg_mat[..., :4].detach().cpu().numpy(), + se_mat.detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + ) + np.testing.assert_allclose( + vg_mat[..., 4].detach().cpu().numpy(), + 0.0, + atol=1e-10, + ) + + def test_sigma_changes_descriptor(self) -> None: + prec = "float64" + pt_dtype = PRECISION_DICT[prec] + nf, nloc, _ = self.nlist.shape + dd = DescrptSeAVg( + self.rcut, + self.rcut_smth, + self.sel, + precision=prec, + seed=GLOBAL_SEED, + ).to(env.DEVICE) + coord = torch.tensor(self.coord_ext, dtype=pt_dtype, device=env.DEVICE) + atype_ext = torch.tensor(self.atype_ext, dtype=torch.int64, device=env.DEVICE) + nlist = torch.tensor(self.nlist, dtype=torch.int64, device=env.DEVICE) + aparam_zero = torch.zeros((nf, nloc, 1), dtype=pt_dtype, device=env.DEVICE) + aparam_one = torch.ones((nf, nloc, 1), dtype=pt_dtype, device=env.DEVICE) + + out0, _, _, _, _ = dd(coord, atype_ext, nlist, aparam=aparam_zero) + out1, _, _, _, _ = dd(coord, atype_ext, nlist, aparam=aparam_one) + diff = (out0 - out1).abs().max().item() + self.assertGreater(diff, 0.0) + + def test_forward_shape(self) -> None: + prec = "float64" + pt_dtype = PRECISION_DICT[prec] + nf, nloc, nnei = self.nlist.shape + axis_neuron = 4 + neuron = [8, 16] + dd = DescrptSeAVg( + self.rcut, + self.rcut_smth, + self.sel, + neuron=neuron, + axis_neuron=axis_neuron, + precision=prec, + seed=GLOBAL_SEED, + ).to(env.DEVICE) + coord = torch.tensor(self.coord_ext, dtype=pt_dtype, device=env.DEVICE) + atype_ext = torch.tensor(self.atype_ext, dtype=torch.int64, device=env.DEVICE) + nlist = torch.tensor(self.nlist, dtype=torch.int64, device=env.DEVICE) + aparam = torch.full((nf, nloc, 1), 0.5, dtype=pt_dtype, device=env.DEVICE) + + out, rot, _, _, sw = dd(coord, atype_ext, nlist, aparam=aparam) + self.assertEqual(out.shape, (nf, nloc, neuron[-1] * axis_neuron)) + self.assertEqual(rot.shape, (nf, nloc, neuron[-1], 3)) + self.assertEqual(dd.sea.ndescrpt, nnei * VG_ENV_DIM) + self.assertIsNotNone(sw) + + def test_dp_atomic_model_accepts_aparam_for_freeze(self) -> None: + type_map = ["A", "B"] + dd = DescrptSeAVg( + self.rcut, + self.rcut_smth, + self.sel, + precision="float64", + seed=GLOBAL_SEED, + type_map=type_map, + ).to(env.DEVICE) + dim_descrpt = dd.get_dim_out() + fitting = InvarFitting( + "energy", + ntypes=self.nt, + dim_descrpt=dim_descrpt, + dim_out=1, + neuron=[4], + numb_aparam=1, + mixed_types=True, + ).to(env.DEVICE) + model = DPAtomicModel(dd, fitting, type_map).to(env.DEVICE) + self.assertTrue(model._descriptor_accepts_aparam) + scripted = torch.jit.script(model) + self.assertTrue(scripted._descriptor_accepts_aparam) + + def test_scripted_descriptor_uses_aparam(self) -> None: + prec = "float64" + pt_dtype = PRECISION_DICT[prec] + nf, nloc, _ = self.nlist.shape + dd = DescrptSeAVg( + self.rcut, + self.rcut_smth, + self.sel, + precision=prec, + seed=GLOBAL_SEED, + ).to(env.DEVICE) + coord = torch.tensor(self.coord_ext, dtype=pt_dtype, device=env.DEVICE) + atype_ext = torch.tensor(self.atype_ext, dtype=torch.int64, device=env.DEVICE) + nlist = torch.tensor(self.nlist, dtype=torch.int64, device=env.DEVICE) + aparam_zero = torch.zeros((nf, nloc, 1), dtype=pt_dtype, device=env.DEVICE) + aparam_one = torch.ones((nf, nloc, 1), dtype=pt_dtype, device=env.DEVICE) + + dd.eval() + out0, _, _, _, _ = dd(coord, atype_ext, nlist, aparam=aparam_zero) + out1, _, _, _, _ = dd(coord, atype_ext, nlist, aparam=aparam_one) + + scripted = torch.jit.script(dd) + out0_jit, _, _, _, _ = scripted(coord, atype_ext, nlist, aparam=aparam_zero) + out1_jit, _, _, _, _ = scripted(coord, atype_ext, nlist, aparam=aparam_one) + self.assertGreater((out0_jit - out1_jit).abs().max().item(), 0.0) + torch.testing.assert_close(out0, out0_jit) + torch.testing.assert_close(out1, out1_jit) + + def test_compression_matches_uncompressed(self) -> None: + if not hasattr(torch.ops.deepmd, "tabulate_fusion_se_a"): + self.skipTest("tabulate_fusion_se_a op is not built") + prec = "float64" + pt_dtype = PRECISION_DICT[prec] + nf, nloc, _ = self.nlist.shape + dd = DescrptSeAVg( + self.rcut, + self.rcut_smth, + self.sel, + neuron=[8, 16], + axis_neuron=4, + precision=prec, + seed=GLOBAL_SEED, + ).to(env.DEVICE) + coord = torch.tensor(self.coord_ext, dtype=pt_dtype, device=env.DEVICE) + atype_ext = torch.tensor(self.atype_ext, dtype=torch.int64, device=env.DEVICE) + nlist = torch.tensor(self.nlist, dtype=torch.int64, device=env.DEVICE) + aparam = torch.full((nf, nloc, 1), 0.5, dtype=pt_dtype, device=env.DEVICE) + + out_ref, _, _, _, _ = dd(coord, atype_ext, nlist, aparam=aparam) + dd.enable_compression( + min_nbor_dist=0.5, + table_extrapolate=5.0, + table_stride_1=0.01, + table_stride_2=0.1, + ) + out_cmp, _, _, _, _ = dd(coord, atype_ext, nlist, aparam=aparam) + np.testing.assert_allclose( + out_ref.detach().cpu().numpy(), + out_cmp.detach().cpu().numpy(), + rtol=1e-5, + atol=1e-5, + )