diff --git a/deepmd/dpmodel/utils/lmdb_data.py b/deepmd/dpmodel/utils/lmdb_data.py index 243d4f525d..414b5a8aab 100644 --- a/deepmd/dpmodel/utils/lmdb_data.py +++ b/deepmd/dpmodel/utils/lmdb_data.py @@ -208,6 +208,29 @@ def _compute_batch_size(nloc: int, rule: int) -> int: return max(bsi, 1) +def _parse_positive_rule(spec: str, prefix: str) -> int: + """Parse the ``N`` in ``N`` and require ``N > 0``. + + Rejects missing/non-integer/non-positive ``N`` up front so that + misconfigurations (``"filter:"``, ``"filter:0"``, ``"max:-5"``) fail at + construction time instead of silently producing an empty dataset or a + batch_size=1 fallback downstream. + """ + _, _, raw = spec.partition(":") + try: + n = int(raw) + except ValueError: + raise ValueError( + f"Unsupported batch_size {spec!r}. " + f"Expected '{prefix}N' with N a positive integer." + ) from None + if n <= 0: + raise ValueError( + f"Unsupported batch_size {spec!r}: N must be a positive integer, got {n}." + ) + return n + + class LmdbDataReader: """Framework-agnostic LMDB dataset reader. @@ -232,7 +255,22 @@ class LmdbDataReader: type_map : list[str] Global type map from model config. batch_size : int or str - Batch size. Supports int, "auto", "auto:N". + Batch size rule used to derive per-nloc batch sizes. Supports: + + - ``int``: fixed, identical batch size for every nloc group. + - ``"auto"`` / ``"auto:N"``: ``ceil(N / nloc)`` per nloc group + (``N=32`` for bare ``"auto"``). Acts as a *lower* bound — + each batch has at least ``N`` atoms, but may exceed ``N`` + by up to ``nloc - 1``. + - ``"max:N"``: ``max(1, floor(N / nloc))`` per nloc group. + Acts as an *upper* bound for groups with ``nloc <= N`` + (batch has at most ``N`` atoms). For groups with + ``nloc > N`` the ``max(1, ...)`` floor kicks in: ``bsi=1`` + and a single-frame batch still carries ``nloc`` atoms, + which exceeds ``N``. + - ``"filter:N"``: same per-nloc formula as ``"max:N"`` **and** + drops every frame whose ``nloc > N`` from the dataset. By + construction every retained batch has at most ``N`` atoms. mixed_batch : bool If True, allow different nloc in the same batch (future). If False (default), enforce same-nloc-per-batch. @@ -283,51 +321,139 @@ def __init__( # Scan per-frame nloc only when needed for same-nloc batching. # For mixed_batch=True, skip the scan entirely (future: padding handles it). + # ``orig_frame_nlocs`` / ``orig_frame_system_ids`` are indexed by the + # *original* LMDB frame index. After a potential ``filter:N`` drop we + # rebuild ``self._frame_nlocs`` / ``self._frame_system_ids`` so they + # are parallel arrays over the *dataset* index space (0..len(self)); + # the dataset-to-original mapping lives in ``self._retained_keys``. if not mixed_batch: # Fast path: use pre-computed frame_nlocs from metadata if available. # Falls back to scanning each frame's atom_types shape (~10 us/frame). meta_nlocs = meta.get("frame_nlocs") if meta_nlocs is not None: - self._frame_nlocs = [int(n) for n in meta_nlocs] + orig_frame_nlocs = [int(n) for n in meta_nlocs] else: - self._frame_nlocs = _scan_frame_nlocs( + orig_frame_nlocs = _scan_frame_nlocs( self._env, self.nframes, self._frame_fmt, self._natoms ) - self._nloc_groups: dict[int, list[int]] = {} - for idx, nloc in enumerate(self._frame_nlocs): - self._nloc_groups.setdefault(nloc, []).append(idx) else: - self._frame_nlocs = [] - self._nloc_groups = {} + orig_frame_nlocs = [] - # Parse frame_system_ids for auto_prob support + # Parse frame_system_ids for auto_prob support. ``_nsystems`` must stay + # at ``max(original_sid) + 1`` even after filter:N so that user-facing + # auto_prob block slicing (e.g. ``prob_sys_size;0:284:0.5;284:842:0.5``) + # keeps its meaning across filter thresholds. meta_sys_ids = meta.get("frame_system_ids") if meta_sys_ids is not None: - self._frame_system_ids: list[int] | None = [int(s) for s in meta_sys_ids] - self._nsystems = max(self._frame_system_ids) + 1 - self._system_groups: dict[int, list[int]] = {} - for idx, sid in enumerate(self._frame_system_ids): - self._system_groups.setdefault(sid, []).append(idx) - self._system_nframes: list[int] = [ - len(self._system_groups.get(i, [])) for i in range(self._nsystems) - ] + orig_frame_system_ids: list[int] | None = [int(s) for s in meta_sys_ids] + self._nsystems = max(orig_frame_system_ids) + 1 else: - self._frame_system_ids = None + orig_frame_system_ids = None self._nsystems = 1 - self._system_groups = {0: list(range(self.nframes))} - self._system_nframes = [self.nframes] - # Parse batch_size spec + # Parse batch_size spec. ``auto_rule`` and ``max_rule`` are mutually + # exclusive; ``filter_rule`` implies ``max_rule`` plus dropping frames + # whose nloc exceeds the threshold. self._auto_rule: int | None = None + self._max_rule: int | None = None + self._filter_rule: int | None = None if isinstance(batch_size, str): if batch_size == "auto": self._auto_rule = 32 elif batch_size.startswith("auto:"): - self._auto_rule = int(batch_size.split(":")[1]) + self._auto_rule = _parse_positive_rule(batch_size, "auto:") + elif batch_size.startswith("max:"): + self._max_rule = _parse_positive_rule(batch_size, "max:") + elif batch_size.startswith("filter:"): + self._filter_rule = _parse_positive_rule(batch_size, "filter:") + self._max_rule = self._filter_rule else: - self._auto_rule = 32 - # Default batch_size uses first frame's nloc (for total_batch estimate) + raise ValueError( + f"Unsupported batch_size {batch_size!r}. " + "Expected int, 'auto', 'auto:N', 'max:N', or 'filter:N'." + ) + + # ``filter:N`` needs per-frame nloc to drop oversized frames; the + # ``mixed_batch=True`` fast path skips the nloc scan entirely, so the + # two options are incompatible. Fail fast rather than silently + # retaining every frame and breaking the documented contract. + if self._filter_rule is not None and mixed_batch: + raise ValueError( + "batch_size='filter:N' is incompatible with mixed_batch=True: " + "per-frame nloc is unavailable in the mixed-batch fast path. " + "Use mixed_batch=False, or switch to 'max:N' / a fixed int." + ) + + # Determine which original-index frames survive the filter. Without + # ``filter:N`` every frame is retained. + if self._filter_rule is not None: + retained_keys = [ + i for i, n in enumerate(orig_frame_nlocs) if n <= self._filter_rule + ] + n_dropped = self.nframes - len(retained_keys) + if n_dropped > 0: + log.info( + f"LMDB filter:{self._filter_rule} drops {n_dropped}/" + f"{self.nframes} frames with nloc > {self._filter_rule} " + f"({self.lmdb_path})." + ) + else: + retained_keys = list(range(self.nframes)) + + # Dataset-index → original LMDB frame key. ``__getitem__`` looks up + # this table so that ``reader[i]`` is a valid LMDB read for every + # ``0 <= i < len(reader)``, no matter how many frames were filtered. + self._retained_keys: list[int] = retained_keys + + # Re-key _frame_nlocs / _frame_system_ids into the dataset-index + # space so that every downstream consumer (nloc_groups, system_groups, + # SameNlocBatchSampler, _expand_indices_by_blocks) operates in a + # single, self-consistent indexing scheme. + if not mixed_batch: + self._frame_nlocs = [orig_frame_nlocs[k] for k in retained_keys] + else: + self._frame_nlocs = [] + + if orig_frame_system_ids is not None: + self._frame_system_ids: list[int] | None = [ + orig_frame_system_ids[k] for k in retained_keys + ] + else: + self._frame_system_ids = None + + # Group retained frames by nloc using dataset indices (0..len-1). + if not mixed_batch: + self._nloc_groups: dict[int, list[int]] = {} + for ds_idx, nloc in enumerate(self._frame_nlocs): + self._nloc_groups.setdefault(nloc, []).append(ds_idx) + else: + self._nloc_groups = {} + + # Group retained frames by original system id; the sid numbering is + # preserved (no compression) so user-facing auto_prob slices stay + # meaningful across filter thresholds. Fully-dropped systems appear + # as zero-frame entries in ``_system_nframes``. + if self._frame_system_ids is not None: + self._system_groups: dict[int, list[int]] = {} + for ds_idx, sid in enumerate(self._frame_system_ids): + self._system_groups.setdefault(sid, []).append(ds_idx) + self._system_nframes: list[int] = [ + len(self._system_groups.get(i, [])) for i in range(self._nsystems) + ] + else: + self._system_groups = {0: list(range(len(retained_keys)))} + self._system_nframes = [len(retained_keys)] + + # nframes now reflects retained frames; __len__ returns this and the + # valid index domain for __getitem__ is [0, self.nframes). + self.nframes = len(retained_keys) + + # Default batch_size used only by the index/total_batch estimate. The + # sampler always goes through get_batch_size_for_nloc for real batches. + if self._auto_rule is not None: self.batch_size = _compute_batch_size(self._natoms, self._auto_rule) + elif self._max_rule is not None: + self.batch_size = max(1, self._max_rule // max(self._natoms, 1)) else: self.batch_size = int(batch_size) @@ -382,20 +508,44 @@ def __del__(self) -> None: _close_lmdb(path) def get_batch_size_for_nloc(self, nloc: int) -> int: - """Get batch_size for a given nloc. Uses auto rule if configured.""" + """Return the per-nloc batch size for the configured rule. + + - ``auto`` / ``auto:N``: ``ceil(N / nloc)`` — may overshoot the + atom budget by up to ``nloc - 1`` atoms. + - ``max:N``: ``max(1, floor(N / nloc))``. Acts as an upper bound + for groups with ``nloc <= N`` (batch has at most ``N`` atoms). + For groups with ``nloc > N`` the floor clamps to 1 and the + single-frame batch still carries ``nloc`` atoms, exceeding ``N``. + - ``filter:N``: same per-nloc formula as ``max:N``; by + construction every retained group satisfies ``nloc <= N`` so + no overshoot occurs. + - fixed int: the same value for every nloc group. + """ if self._auto_rule is not None: return _compute_batch_size(nloc, self._auto_rule) + if self._max_rule is not None: + return max(1, self._max_rule // max(nloc, 1)) return self.batch_size def __len__(self) -> int: return self.nframes def __getitem__(self, index: int) -> dict[str, Any]: - """Read frame from LMDB, decode, remap keys, return dict of numpy arrays.""" - key = format(index, self._frame_fmt).encode() + """Read frame from LMDB, decode, remap keys, return dict of numpy arrays. + + ``index`` is a dataset-level index in ``[0, len(self))``. Under + ``filter:N`` the LMDB key space may have gaps (dropped frames), so + we translate through ``self._retained_keys`` before hitting LMDB. + """ + if index < 0 or index >= self.nframes: + raise IndexError(f"dataset index {index} out of range [0, {self.nframes})") + original_key = self._retained_keys[index] + key = format(original_key, self._frame_fmt).encode() raw = self._txn.get(key) if raw is None: - raise IndexError(f"Frame {index} not found in LMDB") + raise IndexError( + f"Frame {original_key} not found in LMDB (dataset index {index})" + ) frame = _decode_frame(raw) frame = _remap_keys(frame) @@ -524,7 +674,7 @@ def __getitem__(self, index: int) -> dict[str, Any]: np.float32(1.0) if extra_key in frame else np.float32(0.0) ) - frame["fid"] = index + frame["fid"] = original_key return frame @@ -538,11 +688,19 @@ def add_data_requirement(self, data_requirement: list[DataRequirementItem]) -> N def print_summary(self, name: str, prob: Any) -> None: """Print basic dataset info.""" n_groups = len(self._nloc_groups) + if self._auto_rule is not None: + bs_str = f"auto:{self._auto_rule}" + elif self._filter_rule is not None: + bs_str = f"filter:{self._filter_rule}" + elif self._max_rule is not None: + bs_str = f"max:{self._max_rule}" + else: + bs_str = str(self.batch_size) log.info( f"LMDB {name}: {self.lmdb_path}, " f"{self.nframes} frames, {n_groups} nloc groups, " - f"batch_size={'auto' if self._auto_rule else self.batch_size}, " + f"batch_size={bs_str}, " f"mixed_batch={self.mixed_batch}" ) # Print nloc groups in rows of ~10 for readability @@ -646,6 +804,43 @@ def compute_block_targets( stt, end, weight = part.split(":") blocks.append((int(stt), int(end), float(weight))) + # Drop blocks that retain zero frames (can happen when ``filter:N`` + # eliminates every system in a block). prob_sys_size_ext's per-block + # ``nbatch_block / sum(nbatch_block)`` would otherwise propagate NaN + # when the whole block sums to zero. An all-zero dataset yields no + # targets at all. + nonempty = [ + (stt, end, weight) + for stt, end, weight in blocks + if sum(system_nframes[stt:end]) > 0 + ] + if not nonempty: + log.info( + "compute_block_targets: all blocks are empty in " + f"{auto_prob_style!r}; dataset has no retained frames." + ) + return [] + if len(nonempty) < len(blocks): + # Rewriting auto_prob_style silently re-normalises the remaining + # weights so they sum to 1.0 — e.g. ``0:3:0.8;3:10:0.2`` with block + # ``0:3`` empty becomes effectively weight 1.0 on block ``3:10``. + # Surface this reweighting so operators can correlate it with the + # preceding ``filter:N`` log line. + dropped = [ + f"{stt}:{end}:{weight}" + for (stt, end, weight) in blocks + if (stt, end, weight) not in nonempty + ] + log.info( + "compute_block_targets: dropping empty blocks (all systems have " + f"0 frames, likely after filter:N): {dropped}. Remaining block " + "weights will be renormalised to sum to 1.0." + ) + auto_prob_style = "prob_sys_size;" + ";".join( + f"{stt}:{end}:{weight}" for stt, end, weight in nonempty + ) + blocks = nonempty + # Compute per-system probabilities using the standard function sys_probs = prob_sys_size_ext(auto_prob_style, nsystems, system_nframes) diff --git a/deepmd/pt/utils/lmdb_dataset.py b/deepmd/pt/utils/lmdb_dataset.py index 44d67be242..6bf300a2fc 100644 --- a/deepmd/pt/utils/lmdb_dataset.py +++ b/deepmd/pt/utils/lmdb_dataset.py @@ -112,7 +112,14 @@ class LmdbDataset(Dataset): type_map : list[str] Global type map from model config. batch_size : int or str - Batch size. Supports int, "auto", "auto:N". + Batch size rule forwarded to :class:`LmdbDataReader`. Supports: + + - ``int``: fixed batch size for every nloc group. + - ``"auto"`` / ``"auto:N"``: ``ceil(N / nloc)`` per nloc group + (``N=32`` for bare ``"auto"``). + - ``"max:N"``: ``max(1, floor(N / nloc))`` per nloc group. + - ``"filter:N"``: same per-nloc formula as ``"max:N"`` and drops + every frame whose ``nloc > N`` from the dataset. mixed_batch : bool If True, allow different nloc in the same batch (future). If False (default), use SameNlocBatchSampler. diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 4ac49cf4d8..07e42d725d 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -3677,8 +3677,8 @@ def training_data_args() -> list[ - string "auto": automatically determines the batch size so that the batch_size times the number of atoms in the system is no less than 32.\n\n\ - string "auto:N": automatically determines the batch size so that the batch_size times the number of atoms in the system is no less than N.\n\n\ - string "mixed:N": the batch data will be sampled from all systems and merged into a mixed system with the batch size N. Only support the se_atten descriptor for TensorFlow backend.\n\n\ -- string "max:N": automatically determines the batch size so that the batch_size times the number of atoms in the system is no more than N.\n\n\ -- string "filter:N": the same as `"max:N"` but removes the systems with the number of atoms larger than `N` from the data set.\n\n\ +- string "max:N": automatically determines the batch size so that `batch_size * natoms` is at most `N`. `natoms` is the per-system atom count for npy data and the per-frame nloc for LMDB data. When a single system/frame already has more than `N` atoms, the batch size clamps to 1 and that batch will exceed `N`.\n\n\ +- string "filter:N": the same as `"max:N"` but additionally drops data whose atom count exceeds `N`. For npy data this removes whole systems with natoms > `N`; for LMDB data this removes individual frames with nloc > `N`.\n\n\ If MPI is used, the value should be considered as the batch size per task.' doc_auto_prob_style = 'Determine the probability of systems automatically. The method is assigned by this key and can be\n\n\ - "prob_uniform" : the probability all the systems are equal, namely 1.0/self.get_nsystems()\n\n\ @@ -3757,7 +3757,9 @@ def validation_data_args() -> list[ - list: the length of which is the same as the {link_sys}. The batch size of each system is given by the elements of the list.\n\n\ - int: all {link_sys} use the same batch size.\n\n\ - string "auto": automatically determines the batch size so that the batch_size times the number of atoms in the system is no less than 32.\n\n\ -- string "auto:N": automatically determines the batch size so that the batch_size times the number of atoms in the system is no less than N.' +- string "auto:N": automatically determines the batch size so that the batch_size times the number of atoms in the system is no less than N.\n\n\ +- string "max:N": automatically determines the batch size so that `batch_size * natoms` is at most `N`. `natoms` is the per-system atom count for npy data and the per-frame nloc for LMDB data. When a single system/frame already has more than `N` atoms, the batch size clamps to 1 and that batch will exceed `N`.\n\n\ +- string "filter:N": the same as `"max:N"` but additionally drops data whose atom count exceeds `N`. For npy data this removes whole systems with natoms > `N`; for LMDB data this removes individual frames with nloc > `N`.' doc_auto_prob_style = 'Determine the probability of systems automatically. The method is assigned by this key and can be\n\n\ - "prob_uniform" : the probability all the systems are equal, namely 1.0/self.get_nsystems()\n\n\ - "prob_sys_size" : the probability of a system is proportional to the number of batches in the system\n\n\ diff --git a/source/tests/common/dpmodel/test_lmdb_data.py b/source/tests/common/dpmodel/test_lmdb_data.py index ac096633c2..67c3a4b1bb 100644 --- a/source/tests/common/dpmodel/test_lmdb_data.py +++ b/source/tests/common/dpmodel/test_lmdb_data.py @@ -599,6 +599,22 @@ def test_compute_block_targets_asymmetric(self): self.assertEqual(result[0][0], [0, 1]) self.assertEqual(result[0][1], 400) + def test_compute_block_targets_logs_dropped_block(self): + """Emptied blocks trigger an INFO log. + + The silent re-normalisation of ``auto_prob_style`` (remaining + weights rescaled to sum to 1.0) must be visible to operators + alongside the ``filter:N`` drop line. + """ + with self.assertLogs("deepmd.dpmodel.utils.lmdb_data", level="INFO") as cm: + result = compute_block_targets( + "prob_sys_size;0:1:0.8;1:2:0.2", + nsystems=2, + system_nframes=[0, 500], + ) + self.assertTrue(any("empty blocks" in msg for msg in cm.output)) + self.assertEqual(result, []) + def test_expand_indices_basic(self): frame_system_ids = [0] * 5 + [1] * 5 block_targets = [([0], 25), ([1], 25)] @@ -670,6 +686,297 @@ def test_sampler_without_block_targets(self): self.assertEqual(sorted(all_indices), list(range(600))) +# ============================================================ +# batch_size = "max:N" / "filter:N" tests +# ============================================================ + + +def _create_mixed_sid_nloc_lmdb( + path: str, + system_specs: list[tuple[int, int]], + type_map: list[str] | None = None, +) -> str: + """Build an LMDB whose systems have *different* nloc per sid. + + Existing helpers either fix nloc globally or fix sid boundaries; the + filter:N behaviour we want to exercise depends on nloc varying across + systems so this helper glues both axes together in one tiny LMDB. + + Parameters + ---------- + path + Output LMDB directory. + system_specs + ``[(nframes, natoms), ...]`` for each system (sid = list index). + type_map + Optional element list stored in metadata. + """ + total = sum(nf for nf, _ in system_specs) + frame_system_ids: list[int] = [] + frame_nlocs: list[int] = [] + for sid, (nf, natoms) in enumerate(system_specs): + frame_system_ids.extend([sid] * nf) + frame_nlocs.extend([natoms] * nf) + + env = lmdb.open(path, map_size=100 * 1024 * 1024) + with env.begin(write=True) as txn: + first_natoms = system_specs[0][1] + n0 = max(1, first_natoms // 3) + n1 = first_natoms - n0 + meta = { + "nframes": total, + "frame_idx_fmt": "012d", + "system_info": {"natoms": [n0, n1]}, + "frame_system_ids": frame_system_ids, + "frame_nlocs": frame_nlocs, + } + if type_map is not None: + meta["type_map"] = type_map + txn.put(b"__metadata__", msgpack.packb(meta, use_bin_type=True)) + idx = 0 + for _sid, (nf, natoms) in enumerate(system_specs): + for _ in range(nf): + frame = _make_frame(natoms=natoms, seed=idx % 100) + txn.put( + format(idx, "012d").encode(), + msgpack.packb(frame, use_bin_type=True), + ) + idx += 1 + env.close() + return path + + +class TestMaxFilterBatchSize(unittest.TestCase): + """Tests for ``batch_size='max:N'`` and ``batch_size='filter:N'``.""" + + @classmethod + def setUpClass(cls): + cls._tmpdir = tempfile.TemporaryDirectory() + cls._uniform_path = _create_lmdb( + f"{cls._tmpdir.name}/uniform.lmdb", nframes=10, natoms=6 + ) + cls._mixed_path = _create_mixed_nloc_lmdb(f"{cls._tmpdir.name}/mixed.lmdb") + cls._type_map = ["O", "H"] + + @classmethod + def tearDownClass(cls): + cls._tmpdir.cleanup() + + def test_max_batch_size_single_nloc(self): + """``max:N`` uses floor division and clamps to 1.""" + reader = LmdbDataReader( + self._uniform_path, self._type_map, batch_size="max:500" + ) + # floor(500 / 6) = 83. + self.assertEqual(reader.get_batch_size_for_nloc(6), 83) + self.assertEqual(reader._max_rule, 500) + self.assertIsNone(reader._auto_rule) + self.assertIsNone(reader._filter_rule) + + reader_small = LmdbDataReader( + self._uniform_path, self._type_map, batch_size="max:5" + ) + # floor(5 / 6) == 0 → clamped to 1 so every nloc still yields a batch. + self.assertEqual(reader_small.get_batch_size_for_nloc(6), 1) + + def test_auto_vs_max_ceiling_vs_floor(self): + """``auto:N`` rounds up; ``max:N`` rounds down for the same budget.""" + auto_reader = LmdbDataReader( + self._uniform_path, self._type_map, batch_size="auto:1024" + ) + max_reader = LmdbDataReader( + self._uniform_path, self._type_map, batch_size="max:1000" + ) + # nloc=148: ceil(1024/148)=7, floor(1000/148)=6 + self.assertEqual(auto_reader.get_batch_size_for_nloc(148), 7) + self.assertEqual(max_reader.get_batch_size_for_nloc(148), 6) + # nloc=2: ceil(1024/2)=512, floor(1000/2)=500 + self.assertEqual(auto_reader.get_batch_size_for_nloc(2), 512) + self.assertEqual(max_reader.get_batch_size_for_nloc(2), 500) + + def test_filter_drops_large_nloc_groups(self): + """``filter:N`` removes whole nloc groups above the threshold.""" + # _create_mixed_nloc_lmdb produces nloc groups {6:4, 9:4, 12:2}. + r10 = LmdbDataReader(self._mixed_path, self._type_map, batch_size="filter:10") + self.assertEqual(set(r10.nloc_groups.keys()), {6, 9}) + self.assertEqual(len(r10), 8) + self.assertEqual(r10._max_rule, 10) + self.assertEqual(r10._filter_rule, 10) + + r6 = LmdbDataReader(self._mixed_path, self._type_map, batch_size="filter:6") + self.assertEqual(set(r6.nloc_groups.keys()), {6}) + self.assertEqual(len(r6), 4) + + r100 = LmdbDataReader(self._mixed_path, self._type_map, batch_size="filter:100") + self.assertEqual(set(r100.nloc_groups.keys()), {6, 9, 12}) + self.assertEqual(len(r100), 10) + + def test_filter_preserves_system_id_numbering(self): + """filter:N keeps original sid numbering and zeroes dropped systems.""" + path = f"{self._tmpdir.name}/mixed_sids.lmdb" + # sid 0..2 at natoms=6; sid=3 at natoms=20 (fully dropped by filter:10). + _create_mixed_sid_nloc_lmdb( + path, + system_specs=[(100, 6), (200, 6), (300, 6), (20, 20)], + type_map=self._type_map, + ) + reader = LmdbDataReader(path, self._type_map, batch_size="filter:10") + # sid=3 is fully filtered but the numbering must survive so that + # auto_prob block slicing keeps its user-facing semantics. + self.assertEqual(reader.nsystems, 4) + self.assertEqual(reader.system_nframes, [100, 200, 300, 0]) + self.assertEqual(reader.system_groups.get(3, []), []) + self.assertEqual(len(reader), 600) + + block_targets = compute_block_targets( + "prob_sys_size;0:3:0.5;3:4:0.5", + nsystems=reader.nsystems, + system_nframes=reader.system_nframes, + ) + # Empty block (3:4) drops out, remaining block is already balanced + # after re-normalisation → no expansion needed. + self.assertEqual(block_targets, []) + + def test_filter_dataset_index_is_contiguous_and_live(self): + """After filter:N, every i in range(len(reader)) is a live retrievable frame. + + Regression for the earlier indexing bug where ``len(reader)`` shrank + to the retained count but ``__getitem__`` still indexed the original + LMDB key space. Under filter:10 the mixed-nloc LMDB drops the two + 12-atom frames at original keys 8 & 9; we check here that: + + * every dataset index ``0..len(reader)-1`` decodes without raising + and never returns a filtered-out frame, and + * ``fid`` reports the stable original LMDB key, not the dataset + index (so downstream logs survive the remap), and + * out-of-range indices still raise IndexError. + """ + reader = LmdbDataReader( + self._mixed_path, self._type_map, batch_size="filter:10" + ) + self.assertEqual(len(reader), 8) + self.assertEqual(len(reader._retained_keys), 8) + self.assertEqual(reader._retained_keys, [0, 1, 2, 3, 4, 5, 6, 7]) + + seen_fids = [] + for i in range(len(reader)): + frame = reader[i] + self.assertLessEqual(frame["atype"].shape[0], 10) + self.assertEqual( + frame["fid"], + reader._retained_keys[i], + msg=f"fid should be the original LMDB key, not dataset index {i}", + ) + seen_fids.append(frame["fid"]) + # Dropped original keys (8, 9) must never appear as fids. + self.assertNotIn(8, seen_fids) + self.assertNotIn(9, seen_fids) + + with self.assertRaises(IndexError): + reader[len(reader)] + with self.assertRaises(IndexError): + reader[-1] + + def test_sampler_with_filter(self): + """SameNlocBatchSampler only emits retained, same-nloc frames.""" + reader = LmdbDataReader( + self._mixed_path, self._type_map, batch_size="filter:10" + ) + sampler = SameNlocBatchSampler(reader, shuffle=False, seed=0) + all_batches = list(sampler) + all_indices = [idx for batch in all_batches for idx in batch] + + # (a) every frame in every batch has nloc <= 10 + for batch in all_batches: + for idx in batch: + self.assertLessEqual(reader.frame_nlocs[idx], 10) + # (b) unique frame index count equals retained frames + self.assertEqual(len(set(all_indices)), len(reader)) + self.assertEqual(len(reader), 8) + # (c) each batch is same-nloc + for batch in all_batches: + nlocs = {reader.frame_nlocs[idx] for idx in batch} + self.assertEqual(len(nlocs), 1) + # The 12-atom frames were at original LMDB keys 8, 9; they must + # never be reachable via any emitted dataset index. + reached_original_keys = {reader._retained_keys[idx] for idx in all_indices} + for original_key in (8, 9): + self.assertNotIn(original_key, reached_original_keys) + + def test_invalid_batch_size_strings_rejected(self): + """``:N`` specs with missing / non-positive N fail at init. + + Before this hardening, ``filter:0`` silently dropped every frame + and ``max:`` raised a cryptic ``invalid literal for int()``. + One case per failure mode is enough to pin the behaviour. + """ + for spec in ("filter:", "filter:0", "max:-1"): + with self.assertRaises(ValueError) as ctx: + LmdbDataReader(self._uniform_path, self._type_map, batch_size=spec) + self.assertIn("positive", str(ctx.exception)) + + def test_filter_with_mixed_batch_rejected(self): + """``filter:N`` + ``mixed_batch=True`` must fail loudly. + + The mixed-batch fast path skips the per-frame nloc scan, so + filter:N cannot honour its documented ``nloc > N`` drop. + """ + with self.assertRaises(ValueError) as ctx: + LmdbDataReader( + self._mixed_path, + self._type_map, + batch_size="filter:10", + mixed_batch=True, + ) + self.assertIn("filter", str(ctx.exception)) + self.assertIn("mixed_batch", str(ctx.exception)) + + def test_auto_prob_with_filter_still_works(self): + """compute_block_targets + sampler survive a fully-dropped block.""" + path = f"{self._tmpdir.name}/auto_prob_filter.lmdb" + # filter:10 drops sid=2 (natoms=20), and sid=0 is under-represented + # relative to sid=1 so at least one block still needs expansion. + _create_mixed_sid_nloc_lmdb( + path, + system_specs=[(50, 6), (500, 6), (30, 20)], + type_map=self._type_map, + ) + reader = LmdbDataReader(path, self._type_map, batch_size="filter:10") + self.assertEqual(reader.nsystems, 3) + self.assertEqual(reader.system_nframes, [50, 500, 0]) + self.assertEqual(len(reader), 550) + + block_targets = compute_block_targets( + "prob_sys_size;0:1:0.5;1:3:0.5", + nsystems=reader.nsystems, + system_nframes=reader.system_nframes, + ) + # sid=2 in block 1:3 is empty but block 1:3 overall still has 500 + # frames, so compute_block_targets should produce finite targets. + self.assertTrue( + all(np.isfinite(target) for _sys_ids, target in block_targets), + block_targets, + ) + # Block 0 under-represented relative to weight → expansion needed. + self.assertGreater(len(block_targets), 0) + + sampler = SameNlocBatchSampler( + reader, shuffle=False, seed=0, block_targets=block_targets + ) + all_batches = list(sampler) + all_indices = [idx for batch in all_batches for idx in batch] + # Every index must be a retained frame — no dropped sid=2 / nloc=20. + for idx in all_indices: + self.assertLessEqual(reader.frame_nlocs[idx], 10) + self.assertNotEqual(reader.frame_system_ids[idx], 2) + # Every batch is same-nloc + for batch in all_batches: + nlocs = {reader.frame_nlocs[idx] for idx in batch} + self.assertEqual(len(nlocs), 1) + # Expansion produces more indices than the retained dataset size. + self.assertGreater(len(all_indices), len(reader)) + + # ============================================================ # Neighbor stat from LMDB tests # ============================================================