diff --git a/pyiceberg/table/snapshots.py b/pyiceberg/table/snapshots.py
index 7e4c6eb1ec..7bd4597399 100644
--- a/pyiceberg/table/snapshots.py
+++ b/pyiceberg/table/snapshots.py
@@ -344,7 +344,7 @@ def _partition_summary(self, update_metrics: UpdateMetrics) -> str:
 
 
 def update_snapshot_summaries(summary: Summary, previous_summary: Mapping[str, str] | None = None) -> Summary:
-    if summary.operation not in {Operation.APPEND, Operation.OVERWRITE, Operation.DELETE}:
+    if summary.operation not in {Operation.APPEND, Operation.OVERWRITE, Operation.DELETE, Operation.REPLACE}:
         raise ValueError(f"Operation not implemented: {summary.operation}")
 
     if not previous_summary:
diff --git a/pyiceberg/table/update/snapshot.py b/pyiceberg/table/update/snapshot.py
index 37d120969a..ece046228f 100644
--- a/pyiceberg/table/update/snapshot.py
+++ b/pyiceberg/table/update/snapshot.py
@@ -165,6 +165,64 @@ def _calculate_added_rows(self, manifests: list[ManifestFile]) -> int:
                 added_rows += manifest.added_rows_count
         return added_rows
 
+    def _get_existing_manifests(self, should_use_manifest_pruning: bool = False) -> list[ManifestFile]:
+        """Return the manifests of the target branch to carry over into the new snapshot.
+
+        A manifest that references none of the files in ``self._deleted_data_files`` is
+        reused as-is. A manifest that references at least one of them is rewritten with
+        only its surviving entries (marked ``EXISTING``), or dropped when none survives.
+
+        Args:
+            should_use_manifest_pruning: When True, a manifest rejected by the evaluator
+                for its partition spec is kept without reading its entries. When False
+                (the default), every manifest is read.
+
+        Returns:
+            The reused and rewritten manifest files; an empty list when the target branch
+            does not resolve to a snapshot.
+        """
+        existing_files: list[ManifestFile] = []
+        # One evaluator per partition spec id, obtained through _build_manifest_evaluator.
+        manifest_evaluators: dict[int, Callable[[ManifestFile], bool]] = KeyDefaultDict(self._build_manifest_evaluator)
+
+        snapshot = self._transaction.table_metadata.snapshot_by_name(name=self._target_branch)
+        if not snapshot:
+            return existing_files
+
+        for manifest_file in snapshot.manifests(io=self._io):
+            # With pruning enabled, a manifest the evaluator rejects is reused untouched.
+            if should_use_manifest_pruning and not manifest_evaluators[manifest_file.partition_spec_id](manifest_file):
+                existing_files.append(manifest_file)
+                continue
+
+            live_entries = list(manifest_file.fetch_manifest_entry(io=self._io, discard_deleted=True))
+            entries_to_write = [entry for entry in live_entries if entry.data_file not in self._deleted_data_files]
+
+            # No entry is being deleted: reuse the manifest file as-is.
+            if len(entries_to_write) == len(live_entries):
+                existing_files.append(manifest_file)
+                continue
+
+            # Every entry is being deleted: drop the manifest altogether.
+            if not entries_to_write:
+                continue
+
+            # Rewrite the manifest without the deleted data files.
+            with self.new_manifest_writer(self.spec(manifest_file.partition_spec_id)) as writer:
+                for entry in entries_to_write:
+                    writer.add_entry(
+                        ManifestEntry.from_args(
+                            status=ManifestEntryStatus.EXISTING,
+                            snapshot_id=entry.snapshot_id,
+                            sequence_number=entry.sequence_number,
+                            file_sequence_number=entry.file_sequence_number,
+                            data_file=entry.data_file,
+                        )
+                    )
+            existing_files.append(writer.to_manifest_file())
+
+        return existing_files
+
     @abstractmethod
     def _deleted_entries(self) -> list[ManifestEntry]: ...
 
@@ -585,49 +643,7 @@ class _OverwriteFiles(_SnapshotProducer["_OverwriteFiles"]):
 
     def _existing_manifests(self) -> list[ManifestFile]:
         """Determine if there are any existing manifest files."""
-        existing_files = []
-
-        manifest_evaluators: dict[int, Callable[[ManifestFile], bool]] = KeyDefaultDict(self._build_manifest_evaluator)
-        if snapshot := self._transaction.table_metadata.snapshot_by_name(name=self._target_branch):
-            for manifest_file in snapshot.manifests(io=self._io):
-                # Manifest does not contain rows that match the files to delete partitions
-                if not manifest_evaluators[manifest_file.partition_spec_id](manifest_file):
-                    existing_files.append(manifest_file)
-                    continue
-
-                entries_to_write: set[ManifestEntry] = set()
-                found_deleted_entries: set[ManifestEntry] = set()
-
-                for entry in manifest_file.fetch_manifest_entry(io=self._io, discard_deleted=True):
-                    if entry.data_file in self._deleted_data_files:
-                        found_deleted_entries.add(entry)
-                    else:
-                        entries_to_write.add(entry)
-
-                # Is the intercept the empty set?
-                if len(found_deleted_entries) == 0:
-                    existing_files.append(manifest_file)
-                    continue
-
-                # Delete all files from manifest
-                if len(entries_to_write) == 0:
-                    continue
-
-                # We have to rewrite the manifest file without the deleted data files
-                with self.new_manifest_writer(self.spec(manifest_file.partition_spec_id)) as writer:
-                    for entry in entries_to_write:
-                        writer.add_entry(
-                            ManifestEntry.from_args(
-                                status=ManifestEntryStatus.EXISTING,
-                                snapshot_id=entry.snapshot_id,
-                                sequence_number=entry.sequence_number,
-                                file_sequence_number=entry.file_sequence_number,
-                                data_file=entry.data_file,
-                            )
-                        )
-                existing_files.append(writer.to_manifest_file())
-
-        return existing_files
+        return self._get_existing_manifests(should_use_manifest_pruning=True)
 
     def _deleted_entries(self) -> list[ManifestEntry]:
         """To determine if we need to record any deleted entries.
@@ -667,6 +683,92 @@ def _get_entries(manifest: ManifestFile) -> list[ManifestEntry]:
             return []
 
 
+class _RewriteFiles(_SnapshotProducer["_RewriteFiles"]):
+    """A snapshot producer that replaces a set of data files with another set.
+
+    Every file scheduled for deletion must be found in the parent snapshot, and the
+    added files may not hold more records than the removed ones.
+    """
+
+    def _commit(self) -> UpdatesAndRequirements:
+        """Validate the rewrite and delegate the commit to the parent producer.
+
+        Returns:
+            The result of the parent ``_commit``, or a pair of empty tuples when no file
+            was added or deleted.
+
+        Raises:
+            ValueError: If a file to delete is not found among the deleted entries, or if
+                more records are added than removed.
+        """
+        # Nothing to rewrite: produce no updates and no requirements.
+        if not self._deleted_data_files and not self._added_data_files:
+            return (), ()
+
+        # Files that were actually found in the table's manifests. Using a set means a
+        # file referenced by more than one manifest entry is only counted once below.
+        found_deleted_files = {entry.data_file for entry in self._deleted_entries()}
+
+        # If the user asked to delete files that aren't in the table, abort and name them.
+        missing_paths = sorted(f.file_path for f in self._deleted_data_files if f not in found_deleted_files)
+        if missing_paths:
+            raise ValueError(f"Cannot delete files that are not present in the table: {missing_paths}")
+
+        added_records = sum(f.record_count for f in self._added_data_files)
+        deleted_records = sum(f.record_count for f in found_deleted_files)
+
+        if added_records > deleted_records:
+            raise ValueError(f"Invalid replace: records added ({added_records}) exceeds records removed ({deleted_records})")
+
+        return super()._commit()
+
+    @cached_property
+    def _cached_deleted_entries(self) -> list[ManifestEntry]:
+        """Build the ``DELETED`` manifest entries for the data files being replaced.
+
+        Each live entry of the parent snapshot whose file has ``DATA`` content and is
+        scheduled for deletion yields a ``DELETED`` entry carrying this producer's
+        snapshot id. The result is cached because ``_commit`` reads it for validation
+        and ``_deleted_entries`` returns the same list afterwards.
+
+        Raises:
+            ValueError: If the parent snapshot id does not resolve to a snapshot.
+        """
+        if self._parent_snapshot_id is None:
+            return []
+
+        previous_snapshot = self._transaction.table_metadata.snapshot_by_id(self._parent_snapshot_id)
+        if previous_snapshot is None:
+            raise ValueError(f"Could not find the previous snapshot: {self._parent_snapshot_id}")
+
+        executor = ExecutorFactory.get_or_create()
+
+        def _get_entries(manifest: ManifestFile) -> list[ManifestEntry]:
+            return [
+                ManifestEntry.from_args(
+                    status=ManifestEntryStatus.DELETED,
+                    snapshot_id=self.snapshot_id,
+                    sequence_number=entry.sequence_number,
+                    file_sequence_number=entry.file_sequence_number,
+                    data_file=entry.data_file,
+                )
+                for entry in manifest.fetch_manifest_entry(self._io, discard_deleted=True)
+                if entry.data_file.content == DataFileContent.DATA and entry.data_file in self._deleted_data_files
+            ]
+
+        # Manifests are read through the executor, one task per manifest.
+        list_of_entries = executor.map(_get_entries, previous_snapshot.manifests(self._io))
+        return list(itertools.chain.from_iterable(list_of_entries))
+
+    def _deleted_entries(self) -> list[ManifestEntry]:
+        """Return the cached ``DELETED`` entries for the replaced data files."""
+        return self._cached_deleted_entries
+
+    def _existing_manifests(self) -> list[ManifestFile]:
+        """Carry over the current manifests, rewriting those that hold replaced files."""
+        return self._get_existing_manifests()
+
+
 class UpdateSnapshot:
     _transaction: Transaction
     _io: FileIO
@@ -724,6 +826,16 @@ def delete(self) -> _DeleteFiles:
             snapshot_properties=self._snapshot_properties,
         )
 
+    def replace(self) -> _RewriteFiles:
+        """Create a ``_RewriteFiles`` producer using ``Operation.REPLACE`` for this transaction and branch."""
+        return _RewriteFiles(
+            operation=Operation.REPLACE,
+            transaction=self._transaction,
+            io=self._io,
+            branch=self._branch,
+            snapshot_properties=self._snapshot_properties,
+        )
+
 
 class _ManifestMergeManager(Generic[U]):
     _target_size_bytes: int
diff --git a/tests/table/test_replace.py b/tests/table/test_replace.py
new file mode 100644
index 0000000000..1708e7173d
--- /dev/null
+++ b/tests/table/test_replace.py
@@ -0,0 +1,619 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from typing import cast
+
+import pytest
+
+from pyiceberg.catalog import Catalog
+from pyiceberg.manifest import (
+    DataFile,
+    DataFileContent,
+    FileFormat,
+    ManifestEntry,
+    ManifestEntryStatus,
+)
+from pyiceberg.schema import Schema
+from pyiceberg.table.snapshots import Operation, Snapshot, Summary
+from pyiceberg.typedef import Record
+
+
+def _create_dummy_data_file(
+    file_path: str,
+    record_count: int,
+    file_size_in_bytes: int = 1024,
+    content: DataFileContent = DataFileContent.DATA,
+    partition: Record | None = None,
+    spec_id: int = 0,
+) -> DataFile:
+    if partition is None:
+        partition = Record()
+    df = DataFile.from_args(
+        file_path=file_path,
+        file_format=FileFormat.PARQUET,
+        partition=partition,
+        record_count=record_count,
+        file_size_in_bytes=file_size_in_bytes,
+        content=content,
+    )
+    df.spec_id = spec_id
+    return df
+
+
+def test_replace_internally(catalog: Catalog) -> None:
+    # Set up a basic table using the catalog fixture
+    catalog.create_namespace("default")
+    table = catalog.create_table(
+        identifier="default.test_replace",
+        schema=Schema(),
+    )
+
+    # 1. File we will delete
+    file_to_delete = _create_dummy_data_file(
+        file_path="s3://bucket/test/data/deleted.parquet",
+        record_count=100,
+    )
+
+    # 2. File we will leave completely untouched
+    file_to_keep = _create_dummy_data_file(
+        file_path="s3://bucket/test/data/kept.parquet",
+        record_count=50,
+        file_size_in_bytes=512,
+    )
+
+    # 3. File we are adding as a replacement
+    file_to_add = _create_dummy_data_file(
+        file_path="s3://bucket/test/data/added.parquet",
+        record_count=100,
+    )
+
+    # Initially append BOTH the file to delete and the file to keep
+    with table.transaction() as tx:
+        with tx.update_snapshot().fast_append() as append_snapshot:
+            append_snapshot.append_data_file(file_to_delete)
+            append_snapshot.append_data_file(file_to_keep)
+
+    old_snapshot = cast(Snapshot, table.current_snapshot())
+    old_snapshot_id = old_snapshot.snapshot_id
+    old_sequence_number = cast(int, old_snapshot.sequence_number)
+
+    # Call the internal replace API
+    with table.transaction() as tx:
+        with tx.update_snapshot().replace() as rewrite:
+            rewrite.delete_data_file(file_to_delete)
+            rewrite.append_data_file(file_to_add)
+
+    snapshot = cast(Snapshot, table.current_snapshot())
+    summary = cast(Summary, snapshot.summary)
+
+    # 1. Has a unique snapshot ID
+    assert snapshot.snapshot_id is not None
+    assert snapshot.snapshot_id != old_snapshot_id
+
+    # 2. Parent points to the previous snapshot
+    assert snapshot.parent_snapshot_id == old_snapshot_id
+
+    # 3. Sequence number is exactly previous + 1
+    assert snapshot.sequence_number == old_sequence_number + 1
+
+    # 4. Operation type is set to "replace"
+    assert summary["operation"] == Operation.REPLACE
+
+    # 5. Manifest list path is set (only verified to be a non-None string, not its value)
+    assert snapshot.manifest_list is not None
+    assert isinstance(snapshot.manifest_list, str)
+
+    # 6. Summary counts are accurate (total = 50 kept + 100 added)
+    assert summary["added-data-files"] == "1"
+    assert summary["deleted-data-files"] == "1"
+    assert summary["added-records"] == "100"
+    assert summary["deleted-records"] == "100"
+    assert summary["total-records"] == "150"
+
+    # Fetch all entries from the new manifests, keeping DELETED ones (discard_deleted=False)
+    manifest_files = snapshot.manifests(table.io)
+    entries: list[ManifestEntry] = []
+    for manifest in manifest_files:
+        entries.extend(manifest.fetch_manifest_entry(table.io, discard_deleted=False))
+
+    # We expect 3 entries: ADDED, DELETED, and EXISTING
+    assert len(entries) == 3
+
+    # Check ADDED: the replacement file, stamped with the new snapshot id
+    added_entries = [e for e in entries if e.status == ManifestEntryStatus.ADDED]
+    assert len(added_entries) == 1
+    assert added_entries[0].data_file.file_path == file_to_add.file_path
+    assert added_entries[0].snapshot_id == snapshot.snapshot_id
+
+    # Check DELETED: the replaced file, stamped with the new snapshot id
+    deleted_entries = [e for e in entries if e.status == ManifestEntryStatus.DELETED]
+    assert len(deleted_entries) == 1
+    assert deleted_entries[0].data_file.file_path == file_to_delete.file_path
+    assert deleted_entries[0].snapshot_id == snapshot.snapshot_id
+
+    # Check EXISTING: the untouched file keeps its original snapshot id and sequence number
+    existing_entries = [e for e in entries if e.status == ManifestEntryStatus.EXISTING]
+    assert len(existing_entries) == 1
+    assert existing_entries[0].data_file.file_path == file_to_keep.file_path
+    assert existing_entries[0].snapshot_id == old_snapshot_id
+    assert existing_entries[0].sequence_number == old_sequence_number
+
+
+def test_replace_reuses_unaffected_manifests(catalog: Catalog) -> None:
+    # Set up a basic table
+    catalog.create_namespace("default")
+    table = catalog.create_table(
+        identifier="default.test_replace_reuse_manifest",
+        schema=Schema(),
+    )
+
+    file_a = _create_dummy_data_file(
+        file_path="s3://bucket/test/data/a.parquet",
+        record_count=10,
+        file_size_in_bytes=100,
+    )
+
+    file_b = _create_dummy_data_file(
+        file_path="s3://bucket/test/data/b.parquet",
+        record_count=10,
+        file_size_in_bytes=100,
+    )
+
+    file_c = _create_dummy_data_file(
+        file_path="s3://bucket/test/data/c.parquet",
+        record_count=10,
+        file_size_in_bytes=100,
+    )
+
+    # Commit 1: Append file A (Creates Manifest 1)
+    with table.transaction() as tx:
+        with tx.update_snapshot().fast_append() as append_snapshot:
+            append_snapshot.append_data_file(file_a)
+
+    # Commit 2: Append file B (Creates Manifest 2)
+    with table.transaction() as tx:
+        with tx.update_snapshot().fast_append() as append_snapshot:
+            append_snapshot.append_data_file(file_b)
+
+    snapshot_before = cast(Snapshot, table.current_snapshot())
+    manifests_before = snapshot_before.manifests(table.io)
+    assert len(manifests_before) == 2
+
+    # Identify which manifest holds file_b and which holds file_a
+    manifest_b_path = None
+    manifest_a_path = None
+    for m in manifests_before:
+        entries = m.fetch_manifest_entry(table.io, discard_deleted=False)
+        if any(e.data_file.file_path == file_b.file_path for e in entries):
+            manifest_b_path = m.manifest_path
+        if any(e.data_file.file_path == file_a.file_path for e in entries):
+            manifest_a_path = m.manifest_path
+
+    assert manifest_b_path is not None
+    assert manifest_a_path is not None
+
+    # Commit 3: Replace file A with file C
+    with table.transaction() as tx:
+        with tx.update_snapshot().replace() as rewrite:
+            rewrite.delete_data_file(file_a)
+            rewrite.append_data_file(file_c)
+
+    snapshot_after = cast(Snapshot, table.current_snapshot())
+    assert snapshot_after is not None
+    manifests_after = snapshot_after.manifests(table.io)
+
+    # We expect 3 manifests:
+    # 1. The reused one for file B
+    # 2. The newly written one marking file A as DELETED
+    # 3. The new one for file C (ADDED)
+    assert len(manifests_after) == 3
+
+    manifest_paths_after = [m.manifest_path for m in manifests_after]
+
+    # ASSERTION 1: The untouched manifest is completely reused (the path matches exactly)
+    assert manifest_b_path in manifest_paths_after
+
+    # ASSERTION 2: File A's old manifest is NOT reused (its path no longer appears in the snapshot)
+    assert manifest_a_path not in manifest_paths_after
+
+
+def test_replace_empty_files(catalog: Catalog) -> None:
+    # Set up a basic table using the catalog fixture
+    catalog.create_namespace("default")
+    table = catalog.create_table(
+        identifier="default.test_replace_empty",
+        schema=Schema(),
+    )
+
+    # Replacing nothing should not raise, and should produce no changes.
+    with table.transaction() as tx:
+        with tx.update_snapshot().replace():
+            pass  # Entering and exiting the context manager without adding/deleting
+
+    # History should be completely empty since no files were rewritten
+    assert len(table.history()) == 0
+    assert table.current_snapshot() is None
+
+
+def test_replace_missing_file_abort(catalog: Catalog) -> None:
+    # Set up a basic table
+    catalog.create_namespace("default")
+    table = catalog.create_table(
+        identifier="default.test_replace_missing",
+        schema=Schema(),
+    )
+
+    fake_data_file = _create_dummy_data_file(
+        file_path="s3://bucket/test/data/does_not_exist.parquet",
+        record_count=100,
+    )
+
+    new_data_file = _create_dummy_data_file(
+        file_path="s3://bucket/test/data/new.parquet",
+        record_count=100,
+    )
+
+    # Ensure it aborts when trying to replace a file that isn't in the table
+    with pytest.raises(ValueError, match="Cannot delete files that are not present in the table"):
+        with table.transaction() as tx:
+            with tx.update_snapshot().replace() as rewrite:
+                rewrite.delete_data_file(fake_data_file)
+                rewrite.append_data_file(new_data_file)
+
+
+def test_replace_invariant_violation(catalog: Catalog) -> None:
+    # Set up a basic table
+    catalog.create_namespace("default")
+    table = catalog.create_table(
+        identifier="default.test_replace_invariant",
+        schema=Schema(),
+    )
+
+    file_to_delete = _create_dummy_data_file(
+        file_path="s3://bucket/test/data/deleted.parquet",
+        record_count=100,
+    )
+
+    # Create a new file with MORE records than the one we are deleting
+    too_many_records_file = _create_dummy_data_file(
+        file_path="s3://bucket/test/data/too_many.parquet",
+        record_count=101,
+    )
+
+    # Initially append to have something to replace
+    with table.transaction() as tx:
+        with tx.update_snapshot().fast_append() as append_snapshot:
+            append_snapshot.append_data_file(file_to_delete)
+
+    # Ensure it enforces the invariant: records added <= records removed
+    with pytest.raises(ValueError, match=r"Invalid replace: records added \(101\) exceeds records removed \(100\)"):
+        with table.transaction() as tx:
+            with tx.update_snapshot().replace() as rewrite:
+                rewrite.delete_data_file(file_to_delete)
+                rewrite.append_data_file(too_many_records_file)
+
+
+def test_replace_allows_shrinking_for_soft_deletes(catalog: Catalog) -> None:
+    # Set up a basic table
+    catalog.create_namespace("default")
+    table = catalog.create_table(
+        identifier="default.test_replace_shrink",
+        schema=Schema(),
+    )
+
+    # Old data file has 100 records
+    file_to_delete = _create_dummy_data_file(
+        file_path="s3://bucket/test/data/deleted.parquet",
+        record_count=100,
+    )
+
+    # New data file only has 90 records (simulating 10 records were soft-deleted)
+    shrunk_file_to_add = _create_dummy_data_file(
+        file_path="s3://bucket/test/data/shrunk.parquet",
+        record_count=90,
+        file_size_in_bytes=900,
+    )
+
+    # Initially append
+    with table.transaction() as tx:
+        with tx.update_snapshot().fast_append() as append_snapshot:
+            append_snapshot.append_data_file(file_to_delete)
+
+    # This should succeed: adding fewer records than were removed does not violate the invariant
+    with table.transaction() as tx:
+        with tx.update_snapshot().replace() as rewrite:
+            rewrite.delete_data_file(file_to_delete)
+            rewrite.append_data_file(shrunk_file_to_add)
+
+    snapshot = cast(Snapshot, table.current_snapshot())
+    summary = cast(Summary, snapshot.summary)
+
+    assert summary["operation"] == Operation.REPLACE
+    assert summary["added-records"] == "90"
+    assert summary["deleted-records"] == "100"
+
+
+def test_replace_passes_through_delete_manifests(catalog: Catalog) -> None:
+    # Set up a basic table
+    catalog.create_namespace("default")
+    table = catalog.create_table(
+        identifier="default.test_replace_delete_manifests",
+        schema=Schema(),
+        properties={"format-version": "2"},
+    )
+
+    # 1. Data file we will replace
+    file_a = _create_dummy_data_file(
+        file_path="s3://bucket/test/data/a.parquet",
+        record_count=10,
+        file_size_in_bytes=100,
+    )
+
+    # 2. A Position Delete file (representing row-level deletes)
+    file_a_deletes = _create_dummy_data_file(
+        file_path="s3://bucket/test/data/a_deletes.parquet",
+        record_count=2,
+        file_size_in_bytes=50,
+        content=DataFileContent.POSITION_DELETES,
+    )
+
+    # 3. Data file we are adding as a replacement
+    file_b = _create_dummy_data_file(
+        file_path="s3://bucket/test/data/b.parquet",
+        record_count=10,
+        file_size_in_bytes=100,
+    )
+
+    # Commit 1: Append the data file
+    with table.transaction() as tx:
+        with tx.update_snapshot().fast_append() as append_snapshot:
+            append_snapshot.append_data_file(file_a)
+
+    # Commit 2: Append the delete file
+    with table.transaction() as tx:
+        with tx.update_snapshot().fast_append() as append_snapshot:
+            append_snapshot.append_data_file(file_a_deletes)
+
+    # Find the path of the manifest holding the delete file so we can verify it survives
+    snapshot_before = cast(Snapshot, table.current_snapshot())
+    manifests_before = snapshot_before.manifests(table.io)
+
+    delete_manifest_path = None
+    for m in manifests_before:
+        entries = m.fetch_manifest_entry(table.io, discard_deleted=False)
+        if any(e.data_file.file_path == file_a_deletes.file_path for e in entries):
+            delete_manifest_path = m.manifest_path
+            break
+
+    assert delete_manifest_path is not None
+
+    # Commit 3: Replace data file A with data file B
+    with table.transaction() as tx:
+        with tx.update_snapshot().replace() as rewrite:
+            rewrite.delete_data_file(file_a)
+            rewrite.append_data_file(file_b)
+
+    # Verify the delete manifest was passed through unchanged (same path in the new snapshot)
+    snapshot_after = cast(Snapshot, table.current_snapshot())
+    assert snapshot_after is not None
+    manifests_after = snapshot_after.manifests(table.io)
+    manifest_paths_after = [m.manifest_path for m in manifests_after]
+
+    assert delete_manifest_path in manifest_paths_after
+
+
+def test_replace_multiple_files(catalog: Catalog) -> None:
+    # Set up a basic table
+    catalog.create_namespace("default")
+    table = catalog.create_table(
+        identifier="default.test_replace_multiple",
+        schema=Schema(),
+    )
+
+    file_1 = _create_dummy_data_file(
+        file_path="s3://bucket/test/data/1.parquet",
+        record_count=100,
+    )
+
+    file_2 = _create_dummy_data_file(
+        file_path="s3://bucket/test/data/2.parquet",
+        record_count=100,
+    )
+
+    file_1_new = _create_dummy_data_file(
+        file_path="s3://bucket/test/data/1_new.parquet",
+        record_count=50,
+        file_size_in_bytes=512,
+    )
+
+    file_2_new = _create_dummy_data_file(
+        file_path="s3://bucket/test/data/2_new.parquet",
+        record_count=50,
+        file_size_in_bytes=512,
+    )
+
+    # Append initial files
+    with table.transaction() as tx:
+        with tx.update_snapshot().fast_append() as append_snapshot:
+            append_snapshot.append_data_file(file_1)
+            append_snapshot.append_data_file(file_2)
+
+    # Replace both files with new ones
+    with table.transaction() as tx:
+        with tx.update_snapshot().replace() as rewrite:
+            rewrite.delete_data_file(file_1)
+            rewrite.delete_data_file(file_2)
+            rewrite.append_data_file(file_1_new)
+            rewrite.append_data_file(file_2_new)
+
+    snapshot = cast(Snapshot, table.current_snapshot())
+    summary = cast(Summary, snapshot.summary)
+
+    assert summary["added-data-files"] == "2"
+    assert summary["deleted-data-files"] == "2"
+    assert summary["added-records"] == "100"
+    assert summary["deleted-records"] == "200"
+    assert summary["total-records"] == "100"
+
+
+def test_replace_partitioned_table(catalog: Catalog) -> None:
+    from pyiceberg.partitioning import PartitionField, PartitionSpec
+    from pyiceberg.transforms import IdentityTransform
+    from pyiceberg.types import IntegerType, NestedField, StringType
+
+    # Set up a table partitioned by identity(id)
+    catalog.create_namespace("default")
+    schema = Schema(
+        NestedField(field_id=1, name="id", field_type=IntegerType(), required=True),
+        NestedField(field_id=2, name="data", field_type=StringType(), required=True),
+    )
+    spec = PartitionSpec(PartitionField(source_id=1, field_id=1001, transform=IdentityTransform(), name="id"))
+    table = catalog.create_table(
+        identifier="default.test_replace_partitioned",
+        schema=schema,
+        partition_spec=spec,
+    )
+
+    # File in partition id=1
+    file_part1 = _create_dummy_data_file(
+        file_path="s3://bucket/test/data/part1.parquet",
+        partition=Record(1),
+        record_count=100,
+        spec_id=table.spec().spec_id,
+    )
+
+    # File in partition id=2
+    file_part2 = _create_dummy_data_file(
+        file_path="s3://bucket/test/data/part2.parquet",
+        partition=Record(2),
+        record_count=100,
+        spec_id=table.spec().spec_id,
+    )
+
+    # Add initial files
+    with table.transaction() as tx:
+        with tx.update_snapshot().fast_append() as append_snapshot:
+            append_snapshot.append_data_file(file_part1)
+            append_snapshot.append_data_file(file_part2)
+
+    # Replace the file in partition id=1 with a smaller one in the same partition
+    file_part1_new = _create_dummy_data_file(
+        file_path="s3://bucket/test/data/part1_new.parquet",
+        partition=Record(1),
+        record_count=50,
+        file_size_in_bytes=512,
+        spec_id=table.spec().spec_id,
+    )
+
+    with table.transaction() as tx:
+        with tx.update_snapshot().replace() as rewrite:
+            rewrite.delete_data_file(file_part1)
+            rewrite.append_data_file(file_part1_new)
+
+    snapshot = cast(Snapshot, table.current_snapshot())
+    summary = cast(Summary, snapshot.summary)
+
+    assert summary["added-data-files"] == "1"
+    assert summary["deleted-data-files"] == "1"
+    assert summary["total-records"] == "150"
+
+
+def test_replace_no_op_on_non_empty_table(catalog: Catalog) -> None:
+    # Set up a basic table
+    catalog.create_namespace("default")
+    table = catalog.create_table(
+        identifier="default.test_replace_noop_nonempty",
+        schema=Schema(),
+    )
+
+    file_a = _create_dummy_data_file(
+        file_path="s3://bucket/test/data/a.parquet",
+        record_count=10,
+        file_size_in_bytes=100,
+    )
+
+    # Commit 1: Append file A
+    with table.transaction() as tx:
+        with tx.update_snapshot().fast_append() as append_snapshot:
+            append_snapshot.append_data_file(file_a)
+
+    initial_snapshot = table.current_snapshot()
+    assert initial_snapshot is not None
+
+    # Perform a no-op replace
+    with table.transaction() as tx:
+        with tx.update_snapshot().replace():
+            pass
+
+    # The no-op replace must not add a snapshot: same current snapshot, same history length
+    assert table.current_snapshot() == initial_snapshot
+    assert len(table.history()) == 1
+
+
+def test_replace_on_custom_branch(catalog: Catalog) -> None:
+    # Set up a basic table using the catalog fixture
+    catalog.create_namespace("default")
+    table = catalog.create_table(
+        identifier="default.test_replace_branch",
+        schema=Schema(),
+    )
+
+    # 1. File we will delete
+    file_to_delete = _create_dummy_data_file(
+        file_path="s3://bucket/test/data/deleted.parquet",
+        record_count=100,
+    )
+
+    # 2. File we are adding as a replacement
+    file_to_add = _create_dummy_data_file(
+        file_path="s3://bucket/test/data/added.parquet",
+        record_count=100,
+    )
+
+    # Initially append to have something to replace on main
+    with table.transaction() as tx:
+        with tx.update_snapshot().fast_append() as append_snapshot:
+            append_snapshot.append_data_file(file_to_delete)
+
+    initial_main_snapshot = cast(Snapshot, table.current_snapshot())
+    initial_main_snapshot_id = initial_main_snapshot.snapshot_id
+
+    # Create a new branch called "test-branch" pointing to the initial snapshot
+    table.manage_snapshots().create_branch(branch_name="test-branch", snapshot_id=initial_main_snapshot_id).commit()
+
+    # Perform a replace() operation explicitly targeting "test-branch"
+    with table.transaction() as tx:
+        with tx.update_snapshot(branch="test-branch").replace() as rewrite:
+            rewrite.delete_data_file(file_to_delete)
+            rewrite.append_data_file(file_to_add)
+
+    # Reload table to get updated refs
+    table = catalog.load_table("default.test_replace_branch")
+
+    test_branch_ref = table.metadata.refs["test-branch"]
+    main_branch_ref = table.metadata.refs["main"]
+
+    # Assert that "test-branch" moved off the initial snapshot
+    assert test_branch_ref.snapshot_id != initial_main_snapshot_id
+
+    # Assert that the "test-branch" reference now points to a REPLACE snapshot
+    new_snapshot = table.snapshot_by_id(test_branch_ref.snapshot_id)
+    assert new_snapshot is not None
+    summary = cast(Summary, new_snapshot.summary)
+    assert summary["operation"] == Operation.REPLACE
+
+    # Assert that the "main" branch reference was completely untouched
+    assert main_branch_ref.snapshot_id == initial_main_snapshot_id
diff --git a/tests/table/test_snapshots.py b/tests/table/test_snapshots.py
index cfdc516227..7f78a7546d 100644
--- a/tests/table/test_snapshots.py
+++ b/tests/table/test_snapshots.py
@@ -398,8 +398,8 @@ def test_merge_snapshot_summaries_overwrite_summary() -> None:
 
 def test_invalid_operation() -> None:
     with pytest.raises(ValueError) as e:
-        update_snapshot_summaries(summary=Summary(Operation.REPLACE))
-    assert "Operation not implemented: Operation.REPLACE" in str(e.value)
+        update_snapshot_summaries(summary=Summary.model_construct(operation="unknown_operation"))
+    assert "Operation not implemented: unknown_operation" in str(e.value)
 
 
 def test_invalid_type() -> None: