diff --git a/paimon-python/pypaimon/tests/table_upsert_by_key_test.py b/paimon-python/pypaimon/tests/table_upsert_by_key_test.py index f276e44ab8ea..cf612a968cb0 100644 --- a/paimon-python/pypaimon/tests/table_upsert_by_key_test.py +++ b/paimon-python/pypaimon/tests/table_upsert_by_key_test.py @@ -139,6 +139,82 @@ def test_mixed_update_and_append(self): ) self.assertEqual([(1, 'Alice'), (2, 'Bob_new'), (3, 'Carol')], rows) + def test_upsert_for_existing_table_duplicate_keys(self): + table = self._create_table() + self._write_arrow(table, pa.Table.from_pydict({ + 'id': [1], 'name': ['old_A'], 'age': [10], 'city': ['X'], + }, schema=self.pa_schema)) + self._write_arrow(table, pa.Table.from_pydict({ + 'id': [1], 'name': ['old_B'], 'age': [20], 'city': ['Y'], + }, schema=self.pa_schema)) + + self._upsert(table, pa.Table.from_pydict({ + 'id': [1], 'name': ['UPDATED'], 'age': [99], 'city': ['Z'], + }, schema=self.pa_schema), upsert_keys=['id']) + + result = self._read_all(table) + names = sorted(n for i, n in zip(result['id'].to_pylist(), + result['name'].to_pylist()) if i == 1) + self.assertEqual(['UPDATED', 'UPDATED'], names) + + def test_existing_duplicate_keys_partial_update_cols(self): + """update_cols restricts which columns are rewritten; every matching + row is still updated, other columns keep each row's own value.""" + table = self._create_table() + self._write_arrow(table, pa.Table.from_pydict({ + 'id': [1], 'name': ['old_A'], 'age': [10], 'city': ['X'], + }, schema=self.pa_schema)) + self._write_arrow(table, pa.Table.from_pydict({ + 'id': [1], 'name': ['old_B'], 'age': [20], 'city': ['Y'], + }, schema=self.pa_schema)) + + self._upsert(table, pa.Table.from_pydict({ + 'id': [1], 'name': ['UPDATED'], 'age': [99], 'city': ['Z'], + }, schema=self.pa_schema), upsert_keys=['id'], update_cols=['name']) + + result = self._read_all(table) + rows = sorted(zip(result['id'].to_pylist(), result['name'].to_pylist(), + result['age'].to_pylist(), result['city'].to_pylist())) + self.assertEqual([(1, 'UPDATED', 10, 'X'), (1, 'UPDATED', 20, 'Y')], rows) + + def test_existing_duplicate_keys_partitioned(self): + """Duplicate keys within a partition are all updated; rows in other + partitions are untouched.""" + table = self._create_table( + pa_schema=self.partitioned_pa_schema, partition_keys=['region']) + self._write_arrow(table, pa.Table.from_pydict({ + 'id': [1, 1], 'name': ['a1', 'a2'], 'age': [10, 20], 'region': ['A', 'A'], + }, schema=self.partitioned_pa_schema)) + self._write_arrow(table, pa.Table.from_pydict({ + 'id': [1], 'name': ['b1'], 'age': [30], 'region': ['B'], + }, schema=self.partitioned_pa_schema)) + + self._upsert(table, pa.Table.from_pydict({ + 'id': [1], 'name': ['UPDATED'], 'age': [99], 'region': ['A'], + }, schema=self.partitioned_pa_schema), upsert_keys=['id']) + + result = self._read_all(table) + rows = sorted(zip(result['id'].to_pylist(), result['name'].to_pylist(), + result['region'].to_pylist())) + self.assertEqual( + [(1, 'UPDATED', 'A'), (1, 'UPDATED', 'A'), (1, 'b1', 'B')], rows) + + def test_multiple_keys_each_with_duplicates(self): + """One upsert updates every matching row across several keys.""" + table = self._create_table() + self._write_arrow(table, pa.Table.from_pydict({ + 'id': [1, 1, 2, 2], 'name': ['a', 'b', 'c', 'd'], + 'age': [1, 2, 3, 4], 'city': ['p', 'q', 'r', 's'], + }, schema=self.pa_schema)) + + self._upsert(table, pa.Table.from_pydict({ + 'id': [1, 2], 'name': ['U1', 'U2'], 'age': [10, 20], 'city': ['X', 'Y'], + }, schema=self.pa_schema), upsert_keys=['id']) + + result = self._read_all(table) + names = sorted(zip(result['id'].to_pylist(), result['name'].to_pylist())) + self.assertEqual([(1, 'U1'), (1, 'U1'), (2, 'U2'), (2, 'U2')], names) + def test_composite_key_upsert(self): """Upsert with a multi-column composite key.""" table = self._create_table() @@ -149,8 +225,7 @@ def test_composite_key_upsert(self): 'city': ['NYC', 'LA', 'Chicago'], }, schema=self.pa_schema)) - # (id, name) = (1, Alice) appears twice in the table → matches the - # first occurrence; (2, Carol) is new. + # (id, name) = (1, Alice) appears twice → both are updated; (2, Carol) is new. self._upsert(table, pa.Table.from_pydict({ 'id': [1, 2], 'name': ['Alice', 'Carol'], @@ -165,7 +240,12 @@ def test_composite_key_upsert(self): result['name'].to_pylist(), result['city'].to_pylist(), )) - self.assertIn((2, 'Carol', 'Dallas'), rows) + self.assertEqual([ + (1, 'Alice', 'Updated'), + (1, 'Alice', 'Updated'), + (2, 'Bob', 'Chicago'), + (2, 'Carol', 'Dallas'), + ], rows) def test_sequential_upserts(self): """A second upsert sees the rows inserted by the first.""" diff --git a/paimon-python/pypaimon/write/table_update.py b/paimon-python/pypaimon/write/table_update.py index 4b063dfa7bf5..8271b19f3789 100644 --- a/paimon-python/pypaimon/write/table_update.py +++ b/paimon-python/pypaimon/write/table_update.py @@ -154,8 +154,8 @@ def _upsert_by_arrow_with_key( For each row in the input Arrow table: - * If a row with the same composite ``upsert_keys`` value already - exists → update that row in-place. + * If one or more rows with the same composite ``upsert_keys`` value + already exist → update all of them in-place. * Otherwise → append as a new row. The public method lives on the concrete subclasses so each can diff --git a/paimon-python/pypaimon/write/table_upsert_by_key.py b/paimon-python/pypaimon/write/table_upsert_by_key.py index 42d3ced13ec4..4faf02fea0bd 100644 --- a/paimon-python/pypaimon/write/table_upsert_by_key.py +++ b/paimon-python/pypaimon/write/table_upsert_by_key.py @@ -37,8 +37,8 @@ class TableUpsertByKey: Table upsert by one or more user-specified key columns for append-only tables. For each row in the input Arrow table: - - If a row with the same upsert_keys composite value already exists → update that row - (in-place rewrite). + - If one or more rows with the same upsert_keys composite value already exist → + update all of them (in-place rewrite). - If no matching row exists → append as a new row. All upsert_keys must be columns present in both the input data and the table schema. @@ -168,9 +168,10 @@ def _upsert_partition( partition_data, input_key_tuples, partition_spec, ) - # 3. Scan partition once, keeping only key → _ROW_ID pairs that - # appear in the input (memory ∝ |input|, not |partition|). - key_to_row_id = self._build_key_to_row_id_map( + # 3. Scan partition once, keeping key → [_ROW_ID, ...] for keys that + # appear in the input (memory ∝ matched existing rows, not the + # whole partition). + key_to_row_ids = self._build_key_to_row_ids_map( match_keys, partition_spec, set(input_key_tuples), ) @@ -178,18 +179,26 @@ def _upsert_partition( matched_indices: List[int] = [] new_indices: List[int] = [] for i, key_tuple in enumerate(input_key_tuples): - (matched_indices if key_tuple in key_to_row_id else new_indices).append(i) + (matched_indices if key_tuple in key_to_row_ids else new_indices).append(i) logger.info( "Upserting partition %s: %d matched, %d new", partition_spec, len(matched_indices), len(new_indices), ) + total_updates = sum( + len(key_to_row_ids[input_key_tuples[i]]) for i in matched_indices) + if total_updates > len(matched_indices): + logger.info( + "Upsert fan-out in partition %s: %d input rows expand to " + "%d row updates", partition_spec, + len(matched_indices), total_updates, + ) commit_messages: List[CommitMessage] = [] if matched_indices: commit_messages.extend(self._do_updates( partition_data, matched_indices, - input_key_tuples, key_to_row_id, update_cols, + input_key_tuples, key_to_row_ids, update_cols, )) if new_indices: commit_messages.extend(self._do_appends(partition_data, new_indices)) @@ -274,14 +283,14 @@ def _validate_inputs(self, data: pa.Table, upsert_keys: List[str], # that partition columns can be stripped first. The same non-partition # key may legally appear in different partitions. - def _build_key_to_row_id_map( + def _build_key_to_row_ids_map( self, match_keys: List[str], partition_spec: Optional[Dict[str, Any]], input_key_set: set, - ) -> Dict[_KeyTuple, int]: + ) -> Dict[_KeyTuple, List[int]]: """ - Scan the partition in batches and collect key → _ROW_ID only for + Scan the partition in batches and collect key → [_ROW_ID, ...] for rows whose composite key is in *input_key_set*. The partition spec (if any) is pushed down as an ``and`` of per-key @@ -322,7 +331,7 @@ def _build_key_to_row_id_map( ) # Stream batches and filter against input_key_set on-the-fly - key_to_row_id: Dict[_KeyTuple, int] = {} + key_to_row_ids: Dict[_KeyTuple, List[int]] = {} row_id_col = SpecialFields.ROW_ID.name for batch in table_read.to_arrow_batch_reader(splits): batch_key_cols = [batch.column(k).to_pylist() for k in match_keys] @@ -330,27 +339,28 @@ def _build_key_to_row_id_map( for j, row_id in enumerate(batch_row_ids): key_tuple = tuple(col[j] for col in batch_key_cols) if key_tuple in input_key_set: - key_to_row_id[key_tuple] = row_id + key_to_row_ids.setdefault(key_tuple, []).append(row_id) - return key_to_row_id + return key_to_row_ids def _do_updates( self, data: pa.Table, matched_indices: List[int], input_key_tuples: List[_KeyTuple], - key_to_row_id: Dict[_KeyTuple, int], + key_to_row_ids: Dict[_KeyTuple, List[int]], update_cols: Optional[List[str]] ) -> List[CommitMessage]: - """Update matched rows by rewriting them in-place via - :class:`TableUpdateByRowId`.""" - matched_data = data.take(matched_indices) - row_id_array = pa.array( - [key_to_row_id[input_key_tuples[i]] for i in matched_indices], - type=pa.int64(), - ) - update_data = matched_data.append_column( - SpecialFields.ROW_ID.name, row_id_array, + """Update matched rows in-place via :class:`TableUpdateByRowId`.""" + expanded_input_indices: List[int] = [] + row_ids: List[int] = [] + for i in matched_indices: + for row_id in key_to_row_ids[input_key_tuples[i]]: + expanded_input_indices.append(i) + row_ids.append(row_id) + + update_data = data.take(expanded_input_indices).append_column( + SpecialFields.ROW_ID.name, pa.array(row_ids, type=pa.int64()), ) cols_to_update = list(update_cols) if update_cols else list(self.table.field_names)