From 0ffccfef091b4bdd5e7b9088df946a7f1d55c2cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=96=86=E5=AE=87?= Date: Wed, 29 Apr 2026 17:13:20 +0800 Subject: [PATCH] [python] add vector type in DataType system --- .../org/apache/paimon/data/BinaryRowTest.java | 5 ++ .../paimon/data/DataFormatTestUtil.java | 44 +++++------ .../java/org/apache/paimon/JavaPyE2ETest.java | 53 +++++++++++++ paimon-python/dev/run_mixed_tests.sh | 39 +++++++++- paimon-python/pypaimon/schema/data_types.py | 78 ++++++++++++++++++- .../pypaimon/tests/data_types_test.py | 42 +++++++++- .../tests/e2e/java_py_read_write_test.py | 26 +++++++ paimon-python/pypaimon/tests/rest/api_test.py | 13 +++- 8 files changed, 273 insertions(+), 27 deletions(-) diff --git a/paimon-common/src/test/java/org/apache/paimon/data/BinaryRowTest.java b/paimon-common/src/test/java/org/apache/paimon/data/BinaryRowTest.java index 6dd3c23bdb87..d102f3b4d51a 100644 --- a/paimon-common/src/test/java/org/apache/paimon/data/BinaryRowTest.java +++ b/paimon-common/src/test/java/org/apache/paimon/data/BinaryRowTest.java @@ -550,6 +550,11 @@ public void testBinaryVector() { InternalVector vector2 = row.getVector(0); assertThat(vector2.size()).isEqualTo(vector.size()); assertThat(vector2.toFloatArray()).isEqualTo(vector.toFloatArray()); + assertThat( + DataFormatTestUtil.toStringNoRowKind( + row, + RowType.of(DataTypes.VECTOR(vector.size(), DataTypes.FLOAT())))) + .isEqualTo(Arrays.toString(vector.toFloatArray())); } @Test diff --git a/paimon-common/src/test/java/org/apache/paimon/data/DataFormatTestUtil.java b/paimon-common/src/test/java/org/apache/paimon/data/DataFormatTestUtil.java index c6e2ceab7504..2e87eb557f79 100644 --- a/paimon-common/src/test/java/org/apache/paimon/data/DataFormatTestUtil.java +++ b/paimon-common/src/test/java/org/apache/paimon/data/DataFormatTestUtil.java @@ -23,6 +23,7 @@ import org.apache.paimon.types.DataType; import org.apache.paimon.types.MapType; import org.apache.paimon.types.RowType; +import org.apache.paimon.types.VectorType; import org.apache.paimon.utils.StringUtils; import java.util.Arrays; @@ -48,16 +49,7 @@ public static String toStringNoRowKind(InternalRow row, RowType type) { if (field instanceof byte[]) { build.append(Arrays.toString((byte[]) field)); } else if (field instanceof InternalArray) { - InternalArray internalArray = (InternalArray) field; - ArrayType arrayType = (ArrayType) type.getTypeAt(i); - InternalArray.ElementGetter elementGetter = - InternalArray.createElementGetter(arrayType.getElementType()); - String[] result = new String[internalArray.size()]; - for (int j = 0; j < internalArray.size(); j++) { - Object object = elementGetter.getElementOrNull(internalArray, j); - result[j] = null == object ? null : object.toString(); - } - build.append(Arrays.toString(result)); + build.append(getArrayLikeString((InternalArray) field, type.getTypeAt(i))); } else { build.append(field); } @@ -90,19 +82,7 @@ public static String getDataFieldString(Object field, DataType type) { if (field instanceof byte[]) { return Arrays.toString((byte[]) field); } else if (field instanceof InternalArray) { - InternalArray internalArray = (InternalArray) field; - ArrayType arrayType = (ArrayType) type; - InternalArray.ElementGetter elementGetter = - InternalArray.createElementGetter(arrayType.getElementType()); - String[] result = new String[internalArray.size()]; - for (int j = 0; j < internalArray.size(); j++) { - Object object = elementGetter.getElementOrNull(internalArray, j); - result[j] = - null == object - ? null - : getDataFieldString(object, arrayType.getElementType()); - } - return Arrays.toString(result); + return getArrayLikeString((InternalArray) field, type); } else if (field instanceof InternalRow) { return String.format("(%s)", toStringWithRowKind((InternalRow) field, (RowType) type)); } else if (field instanceof InternalMap) { @@ -139,6 +119,24 @@ public static String getDataFieldString(Object field, DataType type) { } } + private static String getArrayLikeString(InternalArray internalArray, DataType type) { + DataType elementType; + if (type instanceof VectorType) { + elementType = ((VectorType) type).getElementType(); + } else if (type instanceof ArrayType) { + elementType = ((ArrayType) type).getElementType(); + } else { + throw new IllegalArgumentException("Unsupported type for array data: " + type); + } + InternalArray.ElementGetter elementGetter = InternalArray.createElementGetter(elementType); + String[] result = new String[internalArray.size()]; + for (int j = 0; j < internalArray.size(); j++) { + Object object = elementGetter.getElementOrNull(internalArray, j); + result[j] = null == object ? null : getDataFieldString(object, elementType); + } + return Arrays.toString(result); + } + /** Stringify the given {@link InternalRow}. */ public static String internalRowToString(InternalRow row, RowType type) { return row.getRowKind().shortString() + "[" + toStringNoRowKind(row, type) + ']'; diff --git a/paimon-core/src/test/java/org/apache/paimon/JavaPyE2ETest.java b/paimon-core/src/test/java/org/apache/paimon/JavaPyE2ETest.java index b1b02d19cc57..86cea365c7a1 100644 --- a/paimon-core/src/test/java/org/apache/paimon/JavaPyE2ETest.java +++ b/paimon-core/src/test/java/org/apache/paimon/JavaPyE2ETest.java @@ -25,6 +25,7 @@ import org.apache.paimon.catalog.CatalogFactory; import org.apache.paimon.catalog.Identifier; import org.apache.paimon.data.BinaryString; +import org.apache.paimon.data.BinaryVector; import org.apache.paimon.data.DataFormatTestUtil; import org.apache.paimon.data.GenericRow; import org.apache.paimon.data.InternalRow; @@ -798,6 +799,58 @@ public void testJavaWriteCompressedTextAppendTable() throws Exception { } } + @Test + @EnabledIfSystemProperty(named = "run.e2e.tests", matches = "true") + public void testJavaWriteVectorAppendTable() throws Exception { + Identifier identifier = identifier("mixed_test_vector_append_tablej_avro"); + catalog.dropTable(identifier, true); + Schema schema = + Schema.newBuilder() + .column("id", DataTypes.INT()) + .column("embedding", DataTypes.VECTOR(3, DataTypes.FLOAT())) + .column("label", DataTypes.STRING()) + .option("file.format", "avro") + .option("bucket", "-1") + .build(); + + catalog.createTable(identifier, schema, true); + FileStoreTable table = (FileStoreTable) catalog.getTable(identifier); + + BatchWriteBuilder writeBuilder = table.newBatchWriteBuilder(); + try (BatchTableWrite write = writeBuilder.newWrite(); + BatchTableCommit commit = writeBuilder.newCommit()) { + write.write( + GenericRow.of( + 1, + BinaryVector.fromPrimitiveArray(new float[] {1.0f, 2.0f, 3.0f}), + BinaryString.fromString("first"))); + write.write( + GenericRow.of( + 2, + BinaryVector.fromPrimitiveArray(new float[] {4.0f, 5.0f, 6.0f}), + BinaryString.fromString("second"))); + write.write( + GenericRow.of( + 3, + BinaryVector.fromPrimitiveArray(new float[] {-1.0f, 0.5f, 2.5f}), + BinaryString.fromString("third"))); + commit.commit(write.prepareCommit()); + } + + List splits = new ArrayList<>(table.newSnapshotReader().read().dataSplits()); + TableRead read = table.newRead(); + List res = + getResult( + read, + splits, + row -> DataFormatTestUtil.toStringNoRowKind(row, table.rowType())); + assertThat(res) + .containsExactlyInAnyOrder( + "1, [1.0, 2.0, 3.0], first", + "2, [4.0, 5.0, 6.0], second", + "3, [-1.0, 0.5, 2.5], third"); + } + @Test @EnabledIfSystemProperty(named = "run.e2e.tests", matches = "true") public void testBlobWriteAlterCompact() throws Exception { diff --git a/paimon-python/dev/run_mixed_tests.sh b/paimon-python/dev/run_mixed_tests.sh index acfb583cf577..007805573366 100755 --- a/paimon-python/dev/run_mixed_tests.sh +++ b/paimon-python/dev/run_mixed_tests.sh @@ -242,6 +242,29 @@ run_compressed_text_test() { fi } +run_vector_append_table_test() { + echo -e "${YELLOW}=== Running Vector Append Table Test (Java Write, Python Read) ===${NC}" + + cd "$PROJECT_ROOT" + + echo "Running Maven test for JavaPyE2ETest.testJavaWriteVectorAppendTable..." + if mvn test -Dtest=org.apache.paimon.JavaPyE2ETest#testJavaWriteVectorAppendTable -pl paimon-core -q -Drun.e2e.tests=true; then + echo -e "${GREEN}✓ Java test completed successfully${NC}" + else + echo -e "${RED}✗ Java test failed${NC}" + return 1 + fi + cd "$PAIMON_PYTHON_DIR" + echo "Running Python test for JavaPyReadWriteTest.test_read_vector_append_table..." + if python -m pytest java_py_read_write_test.py::JavaPyReadWriteTest::test_read_vector_append_table -v; then + echo -e "${GREEN}✓ Python test completed successfully${NC}" + return 0 + else + echo -e "${RED}✗ Python test failed${NC}" + return 1 + fi +} + # Function to run Tantivy full-text index test (Java write index, Python read and search) run_tantivy_fulltext_test() { echo -e "${YELLOW}=== Step 8: Running Tantivy Full-Text Index Test (Java Write, Python Read) ===${NC}" @@ -504,6 +527,7 @@ main() { local pk_dv_result=0 local btree_index_result=0 local compressed_text_result=0 + local vector_append_table_result=0 local tantivy_fulltext_result=0 local lumina_vector_result=0 local lumina_vector_btree_result=0 @@ -576,6 +600,13 @@ main() { echo "" + # Run Vector append table test (Java write, Python read) + if ! run_vector_append_table_test; then + vector_append_table_result=1 + fi + + echo "" + # Run Tantivy full-text index test (requires Python >= 3.10) if [[ "$PYTHON_MINOR" -ge 10 ]]; then if ! run_tantivy_fulltext_test; then @@ -693,6 +724,12 @@ main() { echo -e "${RED}✗ Compressed Text Test (Java Write, Python Read): FAILED${NC}" fi + if [[ $vector_append_table_result -eq 0 ]]; then + echo -e "${GREEN}✓ Vector Append Table Test (Java Write, Python Read): PASSED${NC}" + else + echo -e "${RED}✗ Vector Append Table Test (Java Write, Python Read): FAILED${NC}" + fi + if [[ $tantivy_fulltext_result -eq 0 ]]; then echo -e "${GREEN}✓ Tantivy Full-Text Index Test (Java Write, Python Read): PASSED${NC}" else @@ -752,7 +789,7 @@ main() { # Clean up warehouse directory after all tests cleanup_warehouse - if [[ $java_write_result -eq 0 && $python_read_result -eq 0 && $python_write_result -eq 0 && $java_read_result -eq 0 && $pk_dv_result -eq 0 && $btree_index_result -eq 0 && $compressed_text_result -eq 0 && $tantivy_fulltext_result -eq 0 && $lumina_vector_result -eq 0 && $lumina_vector_btree_result -eq 0 && $compact_conflict_result -eq 0 && $blob_alter_compact_result -eq 0 && $data_evolution_result -eq 0 && $data_evolution_py_write_result -eq 0 && $java_variant_write_py_read_result -eq 0 && $py_variant_write_java_read_result -eq 0 ]]; then + if [[ $java_write_result -eq 0 && $python_read_result -eq 0 && $python_write_result -eq 0 && $java_read_result -eq 0 && $pk_dv_result -eq 0 && $btree_index_result -eq 0 && $compressed_text_result -eq 0 && $tantivy_fulltext_result -eq 0 && $lumina_vector_result -eq 0 && $lumina_vector_btree_result -eq 0 && $compact_conflict_result -eq 0 && $blob_alter_compact_result -eq 0 && $data_evolution_result -eq 0 && $data_evolution_py_write_result -eq 0 && $java_variant_write_py_read_result -eq 0 && $py_variant_write_java_read_result -eq 0 && $vector_append_table_result -eq 0 ]]; then echo -e "${GREEN}🎉 All tests passed! Java-Python interoperability verified.${NC}" return 0 else diff --git a/paimon-python/pypaimon/schema/data_types.py b/paimon-python/pypaimon/schema/data_types.py index 70639ba1a82d..45ae70ba672c 100755 --- a/paimon-python/pypaimon/schema/data_types.py +++ b/paimon-python/pypaimon/schema/data_types.py @@ -131,6 +131,66 @@ def __str__(self) -> str: return "ARRAY<{}>{}".format(self.element, null_suffix) +@dataclass +class VectorType(DataType): + element: DataType + length: int + + VALID_ELEMENT_TYPES = { + "BOOLEAN", + "TINYINT", + "SMALLINT", + "INT", + "INTEGER", + "BIGINT", + "FLOAT", + "DOUBLE", + } + + def __init__(self, nullable: bool, element_type: DataType, length: int): + super().__init__(nullable) + if length < 1: + raise ValueError("Vector length must be greater than or equal to 1.") + if not self.is_valid_element_type(element_type): + raise ValueError("Invalid element type for vector: {}".format(element_type)) + self.element = element_type + self.length = length + + @classmethod + def is_valid_element_type(cls, element_type: DataType) -> bool: + if not isinstance(element_type, AtomicType): + return False + return element_type.type.upper() in cls.VALID_ELEMENT_TYPES + + def __eq__(self, other): + if self is other: + return True + if not isinstance(other, VectorType): + return False + return (self.element == other.element + and self.length == other.length + and self.nullable == other.nullable) + + def __hash__(self): + return hash((self.element, self.length, self.nullable)) + + def to_dict(self) -> Dict[str, Any]: + return { + "type": "VECTOR" + (" NOT NULL" if not self.nullable else ""), + "element": self.element.to_dict() if self.element else None, + "length": self.length, + "nullable": self.nullable + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "VectorType": + return DataTypeParser.parse_data_type(data) + + def __str__(self) -> str: + null_suffix = "" if self.nullable else " NOT NULL" + return "VECTOR<{}, {}>{}".format(self.element, self.length, null_suffix) + + @dataclass class MultisetType(DataType): element: DataType @@ -387,6 +447,14 @@ def parse_data_type( nullable = "NOT NULL" not in type_string return ArrayType(nullable, element) + elif type_string.startswith("VECTOR"): + element = DataTypeParser.parse_data_type( + json_data.get("element"), field_id + ) + length = int(json_data.get("length")) + nullable = "NOT NULL" not in type_string + return VectorType(nullable, element, length) + elif type_string.startswith("MULTISET"): element = DataTypeParser.parse_data_type( json_data.get("element"), field_id @@ -536,6 +604,8 @@ def from_paimon_type(data_type: DataType) -> pyarrow.DataType: return pyarrow.time32('ms') elif isinstance(data_type, ArrayType): return pyarrow.list_(PyarrowFieldParser.from_paimon_type(data_type.element)) + elif isinstance(data_type, VectorType): + return pyarrow.list_(PyarrowFieldParser.from_paimon_type(data_type.element), data_type.length) elif isinstance(data_type, MapType): key_type = PyarrowFieldParser.from_paimon_type(data_type.key) value_type = PyarrowFieldParser.from_paimon_type(data_type.value) @@ -603,6 +673,10 @@ def to_paimon_type(pa_type: pyarrow.DataType, nullable: bool) -> DataType: type_name = 'DATE' elif types.is_time(pa_type): type_name = 'TIME(0)' + elif types.is_fixed_size_list(pa_type): + pa_type: pyarrow.FixedSizeListType + element_type = PyarrowFieldParser.to_paimon_type(pa_type.value_type, pa_type.value_field.nullable) + return VectorType(nullable, element_type, pa_type.list_size) elif types.is_list(pa_type) or types.is_large_list(pa_type): pa_type: pyarrow.ListType element_type = PyarrowFieldParser.to_paimon_type(pa_type.value_type, nullable) @@ -697,7 +771,9 @@ def to_avro_type(field_type: pyarrow.DataType, field_name: str, return {"type": "long", "logicalType": "local-timestamp-micros"} else: raise ValueError(f"Avro does not support pyarrow timestamp with unit {unit}.") - elif pyarrow.types.is_list(field_type) or pyarrow.types.is_large_list(field_type): + elif pyarrow.types.is_fixed_size_list(field_type) or \ + pyarrow.types.is_list(field_type) or \ + pyarrow.types.is_large_list(field_type): value_field = field_type.value_field return { "type": "array", diff --git a/paimon-python/pypaimon/tests/data_types_test.py b/paimon-python/pypaimon/tests/data_types_test.py index de28235f388c..ce3e7d2e20c2 100755 --- a/paimon-python/pypaimon/tests/data_types_test.py +++ b/paimon-python/pypaimon/tests/data_types_test.py @@ -20,7 +20,7 @@ import pyarrow as pa from pypaimon.schema.data_types import (DataField, AtomicType, ArrayType, MultisetType, MapType, - RowType, PyarrowFieldParser) + RowType, VectorType, PyarrowFieldParser) class DataTypesTest(unittest.TestCase): @@ -59,6 +59,32 @@ def test_map_type(self): self.assertEqual(str(MapType(True, AtomicType("STRING"), AtomicType("TIMESTAMP(6)"))), "MAP") + def test_vector_type(self): + vector_type = VectorType(True, AtomicType("FLOAT"), 3) + self.assertEqual(str(vector_type), "VECTOR") + self.assertEqual( + vector_type.to_dict(), + { + "type": "VECTOR", + "element": "FLOAT", + "length": 3, + "nullable": True + } + ) + self.assertEqual(vector_type, VectorType.from_dict(vector_type.to_dict())) + self.assertEqual(hash(vector_type), hash(VectorType(True, AtomicType("FLOAT"), 3))) + + not_null_vector = VectorType(False, AtomicType("FLOAT", nullable=False), 3) + self.assertEqual(str(not_null_vector), "VECTOR NOT NULL") + self.assertEqual(not_null_vector, VectorType.from_dict(not_null_vector.to_dict())) + + with self.assertRaises(ValueError): + VectorType(True, AtomicType("FLOAT"), 0) + with self.assertRaises(ValueError): + VectorType(True, AtomicType("STRING"), 3) + with self.assertRaises(ValueError): + VectorType(True, ArrayType(True, AtomicType("INT")), 3) + def test_row_type(self): self.assertEqual(str(RowType(True, [DataField(0, "a", AtomicType("STRING"), "Someone's desc."), DataField(1, "b", AtomicType("TIMESTAMP(6)"),)])), @@ -135,6 +161,20 @@ def test_nested_field_roundtrip(self): self.assertEqual(converted_nested_field.fields[0].name, "inner_field1") self.assertEqual(converted_nested_field.fields[1].name, "inner_field2") + def test_vector_pyarrow_roundtrip(self): + paimon_vector = VectorType(True, AtomicType("FLOAT"), 3) + pa_type = PyarrowFieldParser.from_paimon_type(paimon_vector) + + self.assertTrue(pa.types.is_fixed_size_list(pa_type)) + self.assertEqual(pa_type.list_size, 3) + self.assertTrue(pa.types.is_float32(pa_type.value_type)) + + converted_paimon_vector = PyarrowFieldParser.to_paimon_type(pa_type, nullable=True) + self.assertEqual(converted_paimon_vector, paimon_vector) + + avro_type = PyarrowFieldParser.to_avro_type(pa_type, "embedding") + self.assertEqual(avro_type, {"type": "array", "items": "float"}) + def test_time_type(self): pa_type = PyarrowFieldParser.from_paimon_type(AtomicType("TIME")) self.assertEqual(pa_type, pa.time32('ms')) diff --git a/paimon-python/pypaimon/tests/e2e/java_py_read_write_test.py b/paimon-python/pypaimon/tests/e2e/java_py_read_write_test.py index 1f77f36291e3..fde72fcc4f9d 100644 --- a/paimon-python/pypaimon/tests/e2e/java_py_read_write_test.py +++ b/paimon-python/pypaimon/tests/e2e/java_py_read_write_test.py @@ -26,6 +26,7 @@ from parameterized import parameterized from pypaimon.catalog.catalog_factory import CatalogFactory from pypaimon.data.generic_variant import GenericVariant +from pypaimon.schema.data_types import VectorType from pypaimon.schema.schema import Schema from pypaimon.read.read_builder import ReadBuilder @@ -529,6 +530,31 @@ def test_read_compressed_text_append_table(self, file_format): self.assertIn(file_format, str(ctx.exception)) self.assertIn("not yet supported", str(ctx.exception)) + def test_read_vector_append_table(self): + table = self.catalog.get_table('default.mixed_test_vector_append_tablej_avro') + embedding_field = next(field for field in table.fields if field.name == 'embedding') + self.assertIsInstance(embedding_field.type, VectorType) + self.assertEqual(embedding_field.type.length, 3) + self.assertEqual(embedding_field.type.element.type, 'FLOAT') + + read_builder = table.new_read_builder() + table_scan = read_builder.new_scan() + table_read = read_builder.new_read() + pa_table = table_read.to_arrow(table_scan.plan().splits()) + pa_table = table_sort_by(pa_table, 'id') + + embedding_type = pa_table.schema.field('embedding').type + self.assertTrue(pa.types.is_fixed_size_list(embedding_type)) + self.assertEqual(embedding_type.list_size, 3) + self.assertTrue(pa.types.is_float32(embedding_type.value_type)) + + self.assertEqual(pa_table.column('id').to_pylist(), [1, 2, 3]) + self.assertEqual( + pa_table.column('embedding').to_pylist(), + [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [-1.0, 0.5, 2.5]] + ) + self.assertEqual(pa_table.column('label').to_pylist(), ['first', 'second', 'third']) + def test_read_tantivy_full_text_index(self): """Test reading a Tantivy full-text index built by Java.""" table = self.catalog.get_table('default.test_tantivy_fulltext') diff --git a/paimon-python/pypaimon/tests/rest/api_test.py b/paimon-python/pypaimon/tests/rest/api_test.py index b43c16a81ec9..b4a5a9664fbd 100644 --- a/paimon-python/pypaimon/tests/rest/api_test.py +++ b/paimon-python/pypaimon/tests/rest/api_test.py @@ -31,7 +31,7 @@ from pypaimon.common.json_util import JSON from pypaimon.schema.data_types import (ArrayType, AtomicInteger, AtomicType, DataField, DataTypeParser, MapType, - RowType) + RowType, VectorType) from pypaimon.schema.table_schema import TableSchema from pypaimon.tests.rest.rest_server import RESTCatalogServer @@ -125,6 +125,17 @@ def test_parse_data(self): self.assertEqual(value_type.fields[0].type.type, 'BIGINT') self.assertEqual(value_type.fields[1].type.type, 'DOUBLE') + vector_json = { + "type": "VECTOR", + "element": "BOOLEAN NOT NULL", + "length": 7 + } + vector_type: VectorType = DataTypeParser.parse_data_type(vector_json, field_id) + self.assertTrue(vector_type.nullable) + self.assertEqual(vector_type.element.type, "BOOLEAN") + self.assertFalse(vector_type.element.nullable) + self.assertEqual(vector_type.length, 7) + def test_api(self): """Example usage of RESTCatalogServer""" # Setup logging