diff --git a/paimon-python/pypaimon/read/scanner/bucket_select_converter.py b/paimon-python/pypaimon/read/scanner/bucket_select_converter.py
new file mode 100644
index 000000000000..132d60156298
--- /dev/null
+++ b/paimon-python/pypaimon/read/scanner/bucket_select_converter.py
@@ -0,0 +1,227 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+"""
+Predicate-driven bucket pruning for HASH_FIXED tables.
+
+Mirrors Java's ``org.apache.paimon.operation.BucketSelectConverter``:
+walk the predicate, isolate AND clauses that constrain bucket-key fields
+with Equal/In, take the cartesian product of literal values, hash each
+combination using the writer's hash routine, and produce the set of
+buckets the query can possibly hit. Manifest entries in every other
+bucket are then safely dropped.
+
+Hard correctness contract: the bucket set this returns is a *superset* of
+the buckets that contain any matching rows. False positives (keeping
+extra buckets) are allowed; false negatives (dropping a bucket that has
+matching rows) MUST never happen — that would be silent data loss.
+
+The hashing routine reuses ``RowKeyExtractor._hash_bytes_by_words`` /
+``_bucket_from_hash`` from ``pypaimon.write.row_key_extractor`` — the same
+code path the writer uses to assign rows to buckets. Reusing it (rather
+than copying) is what guarantees read/write hash agreement in the face of
+future routine changes.
+
+Conservative scope (deliberately narrower than the Java implementation's
+general flexibility):
+
+  * Only HASH_FIXED tables (caller's responsibility to gate; this module
+    does not look at the bucket mode itself).
+  * All bucket-key fields must be constrained, with Equal or In, in a
+    single AND-of-OR-of-literals shape. If any bucket-key column is
+    unconstrained, return None — the caller must scan all buckets.
+  * Repeated constraints on the same bucket-key column under top-level
+    AND (e.g. ``id = 1 AND id = 2``) → return None. Java does the same
+    rather than reasoning about unsatisfiability.
+  * Total cartesian product capped at MAX_VALUES (1000), again matching
+    Java; above that, fall back to a full scan.
+
+Returns a callable ``selector(bucket: int, total_buckets: int) -> bool``.
+The selector caches its computed bucket set per ``total_buckets`` to
+handle the rare case where bucket count varies across snapshots (rescale).
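+
+Illustrative sketch (``id_field`` and the eight-bucket total are invented
+for this example; the real call site is ``file_scanner.py``)::
+
+    pb = PredicateBuilder([id_field])       # 'id' is the only bucket key
+    sel = create_bucket_selector(pb.is_in('id', [1, 7]), [id_field])
+    kept = [b for b in range(8) if sel(b, 8)]   # at most two buckets kept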
+""" + +from itertools import product +from typing import Any, Callable, Dict, FrozenSet, List, Optional + +from pypaimon.common.predicate import Predicate +from pypaimon.schema.data_types import DataField +from pypaimon.table.row.generic_row import GenericRow, GenericRowSerializer +from pypaimon.table.row.internal_row import RowKind +from pypaimon.write.row_key_extractor import (_bucket_from_hash, + _hash_bytes_by_words) + +MAX_VALUES = 1000 + + +def _split_and(p: Predicate) -> List[Predicate]: + if p.method == 'and': + out: List[Predicate] = [] + for child in (p.literals or []): + out.extend(_split_and(child)) + return out + return [p] + + +def _split_or(p: Predicate) -> List[Predicate]: + if p.method == 'or': + out: List[Predicate] = [] + for child in (p.literals or []): + out.extend(_split_or(child)) + return out + return [p] + + +def _extract_or_clause(or_pred: Predicate, + bk_name_to_slot: Dict[str, int]) -> Optional[List[Any]]: + """For one AND-child predicate, return either: + * ``[slot_index, [literal, ...]]`` — the OR/leaf is a pure + Equal-or-In list on a single bucket-key field; or + * ``None`` — the clause is not a bucket-key constraint we can + safely use; the caller skips it. + + All disjuncts must hit the same bucket-key column. Mixed columns or + non-Equal/In operators disqualify the entire AND clause. + """ + slot: Optional[int] = None + values: List[Any] = [] + for clause in _split_or(or_pred): + if clause.method not in ('equal', 'in'): + return None + if clause.field is None or clause.field not in bk_name_to_slot: + return None + this_slot = bk_name_to_slot[clause.field] + if slot is not None and slot != this_slot: + return None + slot = this_slot + for lit in (clause.literals or []): + # Java filters nulls; null literals are degenerate (NULL = NULL + # is UNKNOWN in SQL). Producing zero values for a slot will + # cascade through the cartesian product to "match nothing", + # which is the same observable behaviour as Java. + if lit is None: + continue + values.append(lit) + return None if slot is None else [slot, values] + + +class _Selector: + """Callable bucket filter, lazy + cached per ``total_buckets``.""" + + __slots__ = ('_combinations', '_bucket_key_fields', '_cache') + + def __init__(self, combinations: List[List[Any]], + bucket_key_fields: List[DataField]): + self._combinations = combinations + self._bucket_key_fields = bucket_key_fields + self._cache: Dict[int, FrozenSet[int]] = {} + + def __call__(self, bucket: int, total_buckets: int) -> bool: + # ``total_buckets <= 0`` shows up for postpone / legacy / special + # entries and must NOT be pruned: returning False here would drop + # rows the writer hashed under a different convention. Fail open. + if total_buckets <= 0: + return True + try: + return bucket in self._compute(total_buckets) + except Exception: + # Fail open on any hashing/serialization error (e.g. a literal + # type that doesn't match the bucket-key column's atomic type: + # ``pb.equal('id_bigint', 'foo')`` — GenericRowSerializer raises + # struct.error trying to pack the string as int64). Crashing + # the entire scan here would be worse than skipping pruning; + # the soundness contract still forbids false-negatives. 
+ return True + + def _compute(self, total_buckets: int) -> FrozenSet[int]: + cached = self._cache.get(total_buckets) + if cached is not None: + return cached + result = set() + for combo in self._combinations: + row = GenericRow(list(combo), self._bucket_key_fields, + RowKind.INSERT) + serialized = GenericRowSerializer.to_bytes(row) + # Skip the 4-byte length prefix — matches the writer's hash + # input exactly (see RowKeyExtractor._binary_row_hash_code). + h = _hash_bytes_by_words(serialized[4:]) + result.add(_bucket_from_hash(h, total_buckets)) + frozen = frozenset(result) + self._cache[total_buckets] = frozen + return frozen + + @property + def bucket_combinations(self) -> int: + """Number of (bucket-key) combinations used to compute the filter. + Exposed for tests / observability.""" + return len(self._combinations) + + +def create_bucket_selector( + predicate: Optional[Predicate], + bucket_key_fields: List[DataField]) -> Optional[Callable[[int, int], bool]]: + """Try to derive a bucket selector from ``predicate`` constrained to + ``bucket_key_fields``. + + Returns: + A callable ``(bucket, total_buckets) -> bool`` if the predicate + pins down all bucket keys to a finite Equal/In set; otherwise None + (caller must NOT prune by bucket). + """ + if predicate is None or not bucket_key_fields: + return None + + bk_name_to_slot: Dict[str, int] = { + f.name: i for i, f in enumerate(bucket_key_fields) + } + n_slots = len(bucket_key_fields) + slot_values: List[Optional[List[Any]]] = [None] * n_slots + + for and_child in _split_and(predicate): + extracted = _extract_or_clause(and_child, bk_name_to_slot) + if extracted is None: + # Not a bucket-key constraint — that's fine, just skip it. The + # remaining predicate still describes a SUPERSET of matching + # rows; bucket pruning stays sound as long as we don't *add* + # constraints that aren't actually true. + continue + slot, values = extracted + if slot_values[slot] is not None: + # Two AND clauses on the same bucket-key column. Java bails; + # so do we. (e.g. ``id = 1 AND id = 2`` is unsatisfiable but + # we don't reason about that — any superset, including "all + # buckets", is correct.) + return None + slot_values[slot] = values + + # Every bucket-key column must be constrained. + for v in slot_values: + if v is None: + return None + + # Cartesian-product cap. Above the cap the bucket set is essentially + # all buckets anyway; punting saves the hash computation. + total = 1 + for v in slot_values: + # An empty slot (e.g. all literals were null) collapses the + # product to 0 — observable behaviour: empty bucket set, drop + # everything. Mirrors Java. 
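+        # Worked example: ``k1 IN (1, 2) AND k2 = 99`` gives slot value
+        # lists ``[[1, 2], [99]]`` -> total = 2 * 1 = 2 combinations.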
+ total *= len(v) + if total > MAX_VALUES: + return None + + combinations = [list(combo) for combo in product(*slot_values)] + return _Selector(combinations, bucket_key_fields) diff --git a/paimon-python/pypaimon/read/scanner/file_scanner.py b/paimon-python/pypaimon/read/scanner/file_scanner.py index ffbd83daf0da..f3a62e0f4857 100755 --- a/paimon-python/pypaimon/read/scanner/file_scanner.py +++ b/paimon-python/pypaimon/read/scanner/file_scanner.py @@ -35,6 +35,8 @@ trim_and_transform_predicate) from pypaimon.read.scanner.append_table_split_generator import \ AppendTableSplitGenerator +from pypaimon.read.scanner.bucket_select_converter import \ + create_bucket_selector from pypaimon.read.scanner.data_evolution_split_generator import \ DataEvolutionSplitGenerator from pypaimon.read.scanner.primary_key_table_split_generator import \ @@ -208,6 +210,12 @@ def __init__( self._scanned_snapshot = None self._scanned_snapshot_id = None + # Predicate-driven bucket pruning (HASH_FIXED only). Mirrors Java + # BucketSelectConverter. Set on demand and reused across all + # _filter_manifest_entry calls; the inner _Selector caches the + # bucket set per ``total_buckets`` value. + self._bucket_selector = self._init_bucket_selector() + def schema_fields_func(schema_id: int): return self.table.schema_manager.get_schema(schema_id).fields @@ -387,9 +395,58 @@ def _filter_manifest_file(self, file: ManifestFileMeta) -> bool: file.partition_stats, file.num_added_files + file.num_deleted_files) + def _init_bucket_selector(self): + """Build the predicate-driven bucket selector if (and only if) the + table is in HASH_FIXED mode and the predicate pins all bucket-key + fields to Equal/In literals. Anything else returns None — the + caller treats None as "no bucket-level pruning". + + Bucket-key fields are derived by instantiating the *writer's* + ``FixedBucketRowKeyExtractor`` and reading back its resolved + fields. Reusing the writer class (rather than re-deriving the + list here) is what guarantees the reader always hashes against + the same fields the writer used at insert time — any future + change to the writer's bucket-key resolution propagates + automatically. + + Sound across rescale: ``_Selector`` caches per ``total_buckets``, + which can vary between manifest entries after a bucket rescale. + """ + if self.predicate is None: + return None + # ``bucket_mode()`` returns HASH_FIXED only when ``options.bucket() + # > 0``; other modes (DYNAMIC / POSTPONE / UNAWARE / CROSS_PARTITION) + # have no fixed hash → bucket mapping at write time and must NOT + # be pruned here. + try: + if self.table.bucket_mode() != BucketMode.HASH_FIXED: + return None + except Exception: + # Defensive: any catalog/proxy table that fails the mode check + # falls back to no pruning rather than crashing the scan. + return None + from pypaimon.write.row_key_extractor import \ + FixedBucketRowKeyExtractor + try: + extractor = FixedBucketRowKeyExtractor(self.table.table_schema) + except Exception: + return None + bucket_key_fields = list(extractor._bucket_key_fields) + if not bucket_key_fields: + return None + return create_bucket_selector(self.predicate, bucket_key_fields) + def _filter_manifest_entry(self, entry: ManifestEntry) -> bool: if self.only_read_real_buckets and entry.bucket < 0: return False + # Predicate-driven bucket pruning. Cheapest possible discriminator + # for PK point queries (``pk = 'X'``) — short-circuits before any + # stats decoding. 
Stays sound across rescale because the selector
+        # keys its internal cache on ``total_buckets``.
+        if (self._bucket_selector is not None
+                and entry.bucket >= 0
+                and not self._bucket_selector(entry.bucket, entry.total_buckets)):
+            return False
         if self.partition_key_predicate and not self.partition_key_predicate.test(entry.partition):
             return False
         # Get SimpleStatsEvolution for this schema
diff --git a/paimon-python/pypaimon/tests/pushdown_bucket_test.py b/paimon-python/pypaimon/tests/pushdown_bucket_test.py
new file mode 100644
index 000000000000..2ca00a7af900
--- /dev/null
+++ b/paimon-python/pypaimon/tests/pushdown_bucket_test.py
@@ -0,0 +1,546 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+"""
+Three-layer correctness tests for predicate-driven bucket pruning.
+
+Mirrors Java's ``BucketSelectConverter`` contract: PK Equal/In queries on
+HASH_FIXED tables must touch only the bucket(s) the writer would have
+placed those keys in. Two correctness obligations:
+
+    1. Sound: the set of buckets the selector keeps is a superset of the
+       buckets that contain matching rows. Buckets that DO contain
+       matching rows are NEVER dropped — false-negative-free.
+    2. Hash-consistent with writers: ``RowKeyExtractor`` (writer) and
+       ``BucketSelectConverter`` (reader) must agree on every literal.
+       This is what makes ``pk = 'X'`` read the bucket holding 'X'.
+
+Layered:
+    * Unit — direct calls to ``create_bucket_selector`` with crafted
+      predicates, asserting selector behaviour.
+    * Integration — real PK tables with multiple buckets; run queries and
+      assert (a) result correctness, (b) that bucket pruning fired.
+    * Property — randomly-seeded PK tables, random Equal/In predicates,
+      result == oracle. No hypothesis dependency (keeps
+      Python 3.6 compat).
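+
+A minimal sketch of the writer-side oracle all three layers share
+(``id_field`` and ``selector`` are placeholders here; ``_hash_bucket``
+is the helper defined below)::
+
+    expected = _hash_bucket([42], [id_field], total=8)
+    assert selector(expected, 8)   # the bucket holding id=42 is kept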
+""" + +import os +import random +import shutil +import tempfile +import unittest +from typing import Any, Dict, List + +import pyarrow as pa + +from pypaimon import CatalogFactory, Schema +from pypaimon.common.predicate_builder import PredicateBuilder +from pypaimon.read.scanner.bucket_select_converter import ( + MAX_VALUES, create_bucket_selector) +from pypaimon.schema.data_types import AtomicType, DataField +from pypaimon.write.row_key_extractor import (FixedBucketRowKeyExtractor, + _bucket_from_hash, + _hash_bytes_by_words) +from pypaimon.table.row.generic_row import GenericRow, GenericRowSerializer +from pypaimon.table.row.internal_row import RowKind + + +def _bigint_field(idx: int, name: str) -> DataField: + return DataField(idx, name, AtomicType('BIGINT', nullable=False)) + + +def _hash_bucket(values: List[Any], fields: List[DataField], total: int) -> int: + """Re-implement the writer's hash so unit tests can compute the + expected bucket without spinning up a real table.""" + row = GenericRow(values, fields, RowKind.INSERT) + serialized = GenericRowSerializer.to_bytes(row) + h = _hash_bytes_by_words(serialized[4:]) + return _bucket_from_hash(h, total) + + +# --------------------------------------------------------------------------- +# Layer 1 — Unit: drive ``create_bucket_selector`` with crafted predicates. +# --------------------------------------------------------------------------- +class BucketSelectConverterUnitTest(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.id_field = _bigint_field(0, 'id') + cls.val_field = _bigint_field(1, 'val') + cls.k1 = _bigint_field(0, 'k1') + cls.k2 = _bigint_field(1, 'k2') + cls.pb_id_val = PredicateBuilder([cls.id_field, cls.val_field]) + cls.pb_k1_k2 = PredicateBuilder([cls.k1, cls.k2]) + + # -- Equal / In on single bucket key --------------------------------- + def test_equal_on_single_bucket_key_yields_single_bucket(self): + sel = create_bucket_selector( + self.pb_id_val.equal('id', 42), [self.id_field]) + self.assertIsNotNone(sel, "PK Equal must produce a selector") + expected = _hash_bucket([42], [self.id_field], total=8) + for b in range(8): + self.assertEqual( + sel(b, 8), b == expected, + "only bucket {} must be kept (got {})".format(expected, b)) + + def test_in_on_single_bucket_key_unions_buckets(self): + sel = create_bucket_selector( + self.pb_id_val.is_in('id', [1, 2, 3, 100]), [self.id_field]) + expected = {_hash_bucket([v], [self.id_field], 8) + for v in (1, 2, 3, 100)} + for b in range(8): + self.assertEqual(sel(b, 8), b in expected) + + def test_or_of_equals_on_same_field_unions_buckets(self): + # ``id = 1 OR id = 2`` must equal ``id IN (1, 2)``. 
+ pred = PredicateBuilder.or_predicates([ + self.pb_id_val.equal('id', 1), + self.pb_id_val.equal('id', 2), + ]) + sel = create_bucket_selector(pred, [self.id_field]) + expected = {_hash_bucket([v], [self.id_field], 8) for v in (1, 2)} + for b in range(8): + self.assertEqual(sel(b, 8), b in expected) + + # -- Composite bucket keys ------------------------------------------ + def test_composite_bucket_key_intersects_via_cartesian(self): + pred = PredicateBuilder.and_predicates([ + self.pb_k1_k2.is_in('k1', [1, 2]), + self.pb_k1_k2.equal('k2', 99), + ]) + sel = create_bucket_selector(pred, [self.k1, self.k2]) + expected = { + _hash_bucket([k1, 99], [self.k1, self.k2], 4) + for k1 in (1, 2) + } + for b in range(4): + self.assertEqual(sel(b, 4), b in expected) + + def test_composite_bucket_key_missing_one_field_returns_none(self): + pred = self.pb_k1_k2.equal('k1', 1) # k2 unconstrained + sel = create_bucket_selector(pred, [self.k1, self.k2]) + self.assertIsNone(sel, + "all bucket keys must be constrained or fall back") + + # -- Predicates that can't be reduced ------------------------------- + def test_non_bucket_key_predicate_returns_none(self): + sel = create_bucket_selector( + self.pb_id_val.equal('val', 5), [self.id_field]) + self.assertIsNone(sel, "predicate not on bucket key -> no selector") + + def test_range_predicate_on_bucket_key_returns_none(self): + sel = create_bucket_selector( + self.pb_id_val.greater_than('id', 100), [self.id_field]) + self.assertIsNone(sel, "ranges can't be turned into a finite bucket set") + + def test_or_with_non_bucket_key_returns_none(self): + # ``id = 1 OR val = 5`` — ``val`` isn't a bucket key, so the OR + # is not a pure bucket-key constraint. + pred = PredicateBuilder.or_predicates([ + self.pb_id_val.equal('id', 1), + self.pb_id_val.equal('val', 5), + ]) + sel = create_bucket_selector(pred, [self.id_field]) + self.assertIsNone(sel) + + def test_repeated_equal_on_same_key_under_and_returns_none(self): + # ``id = 1 AND id = 2``: unsatisfiable, but Java bails to "no + # filter" rather than reasoning. We do the same — any superset + # of the true set is acceptable. + pred = PredicateBuilder.and_predicates([ + self.pb_id_val.equal('id', 1), + self.pb_id_val.equal('id', 2), + ]) + sel = create_bucket_selector(pred, [self.id_field]) + self.assertIsNone(sel) + + def test_and_with_unrelated_clause_is_unaffected(self): + # ``id = 7 AND val > 100`` — the ``val > 100`` part doesn't + # constrain buckets, but mustn't disqualify the ``id = 7`` part. + pred = PredicateBuilder.and_predicates([ + self.pb_id_val.equal('id', 7), + self.pb_id_val.greater_than('val', 100), + ]) + sel = create_bucket_selector(pred, [self.id_field]) + self.assertIsNotNone(sel) + expected = _hash_bucket([7], [self.id_field], 4) + for b in range(4): + self.assertEqual(sel(b, 4), b == expected) + + # -- Cap & degenerate edge cases ------------------------------------ + def test_cartesian_above_max_values_returns_none(self): + # Two columns of size > sqrt(MAX_VALUES) → product > MAX_VALUES. + size = 33 # 33 * 33 = 1089 > 1000 + pred = PredicateBuilder.and_predicates([ + self.pb_k1_k2.is_in('k1', list(range(size))), + self.pb_k1_k2.is_in('k2', list(range(size))), + ]) + self.assertGreater(size * size, MAX_VALUES) + sel = create_bucket_selector(pred, [self.k1, self.k2]) + self.assertIsNone(sel) + + def test_null_only_literal_drops_everything(self): + # ``id IN (NULL)`` after null-stripping has zero literals; the + # cartesian product is empty → selector matches no buckets. 
Same + # behaviour as Java. + pred = self.pb_id_val.is_in('id', [None]) + sel = create_bucket_selector(pred, [self.id_field]) + self.assertIsNotNone(sel) + for b in range(4): + self.assertFalse(sel(b, 4), + "all-null literal collapses bucket set to empty") + + def test_no_predicate_returns_none(self): + self.assertIsNone(create_bucket_selector(None, [self.id_field])) + + def test_no_bucket_keys_returns_none(self): + self.assertIsNone( + create_bucket_selector(self.pb_id_val.equal('id', 1), [])) + + # -- Selector cache + rescale ------------------------------------- + def test_selector_caches_per_total_buckets(self): + """Selector must answer correctly when the same query applies to + different ``total_buckets`` values (the rescale scenario).""" + sel = create_bucket_selector( + self.pb_id_val.equal('id', 42), [self.id_field]) + for total in (4, 8, 16, 32): + expected = _hash_bucket([42], [self.id_field], total) + self.assertTrue(sel(expected, total)) + other = (expected + 1) % total + self.assertFalse(sel(other, total)) + + def test_non_positive_total_buckets_fails_open(self): + """Manifest entries can carry ``total_buckets <= 0`` for legacy / + special bucket modes. Pruning MUST fail open — returning False + would silently drop rows the writer placed in those entries. + This is correctness, not performance: the soundness contract + forbids false-negatives.""" + sel = create_bucket_selector( + self.pb_id_val.equal('id', 1), [self.id_field]) + for total in (0, -1, -2): + self.assertTrue(sel(0, total), + "total_buckets={} must be kept (fail open)".format(total)) + self.assertTrue(sel(-1, total)) + self.assertTrue(sel(99, total)) + + def test_type_mismatched_literal_fails_open_not_crash(self): + """If the user constructs a predicate whose literal type doesn't + match the bucket-key column's atomic type — e.g. a STRING literal + on a BIGINT column — ``GenericRowSerializer`` raises during the + deferred hash inside ``_Selector``. The selector MUST swallow the + exception and fail open (return True for every bucket) rather + than propagate it. Crashing the entire scan with an opaque + ``struct.error`` is a worse user experience than silently + skipping bucket pruning, and the soundness contract still + forbids false-negatives.""" + sel = create_bucket_selector( + self.pb_id_val.equal('id', 'not-an-int'), [self.id_field]) + # Construction itself succeeds (no eager hashing). + self.assertIsNotNone(sel) + # Calling the selector must NOT raise; instead it returns True + # for every (bucket, total_buckets), preserving soundness. + for total in (4, 8): + for b in range(total): + self.assertTrue(sel(b, total), + "type-mismatched literal must fail open, " + "not crash (bucket={}, total={})".format(b, total)) + + +# --------------------------------------------------------------------------- +# Layer 2 — Integration: real tables, public API, assert correctness AND +# that pruning actually fired (otherwise we're not testing the optimisation, +# only that we didn't break full-scan). 
+# --------------------------------------------------------------------------- +class BucketPruningIntegrationTest(unittest.TestCase): + + NUM_BUCKETS = 8 + + @classmethod + def setUpClass(cls): + cls.tempdir = tempfile.mkdtemp() + cls.warehouse = os.path.join(cls.tempdir, 'warehouse') + cls.catalog = CatalogFactory.create({'warehouse': cls.warehouse}) + cls.catalog.create_database('default', False) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls.tempdir, ignore_errors=True) + + def _create_pk_table(self, name: str, num_buckets: int = NUM_BUCKETS, + bucket_key: str = None) -> Any: + opts = {'bucket': str(num_buckets), 'file.format': 'parquet'} + if bucket_key is not None: + opts['bucket-key'] = bucket_key + pa_schema = pa.schema([ + pa.field('id', pa.int64(), nullable=False), + ('val', pa.int64()), + ]) + schema = Schema.from_pyarrow_schema( + pa_schema, primary_keys=['id'], options=opts) + full = 'default.{}'.format(name) + self.catalog.create_table(full, schema, False) + return self.catalog.get_table(full) + + def _write(self, table, rows: List[Dict]): + pa_schema = pa.schema([ + pa.field('id', pa.int64(), nullable=False), + ('val', pa.int64()), + ]) + wb = table.new_batch_write_builder() + w = wb.new_write() + c = wb.new_commit() + try: + w.write_arrow(pa.Table.from_pylist(rows, schema=pa_schema)) + c.commit(w.prepare_commit()) + finally: + w.close() + c.close() + + def _read_with(self, table, predicate=None): + rb = table.new_read_builder() + if predicate is not None: + rb = rb.with_filter(predicate) + splits = rb.new_scan().plan().splits() + if not splits: + return [], splits + return rb.new_read().to_arrow(splits).to_pylist(), splits + + @staticmethod + def _split_buckets(splits) -> set: + """Collect the distinct bucket numbers actually returned in a plan.""" + return {s.bucket for s in splits} + + # -- Equal on PK ----------------------------------------------------- + def test_pk_equal_only_reads_target_bucket(self): + table = self._create_pk_table('int_eq') + rows = [{'id': i, 'val': i * 11} for i in range(100)] + self._write(table, rows) + + target_id = 42 + pred = table.new_read_builder().new_predicate_builder().equal( + 'id', target_id) + got, splits = self._read_with(table, pred) + + # Correctness: row for id=42 returned (and only that). + self.assertEqual(got, [{'id': 42, 'val': 42 * 11}]) + + # Pruning effectiveness: at most 1 bucket touched. 
+ self.assertEqual(len(self._split_buckets(splits)), 1, + "PK equal must touch exactly one bucket") + + def test_pk_in_reads_only_target_buckets(self): + table = self._create_pk_table('int_in') + rows = [{'id': i, 'val': i * 7} for i in range(200)] + self._write(table, rows) + + ids = [3, 17, 99, 150] + pred = table.new_read_builder().new_predicate_builder().is_in( + 'id', ids) + got, splits = self._read_with(table, pred) + got_sorted = sorted(got, key=lambda r: r['id']) + self.assertEqual(got_sorted, + [{'id': i, 'val': i * 7} for i in sorted(ids)]) + + actual = self._split_buckets(splits) + # Selector-derived expectation + ext = FixedBucketRowKeyExtractor(table.table_schema) + expected_buckets = set() + for i in ids: + arr = pa.RecordBatch.from_pylist( + [{'id': i, 'val': 0}], + schema=pa.schema([ + pa.field('id', pa.int64(), nullable=False), + ('val', pa.int64()), + ]), + ) + expected_buckets.update(ext._extract_buckets_batch(arr)) + self.assertTrue(actual.issubset(expected_buckets), + "must not read buckets outside the target set; " + "got {}, expected ⊆ {}".format(actual, expected_buckets)) + + # -- Predicates that should NOT prune ------------------------------- + def test_value_only_predicate_falls_back_to_full_scan(self): + """``val < X`` doesn't constrain the PK → selector must be None + and no bucket pruning may fire. Both checked: result correctness + AND the explicit "selector is None" property.""" + table = self._create_pk_table('val_only') + rows = [{'id': i, 'val': i} for i in range(100)] + self._write(table, rows) + + pred = table.new_read_builder().new_predicate_builder().less_than( + 'val', 30) + got, splits = self._read_with(table, pred) + self.assertEqual(sorted([r['id'] for r in got]), list(range(30))) + + # Inspect the scanner's bucket selector to prove pruning DIDN'T + # fire — without this assertion the test would also pass under a + # buggy selector that prunes wrongly but happens to keep the + # rows we picked. + rb = table.new_read_builder().with_filter(pred) + scan = rb.new_scan() + self.assertIsNone(scan.file_scanner._bucket_selector, + "value-only predicate must NOT produce a selector") + + def test_range_on_pk_falls_back_to_full_scan(self): + """``id > X`` is a range, not Equal/In, so cannot derive a bucket + set. 
Selector returns None — result must still be exact.""" + table = self._create_pk_table('pk_range') + rows = [{'id': i, 'val': i} for i in range(50)] + self._write(table, rows) + + pred = table.new_read_builder().new_predicate_builder().greater_or_equal( + 'id', 40) + got, _ = self._read_with(table, pred) + self.assertEqual(sorted([r['id'] for r in got]), list(range(40, 50))) + + # -- Mixed predicate: Equal on PK AND range on val ------------------ + def test_pk_equal_with_unrelated_value_predicate_still_prunes(self): + table = self._create_pk_table('int_eq_with_val') + rows = [{'id': i, 'val': i} for i in range(50)] + self._write(table, rows) + + pb = table.new_read_builder().new_predicate_builder() + pred = pb.and_predicates([ + pb.equal('id', 25), + pb.greater_than('val', 20), + ]) + got, splits = self._read_with(table, pred) + self.assertEqual(got, [{'id': 25, 'val': 25}]) + self.assertEqual(len(self._split_buckets(splits)), 1, + "Equal on PK still narrows buckets even when " + "AND'd with a non-bucket-key predicate") + + # -- Explicit bucket-key option ------------------------------------ + def test_bucket_key_option_overrides_pk_for_pruning(self): + """When the ``bucket-key`` option is set explicitly, the bucket + derivation must use it — not the trimmed primary keys. This is + the path that catches read/write hash divergence if a refactor + forgets the option.""" + # PK = id, bucket-key = id explicitly (single key but exercises + # the explicit-config branch in ``_init_bucket_selector``). + table = self._create_pk_table('explicit_bk', bucket_key='id') + rows = [{'id': i, 'val': i * 3} for i in range(40)] + self._write(table, rows) + + pred = table.new_read_builder().new_predicate_builder().equal('id', 17) + got, splits = self._read_with(table, pred) + self.assertEqual(got, [{'id': 17, 'val': 51}]) + self.assertEqual(len(self._split_buckets(splits)), 1) + + +# --------------------------------------------------------------------------- +# Layer 3 — Property: random PK tables, random Equal/In predicates, +# correctness vs oracle. 
+# --------------------------------------------------------------------------- +class BucketPruningPropertyTest(unittest.TestCase): + + SEED = 0xB0CC + TRIALS = 30 + + @classmethod + def setUpClass(cls): + cls.tempdir = tempfile.mkdtemp() + cls.warehouse = os.path.join(cls.tempdir, 'warehouse') + cls.catalog = CatalogFactory.create({'warehouse': cls.warehouse}) + cls.catalog.create_database('default', False) + cls.rnd = random.Random(cls.SEED) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls.tempdir, ignore_errors=True) + + def _make_table(self, idx: int, num_buckets: int): + pa_schema = pa.schema([ + pa.field('k', pa.int64(), nullable=False), + ('v', pa.int64()), + ]) + schema = Schema.from_pyarrow_schema( + pa_schema, + primary_keys=['k'], + options={'bucket': str(num_buckets), 'file.format': 'parquet'}, + ) + name = 'default.bp_{}'.format(idx) + self.catalog.create_table(name, schema, False) + return self.catalog.get_table(name) + + def _write(self, table, rows): + pa_schema = pa.schema([ + pa.field('k', pa.int64(), nullable=False), + ('v', pa.int64()), + ]) + wb = table.new_batch_write_builder() + w = wb.new_write() + c = wb.new_commit() + try: + w.write_arrow(pa.Table.from_pylist(rows, schema=pa_schema)) + c.commit(w.prepare_commit()) + finally: + w.close() + c.close() + + def test_property_pk_equal_correctness(self): + for trial in range(self.TRIALS): + num_buckets = self.rnd.choice([2, 4, 8, 16]) + table = self._make_table(trial, num_buckets) + keys = self.rnd.sample(range(1000), self.rnd.randint(20, 100)) + rows = [{'k': k, 'v': k * 13} for k in keys] + self._write(table, rows) + + target = self.rnd.choice(keys) + pb = table.new_read_builder().new_predicate_builder() + pred = pb.equal('k', target) + rb = table.new_read_builder().with_filter(pred) + splits = rb.new_scan().plan().splits() + if splits: + got = rb.new_read().to_arrow(splits).to_pylist() + else: + got = [] + self.assertEqual(got, [{'k': target, 'v': target * 13}], + "trial {} buckets={} target={}: result mismatch" + .format(trial, num_buckets, target)) + + def test_property_pk_in_correctness(self): + for trial in range(self.TRIALS): + num_buckets = self.rnd.choice([2, 4, 8, 16]) + offset = self.TRIALS + trial # avoid name clash with prev test + table = self._make_table(offset, num_buckets) + keys = self.rnd.sample(range(1000), self.rnd.randint(20, 100)) + rows = [{'k': k, 'v': k * 13} for k in keys] + self._write(table, rows) + + target_n = self.rnd.randint(1, min(10, len(keys))) + targets = self.rnd.sample(keys, target_n) + pb = table.new_read_builder().new_predicate_builder() + pred = pb.is_in('k', targets) + rb = table.new_read_builder().with_filter(pred) + splits = rb.new_scan().plan().splits() + if splits: + got = rb.new_read().to_arrow(splits).to_pylist() + else: + got = [] + got_sorted = sorted(got, key=lambda r: r['k']) + want = sorted( + [{'k': k, 'v': k * 13} for k in targets], + key=lambda r: r['k']) + self.assertEqual(got_sorted, want, + "trial {}: IN result mismatch".format(trial)) + + +if __name__ == '__main__': + unittest.main()