diff --git a/paimon-python/pypaimon/ray/__init__.py b/paimon-python/pypaimon/ray/__init__.py new file mode 100644 index 000000000000..cb5307efd9f5 --- /dev/null +++ b/paimon-python/pypaimon/ray/__init__.py @@ -0,0 +1,21 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +from pypaimon.ray.ray_paimon import read_paimon, write_paimon + +__all__ = ["read_paimon", "write_paimon"] diff --git a/paimon-python/pypaimon/ray/ray_paimon.py b/paimon-python/pypaimon/ray/ray_paimon.py new file mode 100644 index 000000000000..5ea2d21096f4 --- /dev/null +++ b/paimon-python/pypaimon/ray/ray_paimon.py @@ -0,0 +1,124 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ +""" +Top-level API for reading and writing Paimon tables with Ray Datasets. + +Usage:: + + from pypaimon.ray import read_paimon, write_paimon + + ds = read_paimon("db.table", catalog_options={"warehouse": "/path"}) + write_paimon(ds, "db.table", catalog_options={"warehouse": "/path"}) +""" + +from typing import Any, Dict, List, Optional + +import ray.data + +from pypaimon.common.predicate import Predicate + + +def read_paimon( + table_identifier: str, + catalog_options: Dict[str, str], + *, + filter: Optional[Predicate] = None, + projection: Optional[List[str]] = None, + limit: Optional[int] = None, + ray_remote_args: Optional[Dict[str, Any]] = None, + concurrency: Optional[int] = None, + override_num_blocks: Optional[int] = None, + **read_args, +) -> ray.data.Dataset: + """Read a Paimon table into a Ray Dataset. + + Args: + table_identifier: Full table name, e.g. ``"db_name.table_name"``. + catalog_options: Options passed to ``CatalogFactory.create()``, + e.g. ``{"warehouse": "/path/to/warehouse"}``. + filter: Optional predicate to push down into the scan. 
+ projection: Optional list of column names to read. + limit: Optional row limit for the scan. + ray_remote_args: Optional kwargs passed to ``ray.remote`` in read tasks. + concurrency: Optional max number of Ray read tasks to run concurrently. + override_num_blocks: Optional override for the number of output blocks. + **read_args: Additional kwargs forwarded to ``ray.data.read_datasource``. + + Returns: + A ``ray.data.Dataset`` containing the table data. + """ + from pypaimon.read.datasource.ray_datasource import RayDatasource + from pypaimon.read.datasource.split_provider import CatalogSplitProvider + + if override_num_blocks is not None and override_num_blocks < 1: + raise ValueError( + "override_num_blocks must be at least 1, got {}".format(override_num_blocks) + ) + + datasource = RayDatasource( + CatalogSplitProvider( + table_identifier=table_identifier, + catalog_options=catalog_options, + predicate=filter, + projection=projection, + limit=limit, + ) + ) + return ray.data.read_datasource( + datasource, + ray_remote_args=ray_remote_args, + concurrency=concurrency, + override_num_blocks=override_num_blocks, + **read_args, + ) + + +def write_paimon( + dataset: ray.data.Dataset, + table_identifier: str, + catalog_options: Dict[str, str], + *, + overwrite: bool = False, + concurrency: Optional[int] = None, + ray_remote_args: Optional[Dict[str, Any]] = None, +) -> None: + """Write a Ray Dataset to a Paimon table. + + Args: + dataset: The Ray Dataset to write. + table_identifier: Full table name, e.g. ``"db_name.table_name"``. + catalog_options: Options passed to ``CatalogFactory.create()``. + overwrite: If ``True``, overwrite existing data in the table. + concurrency: Optional max number of Ray write tasks to run concurrently. + ray_remote_args: Optional kwargs passed to ``ray.remote`` in write tasks. + """ + from pypaimon.catalog.catalog_factory import CatalogFactory + from pypaimon.write.ray_datasink import PaimonDatasink + + catalog = CatalogFactory.create(catalog_options) + table = catalog.get_table(table_identifier) + + datasink = PaimonDatasink(table, overwrite=overwrite) + + write_kwargs = {} + if ray_remote_args is not None: + write_kwargs["ray_remote_args"] = ray_remote_args + if concurrency is not None: + write_kwargs["concurrency"] = concurrency + + dataset.write_datasink(datasink, **write_kwargs) diff --git a/paimon-python/pypaimon/read/datasource/ray_datasource.py b/paimon-python/pypaimon/read/datasource/ray_datasource.py index 33ba6904e511..b08d4c1e50df 100644 --- a/paimon-python/pypaimon/read/datasource/ray_datasource.py +++ b/paimon-python/pypaimon/read/datasource/ray_datasource.py @@ -22,15 +22,15 @@ import itertools import logging from functools import partial -from typing import List, Optional, Iterable +from typing import Iterable, List, Optional import pyarrow from packaging.version import parse import ray from ray.data.datasource import Datasource +from pypaimon.read.datasource.split_provider import SplitProvider from pypaimon.read.split import Split -from pypaimon.read.table_read import TableRead from pypaimon.schema.data_types import PyarrowFieldParser logger = logging.getLogger(__name__) @@ -41,35 +41,45 @@ class RayDatasource(Datasource): - """ - Ray Data Datasource implementation for reading Paimon tables. + """Ray Data ``Datasource`` implementation for reading Paimon tables. + + Holds a :class:`SplitProvider` that supplies the four planning artefacts + needed to build read tasks (table, splits, read_type, predicate). 
Two + provider implementations exist today: - This datasource enables distributed parallel reading of Paimon table splits, - allowing Ray to read multiple splits concurrently across the cluster. + * :class:`CatalogSplitProvider` — resolves a fully-qualified table + identifier through the catalog and runs the ``ReadBuilder`` plan. + Used by the public :func:`pypaimon.ray.read_paimon` facade. + * :class:`PreResolvedSplitProvider` — wraps an already-resolved + ``(table, splits, read_type, predicate)`` tuple. Used by the legacy + ``TableRead.to_ray()`` bridge to skip a second catalog round-trip. + + Both providers are cheap to instantiate; they defer the catalog + round-trip and split planning until the first read. """ - def __init__(self, table_read: TableRead, splits: List[Split]): - """ - Initialize PaimonDatasource. + def __init__(self, split_provider: SplitProvider): + """Initialize a RayDatasource. Args: - table_read: TableRead instance for reading data - splits: List of splits to read + split_provider: The :class:`SplitProvider` that supplies the + table, splits, read_type, and predicate. Construct one with + :class:`CatalogSplitProvider` (from a table identifier + + catalog options) or :class:`PreResolvedSplitProvider` (from + an already-resolved ``TableRead``). """ - self.table_read = table_read - self.splits = splits + self._split_provider = split_provider self._schema = None def get_name(self) -> str: - identifier = self.table_read.table.identifier - table_name = identifier.get_full_name() if hasattr(identifier, 'get_full_name') else str(identifier) - return f"PaimonTable({table_name})" + return f"PaimonTable({self._split_provider.display_name()})" def estimate_inmemory_data_size(self) -> Optional[int]: - if not self.splits: + splits = self._split_provider.splits() + if not splits: return 0 - total_size = sum(split.file_size for split in self.splits) + total_size = sum(split.file_size for split in splits) return total_size if total_size > 0 else None @staticmethod @@ -108,22 +118,26 @@ def get_read_tasks(self, parallelism: int, **kwargs) -> List: if parallelism < 1: raise ValueError(f"parallelism must be at least 1, got {parallelism}") + # Pull provider state into locals once: avoids capturing self in the + # ReadTask closure (see ray-project/ray#49107) and amortises the + # provider-method dispatch over all chunks. 
+ table = self._split_provider.table() + predicate = self._split_provider.predicate() + read_type = self._split_provider.read_type() + splits = self._split_provider.splits() + if not splits: + return [] + if self._schema is None: - self._schema = PyarrowFieldParser.from_paimon_schema(self.table_read.read_type) + self._schema = PyarrowFieldParser.from_paimon_schema(read_type) + schema = self._schema - if parallelism > len(self.splits): - parallelism = len(self.splits) + if parallelism > len(splits): + parallelism = len(splits) logger.warning( f"Reducing the parallelism to {parallelism}, as that is the number of splits" ) - # Store necessary information for creating readers in Ray workers - # Extract these to avoid serializing the entire self object in closures - table = self.table_read.table - predicate = self.table_read.predicate - read_type = self.table_read.read_type - schema = self._schema - # Create a partial function to avoid capturing self in closure # This reduces serialization overhead (see https://github.com/ray-project/ray/issues/49107) def _get_read_task( @@ -163,7 +177,7 @@ def _get_read_task( read_tasks = [] # Distribute splits across tasks using load balancing algorithm - for chunk_splits in self._distribute_splits_into_equal_chunks(self.splits, parallelism): + for chunk_splits in self._distribute_splits_into_equal_chunks(splits, parallelism): if not chunk_splits: continue @@ -174,14 +188,9 @@ def _get_read_task( for split in chunk_splits: if predicate is None: # Only estimate rows if no predicate (predicate filtering changes row count) - row_count = None - if hasattr(split, 'merged_row_count'): - merged_count = split.merged_row_count() - if merged_count is not None: - row_count = merged_count - if row_count is None and hasattr(split, 'row_count') and split.row_count > 0: - row_count = split.row_count - if row_count is not None and row_count > 0: + merged = split.merged_row_count() + row_count = merged if merged is not None else split.row_count + if row_count > 0: total_rows += row_count if hasattr(split, 'file_size') and split.file_size > 0: total_size += split.file_size diff --git a/paimon-python/pypaimon/read/datasource/split_provider.py b/paimon-python/pypaimon/read/datasource/split_provider.py new file mode 100644 index 000000000000..491e8127d2f0 --- /dev/null +++ b/paimon-python/pypaimon/read/datasource/split_provider.py @@ -0,0 +1,164 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ +"""SplitProvider abstraction used by ``RayDatasource``. 
+ +The datasource only needs four things to build read tasks: the underlying +table, the planned splits, the scan read type, and the optional predicate. +``SplitProvider`` decouples how those four items are obtained so the same +datasource can serve both the public ``read_paimon`` facade (which only has +a table identifier + catalog options) and the legacy ``TableRead.to_ray()`` +bridge (which already has a fully resolved ``TableRead``). +""" + +from abc import ABC, abstractmethod +from typing import Dict, List, Optional + +from pypaimon.read.split import Split + + +class SplitProvider(ABC): + """Source of the planning artefacts required by ``RayDatasource``.""" + + @abstractmethod + def table(self): + """Return the ``FileStoreTable`` to read.""" + + @abstractmethod + def splits(self) -> List[Split]: + """Return the planned splits.""" + + @abstractmethod + def read_type(self): + """Return the scan read type (row / record type).""" + + @abstractmethod + def predicate(self): + """Return the scan-time predicate, or ``None``.""" + + @abstractmethod + def display_name(self) -> str: + """Return a short, human-readable name for the source. + + Used by ``RayDatasource.get_name()`` so the datasource doesn't have + to peek at concrete provider types to format its name. + """ + + +class CatalogSplitProvider(SplitProvider): + """Plan splits from a fully-qualified table identifier and catalog options. + + Resolves the catalog and the table lazily on first access, then runs a + single ``ReadBuilder`` plan to populate splits + read type together. The + same provider should be reused across calls — the planning is cached. + """ + + def __init__( + self, + table_identifier: str, + catalog_options: Dict[str, str], + predicate=None, + projection: Optional[List[str]] = None, + limit: Optional[int] = None, + ): + if not table_identifier: + raise ValueError("table_identifier is required") + if catalog_options is None: + raise ValueError("catalog_options is required") + self._table_identifier = table_identifier + self._catalog_options = catalog_options + self._predicate = predicate + self._projection = projection + self._limit = limit + self._table_cached = None + self._splits_cached = None + self._read_type_cached = None + + def _ensure_table(self): + if self._table_cached is None: + from pypaimon.catalog.catalog_factory import CatalogFactory + catalog = CatalogFactory.create(self._catalog_options) + self._table_cached = catalog.get_table(self._table_identifier) + return self._table_cached + + def _ensure_planned(self): + if self._splits_cached is not None and self._read_type_cached is not None: + return + from pypaimon.read.read_builder import ReadBuilder + rb = ReadBuilder(self._ensure_table()) + if self._predicate is not None: + rb = rb.with_filter(self._predicate) + if self._projection is not None: + rb = rb.with_projection(self._projection) + if self._limit is not None: + rb = rb.with_limit(self._limit) + self._read_type_cached = rb.read_type() + self._splits_cached = rb.new_scan().plan().splits() + + @property + def table_identifier(self) -> str: + return self._table_identifier + + def table(self): + return self._ensure_table() + + def splits(self) -> List[Split]: + self._ensure_planned() + return self._splits_cached + + def read_type(self): + self._ensure_planned() + return self._read_type_cached + + def predicate(self): + return self._predicate + + def display_name(self) -> str: + return self._table_identifier + + +class PreResolvedSplitProvider(SplitProvider): + """Wrap an already-planned ``(table, 
splits, read_type, predicate)`` tuple. + + Used by ``TableRead.to_ray()`` where the caller has already built a + ``TableRead`` and planned splits, so the catalog round-trip should be + skipped. + """ + + def __init__(self, table, splits: List[Split], read_type, predicate=None): + self._table = table + self._splits = splits + self._read_type = read_type + self._predicate = predicate + + def table(self): + return self._table + + def splits(self) -> List[Split]: + return self._splits + + def read_type(self): + return self._read_type + + def predicate(self): + return self._predicate + + def display_name(self) -> str: + identifier = self._table.identifier + if hasattr(identifier, 'get_full_name'): + return identifier.get_full_name() + return str(identifier) diff --git a/paimon-python/pypaimon/read/table_read.py b/paimon-python/pypaimon/read/table_read.py index 0aef1d3ca9e5..40cc337aaa4f 100644 --- a/paimon-python/pypaimon/read/table_read.py +++ b/paimon-python/pypaimon/read/table_read.py @@ -232,7 +232,16 @@ def to_ray( raise ValueError(f"override_num_blocks must be at least 1, got {override_num_blocks}") from pypaimon.read.datasource.ray_datasource import RayDatasource - datasource = RayDatasource(self, splits) + from pypaimon.read.datasource.split_provider import PreResolvedSplitProvider + + datasource = RayDatasource( + PreResolvedSplitProvider( + table=self.table, + splits=splits, + read_type=self.read_type, + predicate=self.predicate, + ) + ) return ray.data.read_datasource( datasource, ray_remote_args=ray_remote_args, diff --git a/paimon-python/pypaimon/tests/ray_integration_test.py b/paimon-python/pypaimon/tests/ray_integration_test.py new file mode 100644 index 000000000000..1b8e2df5057d --- /dev/null +++ b/paimon-python/pypaimon/tests/ray_integration_test.py @@ -0,0 +1,291 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+################################################################################ + +import os +import shutil +import tempfile +import unittest + +import pyarrow as pa +import ray + +from pypaimon import CatalogFactory, Schema + + +class RayIntegrationTest(unittest.TestCase): + """Tests for the top-level read_paimon() / write_paimon() API.""" + + @classmethod + def setUpClass(cls): + cls.tempdir = tempfile.mkdtemp() + cls.warehouse = os.path.join(cls.tempdir, 'warehouse') + cls.catalog_options = {'warehouse': cls.warehouse} + + catalog = CatalogFactory.create(cls.catalog_options) + catalog.create_database('default', True) + + if not ray.is_initialized(): + ray.init(ignore_reinit_error=True, num_cpus=2) + + @classmethod + def tearDownClass(cls): + try: + if ray.is_initialized(): + ray.shutdown() + except Exception: + pass + try: + shutil.rmtree(cls.tempdir) + except OSError: + pass + + def _create_and_populate_table(self, table_name, pa_schema, data_dict, + primary_keys=None, partition_keys=None, options=None): + """Helper to create a table and write a single batch of data.""" + identifier = 'default.{}'.format(table_name) + schema = Schema.from_pyarrow_schema( + pa_schema, + primary_keys=primary_keys, + partition_keys=partition_keys, + options=options, + ) + catalog = CatalogFactory.create(self.catalog_options) + catalog.create_table(identifier, schema, False) + table = catalog.get_table(identifier) + + test_data = pa.Table.from_pydict(data_dict, schema=pa_schema) + write_builder = table.new_batch_write_builder() + writer = write_builder.new_write() + writer.write_arrow(test_data) + commit_messages = writer.prepare_commit() + commit = write_builder.new_commit() + commit.commit(commit_messages) + writer.close() + + return identifier + + def test_read_paimon_basic(self): + """read_paimon() reads back the data we wrote.""" + from pypaimon.ray import read_paimon + + pa_schema = pa.schema([ + ('id', pa.int32()), + ('name', pa.string()), + ('value', pa.int64()), + ]) + identifier = self._create_and_populate_table( + 'test_read_basic', pa_schema, + {'id': [1, 2, 3], 'name': ['a', 'b', 'c'], 'value': [10, 20, 30]}, + ) + + ds = read_paimon(identifier, self.catalog_options, override_num_blocks=1) + self.assertEqual(ds.count(), 3) + + df = ds.to_pandas().sort_values('id').reset_index(drop=True) + self.assertEqual(list(df['id']), [1, 2, 3]) + self.assertEqual(list(df['name']), ['a', 'b', 'c']) + + def test_read_paimon_with_projection(self): + """read_paimon() respects column projection.""" + from pypaimon.ray import read_paimon + + pa_schema = pa.schema([ + ('id', pa.int32()), + ('name', pa.string()), + ('value', pa.int64()), + ]) + identifier = self._create_and_populate_table( + 'test_read_proj', pa_schema, + {'id': [1, 2], 'name': ['a', 'b'], 'value': [10, 20]}, + ) + + ds = read_paimon(identifier, self.catalog_options, projection=['id', 'name']) + df = ds.to_pandas() + self.assertEqual(set(df.columns), {'id', 'name'}) + self.assertEqual(len(df), 2) + + def test_read_paimon_with_filter(self): + """read_paimon() pushes down a predicate filter.""" + from pypaimon.ray import read_paimon + + pa_schema = pa.schema([ + ('id', pa.int32()), + ('category', pa.string()), + ]) + identifier = self._create_and_populate_table( + 'test_read_filter', pa_schema, + {'id': [1, 2, 3], 'category': ['A', 'B', 'A']}, + ) + + catalog = CatalogFactory.create(self.catalog_options) + table = catalog.get_table(identifier) + pb = table.new_read_builder().new_predicate_builder() + predicate = pb.equal('category', 'A') + + ds = 
read_paimon(identifier, self.catalog_options, filter=predicate) + self.assertEqual(ds.count(), 2) + df = ds.to_pandas() + self.assertEqual(set(df['category'].tolist()), {'A'}) + + def test_read_paimon_with_limit(self): + """``read_paimon(limit=N)`` propagates the limit into the scan plan. + + Writes 10 rows across two partitions (5 + 5) so the scan produces two + raw-convertible splits. ``limit=3`` causes ``FileScanner`` to drop the + second split once the first already covers the limit, so the Ray + Dataset contains strictly fewer than the full 10 rows. + + We assert ``< 10`` (not ``== N``) because Paimon's scan-time limit is + a per-split cap — whole-split granularity at this layer — not a + row-exact hard limit. Row-exact short-circuiting in the reader is a + separate follow-up. + """ + from pypaimon.ray import read_paimon + + pa_schema = pa.schema([ + ('id', pa.int32()), + ('part', pa.string()), + ('value', pa.string()), + ]) + identifier = self._create_and_populate_table( + 'test_read_limit', pa_schema, + { + 'id': list(range(10)), + 'part': ['a'] * 5 + ['b'] * 5, + 'value': [str(i) for i in range(10)], + }, + partition_keys=['part'], + ) + + # Sanity baseline: the full unbounded scan returns all 10 rows. + ds_full = read_paimon(identifier, self.catalog_options) + self.assertEqual(ds_full.count(), 10) + + # With limit=3, the scan plan drops the second partition's split + # once the first split's row count already covers the limit. + ds = read_paimon(identifier, self.catalog_options, limit=3) + limited_count = ds.count() + self.assertGreater(limited_count, 0) + self.assertLess(limited_count, 10) + + def test_read_paimon_empty_table(self): + """read_paimon() on a table with no data returns an empty dataset.""" + from pypaimon.ray import read_paimon + + pa_schema = pa.schema([('id', pa.int32())]) + identifier = 'default.test_read_empty' + catalog = CatalogFactory.create(self.catalog_options) + schema = Schema.from_pyarrow_schema(pa_schema) + catalog.create_table(identifier, schema, False) + + ds = read_paimon(identifier, self.catalog_options) + self.assertEqual(ds.count(), 0) + + def test_write_paimon_basic(self): + """write_paimon() writes data that read_paimon() can round-trip.""" + from pypaimon.ray import read_paimon, write_paimon + + pa_schema = pa.schema([ + ('id', pa.int32()), + ('name', pa.string()), + ]) + identifier = 'default.test_write_basic' + catalog = CatalogFactory.create(self.catalog_options) + schema = Schema.from_pyarrow_schema(pa_schema) + catalog.create_table(identifier, schema, False) + + source = pa.Table.from_pydict( + {'id': [1, 2, 3], 'name': ['x', 'y', 'z']}, schema=pa_schema, + ) + ds = ray.data.from_arrow(source) + write_paimon(ds, identifier, self.catalog_options) + + result = read_paimon(identifier, self.catalog_options) + self.assertEqual(result.count(), 3) + df = result.to_pandas().sort_values('id').reset_index(drop=True) + self.assertEqual(list(df['name']), ['x', 'y', 'z']) + + def test_write_paimon_overwrite(self): + """write_paimon(overwrite=True) replaces existing data.""" + from pypaimon.ray import read_paimon, write_paimon + + pa_schema = pa.schema([ + ('id', pa.int32()), + ('val', pa.int64()), + ]) + identifier = 'default.test_write_overwrite' + catalog = CatalogFactory.create(self.catalog_options) + schema = Schema.from_pyarrow_schema(pa_schema) + catalog.create_table(identifier, schema, False) + + ds1 = ray.data.from_arrow( + pa.Table.from_pydict({'id': [1, 2], 'val': [10, 20]}, schema=pa_schema) + ) + write_paimon(ds1, identifier, 
self.catalog_options) + + ds2 = ray.data.from_arrow( + pa.Table.from_pydict({'id': [3], 'val': [30]}, schema=pa_schema) + ) + write_paimon(ds2, identifier, self.catalog_options, overwrite=True) + + result = read_paimon(identifier, self.catalog_options) + self.assertEqual(result.count(), 1) + df = result.to_pandas() + self.assertEqual(list(df['id']), [3]) + + def test_read_paimon_primary_key(self): + """read_paimon() merges PK rows correctly after an upsert.""" + from pypaimon.ray import read_paimon + + pa_schema = pa.schema([ + pa.field('id', pa.int32(), nullable=False), + ('name', pa.string()), + ]) + identifier = self._create_and_populate_table( + 'test_read_pk', pa_schema, + {'id': [1, 2, 3], 'name': ['a', 'b', 'c']}, + primary_keys=['id'], + options={'bucket': '2'}, + ) + + catalog = CatalogFactory.create(self.catalog_options) + table = catalog.get_table(identifier) + update = pa.Table.from_pydict({'id': [1, 4], 'name': ['a2', 'd']}, schema=pa_schema) + wb = table.new_batch_write_builder() + w = wb.new_write() + w.write_arrow(update) + msgs = w.prepare_commit() + wb.new_commit().commit(msgs) + w.close() + + ds = read_paimon(identifier, self.catalog_options) + self.assertEqual(ds.count(), 4) + df = ds.to_pandas().sort_values('id').reset_index(drop=True) + self.assertEqual(list(df['name']), ['a2', 'b', 'c', 'd']) + + def test_read_paimon_invalid_override_num_blocks(self): + """override_num_blocks below 1 is rejected with a clear error.""" + from pypaimon.ray import read_paimon + + with self.assertRaises(ValueError): + read_paimon('default.does_not_matter', self.catalog_options, + override_num_blocks=0) + + +if __name__ == '__main__': + unittest.main() diff --git a/paimon-python/pypaimon/tests/split_provider_test.py b/paimon-python/pypaimon/tests/split_provider_test.py new file mode 100644 index 000000000000..31152f28a6d1 --- /dev/null +++ b/paimon-python/pypaimon/tests/split_provider_test.py @@ -0,0 +1,178 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+################################################################################ + +import os +import shutil +import tempfile +import unittest + +import pyarrow as pa + +from pypaimon import CatalogFactory, Schema +from pypaimon.read.datasource.split_provider import ( + CatalogSplitProvider, + PreResolvedSplitProvider, +) + + +class SplitProviderTest(unittest.TestCase): + """Unit tests for the two SplitProvider implementations.""" + + @classmethod + def setUpClass(cls): + cls.tempdir = tempfile.mkdtemp() + cls.warehouse = os.path.join(cls.tempdir, 'warehouse') + cls.catalog_options = {'warehouse': cls.warehouse} + + catalog = CatalogFactory.create(cls.catalog_options) + catalog.create_database('default', True) + + pa_schema = pa.schema([ + ('id', pa.int32()), + ('name', pa.string()), + ]) + cls.identifier = 'default.split_provider_test' + schema = Schema.from_pyarrow_schema(pa_schema) + catalog.create_table(cls.identifier, schema, False) + table = catalog.get_table(cls.identifier) + + data = pa.Table.from_pydict( + {'id': [1, 2, 3], 'name': ['a', 'b', 'c']}, schema=pa_schema + ) + wb = table.new_batch_write_builder() + writer = wb.new_write() + writer.write_arrow(data) + wb.new_commit().commit(writer.prepare_commit()) + writer.close() + + @classmethod + def tearDownClass(cls): + try: + shutil.rmtree(cls.tempdir) + except OSError: + pass + + def test_catalog_provider_resolves_table_and_splits(self): + """CatalogSplitProvider does the catalog→table→ReadBuilder→Scan dance lazily.""" + provider = CatalogSplitProvider( + table_identifier=self.identifier, + catalog_options=self.catalog_options, + ) + + self.assertIsNone(provider._table_cached) + self.assertIsNone(provider._splits_cached) + self.assertIsNone(provider._read_type_cached) + + table = provider.table() + self.assertIsNotNone(table) + self.assertIs(provider.table(), table) # cached + + splits = provider.splits() + self.assertGreater(len(splits), 0) + self.assertIs(provider.splits(), splits) # cached + self.assertIsNotNone(provider.read_type()) + self.assertIsNone(provider.predicate()) + + def test_catalog_provider_propagates_projection(self): + """``projection`` reaches ``ReadBuilder.with_projection`` (visible via read_type).""" + provider = CatalogSplitProvider( + table_identifier=self.identifier, + catalog_options=self.catalog_options, + projection=['id'], + ) + + read_type = provider.read_type() + field_names = [f.name for f in read_type] + self.assertEqual(field_names, ['id']) + + def test_catalog_provider_propagates_predicate(self): + """``predicate`` is held on the provider and surfaced via predicate().""" + catalog = CatalogFactory.create(self.catalog_options) + table = catalog.get_table(self.identifier) + pb = table.new_read_builder().new_predicate_builder() + pred = pb.equal('id', 2) + + provider = CatalogSplitProvider( + table_identifier=self.identifier, + catalog_options=self.catalog_options, + predicate=pred, + ) + + self.assertIs(provider.predicate(), pred) + + def test_catalog_provider_propagates_limit(self): + """``limit`` reaches ``ReadBuilder.with_limit``: splits are pruned once + the per-split row budget is met. 
Uses a fresh partitioned table so + each commit produces its own split.""" + pa_schema = pa.schema([('id', pa.int32()), ('name', pa.string())]) + identifier = 'default.split_provider_limit' + schema = Schema.from_pyarrow_schema(pa_schema, partition_keys=['id']) + catalog = CatalogFactory.create(self.catalog_options) + catalog.create_table(identifier, schema, False) + table = catalog.get_table(identifier) + for i in range(3): + data = pa.Table.from_pydict({'id': [i], 'name': [f'r{i}']}, schema=pa_schema) + wb = table.new_batch_write_builder() + writer = wb.new_write() + writer.write_arrow(data) + wb.new_commit().commit(writer.prepare_commit()) + writer.close() + + unlimited = CatalogSplitProvider( + table_identifier=identifier, catalog_options=self.catalog_options, + ) + limited = CatalogSplitProvider( + table_identifier=identifier, catalog_options=self.catalog_options, + limit=1, + ) + + # Three single-row commits → three splits; limit=1 prunes after the + # first split meets the budget. + self.assertEqual(len(unlimited.splits()), 3) + self.assertLess(len(limited.splits()), len(unlimited.splits())) + + def test_catalog_provider_requires_identifier_and_options(self): + with self.assertRaises(ValueError): + CatalogSplitProvider( + table_identifier='', catalog_options=self.catalog_options + ) + with self.assertRaises(ValueError): + CatalogSplitProvider( + table_identifier=self.identifier, catalog_options=None + ) + + def test_pre_resolved_provider_returns_inputs(self): + """PreResolvedSplitProvider just hands back what it was given.""" + catalog = CatalogFactory.create(self.catalog_options) + table = catalog.get_table(self.identifier) + rb = table.new_read_builder() + splits = rb.new_scan().plan().splits() + read_type = rb.read_type() + + provider = PreResolvedSplitProvider( + table=table, splits=splits, read_type=read_type, predicate=None + ) + + self.assertIs(provider.table(), table) + self.assertIs(provider.splits(), splits) + self.assertIs(provider.read_type(), read_type) + self.assertIsNone(provider.predicate()) + + +if __name__ == '__main__': + unittest.main()
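
For quick reference, a minimal end-to-end sketch of the new ``read_paimon`` / ``write_paimon`` facade, mirroring the patterns exercised by the tests above. This is a usage sketch, not part of the patch: the warehouse path and the ``default.demo`` table name are placeholders, and it assumes ``pypaimon``, ``pyarrow``, and ``ray`` are installed::

    import pyarrow as pa
    import ray

    from pypaimon import CatalogFactory, Schema
    from pypaimon.ray import read_paimon, write_paimon

    catalog_options = {"warehouse": "/tmp/paimon-warehouse"}  # placeholder path

    # Create a demo table (same calls the integration tests use).
    catalog = CatalogFactory.create(catalog_options)
    catalog.create_database("default", True)
    pa_schema = pa.schema([("id", pa.int32()), ("category", pa.string())])
    catalog.create_table("default.demo", Schema.from_pyarrow_schema(pa_schema), False)

    # Write a Ray Dataset into the table.
    ds = ray.data.from_arrow(
        pa.Table.from_pydict(
            {"id": [1, 2, 3], "category": ["A", "B", "A"]}, schema=pa_schema
        )
    )
    write_paimon(ds, "default.demo", catalog_options)

    # Read it back with a pushed-down predicate and a column projection.
    table = catalog.get_table("default.demo")
    predicate = table.new_read_builder().new_predicate_builder().equal("category", "A")
    result = read_paimon(
        "default.demo",
        catalog_options,
        filter=predicate,
        projection=["id", "category"],
    )
    print(result.to_pandas())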