diff --git a/paimon-core/src/test/java/org/apache/paimon/JavaPyE2ETest.java b/paimon-core/src/test/java/org/apache/paimon/JavaPyE2ETest.java
index 86cea365c7a1..fc8345db2392 100644
--- a/paimon-core/src/test/java/org/apache/paimon/JavaPyE2ETest.java
+++ b/paimon-core/src/test/java/org/apache/paimon/JavaPyE2ETest.java
@@ -466,17 +466,18 @@ public void testReadPkTable() throws Exception {
         assertThat(table.rowType().getFieldTypes().get(5))
                 .isEqualTo(DataTypes.TIMESTAMP_WITH_LOCAL_TIME_ZONE());
         assertThat(table.rowType().getFieldTypes().get(6)).isEqualTo(DataTypes.TIME());
-        assertThat(table.rowType().getFieldTypes().get(7)).isInstanceOf(RowType.class);
-        RowType metadataType = (RowType) table.rowType().getFieldTypes().get(7);
+        assertThat(table.rowType().getFieldTypes().get(7)).isEqualTo(DataTypes.BYTES());
+        assertThat(table.rowType().getFieldTypes().get(8)).isInstanceOf(RowType.class);
+        RowType metadataType = (RowType) table.rowType().getFieldTypes().get(8);
         assertThat(metadataType.getFieldTypes().get(2)).isInstanceOf(RowType.class);
         assertThat(res)
                 .containsExactlyInAnyOrder(
-                        "+I[1, Apple, Fruit, 1.5, 1970-01-01T00:16:40, 1970-01-01T00:33:20, 1000, (store1, 1001, (Beijing, China))]",
-                        "+I[2, Banana, Fruit, 0.8, 1970-01-01T00:16:40.001, 1970-01-01T00:33:20.001, 2000, (store1, 1002, (Shanghai, China))]",
-                        "+I[3, Carrot, Vegetable, 0.6, 1970-01-01T00:16:40.002, 1970-01-01T00:33:20.002, 3000, (store2, 1003, (Tokyo, Japan))]",
-                        "+I[4, Broccoli, Vegetable, 1.2, 1970-01-01T00:16:40.003, 1970-01-01T00:33:20.003, 4000, (store2, 1004, (Seoul, Korea))]",
-                        "+I[5, Chicken, Meat, 5.0, 1970-01-01T00:16:40.004, 1970-01-01T00:33:20.004, 5000, (store3, 1005, (NewYork, USA))]",
-                        "+I[6, Beef, Meat, 8.0, 1970-01-01T00:16:40.005, 1970-01-01T00:33:20.005, 6000, (store3, 1006, (London, UK))]");
+                        "+I[1, Apple, Fruit, 1.5, 1970-01-01T00:16:40, 1970-01-01T00:33:20, 1000, [97, 97, 97, 97, 97, 97, 97, 97, 97, 97, 97, 97, 97, 97, 97, 97, 97, 97, 97, 97, 97, 97, 97, 97, 97, 97, 97, 97, 97, 97], (store1, 1001, (Beijing, China))]",
+                        "+I[2, Banana, Fruit, 0.8, 1970-01-01T00:16:40.001, 1970-01-01T00:33:20.001, 2000, [98, 98, 98, 98, 98], (store1, 1002, (Shanghai, China))]",
+                        "+I[3, Carrot, Vegetable, 0.6, 1970-01-01T00:16:40.002, 1970-01-01T00:33:20.002, 3000, [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], (store2, 1003, (Tokyo, Japan))]",
+                        "+I[4, Broccoli, Vegetable, 1.2, 1970-01-01T00:16:40.003, 1970-01-01T00:33:20.003, 4000, [98, 105, 110, 97, 114, 121, 95, 118, 97, 108, 117, 101, 95, 52], (store2, 1004, (Seoul, Korea))]",
+                        "+I[5, Chicken, Meat, 5.0, 1970-01-01T00:16:40.004, 1970-01-01T00:33:20.004, 5000, [98, 105, 110, 97, 114, 121, 95, 118, 97, 108, 117, 101, 95, 53], (store3, 1005, (NewYork, USA))]",
+                        "+I[6, Beef, Meat, 8.0, 1970-01-01T00:16:40.005, 1970-01-01T00:33:20.005, 6000, [98, 105, 110, 97, 114, 121, 95, 118, 97, 108, 117, 101, 95, 54], (store3, 1006, (London, UK))]");

         PredicateBuilder predicateBuilder = new PredicateBuilder(table.rowType());
         int[] ids = {1, 2, 3, 4, 5, 6};
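Note: the new expected rows render the blob column the way the Java test stringifies byte arrays, i.e. as signed 8-bit values, which is why b'\xff' * 16 appears as sixteen -1 entries. A minimal illustrative sketch (the helper name is ours, not part of the patch):

def as_java_byte_array(value: bytes) -> str:
    # Java bytes are signed, so 0xFF prints as -1 and b'a' as 97.
    signed = [b - 256 if b > 127 else b for b in value]
    return '[' + ', '.join(str(b) for b in signed) + ']'

assert as_java_byte_array(b'b' * 5) == '[98, 98, 98, 98, 98]'
assert as_java_byte_array(b'\xff' * 2) == '[-1, -1]'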
diff --git a/paimon-python/pypaimon/common/options/core_options.py b/paimon-python/pypaimon/common/options/core_options.py
index 7d9a227e4a9d..da5f7744ecc2 100644
--- a/paimon-python/pypaimon/common/options/core_options.py
+++ b/paimon-python/pypaimon/common/options/core_options.py
@@ -199,8 +199,10 @@ class CoreOptions:
     METADATA_STATS_MODE: ConfigOption[str] = (
         ConfigOptions.key("metadata.stats-mode")
         .string_type()
-        .default_value("none")
-        .with_description("Stats Mode, Python by default is none. Java is truncate(16).")
+        .default_value("truncate(16)")
+        .with_description("The mode of metadata stats. Available modes: "
+                          "'none' (no stats), 'counts' (null counts only), "
+                          "'full' (exact min/max), 'truncate(length)' (truncated min/max).")
     )

     BLOB_AS_DESCRIPTOR: ConfigOption[bool] = (
@@ -566,7 +568,32 @@ def file_block_size(self, default=None):
         return self.options.get(CoreOptions.FILE_BLOCK_SIZE, default)

     def metadata_stats_enabled(self, default=None):
-        return self.options.get(CoreOptions.METADATA_STATS_MODE, default) == "full"
+        mode, _ = CoreOptions.parse_metadata_stats_mode(
+            self.options.get(CoreOptions.METADATA_STATS_MODE, default))
+        return mode != "NONE"
+
+    def metadata_stats_mode(self, default=None):
+        mode = self.options.get(CoreOptions.METADATA_STATS_MODE, default)
+        CoreOptions.parse_metadata_stats_mode(mode)
+        return mode.strip()
+
+    @staticmethod
+    def parse_metadata_stats_mode(mode: str):
+        if mode is None:
+            mode = CoreOptions.METADATA_STATS_MODE.default_value()
+        normalized = mode.strip()
+        upper = normalized.upper()
+        if upper in ("NONE", "COUNTS", "FULL"):
+            return upper, None
+        if upper.startswith("TRUNCATE(") and upper.endswith(")"):
+            length_text = upper[9:-1]
+            if not length_text or not all('0' <= c <= '9' for c in length_text):
+                raise ValueError(f"Unsupported metadata.stats-mode: {mode}")
+            length = int(length_text)
+            if length <= 0:
+                raise ValueError(f"Truncate length must be > 0, got: {mode}")
+            return "TRUNCATE", length
+        raise ValueError(f"Unsupported metadata.stats-mode: {mode}")

     def blob_as_descriptor(self, default=None):
         return self.options.get(CoreOptions.BLOB_AS_DESCRIPTOR, default)
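Note: a standalone sketch of the parsing contract implemented by CoreOptions.parse_metadata_stats_mode above (an illustrative re-implementation, not the shipped code):

def parse_stats_mode(mode):
    normalized = (mode or "truncate(16)").strip()
    upper = normalized.upper()
    if upper in ("NONE", "COUNTS", "FULL"):
        return upper, None
    if upper.startswith("TRUNCATE(") and upper.endswith(")"):
        length_text = upper[9:-1]
        # only plain decimal digits are accepted, so 'truncate(+1)',
        # 'truncate( 1)' and 'truncate(10.1)' are all rejected
        if length_text.isdigit() and int(length_text) > 0:
            return "TRUNCATE", int(length_text)
    raise ValueError(f"Unsupported metadata.stats-mode: {mode}")

assert parse_stats_mode(None) == ("TRUNCATE", 16)    # new default
assert parse_stats_mode(" none ") == ("NONE", None)  # whitespace tolerated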
diff --git a/paimon-python/pypaimon/tests/e2e/java_py_read_write_test.py b/paimon-python/pypaimon/tests/e2e/java_py_read_write_test.py
index f48f4c99f368..a73abfd9b0e8 100644
--- a/paimon-python/pypaimon/tests/e2e/java_py_read_write_test.py
+++ b/paimon-python/pypaimon/tests/e2e/java_py_read_write_test.py
@@ -135,6 +135,7 @@ def test_py_write_read_pk_table(self, file_format):
             ('ts', pa.timestamp('us')),
             ('ts_ltz', pa.timestamp('us', tz='UTC')),
             ('t', pa.time32('ms')),
+            ('blob', pa.binary()),
             ('metadata', pa.struct([
                 pa.field('source', pa.string()),
                 pa.field('created_at', pa.int64()),
@@ -187,6 +188,7 @@ def test_py_write_read_pk_table(self, file_format):
             'ts_ltz': pd.to_datetime([2000000, 2000001, 2000002, 2000003, 2000004, 2000005], unit='ms', utc=True),
             't': [datetime.time(0, 0, 1), datetime.time(0, 0, 2), datetime.time(0, 0, 3), datetime.time(0, 0, 4), datetime.time(0, 0, 5), datetime.time(0, 0, 6)],
+            'blob': [b'a' * 30, b'b' * 5, b'\xff' * 16, b'binary_value_4', b'binary_value_5', b'binary_value_6'],
             'metadata': [
                 {'source': 'store1', 'created_at': 1001, 'location': {'city': 'Beijing', 'country': 'China'}},
                 {'source': 'store1', 'created_at': 1002, 'location': {'city': 'Shanghai', 'country': 'China'}},
@@ -212,6 +214,19 @@ def test_py_write_read_pk_table(self, file_format):
             print(f"Format: {file_format}, Result:\n{result}")
             self.assertEqual(initial_data.to_dict(), result.to_dict())

+        # Verify binary column stats are None (aligned with Java behavior)
+        if file_format != 'lance' and 'blob' in [f.name for f in table.fields]:
+            from pypaimon.table.row.generic_row import GenericRowDeserializer
+            latest_snapshot = table.snapshot_manager().get_latest_snapshot()
+            manifest_files = table_scan.file_scanner.manifest_list_manager.read_all(latest_snapshot)
+            manifest_entries = table_scan.file_scanner.manifest_file_manager.read(
+                manifest_files[0].file_name, lambda row: True,
+                drop_stats=False)
+            stats = manifest_entries[0].file.value_stats
+            min_row = GenericRowDeserializer.from_bytes(stats.min_values.data, table.fields)
+            blob_idx = next(i for i, f in enumerate(table.fields) if f.name == 'blob')
+            self.assertIsNone(min_row.values[blob_idx], "binary column should have no min/max stats")
+
         from pypaimon.write.row_key_extractor import FixedBucketRowKeyExtractor
         expected_bucket_first_row = 2
         first_row = initial_data.head(1)
diff --git a/paimon-python/pypaimon/tests/predicates_test.py b/paimon-python/pypaimon/tests/predicates_test.py
index 629f0235a00b..19341dc78cdc 100644
--- a/paimon-python/pypaimon/tests/predicates_test.py
+++ b/paimon-python/pypaimon/tests/predicates_test.py
@@ -431,6 +431,15 @@ def test_is_null(self):
         )
         self.assertTrue(pred.test_by_simple_stats(stat_positive, 10))

+    def test_missing_minmax_keeps_file_for_value_predicate(self):
+        stat_missing_minmax = SimpleStats(
+            min_values=GenericRow([None], []),
+            max_values=GenericRow([None], []),
+            null_counts=[0],
+        )
+        pred = Predicate(method="equal", index=0, field="c", literals=["target"])
+        self.assertTrue(pred.test_by_simple_stats(stat_missing_minmax, 10))
+
     def test_filter_with_null_and_or(self):
         p_gt = Predicate(method='greaterThan', index=1, field='score', literals=[10])
         p_null = Predicate(method='isNull', index=1, field='score', literals=[])
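Note: the new predicates_test case encodes the pruning contract that makes stats-less columns safe. A minimal sketch of the conservative rule (simplified; the real check lives in Predicate.test_by_simple_stats):

def may_contain(min_value, max_value, literal):
    # With no recorded bounds, equality can never be disproved,
    # so the file must be kept and scanned.
    if min_value is None or max_value is None:
        return True
    return min_value <= literal <= max_value

assert may_contain(None, None, 'target')  # e.g. a binary column without stats
assert not may_contain('a', 'm', 'z')     # bounds disprove equality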
diff --git a/paimon-python/pypaimon/tests/py36/rest_ao_read_write_test.py b/paimon-python/pypaimon/tests/py36/rest_ao_read_write_test.py
index a0a3ad37e85d..9e5fd3ecf9e1 100644
--- a/paimon-python/pypaimon/tests/py36/rest_ao_read_write_test.py
+++ b/paimon-python/pypaimon/tests/py36/rest_ao_read_write_test.py
@@ -17,7 +17,6 @@
 import logging
 import time
-import random
 from datetime import date
 from decimal import Decimal
 from unittest.mock import Mock
@@ -208,9 +207,7 @@ def test_full_data_types(self):
             ('f10', pa.decimal128(10, 2)),
             ('f11', pa.date32()),
         ])
-        stats_enabled = random.random() < 0.5
-        options = {'metadata.stats-mode': 'full'} if stats_enabled else {}
-        schema = Schema.from_pyarrow_schema(simple_pa_schema, options=options)
+        schema = Schema.from_pyarrow_schema(simple_pa_schema)
         self.rest_catalog.create_table('default.test_full_data_types', schema, False)
         table = self.rest_catalog.get_table('default.test_full_data_types')
@@ -250,25 +247,21 @@ def test_full_data_types(self):
             manifest_files[0].file_name,
             lambda row: table_scan.file_scanner._filter_manifest_entry(row),
             drop_stats=False)
-        # Python write does not produce value stats
-        if stats_enabled:
-            self.assertEqual(manifest_entries[0].file.value_stats_cols, None)
-            min_value_stats = GenericRowDeserializer.from_bytes(manifest_entries[0].file.value_stats.min_values.data,
-                                                                table.fields).values
-            max_value_stats = GenericRowDeserializer.from_bytes(manifest_entries[0].file.value_stats.max_values.data,
-                                                                table.fields).values
-            expected_min_values = [col[0].as_py() for col in expect_data]
-            expected_max_values = [col[1].as_py() for col in expect_data]
-            self.assertEqual(min_value_stats, expected_min_values)
-            self.assertEqual(max_value_stats, expected_max_values)
-        else:
-            self.assertEqual(manifest_entries[0].file.value_stats_cols, [])
-            min_value_stats = GenericRowDeserializer.from_bytes(manifest_entries[0].file.value_stats.min_values.data,
-                                                                []).values
-            max_value_stats = GenericRowDeserializer.from_bytes(manifest_entries[0].file.value_stats.max_values.data,
-                                                                []).values
-            self.assertEqual(min_value_stats, [])
-            self.assertEqual(max_value_stats, [])
+        # Both 'full' and the default 'truncate(16)' modes produce value stats
+        self.assertEqual(manifest_entries[0].file.value_stats_cols, None)
+        min_value_stats = GenericRowDeserializer.from_bytes(manifest_entries[0].file.value_stats.min_values.data,
+                                                            table.fields).values
+        max_value_stats = GenericRowDeserializer.from_bytes(manifest_entries[0].file.value_stats.max_values.data,
+                                                            table.fields).values
+        expected_min_values = [col[0].as_py() for col in expect_data]
+        expected_max_values = [col[1].as_py() for col in expect_data]
+        # binary columns (f8, f9) have no min/max stats, matching Java behavior
+        expected_min_values[8] = None
+        expected_min_values[9] = None
+        expected_max_values[8] = None
+        expected_max_values[9] = None
+        self.assertEqual(min_value_stats, expected_min_values)
+        self.assertEqual(max_value_stats, expected_max_values)

     def test_mixed_add_and_delete_entries_same_partition(self):
         """Test record_count calculation with mixed ADD/DELETE entries in same partition."""
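Note: the per-column inputs behind these min/max/null expectations come from pyarrow.compute, which the writer uses further below. A minimal sketch (assumes pyarrow is installed):

import pyarrow as pa
import pyarrow.compute as pc

col = pa.chunked_array([['Alice', None, 'Charlie']])
assert pc.min(col).as_py() == 'Alice'
assert pc.max(col).as_py() == 'Charlie'
assert col.null_count == 1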
diff --git a/paimon-python/pypaimon/tests/reader_base_test.py b/paimon-python/pypaimon/tests/reader_base_test.py
index 31d8205f8802..ee5e1f72c7b5 100644
--- a/paimon-python/pypaimon/tests/reader_base_test.py
+++ b/paimon-python/pypaimon/tests/reader_base_test.py
@@ -20,7 +20,6 @@
 import shutil
 import tempfile
 import unittest
-import random
 from datetime import date, datetime, time
 from decimal import Decimal
 from unittest.mock import Mock
@@ -243,8 +242,7 @@ def test_full_data_types(self):
             ('f12', pa.date32()),
             ('f13', pa.time32('ms')),
         ])
-        stats_enabled = random.random() < 0.5
-        options = {'metadata.stats-mode': 'full'} if stats_enabled else {}
+        options = {'metadata.stats-mode': 'full'}
         schema = Schema.from_pyarrow_schema(simple_pa_schema, options=options)
         self.catalog.create_table('default.test_full_data_types', schema, False)
         table = self.catalog.get_table('default.test_full_data_types')
@@ -279,10 +277,9 @@ def test_full_data_types(self):
         table_read = read_builder.new_read()
         splits = table_scan.plan().splits()

-        # assert data file without stats
+        # splits have stats dropped (drop_stats=True by default)
         first_file = splits[0].files[0]
         self.assertEqual(first_file.value_stats_cols, [])
-        self.assertEqual(first_file.value_stats, SimpleStats.empty_stats())

         # assert equal
         actual_data = table_read.to_arrow(splits)
@@ -294,25 +291,20 @@ def test_full_data_types(self):
         manifest_entries = table_scan.file_scanner.manifest_file_manager.read(
             manifest_files[0].file_name,
             lambda row: table_scan.file_scanner._filter_manifest_entry(row), False)
-        # Python write does not produce value stats
-        if stats_enabled:
-            self.assertEqual(manifest_entries[0].file.value_stats_cols, None)
-            min_value_stats = GenericRowDeserializer.from_bytes(manifest_entries[0].file.value_stats.min_values.data,
-                                                                table.fields).values
-            max_value_stats = GenericRowDeserializer.from_bytes(manifest_entries[0].file.value_stats.max_values.data,
-                                                                table.fields).values
-            expected_min_values = [col[0].as_py() for col in expect_data]
-            expected_max_values = [col[1].as_py() for col in expect_data]
-            self.assertEqual(min_value_stats, expected_min_values)
-            self.assertEqual(max_value_stats, expected_max_values)
-        else:
-            self.assertEqual(manifest_entries[0].file.value_stats_cols, [])
-            min_value_stats = GenericRowDeserializer.from_bytes(manifest_entries[0].file.value_stats.min_values.data,
-                                                                []).values
-            max_value_stats = GenericRowDeserializer.from_bytes(manifest_entries[0].file.value_stats.max_values.data,
-                                                                []).values
-            self.assertEqual(min_value_stats, [])
-            self.assertEqual(max_value_stats, [])
+        self.assertEqual(manifest_entries[0].file.value_stats_cols, None)
+        min_value_stats = GenericRowDeserializer.from_bytes(manifest_entries[0].file.value_stats.min_values.data,
+                                                            table.fields).values
+        max_value_stats = GenericRowDeserializer.from_bytes(manifest_entries[0].file.value_stats.max_values.data,
+                                                            table.fields).values
+        expected_min_values = [col[0].as_py() for col in expect_data]
+        expected_max_values = [col[1].as_py() for col in expect_data]
+        # binary columns (f8, f9) have no min/max stats, matching Java behavior
+        expected_min_values[8] = None
+        expected_min_values[9] = None
+        expected_max_values[8] = None
+        expected_max_values[9] = None
+        self.assertEqual(min_value_stats, expected_min_values)
+        self.assertEqual(max_value_stats, expected_max_values)

     def test_write_wrong_schema(self):
         self.catalog.create_table('default.test_wrong_schema',
@@ -615,6 +607,263 @@ def test_primary_key_value_stats_excludes_system_fields(self):
                 self.assertFalse(is_system_field,
                                  f"value_stats_cols should not contain system field: {field_name}")

+    def test_truncate_stats(self):
+        from pypaimon.write.writer.data_writer import DataWriter, _truncate_min, _truncate_max
+
+        self.assertEqual(_truncate_min('abcdefghij', 5), 'abcde')
+        self.assertEqual(_truncate_min('abc', 5), 'abc')
+        self.assertEqual(_truncate_min(None, 5), None)
+        self.assertEqual(_truncate_min(42, 5), 42)
+
+        self.assertEqual(_truncate_max('abc', 5), 'abc')
+        self.assertEqual(_truncate_max('abcdefghij', 5), 'abcdf')
+        self.assertIsNone(_truncate_max('\ud7ffx', 1))
+        self.assertEqual(_truncate_max('a\ud7ffx', 2), 'b')
+        self.assertEqual(_truncate_max(None, 5), None)
+        self.assertEqual(_truncate_max(42, 5), 42)
+
+        self.assertEqual(_truncate_min(b'\x01\x02\x03\x04\x05\x06', 3), b'\x01\x02\x03')
+        self.assertEqual(_truncate_max(b'\x01\x02\x03\x04\x05\x06', 3), b'\x01\x02\x04')
+        self.assertIsNone(_truncate_max(b'\xff\xff\xff\x00', 3))
+
+        self.assertEqual(DataWriter._parse_truncate_length('truncate(10)'), ('TRUNCATE', 10))
+        for invalid_mode in ['truncate(+1)', 'truncate( 1)', 'truncate(10.1)', 'truncate()', 'truncate(0)']:
+            with self.assertRaises(ValueError):
+                DataWriter._parse_truncate_length(invalid_mode)
+
+    def test_metadata_stats_mode_defaults(self):
+        from pypaimon.common.options.core_options import CoreOptions
+        from pypaimon.common.options import Options
+
+        core_options = CoreOptions(Options({}))
+
+        self.assertEqual(core_options.metadata_stats_mode(), 'truncate(16)')
+        self.assertTrue(core_options.metadata_stats_enabled())
+
+        disabled_options = CoreOptions(Options({
+            CoreOptions.METADATA_STATS_MODE.key(): ' none '
+        }))
+        self.assertEqual(disabled_options.metadata_stats_mode(), 'none')
+        self.assertFalse(disabled_options.metadata_stats_enabled())
+
+        invalid_options = CoreOptions(Options({
+            CoreOptions.METADATA_STATS_MODE.key(): 'tuncate(16)'
+        }))
+        with self.assertRaises(ValueError):
+            invalid_options.metadata_stats_mode()
+        with self.assertRaises(ValueError):
+            invalid_options.metadata_stats_enabled()
+
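Note: the string cases above follow the usual truncated-upper-bound rule (the same idea as Iceberg's truncate transform): cut to the target length, then bump the last character that still has a valid successor. A sketch of ours mirroring _truncate_max in the data_writer diff below:

def truncate_upper_bound(s: str, length: int):
    if len(s) <= length:
        return s
    cut = s[:length]
    for i in range(length - 1, -1, -1):
        nxt = ord(cut[i]) + 1
        # skip the UTF-16 surrogate range and the end of the code space
        if nxt <= 0x10FFFF and not 0xD800 <= nxt <= 0xDFFF:
            return cut[:i] + chr(nxt)
    return None  # no representable upper bound; stats are dropped

assert truncate_upper_bound('abcdefghij', 5) == 'abcdf'
assert truncate_upper_bound('a\ud7ffx', 2) == 'b'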
+    def test_invalid_stats_mode_rejected_before_writing_file(self):
+        catalog = CatalogFactory.create({"warehouse": self.warehouse})
+        catalog.create_database("test_db_invalid_stats_mode", True)
+
+        pa_schema = pa.schema([('id', pa.int64()), ('name', pa.string())])
+        schema = Schema.from_pyarrow_schema(
+            pa_schema,
+            options={
+                'metadata.stats-mode': 'tuncate(16)',
+                'target-file-size': '1b',
+            }
+        )
+        catalog.create_table("test_db_invalid_stats_mode.t", schema, False)
+        table = catalog.get_table("test_db_invalid_stats_mode.t")
+
+        data = pa.Table.from_pydict({'id': [1], 'name': ['Alice']}, schema=pa_schema)
+        tw = table.new_batch_write_builder().new_write()
+        with self.assertRaises(ValueError):
+            tw.write_arrow(data)
+
+        data_files = []
+        table_path = os.path.join(self.warehouse, "test_db_invalid_stats_mode.db", "t")
+        for root, _, files in os.walk(table_path):
+            for file in files:
+                if file.endswith(('.avro', '.orc', '.parquet')):
+                    data_files.append(os.path.join(root, file))
+        self.assertEqual(data_files, [])
+
+    def test_high_precision_timestamp_stats_skip_minmax(self):
+        from pypaimon.write.writer.data_writer import DataWriter
+
+        data = pa.Table.from_pydict(
+            {
+                'ts_us': [datetime(2024, 1, 1, 0, 0, 0, 999)],
+                'ts_ns': [datetime(2024, 1, 1, 0, 0, 0, 999)],
+            },
+            schema=pa.schema([
+                ('ts_us', pa.timestamp('us')),
+                ('ts_ns', pa.timestamp('ns')),
+            ])
+        )
+
+        for column in ['ts_us', 'ts_ns']:
+            stats = DataWriter._get_column_stats(data, column, 'full')
+            self.assertIsNone(stats['min_values'])
+            self.assertIsNone(stats['max_values'])
+            self.assertEqual(stats['null_counts'], 0)
+
+    def test_default_truncate_stats_e2e(self):
+        catalog = CatalogFactory.create({"warehouse": self.warehouse})
+        catalog.create_database("test_db_truncate_e2e", True)
+
+        pa_schema = pa.schema([('id', pa.int64()), ('name', pa.string())])
+        schema = Schema.from_pyarrow_schema(pa_schema)
+        catalog.create_table("test_db_truncate_e2e.t", schema, False)
+        table = catalog.get_table("test_db_truncate_e2e.t")
+
+        long_str = 'a' * 30
+        data = pa.Table.from_pydict({'id': [1], 'name': [long_str]})
+        wb = table.new_batch_write_builder()
+        tw = wb.new_write()
+        tc = wb.new_commit()
+        tw.write_arrow(data)
+        tc.commit(tw.prepare_commit())
+        tw.close()
+        tc.close()
+
+        snap = table.snapshot_manager().get_latest_snapshot()
+        rb = table.new_read_builder()
+        scan = rb.new_scan()
+        mf = scan.file_scanner.manifest_list_manager.read_all(snap)
+        entries = scan.file_scanner.manifest_file_manager.read(
+            mf[0].file_name, lambda r: True, drop_stats=False)
+        stats = entries[0].file.value_stats
+        min_row = GenericRowDeserializer.from_bytes(stats.min_values.data, table.fields)
+        max_row = GenericRowDeserializer.from_bytes(stats.max_values.data, table.fields)
+        self.assertEqual(min_row.values[1], 'a' * 16)
+        self.assertEqual(max_row.values[1], 'a' * 15 + 'b')
+
+    def test_default_truncate_skips_invalid_surrogate_max_e2e(self):
+        catalog = CatalogFactory.create({"warehouse": self.warehouse})
+        catalog.create_database("test_db_truncate_surrogate_e2e", True)
+
+        pa_schema = pa.schema([('id', pa.int64()), ('name', pa.string())])
+        schema = Schema.from_pyarrow_schema(pa_schema)
+        catalog.create_table("test_db_truncate_surrogate_e2e.t", schema, False)
+        table = catalog.get_table("test_db_truncate_surrogate_e2e.t")
+
+        high_boundary_str = 'a' * 15 + '\ud7ff' + 'tail'
+        data = pa.Table.from_pydict({'id': [1], 'name': [high_boundary_str]}, schema=pa_schema)
+        wb = table.new_batch_write_builder()
+        tw = wb.new_write()
+        tc = wb.new_commit()
+        tw.write_arrow(data)
+        tc.commit(tw.prepare_commit())
+        tw.close()
+        tc.close()
+
+        snap = table.snapshot_manager().get_latest_snapshot()
+        rb = table.new_read_builder()
+        scan = rb.new_scan()
+        mf = scan.file_scanner.manifest_list_manager.read_all(snap)
+        entries = scan.file_scanner.manifest_file_manager.read(
+            mf[0].file_name, lambda r: True, drop_stats=False)
+        stats = entries[0].file.value_stats
+        min_row = GenericRowDeserializer.from_bytes(stats.min_values.data, table.fields)
+        max_row = GenericRowDeserializer.from_bytes(stats.max_values.data, table.fields)
+        self.assertEqual(min_row.values[1], high_boundary_str[:16])
+        self.assertEqual(max_row.values[1], 'a' * 14 + 'b')
+
+    def test_default_truncate_binary_stats_e2e(self):
+        catalog = CatalogFactory.create({"warehouse": self.warehouse})
+        catalog.create_database("test_db_truncate_binary_e2e", True)
+
+        pa_schema = pa.schema([('id', pa.int64()), ('payload', pa.binary())])
+        schema = Schema.from_pyarrow_schema(pa_schema)
+        catalog.create_table("test_db_truncate_binary_e2e.t", schema, False)
+        table = catalog.get_table("test_db_truncate_binary_e2e.t")
+
+        long_bytes = b'a' * 30
+        data = pa.Table.from_pydict({'id': [1], 'payload': [long_bytes]}, schema=pa_schema)
+        wb = table.new_batch_write_builder()
+        tw = wb.new_write()
+        tc = wb.new_commit()
+        tw.write_arrow(data)
+        tc.commit(tw.prepare_commit())
+        tw.close()
+        tc.close()
+
+        snap = table.snapshot_manager().get_latest_snapshot()
+        rb = table.new_read_builder()
+        scan = rb.new_scan()
+        mf = scan.file_scanner.manifest_list_manager.read_all(snap)
+        entries = scan.file_scanner.manifest_file_manager.read(
+            mf[0].file_name, lambda r: True, drop_stats=False)
+        stats = entries[0].file.value_stats
+        min_row = GenericRowDeserializer.from_bytes(stats.min_values.data, table.fields)
+        max_row = GenericRowDeserializer.from_bytes(stats.max_values.data, table.fields)
+        self.assertIsNone(min_row.values[1])
+        self.assertIsNone(max_row.values[1])
+
+    def test_default_stats_skips_high_precision_decimal_minmax(self):
+        catalog = CatalogFactory.create({"warehouse": self.warehouse})
+        catalog.create_database("test_db_decimal_stats", True)
+
+        pa_schema = pa.schema([('id', pa.int64()), ('amount', pa.decimal128(38, 0))])
+        schema = Schema.from_pyarrow_schema(pa_schema)
+        catalog.create_table("test_db_decimal_stats.t", schema, False)
+        table = catalog.get_table("test_db_decimal_stats.t")
+
+        min_amount = Decimal('-123456789012345678901234567890')
+        max_amount = Decimal('123456789012345678901234567890')
+        data = pa.Table.from_pydict(
+            {'id': [1, 2], 'amount': [max_amount, min_amount]},
+            schema=pa_schema)
+        wb = table.new_batch_write_builder()
+        tw = wb.new_write()
+        tc = wb.new_commit()
+        tw.write_arrow(data)
+        tc.commit(tw.prepare_commit())
+        tw.close()
+        tc.close()
+
+        snap = table.snapshot_manager().get_latest_snapshot()
+        rb = table.new_read_builder()
+        scan = rb.new_scan()
+        mf = scan.file_scanner.manifest_list_manager.read_all(snap)
+        entries = scan.file_scanner.manifest_file_manager.read(
+            mf[0].file_name, lambda r: True, drop_stats=False)
+        stats = entries[0].file.value_stats
+        min_row = GenericRowDeserializer.from_bytes(stats.min_values.data, table.fields)
+        max_row = GenericRowDeserializer.from_bytes(stats.max_values.data, table.fields)
+        self.assertIsNone(min_row.values[1])
+        self.assertIsNone(max_row.values[1])
+        self.assertEqual(stats.null_counts, [0, 0])
+        self.assertEqual(rb.new_read().to_arrow(scan.plan().splits()), data)
+
+    def test_default_stats_with_high_precision_timestamp_e2e(self):
+        catalog = CatalogFactory.create({"warehouse": self.warehouse})
+        catalog.create_database("test_db_timestamp_stats", True)
+
+        pa_schema = pa.schema([('id', pa.int64()), ('ts', pa.timestamp('us'))])
+        schema = Schema.from_pyarrow_schema(pa_schema)
+        catalog.create_table("test_db_timestamp_stats.t", schema, False)
+        table = catalog.get_table("test_db_timestamp_stats.t")
+
+        value = datetime(2024, 1, 1, 0, 0, 0, 999)
+        data = pa.Table.from_pydict({'id': [1], 'ts': [value]},
+                                    schema=pa_schema)
+        wb = table.new_batch_write_builder()
+        tw = wb.new_write()
+        tc = wb.new_commit()
+        tw.write_arrow(data)
+        tc.commit(tw.prepare_commit())
+        tw.close()
+        tc.close()
+
+        snap = table.snapshot_manager().get_latest_snapshot()
+        rb = table.new_read_builder()
+        scan = rb.new_scan()
+        mf = scan.file_scanner.manifest_list_manager.read_all(snap)
+        entries = scan.file_scanner.manifest_file_manager.read(
+            mf[0].file_name, lambda r: True, drop_stats=False)
+        stats = entries[0].file.value_stats
+        min_row = GenericRowDeserializer.from_bytes(stats.min_values.data, table.fields)
+        max_row = GenericRowDeserializer.from_bytes(stats.max_values.data, table.fields)
+        self.assertIsNone(min_row.values[1])
+        self.assertIsNone(max_row.values[1])
+        self.assertEqual(stats.null_counts, [0, 0])
+        self.assertEqual(rb.new_read().to_arrow(scan.plan().splits()), data)
+
     def test_value_stats_empty_when_stats_disabled(self):
         catalog = CatalogFactory.create({
             "warehouse": self.warehouse
@@ -698,6 +947,49 @@ def test_value_stats_empty_when_stats_disabled(self):
             "value_stats.null_counts should be empty (same as SimpleStats.empty_stats()) when stats are disabled"
         )

+    def test_value_stats_counts_mode_e2e(self):
+        catalog = CatalogFactory.create({"warehouse": self.warehouse})
+        catalog.create_database("test_db_stats_counts", True)
+
+        pa_schema = pa.schema([
+            ('id', pa.int64()),
+            ('name', pa.string()),
+            ('price', pa.float64()),
+        ])
+        schema = Schema.from_pyarrow_schema(
+            pa_schema,
+            options={'metadata.stats-mode': 'counts'}
+        )
+        catalog.create_table("test_db_stats_counts.t", schema, False)
+        table = catalog.get_table("test_db_stats_counts.t")
+
+        data = pa.Table.from_pydict({
+            'id': [1, 2, 3],
+            'name': ['Alice', None, 'Charlie'],
+            'price': [None, 20.3, None],
+        }, schema=pa_schema)
+        wb = table.new_batch_write_builder()
+        tw = wb.new_write()
+        tc = wb.new_commit()
+        tw.write_arrow(data)
+        tc.commit(tw.prepare_commit())
+        tw.close()
+        tc.close()
+
+        snap = table.snapshot_manager().get_latest_snapshot()
+        rb = table.new_read_builder()
+        scan = rb.new_scan()
+        mf = scan.file_scanner.manifest_list_manager.read_all(snap)
+        entries = scan.file_scanner.manifest_file_manager.read(
+            mf[0].file_name, lambda r: True, drop_stats=False)
+        stats = entries[0].file.value_stats
+        min_row = GenericRowDeserializer.from_bytes(stats.min_values.data, table.fields)
+        max_row = GenericRowDeserializer.from_bytes(stats.max_values.data, table.fields)
+
+        self.assertEqual(min_row.values, [None, None, None])
+        self.assertEqual(max_row.values, [None, None, None])
+        self.assertEqual(stats.null_counts, [0, 1, 2])
+
     def test_types(self):
         data_fields = [
             DataField(0, "f0", AtomicType('TINYINT'), 'desc'),
@@ -1246,6 +1538,69 @@ def test_primary_key_value_stats(self):
         actual_ids = sorted(actual_data.column('id').to_pylist())
         self.assertEqual(actual_ids, [1, 2, 3, 4, 5], "All IDs should be present")

+    def test_primary_key_partial_write_value_stats(self):
+        pa_schema = pa.schema([
+            ('id', pa.int64()),
+            ('name', pa.string()),
+            ('price', pa.float64()),
+            ('category', pa.string())
+        ])
+        partial_schema = pa.schema([
+            ('name', pa.string()),
+            pa.field('id', pa.int64(), nullable=False),
+        ])
+        schema = Schema.from_pyarrow_schema(
+            pa_schema,
+            primary_keys=['id'],
+            options={'metadata.stats-mode': 'full', 'bucket': '2'}
+        )
+        self.catalog.create_table('default.test_pk_partial_value_stats', schema, False)
+        table = self.catalog.get_table('default.test_pk_partial_value_stats')
+
+        partial_data = pa.Table.from_pydict({
+            'name': ['Alice', 'Bob'],
+            'id': [1, 2],
+        }, schema=partial_schema)
+        write_builder = table.new_batch_write_builder()
+        writer = write_builder.new_write().with_write_type(['name', 'id'])
+        writer.write_arrow(partial_data)
+        commit_messages = writer.prepare_commit()
+        commit = write_builder.new_commit()
+        commit.commit(commit_messages)
+        writer.close()
+        commit.close()
+
+        files = [file for msg in commit_messages for file in msg.new_files]
+        self.assertGreater(len(files), 0)
+        for file in files:
+            self.assertIsNone(file.write_cols)
+            self.assertEqual(file.value_stats_cols, ['id', 'name'])
+            self.assertEqual(len(file.value_stats.min_values), 2)
+            self.assertEqual(len(file.value_stats.max_values), 2)
+            self.assertEqual(len(file.value_stats.null_counts), 2)
+
+        read_builder = table.new_read_builder()
+        scan = read_builder.new_scan()
+        snap = table.snapshot_manager().get_latest_snapshot()
+        mf = scan.file_scanner.manifest_list_manager.read_all(snap)
+        entries = scan.file_scanner.manifest_file_manager.read(
+            mf[0].file_name,
+            lambda row: scan.file_scanner._filter_manifest_entry(row),
+            False
+        )
+        stats_file = entries[0].file
+        self.assertIsNone(stats_file.write_cols)
+        self.assertEqual(stats_file.value_stats_cols, ['id', 'name'])
+        stats_fields = [table.field_dict[col] for col in stats_file.value_stats_cols]
+        min_row = GenericRowDeserializer.from_bytes(stats_file.value_stats.min_values.data, stats_fields)
+        max_row = GenericRowDeserializer.from_bytes(stats_file.value_stats.max_values.data, stats_fields)
+        self.assertEqual(min_row.values, [1, 'Alice'])
+        self.assertEqual(max_row.values, [2, 'Bob'])
+
+        actual = read_builder.new_read().to_arrow(scan.plan().splits())
+        self.assertEqual(actual.column('id').to_pylist(), [1, 2])
+        self.assertEqual(actual.column('name').to_pylist(), ['Alice', 'Bob'])
+
     def test_split_target_size(self):
         """Test source.split.target-size configuration effect on split generation."""
         from pypaimon.common.options.core_options import CoreOptions
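Note: the test writes columns in the order ['name', 'id'] but expects value_stats_cols == ['id', 'name']; stats columns follow table-schema order, not write order. A sketch of the selection implemented in the data_writer diff further below:

table_fields = ['id', 'name', 'price', 'category']  # schema order
written = {'name', 'id'}                            # write order is irrelevant
stats_cols = [f for f in table_fields if f in written]
assert stats_cols == ['id', 'name']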
diff --git a/paimon-python/pypaimon/tests/write/table_write_test.py b/paimon-python/pypaimon/tests/write/table_write_test.py
index 96abfb73aeec..58f3386524b7 100644
--- a/paimon-python/pypaimon/tests/write/table_write_test.py
+++ b/paimon-python/pypaimon/tests/write/table_write_test.py
@@ -30,6 +30,7 @@
 from pypaimon.common.json_util import JSON
 from pypaimon.common.options.core_options import CoreOptions
 from pypaimon.write.writer.append_only_data_writer import AppendOnlyDataWriter
+from pypaimon.write.writer.data_writer import DataWriter


 class TableWriteTest(unittest.TestCase):
@@ -97,6 +98,81 @@ def test_write_snapshot(self):
         self.assertEquals(True, snapshot_json.__contains__("baseManifestList"))
         self.assertEquals(False, snapshot_json.__contains__("nextRowId"))

+    def test_partial_write_requires_partition_key(self):
+        schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['dt'])
+        self.catalog.create_table('default.test_partial_missing_partition_key', schema, False)
+        table = self.catalog.get_table('default.test_partial_missing_partition_key')
+
+        write_builder = table.new_batch_write_builder()
+        table_write = write_builder.new_write().with_write_type(['user_id', 'item_id', 'behavior'])
+        partial_schema = pa.schema([
+            ('user_id', pa.int32()),
+            ('item_id', pa.int64()),
+            ('behavior', pa.string())
+        ])
+        partial_data = pa.Table.from_pydict({
+            'user_id': [1],
+            'item_id': [1001],
+            'behavior': ['a']
+        }, schema=partial_schema)
+
+        try:
+            with self.assertRaisesRegex(ValueError, "Missing routing fields.*dt"):
+                table_write.write_arrow(partial_data)
+        finally:
+            table_write.close()
+
+    def test_partial_write_requires_bucket_key(self):
+        schema = Schema.from_pyarrow_schema(
+            self.pk_pa_schema,
+            partition_keys=['dt'],
+            primary_keys=['user_id', 'dt'],
+            options={'bucket': '2'},
+        )
+        self.catalog.create_table('default.test_partial_missing_bucket_key', schema, False)
+        table = self.catalog.get_table('default.test_partial_missing_bucket_key')
+
+        write_builder = table.new_batch_write_builder()
+        table_write = write_builder.new_write().with_write_type(['item_id', 'behavior', 'dt'])
+        partial_schema = pa.schema([
+            ('item_id', pa.int64()),
+            ('behavior', pa.string()),
+            pa.field('dt', pa.string(), nullable=False)
+        ])
+        partial_data = pa.Table.from_pydict({
+            'item_id': [1001],
+            'behavior': ['a'],
+            'dt': ['p1']
+        }, schema=partial_schema)
+
+        try:
+            with self.assertRaisesRegex(ValueError, "Missing routing fields.*user_id"):
+                table_write.write_arrow(partial_data)
+        finally:
+            table_write.close()
+
+    def test_full_stats_skip_binary_minmax(self):
+        data = pa.Table.from_pydict({
+            'payload': [b'zulu', b'alpha', None]
+        }, schema=pa.schema([('payload', pa.binary())]))
+
+        stats = DataWriter._get_column_stats(data, 'payload', 'full')
+
+        self.assertIsNone(stats['min_values'])
+        self.assertIsNone(stats['max_values'])
+        self.assertEqual(stats['null_counts'], 1)
+
+    def test_truncate_stats_skip_binary_minmax(self):
+        data = pa.Table.from_pydict({
+            'payload': [b'zulu', b'alpha']
+        }, schema=pa.schema([('payload', pa.binary())]))
+
+        stats = DataWriter._get_column_stats(data, 'payload', 'truncate(16)')
+
+        self.assertIsNone(stats['min_values'])
+        self.assertIsNone(stats['max_values'])
+        self.assertEqual(stats['null_counts'], 0)
+
     def test_multi_prepare_commit_ao(self):
         schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['dt'])
         self.catalog.create_table('default.test_append_only_parquet', schema, False)
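Note: both partial-write tests above rely on the routing-field check introduced in row_key_extractor below, where indices are now resolved against the input batch's schema rather than the table schema. A minimal sketch of that check:

def resolve_indices(batch_names, routing_fields):
    field_map = {name: i for i, name in enumerate(batch_names)}
    missing = [n for n in routing_fields if n not in field_map]
    if missing:
        raise ValueError(f"Missing routing fields in input data: {missing}")
    return [field_map[n] for n in routing_fields]

assert resolve_indices(['item_id', 'behavior', 'dt'], ['dt']) == [2]
try:
    resolve_indices(['item_id', 'behavior'], ['user_id'])
except ValueError as e:
    assert 'user_id' in str(e)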
diff --git a/paimon-python/pypaimon/write/row_key_extractor.py b/paimon-python/pypaimon/write/row_key_extractor.py
index dad93bf1eda7..2add0fe7e744 100644
--- a/paimon-python/pypaimon/write/row_key_extractor.py
+++ b/paimon-python/pypaimon/write/row_key_extractor.py
@@ -83,24 +83,28 @@ class RowKeyExtractor(ABC):

     def __init__(self, table_schema: TableSchema):
         self.table_schema = table_schema
-        self.partition_indices = self._get_field_indices(table_schema.partition_keys)

     def extract_partition_bucket_batch(self, data: pa.RecordBatch) -> Tuple[List[Tuple], List[int]]:
         partitions = self._extract_partitions_batch(data)
         buckets = self._extract_buckets_batch(data)
         return partitions, buckets

-    def _get_field_indices(self, field_names: List[str]) -> List[int]:
+    @staticmethod
+    def _get_data_field_indices(data: pa.RecordBatch, field_names: List[str]) -> List[int]:
         if not field_names:
             return []
-        field_map = {field.name: i for i, field in enumerate(self.table_schema.fields)}
-        return [field_map[name] for name in field_names if name in field_map]
+        field_map = {field.name: i for i, field in enumerate(data.schema)}
+        missing = [name for name in field_names if name not in field_map]
+        if missing:
+            raise ValueError(f"Missing routing fields in input data: {missing}")
+        return [field_map[name] for name in field_names]

     def _extract_partitions_batch(self, data: pa.RecordBatch) -> List[Tuple]:
-        if not self.partition_indices:
+        partition_indices = self._get_data_field_indices(data, self.table_schema.partition_keys)
+        if not partition_indices:
             return [() for _ in range(data.num_rows)]

-        partition_columns = [data.column(i) for i in self.partition_indices]
+        partition_columns = [data.column(i) for i in partition_indices]

         partitions = []
         for row_idx in range(data.num_rows):
@@ -124,15 +128,12 @@ def __init__(self, table_schema: TableSchema):
         if self.num_buckets <= 0:
             raise ValueError(f"Fixed bucket mode requires bucket > 0, got {self.num_buckets}")

-        # Bucket-key resolution lives on TableSchema (mirrors Java
-        # ``TableSchema.bucketKeys()`` / ``logicalBucketKeyType()``); reuse
-        # it so any reader path that walks the same logic stays in sync.
         self.bucket_keys = table_schema.bucket_keys
-        self.bucket_key_indices = self._get_field_indices(self.bucket_keys)
         self._bucket_key_fields = table_schema.logical_bucket_key_fields

     def _extract_buckets_batch(self, data: pa.RecordBatch) -> List[int]:
-        columns = [data.column(i) for i in self.bucket_key_indices]
+        bucket_key_indices = self._get_data_field_indices(data, self.bucket_keys)
+        columns = [data.column(i) for i in bucket_key_indices]
         return [
             _bucket_from_hash(
                 self._binary_row_hash_code(tuple(col[row_idx].as_py() for col in columns)),
@@ -278,24 +279,13 @@ def __init__(self, table_schema: 'TableSchema'):
             target_bucket_row_number=opts.dynamic_bucket_target_row_num(),
             max_buckets_num=opts.dynamic_bucket_max_buckets(),
         )
-        # TODO: extract bucket key init logic to base class (shared with FixedBucketRowKeyExtractor)
-        bucket_key_option = opts.bucket_key()
-        if bucket_key_option and bucket_key_option.strip():
-            self.bucket_keys = [k.strip() for k in bucket_key_option.split(',')]
-        else:
-            self.bucket_keys = [
-                pk for pk in table_schema.primary_keys
-                if pk not in table_schema.partition_keys
-            ]
-        self.bucket_key_indices = self._get_field_indices(self.bucket_keys)
-        field_map = {f.name: f for f in table_schema.fields}
-        self._bucket_key_fields = [
-            field_map[name] for name in self.bucket_keys if name in field_map
-        ]
+        self.bucket_keys = table_schema.bucket_keys
+        self._bucket_key_fields = table_schema.logical_bucket_key_fields

     def _extract_buckets_batch(self, data: pa.RecordBatch) -> List[int]:
         partitions = self._extract_partitions_batch(data)
-        columns = [data.column(i) for i in self.bucket_key_indices]
+        bucket_key_indices = self._get_data_field_indices(data, self.bucket_keys)
+        columns = [data.column(i) for i in bucket_key_indices]
         buckets = []
         for row_idx in range(data.num_rows):
             key_hash = _hash_bytes_by_words(
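Note: the byte-wise counterpart of the upper-bound rule appears as _truncate_max in the next diff; a sketch of ours using the same test vectors as test_truncate_stats:

def bytes_upper_bound(b: bytes, length: int):
    if len(b) <= length:
        return b
    cut = bytearray(b[:length])
    for i in range(length - 1, -1, -1):
        if cut[i] < 0xFF:
            cut[i] += 1
            return bytes(cut[:i + 1])  # drop the tail after the bump
    return None  # every byte is 0xFF: no finite bound at this length

assert bytes_upper_bound(b'\x01\x02\x03\x04\x05\x06', 3) == b'\x01\x02\x04'
assert bytes_upper_bound(b'\xff\xff\xff\x00', 3) is None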
diff --git a/paimon-python/pypaimon/write/writer/data_writer.py b/paimon-python/pypaimon/write/writer/data_writer.py
index ef8ef3a2df84..37b20fc70482 100644
--- a/paimon-python/pypaimon/write/writer/data_writer.py
+++ b/paimon-python/pypaimon/write/writer/data_writer.py
@@ -31,6 +31,38 @@
 from pypaimon.table.row.generic_row import GenericRow


+def _truncate_min(value, length):
+    if value is None:
+        return None
+    if isinstance(value, (bytes, str)) and len(value) > length:
+        return value[:length]
+    return value
+
+
+def _truncate_max(value, length):
+    if value is None:
+        return None
+    if isinstance(value, bytes):
+        if len(value) <= length:
+            return value
+        truncated = bytearray(value[:length])
+        for i in range(len(truncated) - 1, -1, -1):
+            if truncated[i] < 0xFF:
+                truncated[i] += 1
+                return bytes(truncated[:i + 1])
+        return None
+    if isinstance(value, str):
+        if len(value) <= length:
+            return value
+        truncated = value[:length]
+        for i in range(len(truncated) - 1, -1, -1):
+            next_cp = ord(truncated[i]) + 1
+            if next_cp <= 0x10FFFF and not 0xD800 <= next_cp <= 0xDFFF:
+                return truncated[:i] + chr(next_cp)
+        return None
+    return value
+
+
 class DataWriter(ABC):
     """Base class for data writers that handle PyArrow tables directly."""

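Note: informal summary of what each mode records per column once the hunks below land, illustrated against a column with one null (expected shapes derived from _get_column_stats in this patch):

import pyarrow as pa

data = pa.table({'name': ['Alice', None, 'Charlie']})
# Expected shape of DataWriter._get_column_stats(data, 'name', mode):
#   'none'     -> {'min_values': None,    'max_values': None,      'null_counts': None}
#   'counts'   -> {'min_values': None,    'max_values': None,      'null_counts': 1}
#   'full'     -> {'min_values': 'Alice', 'max_values': 'Charlie', 'null_counts': 1}
#   'truncate(n)' -> like 'full', but string/bytes bounds are cut to length <= n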
@@ -63,6 +95,8 @@ def __init__(self, table, partition: Tuple, bucket: int, max_seq_number: int, op
         self.committed_files: List[DataFileMeta] = []
         self.write_cols = write_cols
         self.blob_as_descriptor = self.options.blob_as_descriptor()
+        self.stats_mode = self.options.metadata_stats_mode()
+        self._parse_truncate_length(self.stats_mode)

         self.path_factory = self.table.path_factory()
         self.external_path_provider: Optional[ExternalPathProvider] = self.path_factory.create_external_path_provider(
@@ -210,24 +244,33 @@ def _write_data_to_file(self, data: pa.Table):
         min_key = [col.to_pylist()[0] for col in min_key_row_batch.columns]
         max_key = [col.to_pylist()[0] for col in max_key_row_batch.columns]

-        # key stats & value stats
-        value_stats_enabled = self.options.metadata_stats_enabled()
-        if value_stats_enabled:
-            stats_fields = self.table.fields if self.table.is_primary_key_table \
-                else PyarrowFieldParser.to_paimon_schema(data.schema)
-        else:
-            stats_fields = self.table.trimmed_primary_keys_fields
-        column_stats = {
-            field.name: self._get_column_stats(data, field.name)
-            for field in stats_fields
-        }
+        # key stats (always computed with "full" mode, not affected by stats mode)
         key_fields = self.trimmed_primary_keys_fields
-        key_stats = self._collect_value_stats(data, key_fields, column_stats)
+        key_stats = self._collect_value_stats(data, key_fields, mode="full")
         if not all(count == 0 for count in key_stats.null_counts):
             raise RuntimeError("Primary key should not be null")

-        value_fields = stats_fields if value_stats_enabled else []
-        value_stats = self._collect_value_stats(data, value_fields, column_stats)
+        # value stats
+        value_stats_enabled = self.options.metadata_stats_enabled()
+        if value_stats_enabled:
+            if self.table.is_primary_key_table:
+                data_col_names = set(data.schema.names)
+                value_fields = [
+                    field
+                    for field in self.table.fields
+                    if field.name in data_col_names
+                ]
+                if len(value_fields) < len(self.table.fields):
+                    value_stats_cols = [field.name for field in value_fields]
+                else:
+                    value_stats_cols = None
+            else:
+                value_fields = PyarrowFieldParser.to_paimon_schema(data.schema)
+                value_stats_cols = None if self.write_cols is None else [field.name for field in value_fields]
+        else:
+            value_fields = []
+            value_stats_cols = []
+        value_stats = self._collect_value_stats(data, value_fields)

         min_seq = self.sequence_generator.start
         max_seq = self.sequence_generator.current
@@ -248,11 +291,10 @@ def _write_data_to_file(self, data: pa.Table):
             creation_time=Timestamp.now(),
             delete_row_count=0,
             file_source=0,
-            value_stats_cols=None if value_stats_enabled else [],
+            value_stats_cols=value_stats_cols,
             external_path=external_path_str,  # Set external path if using external paths
             first_row_id=None,
             write_cols=self.write_cols,
-            # None means all columns in the table have been written
             file_path=file_path,
         ))

@@ -312,16 +354,15 @@ def _find_optimal_split_point(data: pa.RecordBatch, target_size: int) -> int:

         return best_split

-    def _collect_value_stats(self, data: pa.Table, fields: List,
-                             column_stats: Optional[Dict[str, Dict]] = None) -> SimpleStats:
+    def _collect_value_stats(self, data: pa.Table, fields: List, mode: str = None) -> SimpleStats:
         if not fields:
             return SimpleStats.empty_stats()
-
-        if column_stats is None or not column_stats:
-            column_stats = {
-                field.name: self._get_column_stats(data, field.name)
-                for field in fields
-            }
+
+        m = mode or self.stats_mode
+        column_stats = {
+            field.name: self._get_column_stats(data, field.name, m)
+            for field in fields
+        }

         min_stats = [column_stats[field.name]['min_values'] for field in fields]
         max_stats = [column_stats[field.name]['max_values'] for field in fields]
@@ -334,32 +375,66 @@ def _collect_value_stats(self, data: pa.Table, fields: List,
         )

     @staticmethod
-    def _get_column_stats(record_batch: pa.RecordBatch, column_name: str) -> Dict:
+    def _parse_truncate_length(mode: str):
+        return CoreOptions.parse_metadata_stats_mode(mode)
+
+    @staticmethod
+    def _get_column_stats(record_batch: pa.RecordBatch, column_name: str,
+                          mode: str = "truncate(16)") -> Dict:
+        parsed_mode, truncate_length = DataWriter._parse_truncate_length(mode)
+
+        if parsed_mode == "NONE":
+            return {
+                "min_values": None,
+                "max_values": None,
+                "null_counts": None,
+            }
+
         column_array = record_batch.column(column_name)
+
+        if parsed_mode == "COUNTS":
+            return {
+                "min_values": None,
+                "max_values": None,
+                "null_counts": column_array.null_count,
+            }
+
         if column_array.null_count == len(column_array):
             return {
                 "min_values": None,
                 "max_values": None,
                 "null_counts": column_array.null_count,
             }
-
+
         column_type = column_array.type
-        supports_minmax = not (pa.types.is_nested(column_type) or pa.types.is_map(column_type))
-
+        supports_minmax = not (pa.types.is_nested(column_type) or pa.types.is_map(column_type)
+                               or pa.types.is_binary(column_type)
+                               or pa.types.is_large_binary(column_type)
+                               or pa.types.is_fixed_size_binary(column_type)
+                               or (pa.types.is_decimal(column_type) and column_type.precision > 18)
+                               or (pa.types.is_timestamp(column_type)
+                                   and (column_type.tz is not None or column_type.unit not in ('s', 'ms'))))
+
         if not supports_minmax:
             return {
                 "min_values": None,
                 "max_values": None,
                 "null_counts": column_array.null_count,
             }
-
+
         min_values = pc.min(column_array).as_py()
         max_values = pc.max(column_array).as_py()
-        null_counts = column_array.null_count
+
+        if truncate_length is not None:
+            min_values = _truncate_min(min_values, truncate_length)
+            max_values = _truncate_max(max_values, truncate_length)
+            if max_values is None:
+                min_values = None
+
         return {
             "min_values": min_values,
             "max_values": max_values,
-            "null_counts": null_counts,
+            "null_counts": column_array.null_count,
        }
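Note on the timestamp exclusion: the unit check (unit not in ('s', 'ms')) suggests the manifest stats path here preserves only millisecond precision, so sub-millisecond values could not round-trip; presumably that is why min/max are skipped rather than silently coarsened. An illustration of the precision loss (assumes pyarrow):

import pyarrow as pa
from datetime import datetime

ts = pa.array([datetime(2024, 1, 1, 0, 0, 0, 999)], type=pa.timestamp('us'))
coarse = ts.cast(pa.timestamp('ms'), safe=False).cast(pa.timestamp('us'))
assert coarse[0].as_py() != ts[0].as_py()  # 999us cannot survive a ms round-trip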