Skip to content
Open
27 changes: 21 additions & 6 deletions paimon-python/pypaimon/read/scanner/file_scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@
from pypaimon.manifest.schema.manifest_file_meta import ManifestFileMeta
from pypaimon.manifest.simple_stats_evolutions import SimpleStatsEvolutions
from pypaimon.read.plan import Plan
from pypaimon.read.push_down_utils import (remove_row_id_filter,
from pypaimon.read.push_down_utils import (_get_all_fields,
remove_row_id_filter,
trim_and_transform_predicate)
from pypaimon.read.scanner.append_table_split_generator import \
AppendTableSplitGenerator
Expand Down Expand Up @@ -366,20 +367,34 @@ def with_global_index_result(self, result) -> 'FileScanner':
return self

def _apply_push_down_limit(self, splits: List[DataSplit]) -> List[DataSplit]:
    """Trim ``splits`` to the shortest prefix that satisfies ``self.limit``.

    Mirrors Java ``DataTableBatchScan.applyPushDownLimit``: sum the
    DV-aware ``merged_row_count`` (== Java ``Split.mergedRowCount()``)
    until the limit is met.

    Args:
        splits: Planned splits, in the order they should be consumed.

    Returns:
        ``splits`` unchanged when no limit is set, when the predicate
        touches a non-partition column, or when the known merged counts
        never reach the limit. Otherwise a new list holding only the
        splits whose merged count was known, up to and including the one
        that reaches the limit.
    """
    if self.limit is None:
        return splits
    # A filter on non-partition columns can drop rows at read time, so the
    # per-split counts are no longer a safe lower bound for the limit.
    if self._has_non_partition_filter():
        return splits

    scanned_row_count = 0
    limited_splits: List[DataSplit] = []
    for split in splits:
        merged = split.merged_row_count()
        # Splits whose merged count is unknown (None) contribute nothing
        # to the budget and are not collected.
        if merged is not None:
            limited_splits.append(split)
            scanned_row_count += merged
            if scanned_row_count >= self.limit:
                return limited_splits

    # Limit never reached with the known counts: hand every split to the
    # reader unchanged.
    return splits

def _has_non_partition_filter(self) -> bool:
    """Tell whether the scan predicate reaches beyond the partition keys.

    Mirrors Java ``SnapshotReaderImpl.hasNonPartitionFilter``.

    Returns:
        ``False`` when no predicate is set, or when every field reported
        by ``_get_all_fields`` for the predicate is one of the table's
        partition keys; ``True`` otherwise.
    """
    if self.predicate is None:
        return False
    # ``partition_keys`` may be None or empty; treat both as "no keys".
    partition_keys = set(self.table.partition_keys or [])
    # NOTE(review): presumably _get_all_fields returns the set of field
    # names referenced by the predicate -- confirm in push_down_utils.
    return not _get_all_fields(self.predicate).issubset(partition_keys)

def _filter_manifest_file(self, file: ManifestFileMeta) -> bool:
if not self.partition_key_predicate:
return True
Expand Down
71 changes: 71 additions & 0 deletions paimon-python/pypaimon/tests/reader_split_generator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,5 +330,76 @@ def test_sliced_split_merged_row_count(self):
# This test ensures that if SlicedSplit is created, merged_row_count() works correctly


class ApplyPushDownLimitUnitTest(unittest.TestCase):
    """Mock-driven coverage of ``FileScanner._apply_push_down_limit``."""

    @staticmethod
    def _apply(splits, limit, has_non_partition_filter=False):
        """Call ``_apply_push_down_limit`` on a minimal stand-in scanner.

        The stand-in carries only what the method reads: ``limit`` and a
        ``_has_non_partition_filter`` callable returning the given flag.
        """
        from pypaimon.read.scanner.file_scanner import FileScanner

        class _FakeScanner:
            pass

        scanner = _FakeScanner()
        scanner.limit = limit
        scanner._has_non_partition_filter = lambda: has_non_partition_filter
        return FileScanner._apply_push_down_limit(scanner, splits)

    @staticmethod
    def _split(raw_convertible, row_count, merged_row_count):
        """Build a fake split whose ``merged_row_count()`` returns the
        given value (``None`` stands for an unknown merged count)."""
        class _FakeSplit:
            pass

        s = _FakeSplit()
        s.raw_convertible = raw_convertible
        s.row_count = row_count
        s._merged = merged_row_count
        s.merged_row_count = lambda: s._merged
        return s

    def test_dv_aware_accumulator_uses_merged_row_count(self):
        """DV-aware raw split + trailing non-raw splits, ``limit > merged``:
        pre-fix (``+= row_count``) early-returns ``[raw]``; post-fix
        (``+= merged_row_count``) leaves the budget at 4 < 5, the loop
        completes, and the fall-through returns all three splits."""
        s_raw = self._split(raw_convertible=True, row_count=10, merged_row_count=4)
        s_nr1 = self._split(raw_convertible=False, row_count=10, merged_row_count=None)
        s_nr2 = self._split(raw_convertible=False, row_count=10, merged_row_count=None)

        result = self._apply([s_raw, s_nr1, s_nr2], limit=5)
        self.assertEqual(len(result), 3)

    def test_limit_reached_returns_counted_prefix(self):
        """Once the summed merged counts reach the limit, only the splits
        counted so far are returned and the remaining ones are dropped."""
        s1 = self._split(raw_convertible=True, row_count=10, merged_row_count=3)
        s2 = self._split(raw_convertible=True, row_count=10, merged_row_count=3)
        s3 = self._split(raw_convertible=True, row_count=10, merged_row_count=3)

        result = self._apply([s1, s2, s3], limit=5)
        self.assertEqual(result, [s1, s2])

    def test_accumulator_skips_splits_with_unknown_merged_count(self):
        """A split whose ``merged_row_count()`` returns ``None`` does not
        contribute to the budget; the loop completes and returns the
        input via the fall-through."""
        s = self._split(raw_convertible=True, row_count=10, merged_row_count=None)
        result = self._apply([s], limit=5)
        self.assertEqual(result, [s])

    def test_no_raw_splits_falls_through_to_full_list(self):
        """No split contributes to the budget -> fall-through returns all."""
        s1 = self._split(raw_convertible=False, row_count=10, merged_row_count=None)
        s2 = self._split(raw_convertible=False, row_count=10, merged_row_count=None)
        result = self._apply([s1, s2], limit=5)
        self.assertEqual(result, [s1, s2])

    def test_empty_splits_returns_empty(self):
        """An empty plan stays empty when a limit is set."""
        self.assertEqual(self._apply([], limit=5), [])

    def test_no_limit_returns_input_unchanged(self):
        """``limit=None`` disables the push-down entirely."""
        s = self._split(raw_convertible=True, row_count=10, merged_row_count=10)
        result = self._apply([s], limit=None)
        self.assertEqual(result, [s])

    def test_non_partition_filter_short_circuits_pushdown(self):
        """Predicate touching a non-partition column -> no pushdown,
        regardless of how many DV-aware splits the plan contains."""
        s_raw = self._split(raw_convertible=True, row_count=10, merged_row_count=10)
        result = self._apply(
            [s_raw, s_raw, s_raw], limit=5, has_non_partition_filter=True)
        self.assertEqual(len(result), 3)

# Run this module's test cases when the file is executed directly.
if __name__ == '__main__':
    unittest.main()
Loading