From 93d77d38fd892161de1b34167426b966bc4ae2d4 Mon Sep 17 00:00:00 2001 From: Sreesh Maheshwar Date: Sun, 19 Apr 2026 01:07:45 -0700 Subject: [PATCH] Implementation Co-Authored-By: Claude Opus 4.6 (1M context) --- pyiceberg/catalog/__init__.py | 83 +++++ pyiceberg/catalog/noop.py | 23 ++ pyiceberg/catalog/rest/__init__.py | 81 ++++- pyiceberg/partitioning.py | 67 ++++ pyiceberg/schema.py | 52 +++ pyiceberg/table/__init__.py | 122 +++++++ tests/catalog/test_rest.py | 479 ++++++++++++++++++++++++- tests/integration/test_rest_catalog.py | 478 ++++++++++++++++++++++++ tests/table/test_partitioning.py | 54 ++- tests/test_schema.py | 88 +++++ 10 files changed, 1523 insertions(+), 4 deletions(-) diff --git a/pyiceberg/catalog/__init__.py b/pyiceberg/catalog/__init__.py index 5797e1f050..be782576dd 100644 --- a/pyiceberg/catalog/__init__.py +++ b/pyiceberg/catalog/__init__.py @@ -47,6 +47,7 @@ DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE, CommitTableResponse, CreateTableTransaction, + ReplaceTableTransaction, StagedTable, Table, TableProperties, @@ -442,6 +443,66 @@ def create_table_if_not_exists( except TableAlreadyExistsError: return self.load_table(identifier) + @abstractmethod + def replace_table( + self, + identifier: str | Identifier, + schema: Schema | pa.Schema, + location: str | None = None, + partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC, + sort_order: SortOrder = UNSORTED_SORT_ORDER, + properties: Properties = EMPTY_DICT, + ) -> Table: + """Atomically replace a table's schema, spec, sort order, location, and properties. + + The table UUID and history (snapshots, schemas, specs, sort orders) are preserved. + The current snapshot is cleared (main branch ref is removed). + + Args: + identifier (str | Identifier): Table identifier. + schema (Schema): New table schema. + location (str | None): New table location. Defaults to the existing location. + partition_spec (PartitionSpec): New partition spec. + sort_order (SortOrder): New sort order. + properties (Properties): New table properties (merged with existing). + + Returns: + Table: the replaced table instance. + + Raises: + NoSuchTableError: If the table does not exist. + """ + + @abstractmethod + def replace_table_transaction( + self, + identifier: str | Identifier, + schema: Schema | pa.Schema, + location: str | None = None, + partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC, + sort_order: SortOrder = UNSORTED_SORT_ORDER, + properties: Properties = EMPTY_DICT, + ) -> ReplaceTableTransaction: + """Create a ReplaceTableTransaction. + + The transaction can be used to stage additional changes (schema evolution, + partition evolution, etc.) before committing. + + Args: + identifier (str | Identifier): Table identifier. + schema (Schema): New table schema. + location (str | None): New table location. Defaults to the existing location. + partition_spec (PartitionSpec): New partition spec. + sort_order (SortOrder): New sort order. + properties (Properties): New table properties (merged with existing). + + Returns: + ReplaceTableTransaction: A transaction for the replace operation. + + Raises: + NoSuchTableError: If the table does not exist. + """ + @abstractmethod def load_table(self, identifier: str | Identifier) -> Table: """Load the table's metadata and returns the table instance. @@ -888,6 +949,28 @@ def create_table_transaction( self._create_staged_table(identifier, schema, location, partition_spec, sort_order, properties) ) + def replace_table( + self, + identifier: str | Identifier, + schema: Schema | pa.Schema, + location: str | None = None, + partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC, + sort_order: SortOrder = UNSORTED_SORT_ORDER, + properties: Properties = EMPTY_DICT, + ) -> Table: + raise NotImplementedError("replace_table is not yet supported for this catalog type") + + def replace_table_transaction( + self, + identifier: str | Identifier, + schema: Schema | pa.Schema, + location: str | None = None, + partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC, + sort_order: SortOrder = UNSORTED_SORT_ORDER, + properties: Properties = EMPTY_DICT, + ) -> ReplaceTableTransaction: + raise NotImplementedError("replace_table_transaction is not yet supported for this catalog type") + def table_exists(self, identifier: str | Identifier) -> bool: try: self.load_table(identifier) diff --git a/pyiceberg/catalog/noop.py b/pyiceberg/catalog/noop.py index c5399ad62e..b07e2cc824 100644 --- a/pyiceberg/catalog/noop.py +++ b/pyiceberg/catalog/noop.py @@ -26,6 +26,7 @@ from pyiceberg.table import ( CommitTableResponse, CreateTableTransaction, + ReplaceTableTransaction, Table, ) from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder @@ -64,6 +65,28 @@ def create_table_transaction( ) -> CreateTableTransaction: raise NotImplementedError + def replace_table( + self, + identifier: str | Identifier, + schema: Schema | pa.Schema, + location: str | None = None, + partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC, + sort_order: SortOrder = UNSORTED_SORT_ORDER, + properties: Properties = EMPTY_DICT, + ) -> Table: + raise NotImplementedError + + def replace_table_transaction( + self, + identifier: str | Identifier, + schema: Schema | pa.Schema, + location: str | None = None, + partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC, + sort_order: SortOrder = UNSORTED_SORT_ORDER, + properties: Properties = EMPTY_DICT, + ) -> ReplaceTableTransaction: + raise NotImplementedError + def load_table(self, identifier: str | Identifier) -> Table: raise NotImplementedError diff --git a/pyiceberg/catalog/rest/__init__.py b/pyiceberg/catalog/rest/__init__.py index d06fd3885b..94c3a92f6b 100644 --- a/pyiceberg/catalog/rest/__init__.py +++ b/pyiceberg/catalog/rest/__init__.py @@ -67,13 +67,19 @@ FileIO, load_file_io, ) -from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionSpec, assign_fresh_partition_spec_ids -from pyiceberg.schema import Schema, assign_fresh_schema_ids +from pyiceberg.partitioning import ( + UNPARTITIONED_PARTITION_SPEC, + PartitionSpec, + assign_fresh_partition_spec_ids, + assign_fresh_partition_spec_ids_for_replace, +) +from pyiceberg.schema import Schema, assign_fresh_schema_ids, assign_fresh_schema_ids_for_replace from pyiceberg.table import ( CommitTableRequest, CommitTableResponse, CreateTableTransaction, FileScanTask, + ReplaceTableTransaction, StagedTable, Table, TableIdentifier, @@ -937,6 +943,77 @@ def create_table_transaction( staged_table = self._response_to_staged_table(self.identifier_to_tuple(identifier), table_response) return CreateTableTransaction(staged_table) + def replace_table( + self, + identifier: str | Identifier, + schema: Schema | pa.Schema, + location: str | None = None, + partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC, + sort_order: SortOrder = UNSORTED_SORT_ORDER, + properties: Properties = EMPTY_DICT, + ) -> Table: + txn = self.replace_table_transaction( + identifier=identifier, + schema=schema, + location=location, + partition_spec=partition_spec, + sort_order=sort_order, + properties=properties, + ) + return txn.commit_transaction() + + @retry(**_RETRY_ARGS) + def replace_table_transaction( + self, + identifier: str | Identifier, + schema: Schema | pa.Schema, + location: str | None = None, + partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC, + sort_order: SortOrder = UNSORTED_SORT_ORDER, + properties: Properties = EMPTY_DICT, + ) -> ReplaceTableTransaction: + existing_table = self.load_table(identifier) + existing_metadata = existing_table.metadata + + iceberg_schema = self._convert_schema_if_needed( + schema, + int(properties.get(TableProperties.FORMAT_VERSION, existing_metadata.format_version)), # type: ignore + ) + + # Assign fresh schema IDs, reusing IDs from the existing schema by field name + fresh_schema, _ = assign_fresh_schema_ids_for_replace( + iceberg_schema, existing_metadata.schema(), existing_metadata.last_column_id + ) + + # Assign fresh partition spec IDs, reusing IDs from existing specs + fresh_partition_spec, _ = assign_fresh_partition_spec_ids_for_replace( + partition_spec, iceberg_schema, fresh_schema, existing_metadata.partition_specs, existing_metadata.last_partition_id + ) + + # Assign fresh sort order IDs + fresh_sort_order = assign_fresh_sort_order_ids(sort_order, iceberg_schema, fresh_schema) + + # Use existing location if not specified + resolved_location = location.rstrip("/") if location else existing_metadata.location + + # Create a StagedTable from the existing table + staged_table = StagedTable( + identifier=existing_table.name(), + metadata=existing_metadata, + metadata_location=existing_table.metadata_location, + io=existing_table.io, + catalog=self, + ) + + return ReplaceTableTransaction( + table=staged_table, + new_schema=fresh_schema, + new_spec=fresh_partition_spec, + new_sort_order=fresh_sort_order, + new_location=resolved_location, + new_properties=properties, + ) + @retry(**_RETRY_ARGS) def create_view( self, diff --git a/pyiceberg/partitioning.py b/pyiceberg/partitioning.py index 3de185d886..3861e62c30 100644 --- a/pyiceberg/partitioning.py +++ b/pyiceberg/partitioning.py @@ -335,6 +335,73 @@ def assign_fresh_partition_spec_ids(spec: PartitionSpec, old_schema: Schema, fre return PartitionSpec(*partition_fields, spec_id=INITIAL_PARTITION_SPEC_ID) +def assign_fresh_partition_spec_ids_for_replace( + spec: PartitionSpec, + old_schema: Schema, + fresh_schema: Schema, + existing_specs: list[PartitionSpec], + last_partition_id: int | None, +) -> tuple[PartitionSpec, int]: + """Assign partition field IDs for a replace operation, reusing IDs from existing specs. + + For each partition field, if a field with the same (source_id, transform) pair exists in + any of the existing specs, its partition field ID is reused; otherwise a fresh ID is + allocated starting from last_partition_id + 1. + + Args: + spec: The new partition spec to assign IDs to. + old_schema: The schema that the new spec's source_ids reference. + fresh_schema: The schema with freshly assigned field IDs. + existing_specs: All partition specs from the existing table metadata. + last_partition_id: The current table's last_partition_id. + + Returns: + A tuple of (fresh_spec, new_last_partition_id). + """ + effective_last_partition_id = last_partition_id if last_partition_id is not None else PARTITION_FIELD_ID_START - 1 + + # Build (source_id, transform) → partition_field_id mapping from all existing specs + # Use max() for dedup when the same (source_id, transform) appears in multiple specs + transform_to_field_id: dict[tuple[int, str], int] = {} + for existing_spec in existing_specs: + for field in existing_spec.fields: + key = (field.source_id, str(field.transform)) + if key not in transform_to_field_id or field.field_id > transform_to_field_id[key]: + transform_to_field_id[key] = field.field_id + + next_id = effective_last_partition_id + partition_fields = [] + for field in spec.fields: + original_column_name = old_schema.find_column_name(field.source_id) + if original_column_name is None: + raise ValueError(f"Could not find in old schema: {field}") + fresh_field = fresh_schema.find_field(original_column_name) + if fresh_field is None: + raise ValueError(f"Could not find field in fresh schema: {original_column_name}") + + validate_partition_name(field.name, field.transform, fresh_field.field_id, fresh_schema, set()) + + key = (fresh_field.field_id, str(field.transform)) + if key in transform_to_field_id: + partition_field_id = transform_to_field_id[key] + else: + next_id += 1 + partition_field_id = next_id + transform_to_field_id[key] = partition_field_id + + partition_fields.append( + PartitionField( + name=field.name, + source_id=fresh_field.field_id, + field_id=partition_field_id, + transform=field.transform, + ) + ) + + new_last_partition_id = max(next_id, effective_last_partition_id) + return PartitionSpec(*partition_fields, spec_id=INITIAL_PARTITION_SPEC_ID), new_last_partition_id + + T = TypeVar("T") diff --git a/pyiceberg/schema.py b/pyiceberg/schema.py index fd60eb8f94..9d60787978 100644 --- a/pyiceberg/schema.py +++ b/pyiceberg/schema.py @@ -1380,6 +1380,58 @@ def primitive(self, primitive: PrimitiveType) -> PrimitiveType: return primitive +class _SetFreshIDsForReplace(_SetFreshIDs): + """Assign fresh IDs for a replace operation, reusing IDs from the base schema by field name. + + For each field in the new schema, if a field with the same full name exists in the + base schema, its ID is reused; otherwise a fresh ID is allocated starting from + last_column_id + 1. + """ + + def __init__(self, old_id_to_base_id: dict[int, int], starting_id: int) -> None: + self.old_id_to_new_id: dict[int, int] = {} + self._old_id_to_base_id = old_id_to_base_id + counter = itertools.count(starting_id + 1) + self.next_id_func = lambda: next(counter) + + def _get_and_increment(self, current_id: int) -> int: + if current_id in self._old_id_to_base_id: + new_id = self._old_id_to_base_id[current_id] + else: + new_id = self.next_id_func() + self.old_id_to_new_id[current_id] = new_id + return new_id + + +def assign_fresh_schema_ids_for_replace(schema: Schema, base_schema: Schema, last_column_id: int) -> tuple[Schema, int]: + """Assign fresh IDs to a schema for a replace operation, reusing IDs from the base schema. + + For each field in the new schema, if a field with the same full path name exists + in the base schema, its ID is reused. New fields get IDs starting from + last_column_id + 1. + + Args: + schema: The new schema to assign IDs to. + base_schema: The existing table's current schema (IDs are reused from here by name). + last_column_id: The current table's last_column_id (new IDs start above this). + + Returns: + A tuple of (fresh_schema, new_last_column_id). + """ + base_name_to_id = index_by_name(base_schema) + new_id_to_name = index_name_by_id(schema) + + old_id_to_base_id: dict[int, int] = {} + for old_id, name in new_id_to_name.items(): + if name in base_name_to_id: + old_id_to_base_id[old_id] = base_name_to_id[name] + + visitor = _SetFreshIDsForReplace(old_id_to_base_id, last_column_id) + fresh_schema = pre_order_visit(schema, visitor) + new_last_column_id = max(fresh_schema.highest_field_id, last_column_id) + return fresh_schema, new_last_column_id + + # Implementation copied from Apache Iceberg repo. def make_compatible_name(name: str) -> str: """Make a field name compatible with Avro specification. diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index bb8765b651..ec2f54ad57 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -62,6 +62,7 @@ AssertTableUUID, AssignUUIDUpdate, RemovePropertiesUpdate, + RemoveSnapshotRefUpdate, SetCurrentSchemaUpdate, SetDefaultSortOrderUpdate, SetDefaultSpecUpdate, @@ -1009,6 +1010,127 @@ def commit_transaction(self) -> Table: return self._table +class ReplaceTableTransaction(Transaction): + """A transaction that replaces an existing table's schema, spec, sort order, location, and properties. + + The existing table UUID, snapshots, snapshot log, metadata log, and history are preserved. + The "main" branch ref is removed (current-snapshot-id set to -1), and new + schema/spec/sort-order/location/properties are applied. + """ + + def _initial_changes( + self, + table_metadata: TableMetadata, + new_schema: Schema, + new_spec: PartitionSpec, + new_sort_order: SortOrder, + new_location: str, + new_properties: Properties, + ) -> None: + """Set the initial changes that transform the existing table into the replacement.""" + # Remove the main branch ref to clear the current snapshot + self._updates += (RemoveSnapshotRefUpdate(ref_name=MAIN_BRANCH),) + + # Reuse an existing schema if structurally identical (ignoring schema_id). + existing_schema_id = self._find_matching_schema_id(table_metadata, new_schema) + if existing_schema_id is not None: + if existing_schema_id != table_metadata.current_schema_id: + self._updates += (SetCurrentSchemaUpdate(schema_id=existing_schema_id),) + else: + self._updates += ( + AddSchemaUpdate(schema_=new_schema), + SetCurrentSchemaUpdate(schema_id=-1), + ) + + # Only emit AddPartitionSpecUpdate + SetDefaultSpecUpdate(-1) when the spec + # is new. If an identical spec already exists, use its concrete ID. + effective_spec = UNPARTITIONED_PARTITION_SPEC if new_spec.is_unpartitioned() else new_spec + existing_spec_id = self._find_matching_spec_id(table_metadata, effective_spec) + if existing_spec_id is not None: + if existing_spec_id != table_metadata.default_spec_id: + self._updates += (SetDefaultSpecUpdate(spec_id=existing_spec_id),) + else: + self._updates += ( + AddPartitionSpecUpdate(spec=effective_spec), + SetDefaultSpecUpdate(spec_id=-1), + ) + + # Set the new sort order (same logic as spec). + effective_sort_order = UNSORTED_SORT_ORDER if new_sort_order.is_unsorted else new_sort_order + existing_order_id = self._find_matching_sort_order_id(table_metadata, effective_sort_order) + if existing_order_id is not None: + if existing_order_id != table_metadata.default_sort_order_id: + self._updates += (SetDefaultSortOrderUpdate(sort_order_id=existing_order_id),) + else: + self._updates += ( + AddSortOrderUpdate(sort_order=effective_sort_order), + SetDefaultSortOrderUpdate(sort_order_id=-1), + ) + + # Set location if changed + if new_location != table_metadata.location: + self._updates += (SetLocationUpdate(location=new_location),) + + # Merge properties (SetPropertiesUpdate merges onto existing properties) + if new_properties: + self._updates += (SetPropertiesUpdate(updates=new_properties),) + + @staticmethod + def _find_matching_schema_id(table_metadata: TableMetadata, schema: Schema) -> int | None: + """Find an existing schema structurally equal to the given one, returning its schema_id or None.""" + for existing in table_metadata.schemas: + if existing == schema: + return existing.schema_id + return None + + @staticmethod + def _find_matching_spec_id(table_metadata: TableMetadata, spec: PartitionSpec) -> int | None: + """Find an existing partition spec with the same fields, returning its spec_id or None.""" + for existing in table_metadata.partition_specs: + if existing.fields == spec.fields: + return existing.spec_id + return None + + @staticmethod + def _find_matching_sort_order_id(table_metadata: TableMetadata, sort_order: SortOrder) -> int | None: + """Find an existing sort order with the same fields, returning its order_id or None.""" + for existing in table_metadata.sort_orders: + if existing.fields == sort_order.fields: + return existing.order_id + return None + + def __init__( + self, + table: StagedTable, + new_schema: Schema, + new_spec: PartitionSpec, + new_sort_order: SortOrder, + new_location: str, + new_properties: Properties, + ) -> None: + super().__init__(table, autocommit=False) + self._initial_changes(table.metadata, new_schema, new_spec, new_sort_order, new_location, new_properties) + + def commit_transaction(self) -> Table: + """Commit the replace changes to the catalog. + + Uses AssertTableUUID as the only requirement. + + Returns: + The table with the updates applied. + """ + if len(self._updates) > 0: + self._table._do_commit( # pylint: disable=W0212 + updates=self._updates, + requirements=(AssertTableUUID(uuid=self._table.metadata.table_uuid),), + ) + + self._updates = () + self._requirements = () + + return self._table + + class Namespace(IcebergRootModel[list[str]]): """Reference to one or more levels of a namespace.""" diff --git a/tests/catalog/test_rest.py b/tests/catalog/test_rest.py index aa9a467381..a92748c6bc 100644 --- a/tests/catalog/test_rest.py +++ b/tests/catalog/test_rest.py @@ -64,7 +64,7 @@ from pyiceberg.table.sorting import SortField, SortOrder from pyiceberg.transforms import IdentityTransform, TruncateTransform from pyiceberg.typedef import RecursiveDict -from pyiceberg.types import StringType +from pyiceberg.types import BooleanType, IntegerType, NestedField, StringType from pyiceberg.utils.config import Config from pyiceberg.view import View from pyiceberg.view.metadata import ViewMetadata, ViewVersion @@ -2654,3 +2654,480 @@ def test_load_table_without_storage_credentials( ) assert actual.metadata.model_dump() == expected.metadata.model_dump() assert actual == expected + + +def test_replace_table_transaction_200( + rest_mock: Mocker, + example_table_metadata_with_snapshot_v1_rest_json: dict[str, Any], + example_table_metadata_no_snapshot_v1_rest_json: dict[str, Any], +) -> None: + """Test that replace_table_transaction loads the existing table, then commits with AssertTableUUID.""" + expected_table_uuid = example_table_metadata_with_snapshot_v1_rest_json["metadata"]["table-uuid"] + example_table_metadata_no_snapshot_v1_rest_json["metadata"]["table-uuid"] = expected_table_uuid + + # Mock load_table (GET) to return existing table with snapshot + rest_mock.get( + f"{TEST_URI}v1/namespaces/fokko/tables/fokko2", + json=example_table_metadata_with_snapshot_v1_rest_json, + status_code=200, + request_headers=TEST_HEADERS, + ) + # Mock commit (POST) for the replace + rest_mock.post( + f"{TEST_URI}v1/namespaces/fokko/tables/fokko2", + json=example_table_metadata_no_snapshot_v1_rest_json, + status_code=200, + request_headers=TEST_HEADERS, + ) + catalog = RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN) + + # Replace with a new schema (3 fields: id stays, data stays, new_col is new) + new_schema = Schema( + NestedField(field_id=1, name="id", field_type=IntegerType(), required=False), + NestedField(field_id=2, name="data", field_type=StringType(), required=False), + NestedField(field_id=3, name="new_col", field_type=BooleanType(), required=False), + ) + txn = catalog.replace_table_transaction( + identifier=("fokko", "fokko2"), + schema=new_schema, + ) + txn.commit_transaction() + + actual_request = rest_mock.last_request.json() + + # Verify requirements: only AssertTableUUID + assert actual_request["requirements"] == [ + {"type": "assert-table-uuid", "uuid": expected_table_uuid}, + ] + + # Verify updates sequence. Since the existing table already has the same + # unpartitioned spec and unsorted sort order, those updates are skipped. + updates = actual_request["updates"] + actions = [u["action"] for u in updates] + assert actions == [ + "remove-snapshot-ref", + "add-schema", + "set-current-schema", + ] + + # Verify remove-snapshot-ref targets "main" + assert updates[0] == {"action": "remove-snapshot-ref", "ref-name": "main"} + + # Verify schema has reused field IDs (id=1, data=2 reused from existing schema) + schema_fields = updates[1]["schema"]["fields"] + assert schema_fields[0]["id"] == 1 + assert schema_fields[0]["name"] == "id" + assert schema_fields[1]["id"] == 2 + assert schema_fields[1]["name"] == "data" + # new_col gets a fresh ID above last_column_id (which is 2), so it gets 3 + assert schema_fields[2]["id"] == 3 + assert schema_fields[2]["name"] == "new_col" + + # set-current-schema uses -1 (meaning last added) + assert updates[2] == {"action": "set-current-schema", "schema-id": -1} + + +def test_replace_table_transaction_preserves_uuid( + rest_mock: Mocker, + example_table_metadata_with_snapshot_v1_rest_json: dict[str, Any], + example_table_metadata_no_snapshot_v1_rest_json: dict[str, Any], +) -> None: + """Test that replace preserves the table UUID.""" + table_uuid = example_table_metadata_with_snapshot_v1_rest_json["metadata"]["table-uuid"] + example_table_metadata_no_snapshot_v1_rest_json["metadata"]["table-uuid"] = table_uuid + + rest_mock.get( + f"{TEST_URI}v1/namespaces/fokko/tables/fokko2", + json=example_table_metadata_with_snapshot_v1_rest_json, + status_code=200, + request_headers=TEST_HEADERS, + ) + rest_mock.post( + f"{TEST_URI}v1/namespaces/fokko/tables/fokko2", + json=example_table_metadata_no_snapshot_v1_rest_json, + status_code=200, + request_headers=TEST_HEADERS, + ) + catalog = RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN) + + txn = catalog.replace_table_transaction( + identifier=("fokko", "fokko2"), + schema=Schema( + NestedField(field_id=1, name="id", field_type=IntegerType(), required=False), + NestedField(field_id=2, name="data", field_type=StringType(), required=False), + ), + ) + + # The staged table should have the same UUID as the existing table + assert str(txn.table_metadata.table_uuid) == table_uuid + + result = txn.commit_transaction() + # After commit, the table should still have the same UUID + assert str(result.metadata.table_uuid) == table_uuid + + +def test_replace_table_transaction_with_new_location( + rest_mock: Mocker, + example_table_metadata_with_snapshot_v1_rest_json: dict[str, Any], + example_table_metadata_no_snapshot_v1_rest_json: dict[str, Any], +) -> None: + """Test that replace_table_transaction with a new location includes SetLocationUpdate.""" + table_uuid = example_table_metadata_with_snapshot_v1_rest_json["metadata"]["table-uuid"] + example_table_metadata_no_snapshot_v1_rest_json["metadata"]["table-uuid"] = table_uuid + + rest_mock.get( + f"{TEST_URI}v1/namespaces/fokko/tables/fokko2", + json=example_table_metadata_with_snapshot_v1_rest_json, + status_code=200, + request_headers=TEST_HEADERS, + ) + rest_mock.post( + f"{TEST_URI}v1/namespaces/fokko/tables/fokko2", + json=example_table_metadata_no_snapshot_v1_rest_json, + status_code=200, + request_headers=TEST_HEADERS, + ) + catalog = RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN) + + txn = catalog.replace_table_transaction( + identifier=("fokko", "fokko2"), + schema=Schema( + NestedField(field_id=1, name="id", field_type=IntegerType(), required=False), + ), + location="s3://new-warehouse/database/table", + ) + txn.commit_transaction() + + updates = rest_mock.last_request.json()["updates"] + actions = [u["action"] for u in updates] + + # Should include set-location since the location changed + assert "set-location" in actions + set_location_update = next(u for u in updates if u["action"] == "set-location") + assert set_location_update["location"] == "s3://new-warehouse/database/table" + + +def test_replace_table_transaction_with_properties( + rest_mock: Mocker, + example_table_metadata_with_snapshot_v1_rest_json: dict[str, Any], + example_table_metadata_no_snapshot_v1_rest_json: dict[str, Any], +) -> None: + """Test that replace merges properties via SetPropertiesUpdate.""" + table_uuid = example_table_metadata_with_snapshot_v1_rest_json["metadata"]["table-uuid"] + example_table_metadata_no_snapshot_v1_rest_json["metadata"]["table-uuid"] = table_uuid + + rest_mock.get( + f"{TEST_URI}v1/namespaces/fokko/tables/fokko2", + json=example_table_metadata_with_snapshot_v1_rest_json, + status_code=200, + request_headers=TEST_HEADERS, + ) + rest_mock.post( + f"{TEST_URI}v1/namespaces/fokko/tables/fokko2", + json=example_table_metadata_no_snapshot_v1_rest_json, + status_code=200, + request_headers=TEST_HEADERS, + ) + catalog = RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN) + + txn = catalog.replace_table_transaction( + identifier=("fokko", "fokko2"), + schema=Schema( + NestedField(field_id=1, name="id", field_type=IntegerType(), required=False), + ), + properties={"new-prop": "new-value"}, + ) + txn.commit_transaction() + + updates = rest_mock.last_request.json()["updates"] + actions = [u["action"] for u in updates] + + assert "set-properties" in actions + set_props = next(u for u in updates if u["action"] == "set-properties") + # SetPropertiesUpdate sends the user properties; the server merges onto existing + assert set_props["updates"] == {"new-prop": "new-value"} + + +def test_replace_table_transaction_with_partition_spec( + rest_mock: Mocker, + example_table_metadata_with_snapshot_v1_rest_json: dict[str, Any], + example_table_metadata_no_snapshot_v1_rest_json: dict[str, Any], +) -> None: + """Test that replace_table_transaction with a new partition spec includes AddPartitionSpecUpdate.""" + table_uuid = example_table_metadata_with_snapshot_v1_rest_json["metadata"]["table-uuid"] + example_table_metadata_no_snapshot_v1_rest_json["metadata"]["table-uuid"] = table_uuid + + rest_mock.get( + f"{TEST_URI}v1/namespaces/fokko/tables/fokko2", + json=example_table_metadata_with_snapshot_v1_rest_json, + status_code=200, + request_headers=TEST_HEADERS, + ) + rest_mock.post( + f"{TEST_URI}v1/namespaces/fokko/tables/fokko2", + json=example_table_metadata_no_snapshot_v1_rest_json, + status_code=200, + request_headers=TEST_HEADERS, + ) + catalog = RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN) + + txn = catalog.replace_table_transaction( + identifier=("fokko", "fokko2"), + schema=Schema( + NestedField(field_id=1, name="id", field_type=IntegerType(), required=False), + NestedField(field_id=2, name="data", field_type=StringType(), required=False), + ), + partition_spec=PartitionSpec( + PartitionField(source_id=1, field_id=1000, transform=TruncateTransform(width=3), name="id_trunc"), spec_id=1 + ), + ) + txn.commit_transaction() + + updates = rest_mock.last_request.json()["updates"] + add_spec = next(u for u in updates if u["action"] == "add-spec") + spec_fields = add_spec["spec"]["fields"] + assert len(spec_fields) == 1 + assert spec_fields[0]["source-id"] == 1 # id field + assert spec_fields[0]["transform"] == "truncate[3]" + assert spec_fields[0]["name"] == "id_trunc" + + # set-default-spec should also be present, pointing to the newly added spec + actions = [u["action"] for u in updates] + assert "set-default-spec" in actions + set_default_spec = next(u for u in updates if u["action"] == "set-default-spec") + assert set_default_spec["spec-id"] == -1 + + +def test_replace_table_404( + rest_mock: Mocker, +) -> None: + """Test that replace_table raises NoSuchTableError when the table doesn't exist.""" + rest_mock.get( + f"{TEST_URI}v1/namespaces/fokko/tables/nonexistent", + json={ + "error": { + "message": "Table does not exist: fokko.nonexistent", + "type": "NoSuchTableException", + "code": 404, + } + }, + status_code=404, + request_headers=TEST_HEADERS, + ) + catalog = RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN) + + with pytest.raises(NoSuchTableError): + catalog.replace_table( + identifier=("fokko", "nonexistent"), + schema=Schema( + NestedField(field_id=1, name="id", field_type=IntegerType(), required=False), + ), + ) + + +def test_replace_table_200( + rest_mock: Mocker, + example_table_metadata_with_snapshot_v1_rest_json: dict[str, Any], + example_table_metadata_no_snapshot_v1_rest_json: dict[str, Any], +) -> None: + """Test that replace_table commits immediately and returns the table.""" + table_uuid = example_table_metadata_with_snapshot_v1_rest_json["metadata"]["table-uuid"] + example_table_metadata_no_snapshot_v1_rest_json["metadata"]["table-uuid"] = table_uuid + + rest_mock.get( + f"{TEST_URI}v1/namespaces/fokko/tables/fokko2", + json=example_table_metadata_with_snapshot_v1_rest_json, + status_code=200, + request_headers=TEST_HEADERS, + ) + rest_mock.post( + f"{TEST_URI}v1/namespaces/fokko/tables/fokko2", + json=example_table_metadata_no_snapshot_v1_rest_json, + status_code=200, + request_headers=TEST_HEADERS, + ) + catalog = RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN) + + result = catalog.replace_table( + identifier=("fokko", "fokko2"), + schema=Schema( + NestedField(field_id=1, name="id", field_type=IntegerType(), required=False), + ), + ) + + assert isinstance(result, Table) + # The commit should have used assert-table-uuid + actual_request = rest_mock.last_request.json() + assert actual_request["requirements"] == [ + {"type": "assert-table-uuid", "uuid": table_uuid}, + ] + + +def test_replace_table_transaction_same_location_no_set_location( + rest_mock: Mocker, + example_table_metadata_with_snapshot_v1_rest_json: dict[str, Any], + example_table_metadata_no_snapshot_v1_rest_json: dict[str, Any], +) -> None: + """Test that when location is not changed, SetLocationUpdate is NOT included.""" + table_uuid = example_table_metadata_with_snapshot_v1_rest_json["metadata"]["table-uuid"] + example_table_metadata_no_snapshot_v1_rest_json["metadata"]["table-uuid"] = table_uuid + + rest_mock.get( + f"{TEST_URI}v1/namespaces/fokko/tables/fokko2", + json=example_table_metadata_with_snapshot_v1_rest_json, + status_code=200, + request_headers=TEST_HEADERS, + ) + rest_mock.post( + f"{TEST_URI}v1/namespaces/fokko/tables/fokko2", + json=example_table_metadata_no_snapshot_v1_rest_json, + status_code=200, + request_headers=TEST_HEADERS, + ) + catalog = RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN) + + # Replace with no location specified - should use existing location + txn = catalog.replace_table_transaction( + identifier=("fokko", "fokko2"), + schema=Schema( + NestedField(field_id=1, name="id", field_type=IntegerType(), required=False), + ), + ) + txn.commit_transaction() + + updates = rest_mock.last_request.json()["updates"] + actions = [u["action"] for u in updates] + # set-location should NOT be present since location didn't change + assert "set-location" not in actions + + +def test_replace_table_transaction_same_schema_skips_add_schema( + rest_mock: Mocker, + example_table_metadata_with_snapshot_v1_rest_json: dict[str, Any], + example_table_metadata_no_snapshot_v1_rest_json: dict[str, Any], +) -> None: + """Test that replacing with the same schema skips add-schema and set-current-schema.""" + table_uuid = example_table_metadata_with_snapshot_v1_rest_json["metadata"]["table-uuid"] + example_table_metadata_no_snapshot_v1_rest_json["metadata"]["table-uuid"] = table_uuid + + rest_mock.get( + f"{TEST_URI}v1/namespaces/fokko/tables/fokko2", + json=example_table_metadata_with_snapshot_v1_rest_json, + status_code=200, + request_headers=TEST_HEADERS, + ) + rest_mock.post( + f"{TEST_URI}v1/namespaces/fokko/tables/fokko2", + json=example_table_metadata_no_snapshot_v1_rest_json, + status_code=200, + request_headers=TEST_HEADERS, + ) + catalog = RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN) + + # Use the exact same schema as the existing table (id: int, data: string) + same_schema = Schema( + NestedField(field_id=1, name="id", field_type=IntegerType(), required=False), + NestedField(field_id=2, name="data", field_type=StringType(), required=False), + ) + txn = catalog.replace_table_transaction( + identifier=("fokko", "fokko2"), + schema=same_schema, + ) + txn.commit_transaction() + + updates = rest_mock.last_request.json()["updates"] + actions = [u["action"] for u in updates] + + # Since the schema is unchanged, add-schema and set-current-schema should be skipped + assert "add-schema" not in actions + assert "set-current-schema" not in actions + + # The only update should be remove-snapshot-ref + assert actions == ["remove-snapshot-ref"] + + +def test_replace_table_transaction_different_schema_adds_schema( + rest_mock: Mocker, + example_table_metadata_with_snapshot_v1_rest_json: dict[str, Any], + example_table_metadata_no_snapshot_v1_rest_json: dict[str, Any], +) -> None: + """Test that replacing with a genuinely new schema includes add-schema and set-current-schema.""" + table_uuid = example_table_metadata_with_snapshot_v1_rest_json["metadata"]["table-uuid"] + example_table_metadata_no_snapshot_v1_rest_json["metadata"]["table-uuid"] = table_uuid + + rest_mock.get( + f"{TEST_URI}v1/namespaces/fokko/tables/fokko2", + json=example_table_metadata_with_snapshot_v1_rest_json, + status_code=200, + request_headers=TEST_HEADERS, + ) + rest_mock.post( + f"{TEST_URI}v1/namespaces/fokko/tables/fokko2", + json=example_table_metadata_no_snapshot_v1_rest_json, + status_code=200, + request_headers=TEST_HEADERS, + ) + catalog = RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN) + + # A new schema with a different field (new_col instead of data) + new_schema = Schema( + NestedField(field_id=1, name="id", field_type=IntegerType(), required=False), + NestedField(field_id=3, name="new_col", field_type=BooleanType(), required=False), + ) + txn = catalog.replace_table_transaction( + identifier=("fokko", "fokko2"), + schema=new_schema, + ) + txn.commit_transaction() + + updates = rest_mock.last_request.json()["updates"] + actions = [u["action"] for u in updates] + + # Since the schema is different, add-schema and set-current-schema must be present + assert "add-schema" in actions + assert "set-current-schema" in actions + + # set-current-schema should reference -1 (the last added schema) + set_schema = next(u for u in updates if u["action"] == "set-current-schema") + assert set_schema["schema-id"] == -1 + + +def test_replace_table_transaction_with_sort_order( + rest_mock: Mocker, + example_table_metadata_with_snapshot_v1_rest_json: dict[str, Any], + example_table_metadata_no_snapshot_v1_rest_json: dict[str, Any], +) -> None: + """Test that replacing with a custom sort order includes add-sort-order and set-default-sort-order.""" + table_uuid = example_table_metadata_with_snapshot_v1_rest_json["metadata"]["table-uuid"] + example_table_metadata_no_snapshot_v1_rest_json["metadata"]["table-uuid"] = table_uuid + + rest_mock.get( + f"{TEST_URI}v1/namespaces/fokko/tables/fokko2", + json=example_table_metadata_with_snapshot_v1_rest_json, + status_code=200, + request_headers=TEST_HEADERS, + ) + rest_mock.post( + f"{TEST_URI}v1/namespaces/fokko/tables/fokko2", + json=example_table_metadata_no_snapshot_v1_rest_json, + status_code=200, + request_headers=TEST_HEADERS, + ) + catalog = RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN) + + txn = catalog.replace_table_transaction( + identifier=("fokko", "fokko2"), + schema=Schema( + NestedField(field_id=1, name="id", field_type=IntegerType(), required=False), + NestedField(field_id=2, name="data", field_type=StringType(), required=False), + ), + sort_order=SortOrder(SortField(source_id=1, transform=IdentityTransform())), + ) + txn.commit_transaction() + + updates = rest_mock.last_request.json()["updates"] + actions = [u["action"] for u in updates] + + # Should include add-sort-order and set-default-sort-order + assert "add-sort-order" in actions + assert "set-default-sort-order" in actions diff --git a/tests/integration/test_rest_catalog.py b/tests/integration/test_rest_catalog.py index 18aa943175..637fedcaf6 100644 --- a/tests/integration/test_rest_catalog.py +++ b/tests/integration/test_rest_catalog.py @@ -16,10 +16,16 @@ # under the License. # pylint:disable=redefined-outer-name +import pyarrow as pa import pytest from pytest_lazy_fixtures import lf +from pyiceberg.catalog import Catalog from pyiceberg.catalog.rest import RestCatalog +from pyiceberg.exceptions import NoSuchTableError +from pyiceberg.io.pyarrow import _dataframe_to_data_files +from pyiceberg.schema import Schema +from pyiceberg.types import BooleanType, IntegerType, LongType, NestedField, StringType TEST_NAMESPACE_IDENTIFIER = "TEST NS" @@ -62,3 +68,475 @@ def test_create_namespace_if_already_existing(catalog: RestCatalog) -> None: catalog.create_namespace_if_not_exists(TEST_NAMESPACE_IDENTIFIER) assert catalog.namespace_exists(TEST_NAMESPACE_IDENTIFIER) + + +@pytest.mark.integration +@pytest.mark.parametrize("catalog", [lf("session_catalog")]) +@pytest.mark.parametrize("format_version", [1, 2]) +def test_replace_table_transaction(catalog: Catalog, format_version: int) -> None: + identifier = f"default.test_replace_table_txn_{catalog.name}_{format_version}" + try: + catalog.create_namespace("default") + except Exception: + pass + try: + catalog.drop_table(identifier) + except NoSuchTableError: + pass + + # Create a table with initial schema and write some data + original_schema = Schema( + NestedField(field_id=1, name="id", field_type=LongType(), required=False), + NestedField(field_id=2, name="data", field_type=StringType(), required=False), + ) + original = catalog.create_table(identifier, schema=original_schema, properties={"format-version": str(format_version)}) + original_uuid = original.metadata.table_uuid + + pa_table = pa.Table.from_pydict( + {"id": [1, 2, 3], "data": ["a", "b", "c"]}, + schema=pa.schema([pa.field("id", pa.int64()), pa.field("data", pa.large_string())]), + ) + + with original.transaction() as txn: + with txn.update_snapshot().fast_append() as snapshot_update: + for data_file in _dataframe_to_data_files(table_metadata=txn.table_metadata, df=pa_table, io=original.io): + snapshot_update.append_data_file(data_file) + + original.refresh() + current_snapshot = original.current_snapshot() + assert current_snapshot is not None + original_snapshot_id = current_snapshot.snapshot_id + + # Replace with a new schema + new_schema = Schema( + NestedField(field_id=1, name="id", field_type=LongType(), required=False), + NestedField(field_id=2, name="name", field_type=StringType(), required=False), + NestedField(field_id=3, name="active", field_type=BooleanType(), required=False), + ) + + with catalog.replace_table_transaction( + identifier, schema=new_schema, properties={"format-version": str(format_version)} + ) as txn: + pass # just replace the schema, no data + + table = catalog.load_table(identifier) + + # UUID must be preserved + assert table.metadata.table_uuid == original_uuid + + # Current snapshot should be cleared (main ref removed) + assert table.current_snapshot() is None + + # Old snapshots should still exist in the metadata + assert len(table.metadata.snapshots) >= 1 + assert any(s.snapshot_id == original_snapshot_id for s in table.metadata.snapshots) + + # New schema should be current, with field IDs reused for "id" (should still be 1) + current_schema = table.schema() + id_field = current_schema.find_field("id") + assert id_field.field_id == 1 # reused from original schema + + name_field = current_schema.find_field("name") + assert name_field is not None + assert name_field.field_id >= 3 # "name" is new, must not reuse "data"'s ID (2) + + active_field = current_schema.find_field("active") + assert active_field is not None + assert active_field.field_id >= 4 # "active" is new + + # last_column_id must account for the new fields + assert table.metadata.last_column_id >= 4 + + # Old schemas should still exist — exactly 2 (original + replacement) + assert len(table.metadata.schemas) == 2 + + +@pytest.mark.integration +@pytest.mark.parametrize("catalog", [lf("session_catalog")]) +def test_replace_table(catalog: Catalog) -> None: + identifier = f"default.test_replace_table_{catalog.name}" + try: + catalog.create_namespace("default") + except Exception: + pass + try: + catalog.drop_table(identifier) + except NoSuchTableError: + pass + + original_schema = Schema( + NestedField(field_id=1, name="id", field_type=LongType(), required=False), + NestedField(field_id=2, name="data", field_type=StringType(), required=False), + ) + original = catalog.create_table(identifier, schema=original_schema) + original_uuid = original.metadata.table_uuid + + new_schema = Schema( + NestedField(field_id=1, name="x", field_type=IntegerType(), required=False), + NestedField(field_id=2, name="y", field_type=IntegerType(), required=False), + ) + + result = catalog.replace_table(identifier, schema=new_schema) + + # UUID preserved + assert result.metadata.table_uuid == original_uuid + # New schema applied — "x" and "y" are entirely new names, so they get fresh IDs >= 3 + x_field = result.schema().find_field("x") + y_field = result.schema().find_field("y") + assert x_field is not None + assert y_field is not None + assert x_field.field_id >= 3 + assert y_field.field_id >= 4 + assert result.metadata.last_column_id >= 4 + + +@pytest.mark.integration +@pytest.mark.parametrize("catalog", [lf("session_catalog")]) +def test_replace_table_not_found(catalog: Catalog) -> None: + with pytest.raises(NoSuchTableError): + catalog.replace_table( + "default.does_not_exist_for_replace", + schema=Schema(NestedField(field_id=1, name="id", field_type=LongType(), required=False)), + ) + + +@pytest.mark.integration +@pytest.mark.parametrize("catalog", [lf("session_catalog")]) +def test_replace_table_same_schema_no_duplication(catalog: Catalog) -> None: + """Replacing a table with the exact same schema should succeed without adding duplicates. + + The code detects that the schema already exists and skips AddSchemaUpdate, + matching Java's reuseOrCreateNewSchemaId behavior. + """ + identifier = f"default.test_replace_same_schema_{catalog.name}" + try: + catalog.create_namespace("default") + except Exception: + pass + try: + catalog.drop_table(identifier) + except NoSuchTableError: + pass + + schema_a = Schema( + NestedField(field_id=1, name="id", field_type=LongType(), required=False), + NestedField(field_id=2, name="data", field_type=StringType(), required=False), + ) + catalog.create_table(identifier, schema=schema_a) + + # Replace with the SAME schema — should succeed without adding a duplicate + catalog.replace_table(identifier, schema=schema_a) + + table = catalog.load_table(identifier) + # Should still have exactly 1 schema (no duplicate added) + assert len(table.metadata.schemas) == 1 + # Current schema ID should remain 0 (the original) + assert table.metadata.current_schema_id == 0 + # Schema should be unchanged + assert len(table.schema().fields) == 2 + assert table.schema().find_field("id") is not None + assert table.schema().find_field("data") is not None + + +@pytest.mark.integration +@pytest.mark.parametrize("catalog", [lf("session_catalog")]) +def test_replace_table_back_to_previous_schema(catalog: Catalog) -> None: + """Replacing A -> B -> A where A and B have disjoint fields. + + Since field IDs are reused from the current schema only (matching Java), + "data" gets a new field ID when replacing back from B (which doesn't have "data"). + The resulting schema is structurally different from the original, so a 3rd schema + is created. This matches Java's behavior. + """ + identifier = f"default.test_replace_back_to_prev_schema_{catalog.name}" + try: + catalog.create_namespace("default") + except Exception: + pass + try: + catalog.drop_table(identifier) + except NoSuchTableError: + pass + + schema_a = Schema( + NestedField(field_id=1, name="id", field_type=LongType(), required=False), + NestedField(field_id=2, name="data", field_type=StringType(), required=False), + ) + catalog.create_table(identifier, schema=schema_a) + + # Step 2: Replace with schema B (disjoint fields: "name" and "active" instead of "data") + schema_b = Schema( + NestedField(field_id=1, name="id", field_type=LongType(), required=False), + NestedField(field_id=2, name="name", field_type=StringType(), required=False), + NestedField(field_id=3, name="active", field_type=BooleanType(), required=False), + ) + catalog.replace_table(identifier, schema=schema_b) + + table_after_b = catalog.load_table(identifier) + assert len(table_after_b.metadata.schemas) == 2 + + # Step 3: Replace BACK to schema A + catalog.replace_table(identifier, schema=schema_a) + + table = catalog.load_table(identifier) + # "data" gets a new field ID (not reused from historical schema A), so a 3rd schema is created + assert len(table.metadata.schemas) == 3 + # last_column_id must be monotonically non-decreasing + assert table.metadata.last_column_id >= table_after_b.metadata.last_column_id + + +@pytest.mark.integration +@pytest.mark.parametrize("catalog", [lf("session_catalog")]) +def test_replace_table_last_column_id_monotonic(catalog: Catalog) -> None: + """last_column_id must never decrease, even when replacing with fewer columns.""" + identifier = f"default.test_replace_last_col_id_{catalog.name}" + try: + catalog.create_namespace("default") + except Exception: + pass + try: + catalog.drop_table(identifier) + except NoSuchTableError: + pass + + # Create table with 5 columns + schema_5col = Schema( + NestedField(field_id=1, name="a", field_type=LongType(), required=False), + NestedField(field_id=2, name="b", field_type=StringType(), required=False), + NestedField(field_id=3, name="c", field_type=StringType(), required=False), + NestedField(field_id=4, name="d", field_type=StringType(), required=False), + NestedField(field_id=5, name="e", field_type=StringType(), required=False), + ) + catalog.create_table(identifier, schema=schema_5col) + + table = catalog.load_table(identifier) + initial_last_col_id = table.metadata.last_column_id + assert initial_last_col_id >= 5, f"Initial last_column_id should be >= 5, got {initial_last_col_id}" + + # Replace with only 2 columns (subset) + schema_2col = Schema( + NestedField(field_id=1, name="a", field_type=LongType(), required=False), + NestedField(field_id=2, name="b", field_type=StringType(), required=False), + ) + catalog.replace_table(identifier, schema=schema_2col) + + table = catalog.load_table(identifier) + after_shrink_last_col_id = table.metadata.last_column_id + # last_column_id must NOT decrease + assert after_shrink_last_col_id >= initial_last_col_id, ( + f"last_column_id decreased from {initial_last_col_id} to {after_shrink_last_col_id} " + f"after replacing with fewer columns. It must be monotonically non-decreasing." + ) + + # Replace with 3 columns (2 existing + 1 new) + schema_3col = Schema( + NestedField(field_id=1, name="a", field_type=LongType(), required=False), + NestedField(field_id=2, name="b", field_type=StringType(), required=False), + NestedField(field_id=3, name="f", field_type=BooleanType(), required=False), # new column + ) + catalog.replace_table(identifier, schema=schema_3col) + + table = catalog.load_table(identifier) + after_grow_last_col_id = table.metadata.last_column_id + # New column should get an ID > previous last_column_id, so last_column_id should grow + assert after_grow_last_col_id >= initial_last_col_id + 1, ( + f"last_column_id should be >= {initial_last_col_id + 1} after adding a new column, got {after_grow_last_col_id}" + ) + + # Verify the new column "f" got an ID > initial_last_col_id (not reusing a dropped column's ID) + f_field = table.schema().find_field("f") + assert f_field.field_id > initial_last_col_id, ( + f"New field 'f' got field_id={f_field.field_id}, but it should be > {initial_last_col_id} to maintain monotonicity" + ) + + +@pytest.mark.integration +@pytest.mark.parametrize("catalog", [lf("session_catalog")]) +def test_replace_table_metadata_log_grows(catalog: Catalog) -> None: + """After replacing a table, the metadata_log should contain the pre-replace metadata.""" + identifier = f"default.test_replace_metadata_log_{catalog.name}" + try: + catalog.create_namespace("default") + except Exception: + pass + try: + catalog.drop_table(identifier) + except NoSuchTableError: + pass + + schema_a = Schema( + NestedField(field_id=1, name="id", field_type=LongType(), required=False), + NestedField(field_id=2, name="data", field_type=StringType(), required=False), + ) + catalog.create_table(identifier, schema=schema_a) + + table_before = catalog.load_table(identifier) + metadata_log_before = len(table_before.metadata.metadata_log) + + # Replace with a different schema + schema_b = Schema( + NestedField(field_id=1, name="x", field_type=IntegerType(), required=False), + NestedField(field_id=2, name="y", field_type=IntegerType(), required=False), + ) + catalog.replace_table(identifier, schema=schema_b) + + table_after = catalog.load_table(identifier) + metadata_log_after = len(table_after.metadata.metadata_log) + + # The metadata_log should have grown by at least 1 entry + assert metadata_log_after > metadata_log_before, ( + f"metadata_log did not grow after replace_table. " + f"Before: {metadata_log_before} entries, After: {metadata_log_after} entries. " + f"The pre-replace metadata location should have been appended." + ) + + +@pytest.mark.integration +@pytest.mark.parametrize("catalog", [lf("session_catalog")]) +def test_replace_table_snapshot_preserved_after_replace(catalog: Catalog) -> None: + """Snapshots are preserved but current snapshot is cleared after replace.""" + identifier = f"default.test_replace_snapshot_preserved_{catalog.name}" + try: + catalog.create_namespace("default") + except Exception: + pass + try: + catalog.drop_table(identifier) + except NoSuchTableError: + pass + + schema = Schema( + NestedField(field_id=1, name="id", field_type=LongType(), required=False), + NestedField(field_id=2, name="data", field_type=StringType(), required=False), + ) + table = catalog.create_table(identifier, schema=schema) + + # Write data to create a snapshot + pa_table = pa.Table.from_pydict( + {"id": [1, 2], "data": ["a", "b"]}, + schema=pa.schema([pa.field("id", pa.int64()), pa.field("data", pa.large_string())]), + ) + with table.transaction() as txn: + with txn.update_snapshot().fast_append() as snapshot_update: + for data_file in _dataframe_to_data_files(table_metadata=txn.table_metadata, df=pa_table, io=table.io): + snapshot_update.append_data_file(data_file) + + table.refresh() + current_snapshot = table.current_snapshot() + assert current_snapshot is not None + original_snapshot_id = current_snapshot.snapshot_id + original_snapshot_log_len = len(table.metadata.snapshot_log) + + # Replace with new schema + new_schema = Schema( + NestedField(field_id=1, name="id", field_type=LongType(), required=False), + NestedField(field_id=2, name="value", field_type=StringType(), required=False), + ) + catalog.replace_table(identifier, schema=new_schema) + + replaced = catalog.load_table(identifier) + + # Current snapshot cleared (main ref removed) + assert replaced.current_snapshot() is None + + # Old snapshot still in metadata + assert any(s.snapshot_id == original_snapshot_id for s in replaced.metadata.snapshots) + + # Snapshot log preserved + assert len(replaced.metadata.snapshot_log) >= original_snapshot_log_len + + +@pytest.mark.integration +@pytest.mark.parametrize("catalog", [lf("session_catalog")]) +def test_replace_table_with_partition_spec(catalog: Catalog) -> None: + """Replace table with a new partition spec preserves old spec in metadata.""" + from pyiceberg.partitioning import PartitionField, PartitionSpec + from pyiceberg.transforms import IdentityTransform + + identifier = f"default.test_replace_with_spec_{catalog.name}" + try: + catalog.create_namespace("default") + except Exception: + pass + try: + catalog.drop_table(identifier) + except NoSuchTableError: + pass + + schema = Schema( + NestedField(field_id=1, name="id", field_type=LongType(), required=False), + NestedField(field_id=2, name="data", field_type=StringType(), required=False), + ) + catalog.create_table(identifier, schema=schema) + + # Replace with a partition spec on "id" + new_spec = PartitionSpec( + PartitionField(source_id=1, field_id=1000, transform=IdentityTransform(), name="id"), + spec_id=0, + ) + catalog.replace_table(identifier, schema=schema, partition_spec=new_spec) + + table = catalog.load_table(identifier) + + # New spec is the default + current_spec = table.metadata.spec() + assert len(current_spec.fields) == 1 + assert current_spec.fields[0].name == "id" + + # Old unpartitioned spec still in metadata + assert len(table.metadata.partition_specs) >= 2 + + # last_partition_id should be correct + assert table.metadata.last_partition_id is not None and table.metadata.last_partition_id >= 1000 + + +@pytest.mark.integration +@pytest.mark.parametrize("catalog", [lf("session_catalog")]) +def test_replace_table_sequential_replaces(catalog: Catalog) -> None: + """Multiple sequential replaces: schemas grow, last_column_id is monotonic, metadata_log grows.""" + identifier = f"default.test_replace_sequential_{catalog.name}" + try: + catalog.create_namespace("default") + except Exception: + pass + try: + catalog.drop_table(identifier) + except NoSuchTableError: + pass + + schema_a = Schema( + NestedField(field_id=1, name="a", field_type=LongType(), required=False), + NestedField(field_id=2, name="b", field_type=StringType(), required=False), + ) + catalog.create_table(identifier, schema=schema_a) + + table = catalog.load_table(identifier) + prev_last_col_id = table.metadata.last_column_id + prev_metadata_log_len = len(table.metadata.metadata_log) + + # Replace 1: A -> B (completely different fields) + schema_b = Schema( + NestedField(field_id=1, name="x", field_type=IntegerType(), required=False), + NestedField(field_id=2, name="y", field_type=IntegerType(), required=False), + NestedField(field_id=3, name="z", field_type=BooleanType(), required=False), + ) + catalog.replace_table(identifier, schema=schema_b) + + table = catalog.load_table(identifier) + assert table.metadata.last_column_id >= prev_last_col_id + assert len(table.metadata.schemas) == 2 + assert len(table.metadata.metadata_log) > prev_metadata_log_len + prev_last_col_id = table.metadata.last_column_id + prev_metadata_log_len = len(table.metadata.metadata_log) + + # Replace 2: B -> C (again different) + schema_c = Schema( + NestedField(field_id=1, name="p", field_type=StringType(), required=False), + NestedField(field_id=2, name="q", field_type=LongType(), required=False), + ) + catalog.replace_table(identifier, schema=schema_c) + + table = catalog.load_table(identifier) + assert table.metadata.last_column_id >= prev_last_col_id + assert len(table.metadata.schemas) == 3 # A, B, C all have different fields + assert len(table.metadata.metadata_log) > prev_metadata_log_len diff --git a/tests/table/test_partitioning.py b/tests/table/test_partitioning.py index a27046ef30..e7a4c5fc00 100644 --- a/tests/table/test_partitioning.py +++ b/tests/table/test_partitioning.py @@ -22,7 +22,12 @@ import pytest from pyiceberg.exceptions import ValidationError -from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionField, PartitionSpec +from pyiceberg.partitioning import ( + UNPARTITIONED_PARTITION_SPEC, + PartitionField, + PartitionSpec, + assign_fresh_partition_spec_ids_for_replace, +) from pyiceberg.schema import Schema from pyiceberg.transforms import ( BucketTransform, @@ -298,3 +303,50 @@ def test_incompatible_transform_source_type() -> None: spec.check_compatible(schema) assert "Invalid source field foo with type int for transform: year" in str(exc.value) + + +def test_assign_fresh_partition_spec_ids_for_replace_reuses_ids() -> None: + """Test that partition field IDs are reused from existing specs.""" + old_schema = Schema( + NestedField(field_id=1, name="id", field_type=IntegerType(), required=False), + NestedField(field_id=2, name="data", field_type=StringType(), required=False), + ) + fresh_schema = Schema( + NestedField(field_id=1, name="id", field_type=IntegerType(), required=False), + NestedField(field_id=2, name="data", field_type=StringType(), required=False), + ) + existing_specs = [ + PartitionSpec( + PartitionField(source_id=1, field_id=1000, transform=IdentityTransform(), name="id"), + spec_id=0, + ) + ] + spec = PartitionSpec( + PartitionField(source_id=1, field_id=999, transform=IdentityTransform(), name="id"), + spec_id=0, + ) + fresh_spec, last_pid = assign_fresh_partition_spec_ids_for_replace(spec, old_schema, fresh_schema, existing_specs, 1000) + assert fresh_spec.fields[0].field_id == 1000 # reused from existing spec + assert last_pid == 1000 + + +def test_assign_fresh_partition_spec_ids_for_replace_new_field() -> None: + """Test that new partition fields get IDs above last_partition_id.""" + old_schema = Schema( + NestedField(field_id=1, name="id", field_type=IntegerType(), required=False), + NestedField(field_id=2, name="data", field_type=StringType(), required=False), + ) + fresh_schema = Schema( + NestedField(field_id=1, name="id", field_type=IntegerType(), required=False), + NestedField(field_id=2, name="data", field_type=StringType(), required=False), + ) + existing_specs = [ + PartitionSpec(spec_id=0) # unpartitioned + ] + spec = PartitionSpec( + PartitionField(source_id=1, field_id=999, transform=IdentityTransform(), name="id"), + spec_id=0, + ) + fresh_spec, last_pid = assign_fresh_partition_spec_ids_for_replace(spec, old_schema, fresh_schema, existing_specs, 999) + assert fresh_spec.fields[0].field_id == 1000 # new, above last_partition_id=999 + assert last_pid == 1000 diff --git a/tests/test_schema.py b/tests/test_schema.py index 93ddc16202..3937354a5d 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -26,6 +26,7 @@ Accessor, Schema, _check_schema_compatible, + assign_fresh_schema_ids_for_replace, build_position_accessors, index_by_id, index_by_name, @@ -1815,3 +1816,90 @@ def test_check_schema_compatible_optional_map_field_present() -> None: ) # Should not raise - schemas match _check_schema_compatible(requested_schema, provided_schema) + + +def test_assign_fresh_schema_ids_for_replace_reuses_ids() -> None: + """Test that field IDs are reused from the base schema by name.""" + base_schema = Schema( + NestedField(field_id=1, name="id", field_type=IntegerType(), required=False), + NestedField(field_id=2, name="data", field_type=StringType(), required=False), + ) + new_schema = Schema( + NestedField(field_id=10, name="id", field_type=IntegerType(), required=False), + NestedField(field_id=20, name="data", field_type=StringType(), required=False), + ) + fresh, last_col_id = assign_fresh_schema_ids_for_replace(new_schema, base_schema, 2) + assert fresh.fields[0].field_id == 1 # reused from base + assert fresh.fields[1].field_id == 2 # reused from base + assert last_col_id == 2 # no new columns added + + +def test_assign_fresh_schema_ids_for_replace_assigns_new_ids_for_new_fields() -> None: + """Test that new fields get IDs above last_column_id.""" + base_schema = Schema( + NestedField(field_id=1, name="id", field_type=IntegerType(), required=False), + NestedField(field_id=2, name="data", field_type=StringType(), required=False), + ) + new_schema = Schema( + NestedField(field_id=10, name="id", field_type=IntegerType(), required=False), + NestedField(field_id=20, name="data", field_type=StringType(), required=False), + NestedField(field_id=30, name="new_col", field_type=BooleanType(), required=False), + ) + fresh, last_col_id = assign_fresh_schema_ids_for_replace(new_schema, base_schema, 2) + assert fresh.fields[0].field_id == 1 # reused + assert fresh.fields[1].field_id == 2 # reused + assert fresh.fields[2].field_id == 3 # new, starts after last_column_id=2 + assert last_col_id == 3 + + +def test_assign_fresh_schema_ids_for_replace_with_nested_struct() -> None: + """Test that nested struct field IDs are reused by full path name.""" + base_schema = Schema( + NestedField(field_id=1, name="id", field_type=IntegerType(), required=False), + NestedField( + field_id=2, + name="location", + field_type=StructType( + NestedField(field_id=3, name="lat", field_type=FloatType(), required=False), + NestedField(field_id=4, name="lon", field_type=FloatType(), required=False), + ), + required=False, + ), + ) + new_schema = Schema( + NestedField(field_id=10, name="id", field_type=IntegerType(), required=False), + NestedField( + field_id=20, + name="location", + field_type=StructType( + NestedField(field_id=30, name="lat", field_type=FloatType(), required=False), + NestedField(field_id=40, name="lon", field_type=FloatType(), required=False), + NestedField(field_id=50, name="alt", field_type=FloatType(), required=False), + ), + required=False, + ), + ) + fresh, last_col_id = assign_fresh_schema_ids_for_replace(new_schema, base_schema, 4) + assert fresh.fields[0].field_id == 1 # id reused + assert fresh.fields[1].field_id == 2 # location reused + loc_fields = fresh.fields[1].field_type.fields + assert loc_fields[0].field_id == 3 # location.lat reused + assert loc_fields[1].field_id == 4 # location.lon reused + assert loc_fields[2].field_id == 5 # location.alt is new + assert last_col_id == 5 + + +def test_assign_fresh_schema_ids_for_replace_completely_new_schema() -> None: + """Test that a completely new schema gets IDs starting from last_column_id + 1.""" + base_schema = Schema( + NestedField(field_id=1, name="id", field_type=IntegerType(), required=False), + NestedField(field_id=2, name="data", field_type=StringType(), required=False), + ) + new_schema = Schema( + NestedField(field_id=10, name="x", field_type=IntegerType(), required=False), + NestedField(field_id=20, name="y", field_type=IntegerType(), required=False), + ) + fresh, last_col_id = assign_fresh_schema_ids_for_replace(new_schema, base_schema, 2) + assert fresh.fields[0].field_id == 3 # starts after last_column_id=2 + assert fresh.fields[1].field_id == 4 + assert last_col_id == 4