21 changes: 21 additions & 0 deletions paimon-python/pypaimon/ray/__init__.py
@@ -0,0 +1,21 @@
################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################

from pypaimon.ray.ray_paimon import read_paimon, write_paimon

__all__ = ["read_paimon", "write_paimon"]
124 changes: 124 additions & 0 deletions paimon-python/pypaimon/ray/ray_paimon.py
@@ -0,0 +1,124 @@
################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
"""
Top-level API for reading and writing Paimon tables with Ray Datasets.

Usage::

from pypaimon.ray import read_paimon, write_paimon

ds = read_paimon("db.table", catalog_options={"warehouse": "/path"})
write_paimon(ds, "db.table", catalog_options={"warehouse": "/path"})
"""

from typing import Any, Dict, List, Optional

import ray.data

from pypaimon.common.predicate import Predicate


def read_paimon(
table_identifier: str,
catalog_options: Dict[str, str],
*,
filter: Optional[Predicate] = None,
projection: Optional[List[str]] = None,
limit: Optional[int] = None,
ray_remote_args: Optional[Dict[str, Any]] = None,
concurrency: Optional[int] = None,
override_num_blocks: Optional[int] = None,
**read_args,
) -> ray.data.Dataset:
"""Read a Paimon table into a Ray Dataset.

Args:
table_identifier: Full table name, e.g. ``"db_name.table_name"``.
catalog_options: Options passed to ``CatalogFactory.create()``,
e.g. ``{"warehouse": "/path/to/warehouse"}``.
filter: Optional predicate to push down into the scan.
projection: Optional list of column names to read.
limit: Optional row limit for the scan.
ray_remote_args: Optional kwargs passed to ``ray.remote`` in read tasks.
concurrency: Optional max number of Ray read tasks to run concurrently.
override_num_blocks: Optional override for the number of output blocks.
**read_args: Additional kwargs forwarded to ``ray.data.read_datasource``.

Returns:
A ``ray.data.Dataset`` containing the table data.
"""
from pypaimon.read.datasource.ray_datasource import RayDatasource
from pypaimon.read.datasource.split_provider import CatalogSplitProvider

if override_num_blocks is not None and override_num_blocks < 1:
raise ValueError(
"override_num_blocks must be at least 1, got {}".format(override_num_blocks)
)

datasource = RayDatasource(
CatalogSplitProvider(
table_identifier=table_identifier,
catalog_options=catalog_options,
predicate=filter,
projection=projection,
limit=limit,
)
)
return ray.data.read_datasource(
datasource,
ray_remote_args=ray_remote_args,
concurrency=concurrency,
override_num_blocks=override_num_blocks,
**read_args,
)
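
For orientation, a minimal read sketch might look like the following; the warehouse path, table name, and column names are placeholders, and ``filter`` is omitted because constructing a ``Predicate`` depends on the table's schema.

```python
import ray

from pypaimon.ray import read_paimon

ray.init()  # or connect to a running cluster

# Placeholder warehouse path and table name.
ds = read_paimon(
    "default.orders",
    catalog_options={"warehouse": "/tmp/paimon_warehouse"},
    projection=["order_id", "amount"],  # read only these columns
    limit=1_000,                        # cap the scan at 1,000 rows
    override_num_blocks=4,              # ask Ray for 4 output blocks
)
print(ds.schema())
print(ds.take(5))
```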


def write_paimon(
dataset: ray.data.Dataset,
table_identifier: str,
catalog_options: Dict[str, str],
*,
overwrite: bool = False,
concurrency: Optional[int] = None,
ray_remote_args: Optional[Dict[str, Any]] = None,
) -> None:
"""Write a Ray Dataset to a Paimon table.

Args:
dataset: The Ray Dataset to write.
table_identifier: Full table name, e.g. ``"db_name.table_name"``.
catalog_options: Options passed to ``CatalogFactory.create()``.
overwrite: If ``True``, overwrite existing data in the table.
concurrency: Optional max number of Ray write tasks to run concurrently.
ray_remote_args: Optional kwargs passed to ``ray.remote`` in write tasks.
"""
from pypaimon.catalog.catalog_factory import CatalogFactory
from pypaimon.write.ray_datasink import PaimonDatasink

catalog = CatalogFactory.create(catalog_options)
table = catalog.get_table(table_identifier)

datasink = PaimonDatasink(table, overwrite=overwrite)

write_kwargs = {}
if ray_remote_args is not None:
write_kwargs["ray_remote_args"] = ray_remote_args
if concurrency is not None:
write_kwargs["concurrency"] = concurrency

dataset.write_datasink(datasink, **write_kwargs)
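
A round-trip sketch, under the assumption that ``default.orders`` already exists in the warehouse with a matching schema (``write_paimon`` resolves the table through the catalog rather than creating it):

```python
import ray.data

from pypaimon.ray import read_paimon, write_paimon

catalog_options = {"warehouse": "/tmp/paimon_warehouse"}  # placeholder path

# Hypothetical rows matching the existing "default.orders" schema.
ds = ray.data.from_items(
    [{"order_id": i, "amount": 1.5 * i} for i in range(100)]
)
write_paimon(ds, "default.orders", catalog_options=catalog_options, overwrite=True)

# Read it back to verify the round trip.
print(read_paimon("default.orders", catalog_options=catalog_options).count())
```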
83 changes: 46 additions & 37 deletions paimon-python/pypaimon/read/datasource/ray_datasource.py
@@ -22,15 +22,15 @@
import itertools
import logging
from functools import partial
from typing import List, Optional, Iterable
from typing import Iterable, List, Optional

import pyarrow
from packaging.version import parse
import ray
from ray.data.datasource import Datasource

from pypaimon.read.datasource.split_provider import SplitProvider
from pypaimon.read.split import Split
from pypaimon.read.table_read import TableRead
from pypaimon.schema.data_types import PyarrowFieldParser

logger = logging.getLogger(__name__)
@@ -41,35 +41,45 @@


class RayDatasource(Datasource):
"""
Ray Data Datasource implementation for reading Paimon tables.
"""Ray Data ``Datasource`` implementation for reading Paimon tables.

Holds a :class:`SplitProvider` that supplies the four planning artefacts
needed to build read tasks (table, splits, read_type, predicate). Two
provider implementations exist today:

This datasource enables distributed parallel reading of Paimon table splits,
allowing Ray to read multiple splits concurrently across the cluster.
* :class:`CatalogSplitProvider` — resolves a fully-qualified table
identifier through the catalog and runs the ``ReadBuilder`` plan.
Used by the public :func:`pypaimon.ray.read_paimon` facade.
* :class:`PreResolvedSplitProvider` — wraps an already-resolved
``(table, splits, read_type, predicate)`` tuple. Used by the legacy
``TableRead.to_ray()`` bridge to skip a second catalog round-trip.

Both providers are cheap to instantiate; they defer the catalog
round-trip and split planning until the first read.
"""

def __init__(self, table_read: TableRead, splits: List[Split]):
"""
Initialize PaimonDatasource.
def __init__(self, split_provider: SplitProvider):
"""Initialize a RayDatasource.

Args:
table_read: TableRead instance for reading data
splits: List of splits to read
split_provider: The :class:`SplitProvider` that supplies the
table, splits, read_type, and predicate. Construct one with
:class:`CatalogSplitProvider` (from a table identifier +
catalog options) or :class:`PreResolvedSplitProvider` (from
an already-resolved ``TableRead``).
"""
self.table_read = table_read
self.splits = splits
self._split_provider = split_provider
self._schema = None

def get_name(self) -> str:
identifier = self.table_read.table.identifier
table_name = identifier.get_full_name() if hasattr(identifier, 'get_full_name') else str(identifier)
return f"PaimonTable({table_name})"
return f"PaimonTable({self._split_provider.display_name()})"

def estimate_inmemory_data_size(self) -> Optional[int]:
if not self.splits:
splits = self._split_provider.splits()
if not splits:
return 0

total_size = sum(split.file_size for split in self.splits)
total_size = sum(split.file_size for split in splits)
return total_size if total_size > 0 else None

@staticmethod
@@ -108,22 +118,26 @@ def get_read_tasks(self, parallelism: int, **kwargs) -> List:
if parallelism < 1:
raise ValueError(f"parallelism must be at least 1, got {parallelism}")

# Pull provider state into locals once: avoids capturing self in the
# ReadTask closure (see ray-project/ray#49107) and amortises the
# provider-method dispatch over all chunks.
table = self._split_provider.table()
predicate = self._split_provider.predicate()
read_type = self._split_provider.read_type()
splits = self._split_provider.splits()
if not splits:
return []

if self._schema is None:
self._schema = PyarrowFieldParser.from_paimon_schema(self.table_read.read_type)
self._schema = PyarrowFieldParser.from_paimon_schema(read_type)
schema = self._schema

if parallelism > len(self.splits):
parallelism = len(self.splits)
if parallelism > len(splits):
parallelism = len(splits)
logger.warning(
f"Reducing the parallelism to {parallelism}, as that is the number of splits"
)

# Store necessary information for creating readers in Ray workers
# Extract these to avoid serializing the entire self object in closures
table = self.table_read.table
predicate = self.table_read.predicate
read_type = self.table_read.read_type
schema = self._schema

# Create a partial function to avoid capturing self in closure
# This reduces serialization overhead (see https://github.com/ray-project/ray/issues/49107)
def _get_read_task(
@@ -163,7 +177,7 @@ def _get_read_task(
read_tasks = []

# Distribute splits across tasks using a load-balancing algorithm
for chunk_splits in self._distribute_splits_into_equal_chunks(self.splits, parallelism):
for chunk_splits in self._distribute_splits_into_equal_chunks(splits, parallelism):
if not chunk_splits:
continue

@@ -174,14 +188,9 @@ def _get_read_task(
for split in chunk_splits:
if predicate is None:
# Only estimate rows if no predicate (predicate filtering changes row count)
row_count = None
if hasattr(split, 'merged_row_count'):
merged_count = split.merged_row_count()
if merged_count is not None:
row_count = merged_count
if row_count is None and hasattr(split, 'row_count') and split.row_count > 0:
row_count = split.row_count
if row_count is not None and row_count > 0:
merged = split.merged_row_count()
row_count = merged if merged is not None else split.row_count
if row_count > 0:
total_rows += row_count
if hasattr(split, 'file_size') and split.file_size > 0:
total_size += split.file_size
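
The "copy into locals, bind with partial" pattern referenced in the comments above (ray-project/ray#49107) is easy to illustrate outside of Paimon; a toy sketch of why closures over ``self`` are avoided in read tasks:

```python
import pickle
from functools import partial


def _read_split(table_name: str, split_id: str) -> str:
    return f"read {split_id} from {table_name}"


class HeavyReader:
    def __init__(self) -> None:
        self.big_cache = bytearray(64 * 1024 * 1024)  # expensive to serialize
        self.table_name = "default.orders"

    def task_capturing_self(self, split_id: str):
        # The closure references self, so cloudpickle (used by Ray) would
        # ship big_cache to every worker along with the task.
        return lambda: _read_split(self.table_name, split_id)

    def task_with_locals(self, split_id: str):
        # Copy only what the task needs into locals, then bind with partial:
        # the resulting callable serializes to a few hundred bytes.
        table_name = self.table_name
        return partial(_read_split, table_name, split_id)


if __name__ == "__main__":
    reader = HeavyReader()
    print(len(pickle.dumps(reader.task_with_locals("split-0"))))  # tiny payload
    print(reader.task_with_locals("split-0")())  # "read split-0 from default.orders"
```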