Skip to content
86 changes: 83 additions & 3 deletions paimon-python/pypaimon/tests/table_upsert_by_key_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,82 @@ def test_mixed_update_and_append(self):
)
self.assertEqual([(1, 'Alice'), (2, 'Bob_new'), (3, 'Carol')], rows)

def test_upsert_for_existing_table_duplicate_keys(self):
    """Upserting a key that matches two existing rows renames both rows."""
    table = self._create_table()
    # Two separate writes, each contributing one row with id == 1.
    for name, age, city in (('old_A', 10, 'X'), ('old_B', 20, 'Y')):
        self._write_arrow(table, pa.Table.from_pydict({
            'id': [1], 'name': [name], 'age': [age], 'city': [city],
        }, schema=self.pa_schema))

    incoming = pa.Table.from_pydict({
        'id': [1], 'name': ['UPDATED'], 'age': [99], 'city': ['Z'],
    }, schema=self.pa_schema)
    self._upsert(table, incoming, upsert_keys=['id'])

    snapshot = self._read_all(table)
    key_column = snapshot['id'].to_pylist()
    name_column = snapshot['name'].to_pylist()
    # Only rows carrying the upserted key are relevant to the assertion.
    names_for_key = sorted(
        name for key, name in zip(key_column, name_column) if key == 1)
    self.assertEqual(['UPDATED', 'UPDATED'], names_for_key)

def test_existing_duplicate_keys_partial_update_cols(self):
    """With ``update_cols=['name']`` only ``name`` is rewritten on every
    matching row; ``age`` and ``city`` keep each row's own value."""
    table = self._create_table()
    seed_rows = (
        {'id': [1], 'name': ['old_A'], 'age': [10], 'city': ['X']},
        {'id': [1], 'name': ['old_B'], 'age': [20], 'city': ['Y']},
    )
    # One write per seed row so the duplicates arrive in separate commits.
    for seed in seed_rows:
        self._write_arrow(
            table, pa.Table.from_pydict(seed, schema=self.pa_schema))

    incoming = pa.Table.from_pydict({
        'id': [1], 'name': ['UPDATED'], 'age': [99], 'city': ['Z'],
    }, schema=self.pa_schema)
    self._upsert(table, incoming, upsert_keys=['id'], update_cols=['name'])

    snapshot = self._read_all(table)
    columns = [
        snapshot[col].to_pylist() for col in ('id', 'name', 'age', 'city')
    ]
    expected = [(1, 'UPDATED', 10, 'X'), (1, 'UPDATED', 20, 'Y')]
    self.assertEqual(expected, sorted(zip(*columns)))

def test_existing_duplicate_keys_partitioned(self):
    """All duplicate-key rows inside the targeted partition are updated,
    while a row with the same key in another partition is left alone."""
    schema = self.partitioned_pa_schema
    table = self._create_table(pa_schema=schema, partition_keys=['region'])

    # Region A holds two rows with id == 1; region B holds one.
    region_a = pa.Table.from_pydict({
        'id': [1, 1], 'name': ['a1', 'a2'], 'age': [10, 20], 'region': ['A', 'A'],
    }, schema=schema)
    self._write_arrow(table, region_a)
    region_b = pa.Table.from_pydict({
        'id': [1], 'name': ['b1'], 'age': [30], 'region': ['B'],
    }, schema=schema)
    self._write_arrow(table, region_b)

    incoming = pa.Table.from_pydict({
        'id': [1], 'name': ['UPDATED'], 'age': [99], 'region': ['A'],
    }, schema=schema)
    self._upsert(table, incoming, upsert_keys=['id'])

    snapshot = self._read_all(table)
    columns = [snapshot[col].to_pylist() for col in ('id', 'name', 'region')]
    expected = [(1, 'UPDATED', 'A'), (1, 'UPDATED', 'A'), (1, 'b1', 'B')]
    self.assertEqual(expected, sorted(zip(*columns)))

def test_multiple_keys_each_with_duplicates(self):
    """A single upsert rewrites every matching row for each of several keys."""
    table = self._create_table()
    # ids 1 and 2 each occur twice in the seed data.
    seed = pa.Table.from_pydict({
        'id': [1, 1, 2, 2], 'name': ['a', 'b', 'c', 'd'],
        'age': [1, 2, 3, 4], 'city': ['p', 'q', 'r', 's'],
    }, schema=self.pa_schema)
    self._write_arrow(table, seed)

    incoming = pa.Table.from_pydict({
        'id': [1, 2], 'name': ['U1', 'U2'], 'age': [10, 20], 'city': ['X', 'Y'],
    }, schema=self.pa_schema)
    self._upsert(table, incoming, upsert_keys=['id'])

    snapshot = self._read_all(table)
    id_name_pairs = list(
        zip(snapshot['id'].to_pylist(), snapshot['name'].to_pylist()))
    id_name_pairs.sort()
    self.assertEqual(
        [(1, 'U1'), (1, 'U1'), (2, 'U2'), (2, 'U2')], id_name_pairs)

def test_composite_key_upsert(self):
"""Upsert with a multi-column composite key."""
table = self._create_table()
Expand All @@ -149,8 +225,7 @@ def test_composite_key_upsert(self):
'city': ['NYC', 'LA', 'Chicago'],
}, schema=self.pa_schema))

# (id, name) = (1, Alice) appears twice in the table → matches the
# first occurrence; (2, Carol) is new.
# (id, name) = (1, Alice) appears twice → both are updated; (2, Carol) is new.
self._upsert(table, pa.Table.from_pydict({
'id': [1, 2],
'name': ['Alice', 'Carol'],
Expand All @@ -165,7 +240,12 @@ def test_composite_key_upsert(self):
result['name'].to_pylist(),
result['city'].to_pylist(),
))
self.assertIn((2, 'Carol', 'Dallas'), rows)
self.assertEqual([
(1, 'Alice', 'Updated'),
(1, 'Alice', 'Updated'),
(2, 'Bob', 'Chicago'),
(2, 'Carol', 'Dallas'),
], rows)

def test_sequential_upserts(self):
"""A second upsert sees the rows inserted by the first."""
Expand Down
4 changes: 2 additions & 2 deletions paimon-python/pypaimon/write/table_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,8 @@ def _upsert_by_arrow_with_key(

For each row in the input Arrow table:

* If a row with the same composite ``upsert_keys`` value already
exists → update that row in-place.
* If one or more rows with the same composite ``upsert_keys`` value
already exist → update all of them in-place.
* Otherwise → append as a new row.

The public method lives on the concrete subclasses so each can
Expand Down
56 changes: 33 additions & 23 deletions paimon-python/pypaimon/write/table_upsert_by_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ class TableUpsertByKey:
Table upsert by one or more user-specified key columns for append-only tables.

For each row in the input Arrow table:
- If a row with the same upsert_keys composite value already exists → update that row
(in-place rewrite).
- If one or more rows with the same upsert_keys composite value already exist →
update all of them (in-place rewrite).
- If no matching row exists → append as a new row.

All upsert_keys must be columns present in both the input data and the table schema.
Expand Down Expand Up @@ -168,28 +168,37 @@ def _upsert_partition(
partition_data, input_key_tuples, partition_spec,
)

# 3. Scan partition once, keeping only key → _ROW_ID pairs that
# appear in the input (memory ∝ |input|, not |partition|).
key_to_row_id = self._build_key_to_row_id_map(
# 3. Scan partition once, keeping key → [_ROW_ID, ...] for keys that
# appear in the input (memory ∝ matched existing rows, not the
# whole partition).
key_to_row_ids = self._build_key_to_row_ids_map(
match_keys, partition_spec, set(input_key_tuples),
)

# 4. Partition input rows into matched (update) vs unmatched (append).
matched_indices: List[int] = []
new_indices: List[int] = []
for i, key_tuple in enumerate(input_key_tuples):
(matched_indices if key_tuple in key_to_row_id else new_indices).append(i)
(matched_indices if key_tuple in key_to_row_ids else new_indices).append(i)

logger.info(
"Upserting partition %s: %d matched, %d new",
partition_spec, len(matched_indices), len(new_indices),
)
total_updates = sum(
len(key_to_row_ids[input_key_tuples[i]]) for i in matched_indices)
if total_updates > len(matched_indices):
logger.info(
"Upsert fan-out in partition %s: %d input rows expand to "
"%d row updates", partition_spec,
len(matched_indices), total_updates,
)

commit_messages: List[CommitMessage] = []
if matched_indices:
commit_messages.extend(self._do_updates(
partition_data, matched_indices,
input_key_tuples, key_to_row_id, update_cols,
input_key_tuples, key_to_row_ids, update_cols,
))
if new_indices:
commit_messages.extend(self._do_appends(partition_data, new_indices))
Expand Down Expand Up @@ -274,14 +283,14 @@ def _validate_inputs(self, data: pa.Table, upsert_keys: List[str],
# that partition columns can be stripped first. The same non-partition
# key may legally appear in different partitions.

def _build_key_to_row_id_map(
def _build_key_to_row_ids_map(
self,
match_keys: List[str],
partition_spec: Optional[Dict[str, Any]],
input_key_set: set,
) -> Dict[_KeyTuple, int]:
) -> Dict[_KeyTuple, List[int]]:
"""
Scan the partition in batches and collect key → _ROW_ID only for
Scan the partition in batches and collect key → [_ROW_ID, ...] for
rows whose composite key is in *input_key_set*.

The partition spec (if any) is pushed down as an ``and`` of per-key
Expand Down Expand Up @@ -322,35 +331,36 @@ def _build_key_to_row_id_map(
)

# Stream batches and filter against input_key_set on-the-fly
key_to_row_id: Dict[_KeyTuple, int] = {}
key_to_row_ids: Dict[_KeyTuple, List[int]] = {}
row_id_col = SpecialFields.ROW_ID.name
for batch in table_read.to_arrow_batch_reader(splits):
batch_key_cols = [batch.column(k).to_pylist() for k in match_keys]
batch_row_ids = batch.column(row_id_col).to_pylist()
for j, row_id in enumerate(batch_row_ids):
key_tuple = tuple(col[j] for col in batch_key_cols)
if key_tuple in input_key_set:
key_to_row_id[key_tuple] = row_id
key_to_row_ids.setdefault(key_tuple, []).append(row_id)

return key_to_row_id
return key_to_row_ids

def _do_updates(
self,
data: pa.Table,
matched_indices: List[int],
input_key_tuples: List[_KeyTuple],
key_to_row_id: Dict[_KeyTuple, int],
key_to_row_ids: Dict[_KeyTuple, List[int]],
update_cols: Optional[List[str]]
) -> List[CommitMessage]:
"""Update matched rows by rewriting them in-place via
:class:`TableUpdateByRowId`."""
matched_data = data.take(matched_indices)
row_id_array = pa.array(
[key_to_row_id[input_key_tuples[i]] for i in matched_indices],
type=pa.int64(),
)
update_data = matched_data.append_column(
SpecialFields.ROW_ID.name, row_id_array,
"""Update matched rows in-place via :class:`TableUpdateByRowId`."""
expanded_input_indices: List[int] = []
row_ids: List[int] = []
for i in matched_indices:
for row_id in key_to_row_ids[input_key_tuples[i]]:
expanded_input_indices.append(i)
row_ids.append(row_id)

update_data = data.take(expanded_input_indices).append_column(
SpecialFields.ROW_ID.name, pa.array(row_ids, type=pa.int64()),
)

cols_to_update = list(update_cols) if update_cols else list(self.table.field_names)
Expand Down
Loading