Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
255 changes: 225 additions & 30 deletions deepmd/dpmodel/utils/lmdb_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,29 @@ def _compute_batch_size(nloc: int, rule: int) -> int:
return max(bsi, 1)


def _parse_positive_rule(spec: str, prefix: str) -> int:
"""Parse the ``N`` in ``<prefix>N`` and require ``N > 0``.

Rejects missing/non-integer/non-positive ``N`` up front so that
misconfigurations (``"filter:"``, ``"filter:0"``, ``"max:-5"``) fail at
construction time instead of silently producing an empty dataset or a
batch_size=1 fallback downstream.
"""
_, _, raw = spec.partition(":")
try:
n = int(raw)
except ValueError:
raise ValueError(
f"Unsupported batch_size {spec!r}. "
f"Expected '{prefix}N' with N a positive integer."
) from None
if n <= 0:
raise ValueError(
f"Unsupported batch_size {spec!r}: N must be a positive integer, got {n}."
)
return n


class LmdbDataReader:
"""Framework-agnostic LMDB dataset reader.

Expand All @@ -232,7 +255,22 @@ class LmdbDataReader:
type_map : list[str]
Global type map from model config.
batch_size : int or str
Batch size. Supports int, "auto", "auto:N".
Batch size rule used to derive per-nloc batch sizes. Supports:

- ``int``: fixed, identical batch size for every nloc group.
- ``"auto"`` / ``"auto:N"``: ``ceil(N / nloc)`` per nloc group
(``N=32`` for bare ``"auto"``). Acts as a *lower* bound —
each batch has at least ``N`` atoms, but may exceed ``N``
by up to ``nloc - 1``.
- ``"max:N"``: ``max(1, floor(N / nloc))`` per nloc group.
Acts as an *upper* bound for groups with ``nloc <= N``
(batch has at most ``N`` atoms). For groups with
``nloc > N`` the ``max(1, ...)`` floor kicks in: ``bsi=1``
and a single-frame batch still carries ``nloc`` atoms,
which exceeds ``N``.
- ``"filter:N"``: same per-nloc formula as ``"max:N"`` **and**
drops every frame whose ``nloc > N`` from the dataset. By
construction every retained batch has at most ``N`` atoms.
mixed_batch : bool
If True, allow different nloc in the same batch (future).
If False (default), enforce same-nloc-per-batch.
Expand Down Expand Up @@ -283,51 +321,139 @@ def __init__(

# Scan per-frame nloc only when needed for same-nloc batching.
# For mixed_batch=True, skip the scan entirely (future: padding handles it).
# ``orig_frame_nlocs`` / ``orig_frame_system_ids`` are indexed by the
# *original* LMDB frame index. After a potential ``filter:N`` drop we
# rebuild ``self._frame_nlocs`` / ``self._frame_system_ids`` so they
# are parallel arrays over the *dataset* index space (0..len(self));
# the dataset-to-original mapping lives in ``self._retained_keys``.
if not mixed_batch:
# Fast path: use pre-computed frame_nlocs from metadata if available.
# Falls back to scanning each frame's atom_types shape (~10 us/frame).
meta_nlocs = meta.get("frame_nlocs")
if meta_nlocs is not None:
self._frame_nlocs = [int(n) for n in meta_nlocs]
orig_frame_nlocs = [int(n) for n in meta_nlocs]
else:
self._frame_nlocs = _scan_frame_nlocs(
orig_frame_nlocs = _scan_frame_nlocs(
self._env, self.nframes, self._frame_fmt, self._natoms
)
self._nloc_groups: dict[int, list[int]] = {}
for idx, nloc in enumerate(self._frame_nlocs):
self._nloc_groups.setdefault(nloc, []).append(idx)
else:
self._frame_nlocs = []
self._nloc_groups = {}
orig_frame_nlocs = []

# Parse frame_system_ids for auto_prob support
# Parse frame_system_ids for auto_prob support. ``_nsystems`` must stay
# at ``max(original_sid) + 1`` even after filter:N so that user-facing
# auto_prob block slicing (e.g. ``prob_sys_size;0:284:0.5;284:842:0.5``)
# keeps its meaning across filter thresholds.
meta_sys_ids = meta.get("frame_system_ids")
if meta_sys_ids is not None:
self._frame_system_ids: list[int] | None = [int(s) for s in meta_sys_ids]
self._nsystems = max(self._frame_system_ids) + 1
self._system_groups: dict[int, list[int]] = {}
for idx, sid in enumerate(self._frame_system_ids):
self._system_groups.setdefault(sid, []).append(idx)
self._system_nframes: list[int] = [
len(self._system_groups.get(i, [])) for i in range(self._nsystems)
]
orig_frame_system_ids: list[int] | None = [int(s) for s in meta_sys_ids]
self._nsystems = max(orig_frame_system_ids) + 1
else:
self._frame_system_ids = None
orig_frame_system_ids = None
self._nsystems = 1
self._system_groups = {0: list(range(self.nframes))}
self._system_nframes = [self.nframes]

# Parse batch_size spec
# Parse batch_size spec. ``auto_rule`` and ``max_rule`` are mutually
# exclusive; ``filter_rule`` implies ``max_rule`` plus dropping frames
# whose nloc exceeds the threshold.
self._auto_rule: int | None = None
self._max_rule: int | None = None
self._filter_rule: int | None = None
if isinstance(batch_size, str):
if batch_size == "auto":
self._auto_rule = 32
elif batch_size.startswith("auto:"):
self._auto_rule = int(batch_size.split(":")[1])
self._auto_rule = _parse_positive_rule(batch_size, "auto:")
elif batch_size.startswith("max:"):
self._max_rule = _parse_positive_rule(batch_size, "max:")
elif batch_size.startswith("filter:"):
self._filter_rule = _parse_positive_rule(batch_size, "filter:")
self._max_rule = self._filter_rule
else:
self._auto_rule = 32
# Default batch_size uses first frame's nloc (for total_batch estimate)
raise ValueError(
f"Unsupported batch_size {batch_size!r}. "
"Expected int, 'auto', 'auto:N', 'max:N', or 'filter:N'."
)
Comment thread
OutisLi marked this conversation as resolved.

# ``filter:N`` needs per-frame nloc to drop oversized frames; the
# ``mixed_batch=True`` fast path skips the nloc scan entirely, so the
# two options are incompatible. Fail fast rather than silently
# retaining every frame and breaking the documented contract.
if self._filter_rule is not None and mixed_batch:
raise ValueError(
"batch_size='filter:N' is incompatible with mixed_batch=True: "
"per-frame nloc is unavailable in the mixed-batch fast path. "
"Use mixed_batch=False, or switch to 'max:N' / a fixed int."
)

# Determine which original-index frames survive the filter. Without
# ``filter:N`` every frame is retained.
if self._filter_rule is not None:
retained_keys = [
i for i, n in enumerate(orig_frame_nlocs) if n <= self._filter_rule
]
n_dropped = self.nframes - len(retained_keys)
if n_dropped > 0:
log.info(
f"LMDB filter:{self._filter_rule} drops {n_dropped}/"
f"{self.nframes} frames with nloc > {self._filter_rule} "
f"({self.lmdb_path})."
)
else:
retained_keys = list(range(self.nframes))
Comment thread
OutisLi marked this conversation as resolved.

# Dataset-index → original LMDB frame key. ``__getitem__`` looks up
# this table so that ``reader[i]`` is a valid LMDB read for every
# ``0 <= i < len(reader)``, no matter how many frames were filtered.
self._retained_keys: list[int] = retained_keys

# Re-key _frame_nlocs / _frame_system_ids into the dataset-index
# space so that every downstream consumer (nloc_groups, system_groups,
# SameNlocBatchSampler, _expand_indices_by_blocks) operates in a
# single, self-consistent indexing scheme.
if not mixed_batch:
self._frame_nlocs = [orig_frame_nlocs[k] for k in retained_keys]
else:
self._frame_nlocs = []

if orig_frame_system_ids is not None:
self._frame_system_ids: list[int] | None = [
orig_frame_system_ids[k] for k in retained_keys
]
else:
self._frame_system_ids = None

# Group retained frames by nloc using dataset indices (0..len-1).
if not mixed_batch:
self._nloc_groups: dict[int, list[int]] = {}
for ds_idx, nloc in enumerate(self._frame_nlocs):
self._nloc_groups.setdefault(nloc, []).append(ds_idx)
else:
self._nloc_groups = {}

# Group retained frames by original system id; the sid numbering is
# preserved (no compression) so user-facing auto_prob slices stay
# meaningful across filter thresholds. Fully-dropped systems appear
# as zero-frame entries in ``_system_nframes``.
if self._frame_system_ids is not None:
self._system_groups: dict[int, list[int]] = {}
for ds_idx, sid in enumerate(self._frame_system_ids):
self._system_groups.setdefault(sid, []).append(ds_idx)
self._system_nframes: list[int] = [
len(self._system_groups.get(i, [])) for i in range(self._nsystems)
]
else:
self._system_groups = {0: list(range(len(retained_keys)))}
self._system_nframes = [len(retained_keys)]

# nframes now reflects retained frames; __len__ returns this and the
# valid index domain for __getitem__ is [0, self.nframes).
self.nframes = len(retained_keys)

# Default batch_size used only by the index/total_batch estimate. The
# sampler always goes through get_batch_size_for_nloc for real batches.
if self._auto_rule is not None:
self.batch_size = _compute_batch_size(self._natoms, self._auto_rule)
elif self._max_rule is not None:
self.batch_size = max(1, self._max_rule // max(self._natoms, 1))
else:
self.batch_size = int(batch_size)

Expand Down Expand Up @@ -382,20 +508,44 @@ def __del__(self) -> None:
_close_lmdb(path)

def get_batch_size_for_nloc(self, nloc: int) -> int:
    """Return the batch size for one nloc group under the configured rule.

    The rule was fixed at construction time:

    - ``auto`` / ``auto:N``: ``ceil(N / nloc)`` — can overshoot the atom
      budget by up to ``nloc - 1`` atoms.
    - ``max:N``: ``max(1, floor(N / nloc))`` — an upper bound of ``N``
      atoms for groups with ``nloc <= N``; for larger groups the floor
      clamps to 1 and the single-frame batch still exceeds ``N``.
    - ``filter:N``: same formula as ``max:N``; retained groups all have
      ``nloc <= N``, so no overshoot occurs.
    - fixed int: one value shared by every nloc group.
    """
    auto_n = self._auto_rule
    if auto_n is not None:
        return _compute_batch_size(nloc, auto_n)
    atom_cap = self._max_rule
    if atom_cap is None:
        # Fixed-int rule: ignore nloc entirely.
        return self.batch_size
    # max(nloc, 1) guards against a degenerate zero-atom frame.
    return max(1, atom_cap // max(nloc, 1))

def __len__(self) -> int:
    """Number of retained frames; valid ``__getitem__`` indices are [0, nframes)."""
    return self.nframes

def __getitem__(self, index: int) -> dict[str, Any]:
"""Read frame from LMDB, decode, remap keys, return dict of numpy arrays."""
key = format(index, self._frame_fmt).encode()
"""Read frame from LMDB, decode, remap keys, return dict of numpy arrays.

``index`` is a dataset-level index in ``[0, len(self))``. Under
``filter:N`` the LMDB key space may have gaps (dropped frames), so
we translate through ``self._retained_keys`` before hitting LMDB.
"""
if index < 0 or index >= self.nframes:
raise IndexError(f"dataset index {index} out of range [0, {self.nframes})")
original_key = self._retained_keys[index]
key = format(original_key, self._frame_fmt).encode()
raw = self._txn.get(key)
if raw is None:
raise IndexError(f"Frame {index} not found in LMDB")
raise IndexError(
f"Frame {original_key} not found in LMDB (dataset index {index})"
)
frame = _decode_frame(raw)
frame = _remap_keys(frame)

Expand Down Expand Up @@ -524,7 +674,7 @@ def __getitem__(self, index: int) -> dict[str, Any]:
np.float32(1.0) if extra_key in frame else np.float32(0.0)
)

frame["fid"] = index
frame["fid"] = original_key

return frame

Expand All @@ -538,11 +688,19 @@ def add_data_requirement(self, data_requirement: list[DataRequirementItem]) -> N
def print_summary(self, name: str, prob: Any) -> None:
"""Print basic dataset info."""
n_groups = len(self._nloc_groups)
if self._auto_rule is not None:
bs_str = f"auto:{self._auto_rule}"
elif self._filter_rule is not None:
bs_str = f"filter:{self._filter_rule}"
elif self._max_rule is not None:
bs_str = f"max:{self._max_rule}"
else:
bs_str = str(self.batch_size)

log.info(
f"LMDB {name}: {self.lmdb_path}, "
f"{self.nframes} frames, {n_groups} nloc groups, "
f"batch_size={'auto' if self._auto_rule else self.batch_size}, "
f"batch_size={bs_str}, "
f"mixed_batch={self.mixed_batch}"
)
# Print nloc groups in rows of ~10 for readability
Expand Down Expand Up @@ -646,6 +804,43 @@ def compute_block_targets(
stt, end, weight = part.split(":")
blocks.append((int(stt), int(end), float(weight)))

# Drop blocks that retain zero frames (can happen when ``filter:N``
# eliminates every system in a block). prob_sys_size_ext's per-block
# ``nbatch_block / sum(nbatch_block)`` would otherwise propagate NaN
# when the whole block sums to zero. An all-zero dataset yields no
# targets at all.
nonempty = [
(stt, end, weight)
for stt, end, weight in blocks
if sum(system_nframes[stt:end]) > 0
]
if not nonempty:
log.info(
"compute_block_targets: all blocks are empty in "
f"{auto_prob_style!r}; dataset has no retained frames."
)
return []
if len(nonempty) < len(blocks):
# Rewriting auto_prob_style silently re-normalises the remaining
# weights so they sum to 1.0 — e.g. ``0:3:0.8;3:10:0.2`` with block
# ``0:3`` empty becomes effectively weight 1.0 on block ``3:10``.
# Surface this reweighting so operators can correlate it with the
# preceding ``filter:N`` log line.
dropped = [
f"{stt}:{end}:{weight}"
for (stt, end, weight) in blocks
if (stt, end, weight) not in nonempty
]
log.info(
"compute_block_targets: dropping empty blocks (all systems have "
f"0 frames, likely after filter:N): {dropped}. Remaining block "
"weights will be renormalised to sum to 1.0."
)
auto_prob_style = "prob_sys_size;" + ";".join(
f"{stt}:{end}:{weight}" for stt, end, weight in nonempty
)
blocks = nonempty
Comment thread
OutisLi marked this conversation as resolved.

# Compute per-system probabilities using the standard function
sys_probs = prob_sys_size_ext(auto_prob_style, nsystems, system_nframes)

Expand Down
9 changes: 8 additions & 1 deletion deepmd/pt/utils/lmdb_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,14 @@ class LmdbDataset(Dataset):
type_map : list[str]
Global type map from model config.
batch_size : int or str
Batch size. Supports int, "auto", "auto:N".
Batch size rule forwarded to :class:`LmdbDataReader`. Supports:

- ``int``: fixed batch size for every nloc group.
- ``"auto"`` / ``"auto:N"``: ``ceil(N / nloc)`` per nloc group
(``N=32`` for bare ``"auto"``).
- ``"max:N"``: ``max(1, floor(N / nloc))`` per nloc group.
- ``"filter:N"``: same per-nloc formula as ``"max:N"`` and drops
every frame whose ``nloc > N`` from the dataset.
mixed_batch : bool
If True, allow different nloc in the same batch (future).
If False (default), use SameNlocBatchSampler.
Expand Down
8 changes: 5 additions & 3 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -3677,8 +3677,8 @@ def training_data_args() -> list[
- string "auto": automatically determines the batch size so that the batch_size times the number of atoms in the system is no less than 32.\n\n\
- string "auto:N": automatically determines the batch size so that the batch_size times the number of atoms in the system is no less than N.\n\n\
- string "mixed:N": the batch data will be sampled from all systems and merged into a mixed system with the batch size N. Only support the se_atten descriptor for TensorFlow backend.\n\n\
- string "max:N": automatically determines the batch size so that the batch_size times the number of atoms in the system is no more than N.\n\n\
- string "filter:N": the same as `"max:N"` but removes the systems with the number of atoms larger than `N` from the data set.\n\n\
- string "max:N": automatically determines the batch size so that `batch_size * natoms` is at most `N`. `natoms` is the per-system atom count for npy data and the per-frame nloc for LMDB data. When a single system/frame already has more than `N` atoms, the batch size clamps to 1 and that batch will exceed `N`.\n\n\
- string "filter:N": the same as `"max:N"` but additionally drops data whose atom count exceeds `N`. For npy data this removes whole systems with natoms > `N`; for LMDB data this removes individual frames with nloc > `N`.\n\n\
If MPI is used, the value should be considered as the batch size per task.'
doc_auto_prob_style = 'Determine the probability of systems automatically. The method is assigned by this key and can be\n\n\
- "prob_uniform" : the probability all the systems are equal, namely 1.0/self.get_nsystems()\n\n\
Expand Down Expand Up @@ -3757,7 +3757,9 @@ def validation_data_args() -> list[
- list: the length of which is the same as the {link_sys}. The batch size of each system is given by the elements of the list.\n\n\
- int: all {link_sys} use the same batch size.\n\n\
- string "auto": automatically determines the batch size so that the batch_size times the number of atoms in the system is no less than 32.\n\n\
- string "auto:N": automatically determines the batch size so that the batch_size times the number of atoms in the system is no less than N.'
- string "auto:N": automatically determines the batch size so that the batch_size times the number of atoms in the system is no less than N.\n\n\
- string "max:N": automatically determines the batch size so that `batch_size * natoms` is at most `N`. `natoms` is the per-system atom count for npy data and the per-frame nloc for LMDB data. When a single system/frame already has more than `N` atoms, the batch size clamps to 1 and that batch will exceed `N`.\n\n\
- string "filter:N": the same as `"max:N"` but additionally drops data whose atom count exceeds `N`. For npy data this removes whole systems with natoms > `N`; for LMDB data this removes individual frames with nloc > `N`.'
doc_auto_prob_style = 'Determine the probability of systems automatically. The method is assigned by this key and can be\n\n\
- "prob_uniform" : the probability all the systems are equal, namely 1.0/self.get_nsystems()\n\n\
- "prob_sys_size" : the probability of a system is proportional to the number of batches in the system\n\n\
Expand Down
Loading
Loading