From 214939e88806f8a8053127edfaca7cc5cc94d12d Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sat, 13 Jun 2026 18:06:54 +0000 Subject: [PATCH 1/4] feat: add cloud job session controls --- sqlspec/adapters/bigquery/config.py | 4 + sqlspec/adapters/bigquery/core.py | 52 +++- sqlspec/adapters/bigquery/driver.py | 104 ++++++- sqlspec/adapters/spanner/adk/store.py | 13 + sqlspec/adapters/spanner/config.py | 25 +- sqlspec/adapters/spanner/driver.py | 169 ++++++++++- .../adapters/test_bigquery/test_config.py | 1 + .../test_bigquery/test_job_controls.py | 245 ++++++++++++++++ .../unit/adapters/test_spanner/test_config.py | 8 +- .../test_spanner/test_session_controls.py | 267 ++++++++++++++++++ 10 files changed, 861 insertions(+), 27 deletions(-) create mode 100644 tests/unit/adapters/test_bigquery/test_job_controls.py create mode 100644 tests/unit/adapters/test_spanner/test_session_controls.py diff --git a/sqlspec/adapters/bigquery/config.py b/sqlspec/adapters/bigquery/config.py index 6a0916806..5b864726b 100644 --- a/sqlspec/adapters/bigquery/config.py +++ b/sqlspec/adapters/bigquery/config.py @@ -94,6 +94,9 @@ class BigQueryDriverFeatures(TypedDict): job_result_timeout: Timeout (seconds) for polling ``QueryJob.result()``. Defaults to the client polling default (waits indefinitely for the job using the API's per-call default timeouts). Also used as the per-request HTTP timeout when ``request_timeout`` is unset. + use_query_and_wait: Use ``Client.query_and_wait()`` for single-statement queries. + Pair with ``connection_config["default_job_creation_mode"]="JOB_CREATION_OPTIONAL"`` + to allow short queries to run without creating a job. Defaults to False. request_timeout: Per-request HTTP transport timeout (seconds) for the API calls that start query jobs. Bounds each request so a server that accepts the request but never responds (e.g. a wedged emulator) raises instead of blocking indefinitely. Defaults to @@ -111,6 +114,7 @@ class BigQueryDriverFeatures(TypedDict): events_backend: NotRequired[str] job_retry_deadline: NotRequired[float] job_result_timeout: NotRequired[float] + use_query_and_wait: NotRequired[bool] request_timeout: NotRequired[float] query_page_size: NotRequired[int] query_max_results: NotRequired[int] diff --git a/sqlspec/adapters/bigquery/core.py b/sqlspec/adapters/bigquery/core.py index ef8091441..f058b24ac 100644 --- a/sqlspec/adapters/bigquery/core.py +++ b/sqlspec/adapters/bigquery/core.py @@ -88,6 +88,7 @@ def _check_pending_exception(self, exc_handler: "SyncExceptionHandler") -> None: "normalize_script_rowcount", "resolve_column_names", "run_query_job", + "run_query_and_wait", "storage_api_available", "try_bulk_insert", ) @@ -170,6 +171,7 @@ def try_bulk_insert( expression: "exp.Expr | None" = None, *, allow_parse: bool = True, + result_timeout: float | None = None, ) -> "int | None": """Attempt bulk insert via Parquet load. @@ -204,8 +206,8 @@ def try_bulk_insert( buffer.seek(0) job_config = build_load_job_config("parquet", overwrite=False) - job = connection.load_table_from_file(buffer, table_name, job_config=job_config) - job.result() + job = connection.load_table_from_file(buffer, table_name, job_config=job_config, timeout=result_timeout) + job.result(timeout=result_timeout) return len(parameters) except ImportError: logger.debug("pyarrow not available, falling back to literal inlining") @@ -535,6 +537,10 @@ def run_query_job( retry: Retry | None = None, timeout: float | None = None, job_retry: Retry | None = None, + api_method: str | None = None, + timestamp_precision: Any | None = None, + job_id: str | None = None, + job_id_prefix: str | None = None, ) -> QueryJob: """Execute a BigQuery query job with merged configuration. @@ -569,9 +575,51 @@ def run_query_job( "timeout": timeout, "job_retry": job_retry, } + if api_method is not None: + query_kwargs["api_method"] = api_method + if timestamp_precision is not None: + query_kwargs["timestamp_precision"] = timestamp_precision + if job_id is not None: + query_kwargs["job_id"] = job_id + elif job_id_prefix is not None: + query_kwargs["job_id_prefix"] = job_id_prefix return connection.query(sql, **query_kwargs) +def run_query_and_wait( + connection: "BigQueryConnection", + sql: str, + parameters: Any, + *, + default_job_config: QueryJobConfig | None, + json_serializer: "Callable[[Any], str]", + retry: Retry | None = None, + wait_timeout: float | None = None, + job_retry: Retry | None = None, + page_size: int | None = None, + max_results: int | None = None, +) -> Any: + """Execute a BigQuery query via query_and_wait and return the row iterator.""" + final_job_config = QueryJobConfig() + if default_job_config: + copy_job_config(default_job_config, final_job_config) + final_job_config.query_parameters = create_parameters(parameters, json_serializer) + + query_kwargs: dict[str, Any] = {"job_config": final_job_config} + if retry is not None: + query_kwargs["retry"] = retry + if wait_timeout is not None: + query_kwargs["api_timeout"] = wait_timeout + query_kwargs["wait_timeout"] = wait_timeout + if job_retry is not None: + query_kwargs["job_retry"] = job_retry + if page_size is not None: + query_kwargs["page_size"] = page_size + if max_results is not None: + query_kwargs["max_results"] = max_results + return connection.query_and_wait(sql, **query_kwargs) + + def build_load_job_config(file_format: "StorageFormat", overwrite: bool) -> "LoadJobConfig": job_config = LoadJobConfig() job_config.source_format = _map_bigquery_source_format(file_format) diff --git a/sqlspec/adapters/bigquery/driver.py b/sqlspec/adapters/bigquery/driver.py index a3f381f9c..aba59ae6c 100644 --- a/sqlspec/adapters/bigquery/driver.py +++ b/sqlspec/adapters/bigquery/driver.py @@ -30,6 +30,7 @@ normalize_script_rowcount, resolve_column_names, run_query_job, + run_query_and_wait, storage_api_available, try_bulk_insert, ) @@ -53,10 +54,10 @@ from google.api_core.retry import Retry from google.cloud import bigquery_storage # type: ignore[attr-defined, unused-ignore] - from google.cloud.bigquery import QueryJob, QueryJobConfig + from google.cloud.bigquery import ExtractJob, ExtractJobConfig, QueryJob, QueryJobConfig from sqlspec.builder import QueryBuilder - from sqlspec.core import SQL, ArrowResult, SQLResult, Statement, StatementFilter + from sqlspec.core import SQL, ArrowResult, Statement, StatementFilter from sqlspec.storage import ( StorageBridgeJob, StorageDestination, @@ -104,13 +105,14 @@ class BigQueryDriver(SyncDriverAdapterBase): "_column_name_cache", "_data_dictionary", "_default_query_job_config", - "_job_result_kwargs", + "_job_result_kwargs_defaults", "_job_result_timeout", "_job_retry", "_job_retry_deadline", "_json_serializer", "_literal_inliner", "_request_timeout", + "_use_query_and_wait", ) dialect = "bigquery" @@ -138,11 +140,12 @@ def __init__( self._default_query_job_config: QueryJobConfig | None = (driver_features or {}).get("default_query_job_config") self._data_dictionary: BigQueryDataDictionary | None = None self._column_name_cache: dict[int, tuple[Any, list[str]]] = {} - self._job_result_kwargs = self._build_job_result_kwargs(features) + self._job_result_kwargs_defaults = self._build_job_result_kwargs(features) self._job_retry_deadline = float(features.get("job_retry_deadline", 60.0)) self._job_retry: Retry | None = build_retry(self._job_retry_deadline) if self._job_retry_deadline > 0 else None self._job_result_timeout: float | object = features.get("job_result_timeout", POLLING_DEFAULT_VALUE) self._request_timeout = self._resolve_request_timeout(features) + self._use_query_and_wait = bool(features.get("use_query_and_wait", False)) def _resolve_request_timeout(self, features: "dict[str, Any]") -> float: timeout = features.get("request_timeout") @@ -163,6 +166,9 @@ def _build_job_result_kwargs(self, features: dict[str, Any]) -> dict[str, Any]: job_result_kwargs["max_results"] = query_max_results return job_result_kwargs + def _job_request_timeout(self) -> float: + return self._request_timeout + def _run_query_job(self, connection: "BigQueryConnection", sql: str, parameters: Any) -> "QueryJob": return run_query_job( connection, @@ -172,10 +178,13 @@ def _run_query_job(self, connection: "BigQueryConnection", sql: str, parameters: job_config=None, json_serializer=self._json_serializer, retry=self._job_retry, - timeout=self._request_timeout, + timeout=self._job_request_timeout(), job_retry=self._job_retry, ) + def _job_result_kwargs(self) -> dict[str, Any]: + return dict(self._job_result_kwargs_defaults) + # ───────────────────────────────────────────────────────────────────────────── # CORE DISPATCH METHODS # ───────────────────────────────────────────────────────────────────────────── @@ -191,6 +200,35 @@ def dispatch_execute(self, cursor: Any, statement: "SQL") -> ExecutionResult: ExecutionResult with query results and metadata """ sql, parameters = self._get_compiled_sql(statement, self.statement_config) + if self._use_query_and_wait: + row_iterator = run_query_and_wait( + cursor, + sql, + parameters, + default_job_config=self._default_query_job_config, + json_serializer=self._json_serializer, + retry=self._job_retry, + wait_timeout=self._job_request_timeout(), + job_retry=self._job_retry, + ) + cursor.job = None + iterator_schema = getattr(row_iterator, "schema", None) + if statement.returns_rows() or iterator_schema: + column_names = resolve_column_names(iterator_schema, self._column_name_cache) + rows_list, _ = collect_rows(row_iterator, iterator_schema, column_names=column_names) + + return self.create_execution_result( + cursor, + selected_data=rows_list, + column_names=column_names, + data_row_count=len(rows_list), + is_select_result=True, + row_format="record", + ) + + affected_rows = build_dml_rowcount(row_iterator, 0) + return self.create_execution_result(cursor, rowcount_override=affected_rows) + cursor.job = self._run_query_job(cursor, sql, parameters) statement_type = str(cursor.job.statement_type or "").upper() is_select_like = ( @@ -198,9 +236,7 @@ def dispatch_execute(self, cursor: Any, statement: "SQL") -> ExecutionResult: ) if is_select_like: - job_result = cursor.job.result( - job_retry=self._job_retry, timeout=self._job_result_timeout, **self._job_result_kwargs - ) + job_result = cursor.job.result(job_retry=self._job_retry, timeout=self._job_result_timeout, **self._job_result_kwargs()) job_schema = cursor.job.schema or getattr(job_result, "schema", None) column_names = resolve_column_names(job_schema, self._column_name_cache) rows_list, _ = collect_rows(job_result, job_schema, column_names=column_names) @@ -246,7 +282,12 @@ def dispatch_execute_many(self, cursor: Any, statement: "SQL") -> ExecutionResul allow_parse = statement.statement_config.enable_parsing if is_simple_insert(sql, parsed_expression, allow_parse=allow_parse): rowcount = try_bulk_insert( - self.connection, sql, prepared_parameters, parsed_expression, allow_parse=allow_parse + self.connection, + sql, + prepared_parameters, + parsed_expression, + allow_parse=allow_parse, + result_timeout=self._job_request_timeout(), ) if rowcount is not None: return self.create_execution_result(cursor, rowcount_override=rowcount, is_many_result=True) @@ -446,7 +487,7 @@ def select_to_arrow( with exc_handler: query_job = self._run_query_job(self.connection, sql, driver_params) query_job.result( - job_retry=self._job_retry, timeout=self._job_result_timeout, **self._job_result_kwargs + job_retry=self._job_retry, timeout=self._job_result_timeout, **self._job_result_kwargs() ) # Wait for completion # Native Arrow via Storage API @@ -528,8 +569,13 @@ def load_from_arrow( pq.write_table(arrow_table, buffer) buffer.seek(0) job_config = build_load_job_config("parquet", overwrite) - job = self.connection.load_table_from_file(buffer, table, job_config=job_config) - job.result() + job = self.connection.load_table_from_file( + buffer, + table, + job_config=job_config, + timeout=self._job_request_timeout(), + ) + job.result(timeout=self._job_request_timeout()) telemetry_payload = build_load_job_telemetry(job, table, format_label="parquet") if telemetry: telemetry_payload.setdefault("extra", {}) @@ -552,12 +598,42 @@ def load_from_storage( msg = "BigQuery storage bridge currently supports Parquet ingest only" raise StorageCapabilityError(msg, capability="parquet_import_enabled") job_config = build_load_job_config(file_format, overwrite) - job = self.connection.load_table_from_uri(source, table, job_config=job_config) - job.result() + job = self.connection.load_table_from_uri( + source, + table, + job_config=job_config, + retry=self._job_retry, + timeout=self._job_request_timeout(), + ) + job.result(timeout=self._job_request_timeout()) telemetry_payload = build_load_job_telemetry(job, table, format_label=file_format) self._attach_partition_telemetry(telemetry_payload, partitioner) return self._create_storage_job(telemetry_payload) + def export_table_to_storage( + self, + table: str, + destination_uris: "str | list[str]", + *, + job_config: "ExtractJobConfig | None" = None, + job_id: "str | None" = None, + job_id_prefix: "str | None" = None, + location: "str | None" = None, + ) -> "ExtractJob": + """Export a BigQuery table to Cloud Storage via an extract job.""" + job = self.connection.extract_table( + table, + destination_uris, + job_config=job_config, + job_id=job_id, + job_id_prefix=job_id_prefix, + location=location, + retry=self._job_retry, + timeout=self._job_request_timeout(), + ) + job.result(timeout=self._job_request_timeout()) + return job + # ───────────────────────────────────────────────────────────────────────────── # UTILITY METHODS # ───────────────────────────────────────────────────────────────────────────── diff --git a/sqlspec/adapters/spanner/adk/store.py b/sqlspec/adapters/spanner/adk/store.py index b3b8e030b..c94c1c695 100644 --- a/sqlspec/adapters/spanner/adk/store.py +++ b/sqlspec/adapters/spanner/adk/store.py @@ -7,6 +7,7 @@ from google.cloud.spanner_v1 import param_types from sqlspec.adapters.spanner.config import SpannerSyncConfig +from sqlspec.exceptions import OperationalError from sqlspec.extensions.adk import BaseAsyncADKStore, EventRecord, SessionRecord from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore from sqlspec.protocols import SpannerParamTypesProtocol @@ -726,6 +727,12 @@ def __init__(self, statements: "list[tuple[str, dict[str, Any], dict[str, Any]]] self._statements = statements def __call__(self, transaction: "Transaction") -> None: + if len(self._statements) > 1: + status, _row_counts = transaction.batch_update(self._statements) # type: ignore[no-untyped-call] + if status.code != 0: + msg = f"Spanner batch update failed (code {status.code}): {status.message}" + raise OperationalError(msg) + return for sql, params, types in self._statements: transaction.execute_update(sql, params=params, param_types=types) # type: ignore[no-untyped-call] @@ -737,6 +744,12 @@ def __init__(self, statements: "list[tuple[str, dict[str, Any], dict[str, Any]]] self._statements = statements def __call__(self, transaction: "Transaction") -> None: + if len(self._statements) > 1: + status, _row_counts = transaction.batch_update(self._statements) # type: ignore[no-untyped-call] + if status.code != 0: + msg = f"Spanner batch update failed (code {status.code}): {status.message}" + raise OperationalError(msg) + return for sql, params, types in self._statements: transaction.execute_update(sql, params=params, param_types=types) # type: ignore[no-untyped-call] diff --git a/sqlspec/adapters/spanner/config.py b/sqlspec/adapters/spanner/config.py index 6b376845a..6716279b6 100644 --- a/sqlspec/adapters/spanner/config.py +++ b/sqlspec/adapters/spanner/config.py @@ -1,5 +1,6 @@ """Spanner configuration.""" +from contextlib import contextmanager from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast from google.cloud.spanner_v1 import Client @@ -28,7 +29,9 @@ from google.cloud.spanner_admin_database_v1.types import DatabaseDialect, EncryptionConfig from google.cloud.spanner_v1 import DirectedReadOptions, ExecuteSqlRequest from google.cloud.spanner_v1.database import Database + from google.cloud.spanner_v1.database import BatchSnapshot from google.cloud.spanner_v1.transaction import DefaultTransactionOptions + from collections.abc import Iterator from sqlspec.config import ExtensionConfigs from sqlspec.core import StatementConfig @@ -137,6 +140,9 @@ class SpannerDriverFeatures(TypedDict): json_deserializer: Custom JSON deserializer for result conversion. retry: Per-request retry policy passed to execute_sql(), execute_update(), and batch_update(). timeout: Per-request timeout in seconds passed to execute_sql(), execute_update(), and batch_update(). + request_options: Default RequestOptions forwarded to execute_sql(), execute_update(), + and batch_update(). Per-call overrides are available via + SpannerSyncDriver.execute_with_options(). session_labels: Deprecated compatibility alias for pool session labels. Prefer ``connection_config["session_labels"]``. enable_events: Enable database event channel support. @@ -150,6 +156,7 @@ class SpannerDriverFeatures(TypedDict): json_deserializer: "NotRequired[Callable[[str], Any]]" retry: "NotRequired[Retry | None]" timeout: "NotRequired[float | None]" + request_options: "NotRequired[Any]" session_labels: "NotRequired[dict[str, str]]" enable_events: "NotRequired[bool]" events_backend: "NotRequired[str]" @@ -419,12 +426,14 @@ def provide_session( """ connection_ctx = SpannerConnectionContext(self, transaction=transaction) handler = _SpannerSessionConnectionHandler(self, connection_ctx) + session_driver_features: dict[str, Any] = dict(self.driver_features) + session_driver_features["database_provider"] = self.get_database return SpannerSessionContext( acquire_connection=handler.acquire_connection, release_connection=handler.release_connection, statement_config=statement_config or self.statement_config or default_statement_config, - driver_features=self.driver_features, + driver_features=session_driver_features, prepare_driver=self._prepare_driver, ) @@ -444,6 +453,20 @@ def provide_read_session( """ return self.provide_session(*args, statement_config=statement_config, transaction=False, **kwargs) + @contextmanager + def provide_batch_snapshot( + self, *, read_timestamp: "Any | None" = None, exact_staleness: "Any | None" = None + ) -> "Iterator[BatchSnapshot]": + """Yield a BatchSnapshot for partitioned reads across parallel workers.""" + snapshot = self.get_database().batch_snapshot( + read_timestamp=read_timestamp, + exact_staleness=exact_staleness, + ) + try: + yield snapshot + finally: + snapshot.close() + def get_signature_namespace(self) -> "dict[str, Any]": """Get the signature namespace for SpannerSyncConfig types. diff --git a/sqlspec/adapters/spanner/driver.py b/sqlspec/adapters/spanner/driver.py index 8a35850e8..bc6d93581 100644 --- a/sqlspec/adapters/spanner/driver.py +++ b/sqlspec/adapters/spanner/driver.py @@ -5,6 +5,7 @@ import sqlglot as _sqlglot from google.api_core import exceptions as api_exceptions +from google.cloud.spanner_v1.keyset import KeySet from google.cloud.spanner_v1.transaction import Transaction from sqlglot import exp as _sqlglot_exp @@ -26,7 +27,7 @@ from sqlspec.adapters.spanner.type_converter import SpannerOutputConverter from sqlspec.core import StatementConfig, create_arrow_result, register_driver_profile from sqlspec.driver import BaseSyncExceptionHandler, ExecutionResult, SyncDriverAdapterBase -from sqlspec.exceptions import SQLConversionError +from sqlspec.exceptions import ImproperConfigurationError, SQLConversionError from sqlspec.utils.serializers import from_json _READ_ONLY_SNAPSHOT_ERROR_MESSAGE = ( @@ -37,11 +38,12 @@ if TYPE_CHECKING: from collections.abc import Callable + from collections.abc import Sequence from sqlglot.dialects.dialect import DialectType from sqlspec.adapters.spanner._typing import SpannerConnection - from sqlspec.core import ArrowResult + from sqlspec.core import ArrowResult, SQLResult from sqlspec.core.statement import SQL from sqlspec.storage import StorageBridgeJob, StorageDestination, StorageFormat, StorageTelemetry from sqlspec.typing import ArrowReturnFormat @@ -111,11 +113,30 @@ def _handle_exception(self, exc_type: "type[BaseException] | None", exc_val: "Ba return False +class _PerCallExecuteOptions: + """Per-call Spanner execution options captured for a single dispatch.""" + + __slots__ = ("directed_read_options", "request_options", "retry", "timeout") + + def __init__( + self, + *, + request_options: "Any | None" = None, + directed_read_options: "Any | None" = None, + retry: "Any | None" = None, + timeout: "float | None" = None, + ) -> None: + self.request_options = request_options + self.directed_read_options = directed_read_options + self.retry = retry + self.timeout = timeout + + class SpannerSyncDriver(SyncDriverAdapterBase): """Synchronous Spanner driver operating on Snapshot or Transaction contexts.""" dialect: "DialectType" = "spanner" - __slots__ = ("_column_name_cache", "_data_dictionary", "_type_converter") + __slots__ = ("_column_name_cache", "_data_dictionary", "_pending_execute_options", "_type_converter") def __init__( self, @@ -136,6 +157,7 @@ def __init__( ) self._column_name_cache: dict[int, tuple[Any, list[str]]] = {} self._data_dictionary: SpannerDataDictionary | None = None + self._pending_execute_options: _PerCallExecuteOptions | None = None # ───────────────────────────────────────────────────────────────────────────── # CORE DISPATCH METHODS - The Execution Engine @@ -146,10 +168,10 @@ def dispatch_execute(self, cursor: "SpannerConnection", statement: "SQL") -> Exe params = cast("dict[str, Any] | None", params) coerced_params = self._coerce_params(params) param_types_map = self._infer_param_types(coerced_params) - execute_kwargs = self._execute_kwargs() if statement.returns_rows(): reader = cast("_SpannerReadProtocol", cursor) + execute_kwargs = self._execute_kwargs(for_read=True) result_set = reader.execute_sql(sql, params=coerced_params, param_types=param_types_map, **execute_kwargs) rows = list(result_set) try: @@ -174,6 +196,7 @@ def dispatch_execute(self, cursor: "SpannerConnection", statement: "SQL") -> Exe if supports_write(cursor): writer = cast("_SpannerWriteProtocol", cursor) + execute_kwargs = self._execute_kwargs() row_count = writer.execute_update(sql, params=coerced_params, param_types=param_types_map, **execute_kwargs) return self.create_execution_result(cursor, rowcount_override=row_count) @@ -225,7 +248,8 @@ def dispatch_execute_script(self, cursor: "SpannerConnection", statement: "SQL") script_params = cast("dict[str, Any] | None", params) coerced_params = self._coerce_params(script_params) param_types_map = self._infer_param_types(coerced_params) - execute_kwargs = self._execute_kwargs() + read_execute_kwargs = self._execute_kwargs(for_read=True) + write_execute_kwargs = self._execute_kwargs() for stmt in statements: try: parsed = _sqlglot.parse_one(stmt) @@ -236,9 +260,11 @@ def dispatch_execute_script(self, cursor: "SpannerConnection", statement: "SQL") raise SQLConversionError(_READ_ONLY_SNAPSHOT_ERROR_MESSAGE) if not is_select and is_transaction: writer = cast("_SpannerWriteProtocol", cursor) - writer.execute_update(stmt, params=coerced_params, param_types=param_types_map, **execute_kwargs) + writer.execute_update(stmt, params=coerced_params, param_types=param_types_map, **write_execute_kwargs) else: - _ = list(reader.execute_sql(stmt, params=coerced_params, param_types=param_types_map, **execute_kwargs)) + _ = list( + reader.execute_sql(stmt, params=coerced_params, param_types=param_types_map, **read_execute_kwargs) + ) count += 1 return self.create_execution_result( @@ -270,8 +296,133 @@ def with_cursor(self, connection: "SpannerConnection") -> "SpannerSyncCursor": def handle_database_exceptions(self) -> "SpannerExceptionHandler": return SpannerExceptionHandler() - def _execute_kwargs(self) -> dict[str, Any]: - return {key: self.driver_features[key] for key in ("retry", "timeout") if key in self.driver_features} + def _execute_kwargs(self, *, for_read: bool = False) -> dict[str, Any]: + kwargs = {key: self.driver_features[key] for key in ("retry", "timeout") if key in self.driver_features} + request_options = self.driver_features.get("request_options") + if request_options is not None: + kwargs["request_options"] = request_options + pending = self._pending_execute_options + if pending is not None: + if pending.request_options is not None: + kwargs["request_options"] = pending.request_options + if pending.retry is not None: + kwargs["retry"] = pending.retry + if pending.timeout is not None: + kwargs["timeout"] = pending.timeout + if for_read and pending.directed_read_options is not None: + kwargs["directed_read_options"] = pending.directed_read_options + return kwargs + + def execute_with_options( + self, + statement: "Any", + /, + *parameters: "Any", + request_options: "Any | None" = None, + directed_read_options: "Any | None" = None, + retry: "Any | None" = None, + timeout: "float | None" = None, + statement_config: "StatementConfig | None" = None, + **kwargs: Any, + ) -> "SQLResult": + """Execute a single statement with per-call Spanner request options.""" + config = statement_config or self.statement_config + sql_statement = self.prepare_statement(statement, parameters, statement_config=config, kwargs=kwargs) + self._pending_execute_options = _PerCallExecuteOptions( + request_options=request_options, + directed_read_options=directed_read_options, + retry=retry, + timeout=timeout, + ) + try: + return self.dispatch_statement_execution(statement=sql_statement, connection=self.connection) + finally: + self._pending_execute_options = None + + def _require_database(self) -> "Any": + provider = self.driver_features.get("database_provider") + if provider is None: + msg = ( + "Spanner database-level operations require a session created via " + "SpannerSyncConfig.provide_session()." + ) + raise ImproperConfigurationError(msg) + return provider() + + def execute_partitioned_dml( + self, + statement: "Any", + /, + *parameters: "Any", + request_options: "Any | None" = None, + exclude_txn_from_change_streams: bool = False, + statement_config: "StatementConfig | None" = None, + **kwargs: Any, + ) -> int: + """Execute a partitioned DML statement across the whole table.""" + database = self._require_database() + config = statement_config or self.statement_config + sql_statement = self.prepare_statement(statement, parameters, statement_config=config, kwargs=kwargs) + sql, params = self._get_compiled_sql(sql_statement, config) + coerced_params = self._coerce_params(cast("dict[str, Any] | None", params)) + param_types_map = self._infer_param_types(coerced_params) + exc_handler = self.handle_database_exceptions() + row_count = 0 + with exc_handler: + row_count = int( + database.execute_partitioned_dml( + sql, + params=coerced_params, + param_types=param_types_map, + request_options=request_options + if request_options is not None + else self.driver_features.get("request_options"), + exclude_txn_from_change_streams=exclude_txn_from_change_streams, + ) + ) + self._check_pending_exception(exc_handler) + return row_count + + def apply_mutations( + self, + table: str, + *, + columns: "Sequence[str] | None" = None, + insert: "Sequence[Sequence[Any]] | None" = None, + update: "Sequence[Sequence[Any]] | None" = None, + insert_or_update: "Sequence[Sequence[Any]] | None" = None, + replace: "Sequence[Sequence[Any]] | None" = None, + delete_keys: "Sequence[Sequence[Any]] | None" = None, + delete_all: bool = False, + request_options: "Any | None" = None, + max_commit_delay: "Any | None" = None, + ) -> None: + """Apply blind-write mutations to a table in a single atomic commit.""" + row_groups = (insert, update, insert_or_update, replace) + if any(group is not None for group in row_groups) and columns is None: + msg = "apply_mutations() requires 'columns' when row mutations are provided." + raise ImproperConfigurationError(msg) + + database = self._require_database() + resolved_request_options = request_options if request_options is not None else self.driver_features.get( + "request_options" + ) + exc_handler = self.handle_database_exceptions() + with exc_handler: + with database.batch(request_options=resolved_request_options, max_commit_delay=max_commit_delay) as batch: + if insert is not None: + batch.insert(table, columns, insert) + if update is not None: + batch.update(table, columns, update) + if insert_or_update is not None: + batch.insert_or_update(table, columns, insert_or_update) + if replace is not None: + batch.replace(table, columns, replace) + if delete_all: + batch.delete(table, KeySet(all_=True)) + elif delete_keys is not None: + batch.delete(table, KeySet(keys=list(delete_keys))) + self._check_pending_exception(exc_handler) # ───────────────────────────────────────────────────────────────────────────── # ARROW API METHODS diff --git a/tests/unit/adapters/test_bigquery/test_config.py b/tests/unit/adapters/test_bigquery/test_config.py index c5d8d5c40..9c7145d2a 100644 --- a/tests/unit/adapters/test_bigquery/test_config.py +++ b/tests/unit/adapters/test_bigquery/test_config.py @@ -95,5 +95,6 @@ def test_bigquery_config_routes_current_client_level_settings(monkeypatch: Monke def test_bigquery_config_typed_surfaces_do_not_advertise_inert_settings() -> None: """Typed public settings should only include options that SQLSpec routes.""" assert "credentials_path" not in BigQueryConnectionParams.__annotations__ + assert "use_query_and_wait" in BigQueryDriverFeatures.__annotations__ assert "on_job_start" not in BigQueryDriverFeatures.__annotations__ assert "on_job_complete" not in BigQueryDriverFeatures.__annotations__ diff --git a/tests/unit/adapters/test_bigquery/test_job_controls.py b/tests/unit/adapters/test_bigquery/test_job_controls.py new file mode 100644 index 000000000..ab20ea330 --- /dev/null +++ b/tests/unit/adapters/test_bigquery/test_job_controls.py @@ -0,0 +1,245 @@ +"""Unit tests for BigQuery job-control behavior.""" + +from types import SimpleNamespace +from typing import Any, cast + +import pyarrow as pa +from google.cloud.bigquery import LoadJobConfig +from google.cloud.bigquery.enums import QueryApiMethod, TimestampPrecision + +from sqlspec.adapters.bigquery.core import build_load_job_config, run_query_job, try_bulk_insert +from sqlspec.adapters.bigquery.driver import BigQueryDriver +from sqlspec.utils.serializers import to_json + +CAPABILITIES = { + "arrow_export_enabled": True, + "arrow_import_enabled": True, + "parquet_export_enabled": True, + "parquet_import_enabled": True, + "requires_staging_for_load": False, + "staging_protocols": [], + "partition_strategies": ["fixed"], +} + + +class _RecordingJob: + def __init__( + self, + rows: list[dict[str, object]] | None = None, + *, + statement_type: str = "SELECT", + schema: list[SimpleNamespace] | None = None, + num_dml_affected_rows: int | None = None, + job_id: str = "job_123", + ) -> None: + self.rows = rows or [] + self.statement_type = statement_type + self.schema = schema + self.num_dml_affected_rows = num_dml_affected_rows + self.job_id = job_id + self.labels: dict[str, str] = {} + self.started = None + self.ended = None + self._properties: dict[str, Any] = {} + self.result_calls: list[dict[str, Any]] = [] + + def result(self, **kwargs: Any) -> list[dict[str, object]]: + self.result_calls.append(kwargs) + return self.rows + + +class _RecordingExtractJob: + def __init__(self) -> None: + self.result_calls: list[dict[str, Any]] = [] + self.job_id = "extract_123" + + def result(self, **kwargs: Any) -> None: + self.result_calls.append(kwargs) + + +class _RecordingRowIterator: + def __init__( + self, + rows: list[dict[str, object]] | None = None, + *, + schema: list[SimpleNamespace] | None = None, + num_dml_affected_rows: int | None = None, + ) -> None: + self.rows = rows or [] + self.schema = schema + self.num_dml_affected_rows = num_dml_affected_rows + + def __iter__(self) -> Any: + return iter(self.rows) + + +class _RecordingConnection: + def __init__(self) -> None: + self.query_calls: list[tuple[str, dict[str, Any]]] = [] + self.query_and_wait_calls: list[tuple[str, dict[str, Any]]] = [] + self.load_file_calls: list[tuple[Any, Any, dict[str, Any]]] = [] + self.load_uri_calls: list[tuple[Any, Any, dict[str, Any]]] = [] + self.extract_calls: list[tuple[Any, Any, dict[str, Any]]] = [] + self.query_job = _RecordingJob() + self.row_iterator = _RecordingRowIterator() + self.load_job = _RecordingJob(statement_type="LOAD", schema=None) + self.extract_job = _RecordingExtractJob() + + def query(self, sql: str, **kwargs: Any) -> _RecordingJob: + self.query_calls.append((sql, kwargs)) + return self.query_job + + def query_and_wait(self, sql: str, **kwargs: Any) -> _RecordingRowIterator: + self.query_and_wait_calls.append((sql, kwargs)) + return self.row_iterator + + def load_table_from_file(self, file_obj: Any, destination: Any, **kwargs: Any) -> _RecordingJob: + self.load_file_calls.append((file_obj, destination, kwargs)) + return self.load_job + + def load_table_from_uri(self, source_uris: Any, destination: Any, **kwargs: Any) -> _RecordingJob: + self.load_uri_calls.append((source_uris, destination, kwargs)) + return self.load_job + + def extract_table(self, source: Any, destination_uris: Any, **kwargs: Any) -> _RecordingExtractJob: + self.extract_calls.append((source, destination_uris, kwargs)) + return self.extract_job + + +def _schema(*names: str) -> list[SimpleNamespace]: + return [SimpleNamespace(name=name) for name in names] + + +def test_run_query_job_passes_job_id_prefix_only_without_job_id() -> None: + connection = _RecordingConnection() + + run_query_job( + cast(Any, connection), + "SELECT @name", + {"name": "alpha"}, + default_job_config=None, + job_config=None, + json_serializer=to_json, + retry=None, + timeout=3.0, + job_retry=None, + api_method=QueryApiMethod.INSERT, + timestamp_precision=TimestampPrecision.MICROSECOND, + job_id_prefix="prefix-", + ) + + sql, kwargs = connection.query_calls[0] + assert sql == "SELECT @name" + assert kwargs["api_method"] == QueryApiMethod.INSERT + assert kwargs["timestamp_precision"] == TimestampPrecision.MICROSECOND + assert kwargs["job_id_prefix"] == "prefix-" + assert "job_id" not in kwargs + + +def test_query_and_wait_used_when_enabled() -> None: + connection = _RecordingConnection() + connection.row_iterator = _RecordingRowIterator(rows=[{"v": 1}], schema=_schema("v")) + driver = BigQueryDriver(cast(Any, connection), driver_features={"use_query_and_wait": True}) + + result = driver.execute("SELECT 1 AS v") + + assert connection.query_calls == [] + assert connection.query_and_wait_calls[0][0] == "SELECT 1 AS v" + assert connection.query_and_wait_calls[0][1]["api_timeout"] == driver._job_request_timeout() + assert connection.query_and_wait_calls[0][1]["wait_timeout"] == driver._job_request_timeout() + assert result.get_data()[0]["v"] == 1 + + +def test_query_and_wait_dml_rowcount() -> None: + connection = _RecordingConnection() + connection.row_iterator = _RecordingRowIterator(num_dml_affected_rows=3) + driver = BigQueryDriver(cast(Any, connection), driver_features={"use_query_and_wait": True}) + + result = driver.execute("UPDATE t SET v = 1") + + assert connection.query_calls == [] + assert result.rows_affected == 3 + + +def test_query_and_wait_default_off() -> None: + connection = _RecordingConnection() + connection.query_job = _RecordingJob(rows=[{"v": 1}], schema=_schema("v")) + connection.row_iterator = _RecordingRowIterator(rows=[{"v": 1}], schema=_schema("v")) + driver = BigQueryDriver(cast(Any, connection)) + + result = driver.execute("SELECT 1 AS v") + + assert connection.query_calls + assert connection.query_and_wait_calls == [] + assert result.get_data()[0]["v"] == 1 + + +def test_try_bulk_insert_bounds_result_timeout() -> None: + connection = _RecordingConnection() + connection.load_job = _RecordingJob(statement_type="LOAD") + + rowcount = try_bulk_insert( + cast(Any, connection), + "INSERT INTO dataset.table (id) VALUES (@id)", + [{"id": 1}], + result_timeout=5.0, + ) + + assert rowcount == 1 + assert connection.load_file_calls[0][2]["timeout"] == 5.0 + assert connection.load_job.result_calls[0]["timeout"] == 5.0 + + +def test_load_from_arrow_bounds_result_timeout() -> None: + connection = _RecordingConnection() + connection.load_job = _RecordingJob(statement_type="LOAD") + driver = BigQueryDriver( + cast(Any, connection), driver_features={"job_result_timeout": 5.0, "storage_capabilities": CAPABILITIES} + ) + + result = driver.load_from_arrow("dataset.table", pa.table({"id": [1]})) + + assert result.telemetry["rows_processed"] == 0 + assert connection.load_file_calls[0][2]["timeout"] == driver._job_request_timeout() + assert connection.load_job.result_calls[0]["timeout"] == driver._job_request_timeout() + + +def test_load_from_storage_forwards_retry_and_bounds_result_timeout() -> None: + connection = _RecordingConnection() + connection.load_job = _RecordingJob(statement_type="LOAD") + driver = BigQueryDriver(cast(Any, connection), driver_features={"job_result_timeout": 5.0}) + + result = driver.load_from_storage( + "dataset.table", + "gs://bucket/object.parquet", + file_format="parquet", + ) + + assert result.telemetry["rows_processed"] == 0 + assert connection.load_uri_calls[0][2]["retry"] is driver._job_retry + assert connection.load_uri_calls[0][2]["timeout"] == driver._job_request_timeout() + assert connection.load_job.result_calls[0]["timeout"] == driver._job_request_timeout() + + +def test_load_job_config_fill_from_default_preserves_defaults() -> None: + default_job_config = LoadJobConfig(labels={"source": "default"}) + + filled_job_config = build_load_job_config("parquet", overwrite=False)._fill_from_default(default_job_config) + + assert filled_job_config.source_format == "PARQUET" + assert filled_job_config.write_disposition == "WRITE_APPEND" + assert filled_job_config.labels == {"source": "default"} + + +def test_export_table_to_storage_forwards_job_controls() -> None: + connection = _RecordingConnection() + driver = BigQueryDriver(cast(Any, connection)) + + job = driver.export_table_to_storage("dataset.table", "gs://bucket/object.csv") + + assert job is connection.extract_job + assert connection.extract_calls[0][0] == "dataset.table" + assert connection.extract_calls[0][1] == "gs://bucket/object.csv" + assert connection.extract_calls[0][2]["retry"] is driver._job_retry + assert connection.extract_calls[0][2]["timeout"] == driver._job_request_timeout() + assert connection.extract_job.result_calls[0]["timeout"] == driver._job_request_timeout() diff --git a/tests/unit/adapters/test_spanner/test_config.py b/tests/unit/adapters/test_spanner/test_config.py index 52e5aeaa6..7264af465 100644 --- a/tests/unit/adapters/test_spanner/test_config.py +++ b/tests/unit/adapters/test_spanner/test_config.py @@ -5,7 +5,12 @@ from google.cloud.spanner_v1.pool import AbstractSessionPool, BurstyPool, FixedSizePool from sqlspec.adapters.spanner import config as config_module -from sqlspec.adapters.spanner.config import SpannerConnectionParams, SpannerPoolParams, SpannerSyncConfig +from sqlspec.adapters.spanner.config import ( + SpannerConnectionParams, + SpannerDriverFeatures, + SpannerPoolParams, + SpannerSyncConfig, +) from sqlspec.adapters.spanner.core import default_statement_config from sqlspec.driver import SyncDriverAdapterBase from sqlspec.exceptions import ImproperConfigurationError @@ -70,6 +75,7 @@ def test_driver_features_defaults() -> None: config = SpannerSyncConfig(connection_config={"project": "p", "instance_id": "i", "database_id": "d"}) assert config.driver_features["enable_uuid_conversion"] is True assert config.driver_features["json_serializer"] is not None + assert "request_options" in SpannerDriverFeatures.__annotations__ def test_driver_feature_session_labels_are_routed_to_pool() -> None: diff --git a/tests/unit/adapters/test_spanner/test_session_controls.py b/tests/unit/adapters/test_spanner/test_session_controls.py new file mode 100644 index 000000000..59bf32190 --- /dev/null +++ b/tests/unit/adapters/test_spanner/test_session_controls.py @@ -0,0 +1,267 @@ +"""Unit tests for Spanner session-control behavior.""" + +from types import SimpleNamespace +from typing import Any, cast +from unittest.mock import MagicMock + +import pytest +from google.cloud.spanner_v1 import Transaction +from google.cloud.spanner_v1.keyset import KeySet +from google.cloud.spanner_v1.streamed import StreamedResultSet + +from sqlspec.adapters.spanner.config import SpannerSyncConfig +from sqlspec.adapters.spanner.driver import SpannerSyncDriver +from sqlspec.exceptions import ImproperConfigurationError + + +def _mock_result_set() -> MagicMock: + result = MagicMock(spec=StreamedResultSet) + field = MagicMock() + field.name = "id" + result.metadata.row_type.fields = [field] + result.__iter__.return_value = iter([(1,)]) + return result + + +def _mock_transaction() -> MagicMock: + transaction = MagicMock(spec=Transaction) + transaction.execute_sql = MagicMock() + transaction.execute_update = MagicMock() + transaction.batch_update = MagicMock() + return transaction + + +def test_provide_session_injects_database_provider(monkeypatch: pytest.MonkeyPatch) -> None: + captured: dict[str, Any] = {} + + class _SessionContext: + def __init__(self, **kwargs: Any) -> None: + captured.update(kwargs) + + config = SpannerSyncConfig(connection_config={"project": "p", "instance_id": "i", "database_id": "d"}) + sentinel_database = object() + config.get_database = lambda: sentinel_database # type: ignore[assignment] + monkeypatch.setattr("sqlspec.adapters.spanner.config.SpannerSessionContext", _SessionContext) + + context = config.provide_session() + + assert isinstance(context, _SessionContext) + assert captured["driver_features"] is not config.driver_features + assert captured["driver_features"]["database_provider"] is config.get_database + assert "database_provider" not in config.driver_features + + +def test_require_database_raises_without_provider() -> None: + driver = SpannerSyncDriver(MagicMock()) + + with pytest.raises(ImproperConfigurationError, match="provide_session"): + driver._require_database() # type: ignore[attr-defined] + + +def test_driver_feature_request_options_forwarded_to_select() -> None: + request_options = {"priority": 1, "request_tag": "sqlspec-test"} + connection = _mock_transaction() + result_set = _mock_result_set() + connection.execute_sql.return_value = result_set + driver = SpannerSyncDriver(cast(Any, connection), driver_features={"request_options": request_options}) + + statement = driver.prepare_statement("SELECT id FROM users", statement_config=driver.statement_config) + result = driver.dispatch_execute(connection, statement) # type: ignore[protected-access] + + assert result.is_select_result + connection.execute_sql.assert_called_once_with( + "SELECT id FROM users", params=None, param_types={}, request_options=request_options + ) + assert result.selected_data[0] == (1,) + + +def test_driver_feature_request_options_forwarded_to_dml() -> None: + request_options = {"priority": 1, "request_tag": "sqlspec-test"} + connection = _mock_transaction() + connection.execute_update.return_value = 10 + driver = SpannerSyncDriver(cast(Any, connection), driver_features={"request_options": request_options}) + + statement = driver.prepare_statement("UPDATE users SET name = 'Bob'", statement_config=driver.statement_config) + result = driver.dispatch_execute(connection, statement) # type: ignore[protected-access] + + assert result.rowcount_override == 10 + connection.execute_update.assert_called_once_with( + "UPDATE users SET name = 'Bob'", params=None, param_types={}, request_options=request_options + ) + + +def test_driver_feature_request_options_forwarded_to_batch_update() -> None: + request_options = {"priority": 1, "request_tag": "sqlspec-test"} + connection = _mock_transaction() + connection.batch_update.return_value = (None, [1, 1]) + driver = SpannerSyncDriver(cast(Any, connection), driver_features={"request_options": request_options}) + + statement = driver.prepare_statement( + "UPDATE users SET name = @name WHERE id = @id", + parameters=([{"id": 1, "name": "alice"}, {"id": 2, "name": "bob"}],), + statement_config=driver.statement_config, + ) + result = driver.dispatch_execute_many(connection, statement) # type: ignore[protected-access] + + assert result.rowcount_override == 2 + connection.batch_update.assert_called_once() + assert connection.batch_update.call_args.kwargs["request_options"] == request_options + + +def test_execute_with_options_overrides_feature_defaults() -> None: + feature_request_options = {"priority": 0, "request_tag": "feature"} + override_request_options = {"priority": 1, "request_tag": "override"} + connection = _mock_transaction() + result_set = _mock_result_set() + connection.execute_sql.return_value = result_set + driver = SpannerSyncDriver(cast(Any, connection), driver_features={"request_options": feature_request_options}) + + result = driver.execute_with_options("SELECT id FROM users", request_options=override_request_options) + + connection.execute_sql.assert_called_once_with( + "SELECT id FROM users", params=None, param_types={}, request_options=override_request_options + ) + assert result.get_data()[0]["id"] == 1 + assert driver._pending_execute_options is None # type: ignore[attr-defined] + + +def test_execute_with_options_directed_read_only_on_select() -> None: + connection = _mock_transaction() + result_set = _mock_result_set() + connection.execute_sql.return_value = result_set + driver = SpannerSyncDriver(cast(Any, connection)) + + directed_read_options = SimpleNamespace(tag="directed") + driver.execute_with_options("SELECT id FROM users", directed_read_options=directed_read_options) + + assert connection.execute_sql.call_args.kwargs["directed_read_options"] is directed_read_options + + +def test_execute_with_options_directed_read_not_forwarded_to_dml() -> None: + connection = _mock_transaction() + connection.execute_update.return_value = 10 + driver = SpannerSyncDriver(cast(Any, connection)) + + directed_read_options = SimpleNamespace(tag="directed") + result = driver.execute_with_options("UPDATE users SET name = 'Bob'", directed_read_options=directed_read_options) + + assert result.rows_affected == 10 + assert "directed_read_options" not in connection.execute_update.call_args.kwargs + + +def test_execute_with_options_clears_stash_on_error() -> None: + connection = _mock_transaction() + connection.execute_sql.side_effect = RuntimeError("boom") + driver = SpannerSyncDriver(cast(Any, connection)) + + with pytest.raises(RuntimeError, match="boom"): + driver.execute_with_options("SELECT id FROM users") + + assert driver._pending_execute_options is None # type: ignore[attr-defined] + + +def test_execute_partitioned_dml_raises_without_database_provider() -> None: + driver = SpannerSyncDriver(MagicMock()) + + with pytest.raises(ImproperConfigurationError, match="provide_session"): + driver.execute_partitioned_dml("UPDATE users SET name = 'Bob'") + + +def test_execute_partitioned_dml_forwards_to_database() -> None: + request_options = {"priority": 1, "request_tag": "sqlspec-test"} + database = MagicMock() + database.execute_partitioned_dml.return_value = 7 + driver = SpannerSyncDriver( + MagicMock(), + driver_features={"request_options": request_options, "database_provider": lambda: database}, + ) + + row_count = driver.execute_partitioned_dml( + "UPDATE users SET name = 'Bob' WHERE TRUE", + request_options=request_options, + ) + + assert row_count == 7 + database.execute_partitioned_dml.assert_called_once_with( + "UPDATE users SET name = 'Bob' WHERE TRUE", + params=None, + param_types={}, + request_options=request_options, + exclude_txn_from_change_streams=False, + ) + + +def test_apply_mutations_requires_columns_for_rows() -> None: + driver = SpannerSyncDriver(MagicMock(), driver_features={"database_provider": lambda: MagicMock()}) + + with pytest.raises(ImproperConfigurationError, match="columns"): + driver.apply_mutations("users", insert=[(1, "alice")]) + + +def test_apply_mutations_routes_each_group() -> None: + batch = MagicMock() + batch.__enter__.return_value = batch + batch.__exit__.return_value = False + database = MagicMock() + database.batch.return_value = batch + driver = SpannerSyncDriver(MagicMock(), driver_features={"database_provider": lambda: database}) + + driver.apply_mutations( + "users", + columns=("id", "name"), + insert=[(1, "alice")], + update=[(2, "bob")], + insert_or_update=[(3, "carol")], + replace=[(4, "dave")], + delete_keys=[(5,)], + request_options={"priority": 1}, + max_commit_delay=3.0, + ) + + database.batch.assert_called_once_with(request_options={"priority": 1}, max_commit_delay=3.0) + batch.insert.assert_called_once_with("users", ("id", "name"), [(1, "alice")]) + batch.update.assert_called_once_with("users", ("id", "name"), [(2, "bob")]) + batch.insert_or_update.assert_called_once_with("users", ("id", "name"), [(3, "carol")]) + batch.replace.assert_called_once_with("users", ("id", "name"), [(4, "dave")]) + batch.delete.assert_called_once_with("users", KeySet(keys=[(5,)])) + + +def test_apply_mutations_delete_all_wins_over_delete_keys() -> None: + batch = MagicMock() + batch.__enter__.return_value = batch + batch.__exit__.return_value = False + database = MagicMock() + database.batch.return_value = batch + driver = SpannerSyncDriver(MagicMock(), driver_features={"database_provider": lambda: database}) + + driver.apply_mutations("users", delete_keys=[(5,)], delete_all=True) + + batch.delete.assert_called_once_with("users", KeySet(all_=True)) + + +def test_provide_batch_snapshot_closes_on_normal_exit(monkeypatch: pytest.MonkeyPatch) -> None: + snapshot = MagicMock() + database = MagicMock() + database.batch_snapshot.return_value = snapshot + config = SpannerSyncConfig(connection_config={"project": "p", "instance_id": "i", "database_id": "d"}) + config.get_database = lambda: database # type: ignore[assignment] + + with config.provide_batch_snapshot() as yielded: + assert yielded is snapshot + + snapshot.close.assert_called_once() + + +def test_provide_batch_snapshot_closes_on_exception(monkeypatch: pytest.MonkeyPatch) -> None: + snapshot = MagicMock() + database = MagicMock() + database.batch_snapshot.return_value = snapshot + config = SpannerSyncConfig(connection_config={"project": "p", "instance_id": "i", "database_id": "d"}) + config.get_database = lambda: database # type: ignore[assignment] + + with pytest.raises(RuntimeError, match="boom"): + with config.provide_batch_snapshot() as yielded: + assert yielded is snapshot + raise RuntimeError("boom") + + snapshot.close.assert_called_once() From a8ad23c157655d8c9575c1f498967a731adbd9dd Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sat, 13 Jun 2026 19:05:14 +0000 Subject: [PATCH 2/4] fix: satisfy cloud job session CI --- sqlspec/adapters/bigquery/driver.py | 38 +++++++-------- sqlspec/adapters/spanner/config.py | 11 ++--- sqlspec/adapters/spanner/driver.py | 47 +++++++++---------- .../test_bigquery/test_job_controls.py | 17 +++---- .../test_spanner/test_session_controls.py | 11 ++--- 5 files changed, 54 insertions(+), 70 deletions(-) diff --git a/sqlspec/adapters/bigquery/driver.py b/sqlspec/adapters/bigquery/driver.py index aba59ae6c..f1fa325cb 100644 --- a/sqlspec/adapters/bigquery/driver.py +++ b/sqlspec/adapters/bigquery/driver.py @@ -29,8 +29,8 @@ is_simple_insert, normalize_script_rowcount, resolve_column_names, - run_query_job, run_query_and_wait, + run_query_job, storage_api_available, try_bulk_insert, ) @@ -236,7 +236,9 @@ def dispatch_execute(self, cursor: Any, statement: "SQL") -> ExecutionResult: ) if is_select_like: - job_result = cursor.job.result(job_retry=self._job_retry, timeout=self._job_result_timeout, **self._job_result_kwargs()) + job_result = cursor.job.result( + job_retry=self._job_retry, timeout=self._job_result_timeout, **self._job_result_kwargs() + ) job_schema = cursor.job.schema or getattr(job_result, "schema", None) column_names = resolve_column_names(job_schema, self._column_name_cache) rows_list, _ = collect_rows(job_result, job_schema, column_names=column_names) @@ -570,10 +572,7 @@ def load_from_arrow( buffer.seek(0) job_config = build_load_job_config("parquet", overwrite) job = self.connection.load_table_from_file( - buffer, - table, - job_config=job_config, - timeout=self._job_request_timeout(), + buffer, table, job_config=job_config, timeout=self._job_request_timeout() ) job.result(timeout=self._job_request_timeout()) telemetry_payload = build_load_job_telemetry(job, table, format_label="parquet") @@ -599,11 +598,7 @@ def load_from_storage( raise StorageCapabilityError(msg, capability="parquet_import_enabled") job_config = build_load_job_config(file_format, overwrite) job = self.connection.load_table_from_uri( - source, - table, - job_config=job_config, - retry=self._job_retry, - timeout=self._job_request_timeout(), + source, table, job_config=job_config, retry=self._job_retry, timeout=self._job_request_timeout() ) job.result(timeout=self._job_request_timeout()) telemetry_payload = build_load_job_telemetry(job, table, format_label=file_format) @@ -621,15 +616,18 @@ def export_table_to_storage( location: "str | None" = None, ) -> "ExtractJob": """Export a BigQuery table to Cloud Storage via an extract job.""" - job = self.connection.extract_table( - table, - destination_uris, - job_config=job_config, - job_id=job_id, - job_id_prefix=job_id_prefix, - location=location, - retry=self._job_retry, - timeout=self._job_request_timeout(), + job = cast( + "ExtractJob", + self.connection.extract_table( + table, + destination_uris, + job_config=job_config, + job_id=job_id, + job_id_prefix=job_id_prefix, + location=location, + retry=self._job_retry, + timeout=self._job_request_timeout(), + ), ) job.result(timeout=self._job_request_timeout()) return job diff --git a/sqlspec/adapters/spanner/config.py b/sqlspec/adapters/spanner/config.py index 6716279b6..8d7f86fe6 100644 --- a/sqlspec/adapters/spanner/config.py +++ b/sqlspec/adapters/spanner/config.py @@ -18,7 +18,7 @@ from sqlspec.utils.type_guards import supports_close if TYPE_CHECKING: - from collections.abc import Callable + from collections.abc import Callable, Iterator from logging import Logger from types import TracebackType @@ -28,10 +28,8 @@ from google.auth.credentials import Credentials from google.cloud.spanner_admin_database_v1.types import DatabaseDialect, EncryptionConfig from google.cloud.spanner_v1 import DirectedReadOptions, ExecuteSqlRequest - from google.cloud.spanner_v1.database import Database - from google.cloud.spanner_v1.database import BatchSnapshot + from google.cloud.spanner_v1.database import BatchSnapshot, Database from google.cloud.spanner_v1.transaction import DefaultTransactionOptions - from collections.abc import Iterator from sqlspec.config import ExtensionConfigs from sqlspec.core import StatementConfig @@ -458,9 +456,8 @@ def provide_batch_snapshot( self, *, read_timestamp: "Any | None" = None, exact_staleness: "Any | None" = None ) -> "Iterator[BatchSnapshot]": """Yield a BatchSnapshot for partitioned reads across parallel workers.""" - snapshot = self.get_database().batch_snapshot( - read_timestamp=read_timestamp, - exact_staleness=exact_staleness, + snapshot = self.get_database().batch_snapshot( # type: ignore[no-untyped-call] + read_timestamp=read_timestamp, exact_staleness=exact_staleness ) try: yield snapshot diff --git a/sqlspec/adapters/spanner/driver.py b/sqlspec/adapters/spanner/driver.py index bc6d93581..31f4332ba 100644 --- a/sqlspec/adapters/spanner/driver.py +++ b/sqlspec/adapters/spanner/driver.py @@ -37,8 +37,7 @@ ) if TYPE_CHECKING: - from collections.abc import Callable - from collections.abc import Sequence + from collections.abc import Callable, Sequence from sqlglot.dialects.dialect import DialectType @@ -329,10 +328,7 @@ def execute_with_options( config = statement_config or self.statement_config sql_statement = self.prepare_statement(statement, parameters, statement_config=config, kwargs=kwargs) self._pending_execute_options = _PerCallExecuteOptions( - request_options=request_options, - directed_read_options=directed_read_options, - retry=retry, - timeout=timeout, + request_options=request_options, directed_read_options=directed_read_options, retry=retry, timeout=timeout ) try: return self.dispatch_statement_execution(statement=sql_statement, connection=self.connection) @@ -342,10 +338,7 @@ def execute_with_options( def _require_database(self) -> "Any": provider = self.driver_features.get("database_provider") if provider is None: - msg = ( - "Spanner database-level operations require a session created via " - "SpannerSyncConfig.provide_session()." - ) + msg = "Spanner database-level operations require a session created via SpannerSyncConfig.provide_session()." raise ImproperConfigurationError(msg) return provider() @@ -404,24 +397,26 @@ def apply_mutations( raise ImproperConfigurationError(msg) database = self._require_database() - resolved_request_options = request_options if request_options is not None else self.driver_features.get( - "request_options" + resolved_request_options = ( + request_options if request_options is not None else self.driver_features.get("request_options") ) exc_handler = self.handle_database_exceptions() - with exc_handler: - with database.batch(request_options=resolved_request_options, max_commit_delay=max_commit_delay) as batch: - if insert is not None: - batch.insert(table, columns, insert) - if update is not None: - batch.update(table, columns, update) - if insert_or_update is not None: - batch.insert_or_update(table, columns, insert_or_update) - if replace is not None: - batch.replace(table, columns, replace) - if delete_all: - batch.delete(table, KeySet(all_=True)) - elif delete_keys is not None: - batch.delete(table, KeySet(keys=list(delete_keys))) + with ( + exc_handler, + database.batch(request_options=resolved_request_options, max_commit_delay=max_commit_delay) as batch, + ): + if insert is not None: + batch.insert(table, columns, insert) + if update is not None: + batch.update(table, columns, update) + if insert_or_update is not None: + batch.insert_or_update(table, columns, insert_or_update) + if replace is not None: + batch.replace(table, columns, replace) + if delete_all: + batch.delete(table, KeySet(all_=True)) # type: ignore[no-untyped-call] + elif delete_keys is not None: + batch.delete(table, KeySet(keys=list(delete_keys))) # type: ignore[no-untyped-call] self._check_pending_exception(exc_handler) # ───────────────────────────────────────────────────────────────────────────── diff --git a/tests/unit/adapters/test_bigquery/test_job_controls.py b/tests/unit/adapters/test_bigquery/test_job_controls.py index ab20ea330..e18defb90 100644 --- a/tests/unit/adapters/test_bigquery/test_job_controls.py +++ b/tests/unit/adapters/test_bigquery/test_job_controls.py @@ -179,10 +179,7 @@ def test_try_bulk_insert_bounds_result_timeout() -> None: connection.load_job = _RecordingJob(statement_type="LOAD") rowcount = try_bulk_insert( - cast(Any, connection), - "INSERT INTO dataset.table (id) VALUES (@id)", - [{"id": 1}], - result_timeout=5.0, + cast(Any, connection), "INSERT INTO dataset.table (id) VALUES (@id)", [{"id": 1}], result_timeout=5.0 ) assert rowcount == 1 @@ -209,11 +206,7 @@ def test_load_from_storage_forwards_retry_and_bounds_result_timeout() -> None: connection.load_job = _RecordingJob(statement_type="LOAD") driver = BigQueryDriver(cast(Any, connection), driver_features={"job_result_timeout": 5.0}) - result = driver.load_from_storage( - "dataset.table", - "gs://bucket/object.parquet", - file_format="parquet", - ) + result = driver.load_from_storage("dataset.table", "gs://bucket/object.parquet", file_format="parquet") assert result.telemetry["rows_processed"] == 0 assert connection.load_uri_calls[0][2]["retry"] is driver._job_retry @@ -224,7 +217,9 @@ def test_load_from_storage_forwards_retry_and_bounds_result_timeout() -> None: def test_load_job_config_fill_from_default_preserves_defaults() -> None: default_job_config = LoadJobConfig(labels={"source": "default"}) - filled_job_config = build_load_job_config("parquet", overwrite=False)._fill_from_default(default_job_config) + filled_job_config = build_load_job_config("parquet", overwrite=False)._fill_from_default( # type: ignore[no-untyped-call] + default_job_config + ) assert filled_job_config.source_format == "PARQUET" assert filled_job_config.write_disposition == "WRITE_APPEND" @@ -237,7 +232,7 @@ def test_export_table_to_storage_forwards_job_controls() -> None: job = driver.export_table_to_storage("dataset.table", "gs://bucket/object.csv") - assert job is connection.extract_job + assert cast(object, job) is connection.extract_job assert connection.extract_calls[0][0] == "dataset.table" assert connection.extract_calls[0][1] == "gs://bucket/object.csv" assert connection.extract_calls[0][2]["retry"] is driver._job_retry diff --git a/tests/unit/adapters/test_spanner/test_session_controls.py b/tests/unit/adapters/test_spanner/test_session_controls.py index 59bf32190..496ba6e6d 100644 --- a/tests/unit/adapters/test_spanner/test_session_controls.py +++ b/tests/unit/adapters/test_spanner/test_session_controls.py @@ -72,6 +72,7 @@ def test_driver_feature_request_options_forwarded_to_select() -> None: connection.execute_sql.assert_called_once_with( "SELECT id FROM users", params=None, param_types={}, request_options=request_options ) + assert result.selected_data is not None assert result.selected_data[0] == (1,) @@ -172,13 +173,11 @@ def test_execute_partitioned_dml_forwards_to_database() -> None: database = MagicMock() database.execute_partitioned_dml.return_value = 7 driver = SpannerSyncDriver( - MagicMock(), - driver_features={"request_options": request_options, "database_provider": lambda: database}, + MagicMock(), driver_features={"request_options": request_options, "database_provider": lambda: database} ) row_count = driver.execute_partitioned_dml( - "UPDATE users SET name = 'Bob' WHERE TRUE", - request_options=request_options, + "UPDATE users SET name = 'Bob' WHERE TRUE", request_options=request_options ) assert row_count == 7 @@ -223,7 +222,7 @@ def test_apply_mutations_routes_each_group() -> None: batch.update.assert_called_once_with("users", ("id", "name"), [(2, "bob")]) batch.insert_or_update.assert_called_once_with("users", ("id", "name"), [(3, "carol")]) batch.replace.assert_called_once_with("users", ("id", "name"), [(4, "dave")]) - batch.delete.assert_called_once_with("users", KeySet(keys=[(5,)])) + batch.delete.assert_called_once_with("users", KeySet(keys=[(5,)])) # type: ignore[no-untyped-call] def test_apply_mutations_delete_all_wins_over_delete_keys() -> None: @@ -236,7 +235,7 @@ def test_apply_mutations_delete_all_wins_over_delete_keys() -> None: driver.apply_mutations("users", delete_keys=[(5,)], delete_all=True) - batch.delete.assert_called_once_with("users", KeySet(all_=True)) + batch.delete.assert_called_once_with("users", KeySet(all_=True)) # type: ignore[no-untyped-call] def test_provide_batch_snapshot_closes_on_normal_exit(monkeypatch: pytest.MonkeyPatch) -> None: From 1b531d6042bc8a050a32f4d795bca4f9993fc528 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sat, 13 Jun 2026 19:13:12 +0000 Subject: [PATCH 3/4] fix: harden cloud job session CI --- sqlspec/adapters/bigquery/core.py | 8 +++++++- tests/unit/adapters/test_spanner/test_session_controls.py | 3 ++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/sqlspec/adapters/bigquery/core.py b/sqlspec/adapters/bigquery/core.py index f058b24ac..d6f8eb1b7 100644 --- a/sqlspec/adapters/bigquery/core.py +++ b/sqlspec/adapters/bigquery/core.py @@ -87,8 +87,8 @@ def _check_pending_exception(self, exc_handler: "SyncExceptionHandler") -> None: "is_simple_insert", "normalize_script_rowcount", "resolve_column_names", - "run_query_job", "run_query_and_wait", + "run_query_job", "storage_api_available", "try_bulk_insert", ) @@ -181,6 +181,7 @@ def try_bulk_insert( parameters: Parameter dictionaries for the insert. expression: Optional parsed expression to reuse. allow_parse: Whether to parse SQL when expression is unavailable. + result_timeout: Timeout forwarded to the load job request and result wait. Returns: Inserted row count if bulk insert succeeds, otherwise None. @@ -558,6 +559,11 @@ def run_query_job( job_retry: Retry policy attached to the returned query job. ``None`` disables job retries and the client's built-in ``jobs.insert`` retry wrapper (which carries a fixed 600s deadline). + api_method: Optional query API method override. + timestamp_precision: Optional timestamp precision override. + job_id: Explicit BigQuery job ID. + job_id_prefix: Prefix used by BigQuery to generate a job ID when + ``job_id`` is not provided. Returns: QueryJob object representing the executed job. diff --git a/tests/unit/adapters/test_spanner/test_session_controls.py b/tests/unit/adapters/test_spanner/test_session_controls.py index 496ba6e6d..23fea9e2e 100644 --- a/tests/unit/adapters/test_spanner/test_session_controls.py +++ b/tests/unit/adapters/test_spanner/test_session_controls.py @@ -9,6 +9,7 @@ from google.cloud.spanner_v1.keyset import KeySet from google.cloud.spanner_v1.streamed import StreamedResultSet +import sqlspec.adapters.spanner.config as spanner_config from sqlspec.adapters.spanner.config import SpannerSyncConfig from sqlspec.adapters.spanner.driver import SpannerSyncDriver from sqlspec.exceptions import ImproperConfigurationError @@ -41,7 +42,7 @@ def __init__(self, **kwargs: Any) -> None: config = SpannerSyncConfig(connection_config={"project": "p", "instance_id": "i", "database_id": "d"}) sentinel_database = object() config.get_database = lambda: sentinel_database # type: ignore[assignment] - monkeypatch.setattr("sqlspec.adapters.spanner.config.SpannerSessionContext", _SessionContext) + monkeypatch.setattr(spanner_config, "SpannerSessionContext", _SessionContext) context = config.provide_session() From 0dbdf2a32a3d53290c9779fb657e60b772a7f82c Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sat, 13 Jun 2026 20:20:42 +0000 Subject: [PATCH 4/4] fix: keep cloud controls on existing surfaces --- sqlspec/adapters/bigquery/core.py | 3 +- sqlspec/adapters/bigquery/driver.py | 33 +-- sqlspec/adapters/spanner/config.py | 104 ++++++--- sqlspec/adapters/spanner/driver.py | 190 +++++++--------- .../test_bigquery/test_job_controls.py | 35 +-- .../test_spanner/test_session_controls.py | 210 ++++++++---------- 6 files changed, 263 insertions(+), 312 deletions(-) diff --git a/sqlspec/adapters/bigquery/core.py b/sqlspec/adapters/bigquery/core.py index d6f8eb1b7..81be63175 100644 --- a/sqlspec/adapters/bigquery/core.py +++ b/sqlspec/adapters/bigquery/core.py @@ -87,7 +87,6 @@ def _check_pending_exception(self, exc_handler: "SyncExceptionHandler") -> None: "is_simple_insert", "normalize_script_rowcount", "resolve_column_names", - "run_query_and_wait", "run_query_job", "storage_api_available", "try_bulk_insert", @@ -592,7 +591,7 @@ def run_query_job( return connection.query(sql, **query_kwargs) -def run_query_and_wait( +def _run_query_and_wait( connection: "BigQueryConnection", sql: str, parameters: Any, diff --git a/sqlspec/adapters/bigquery/driver.py b/sqlspec/adapters/bigquery/driver.py index f1fa325cb..534032cf8 100644 --- a/sqlspec/adapters/bigquery/driver.py +++ b/sqlspec/adapters/bigquery/driver.py @@ -16,6 +16,7 @@ from sqlspec.adapters.bigquery.core import ( DEFAULT_REQUEST_TIMEOUT, BigQueryStreamSource, + _run_query_and_wait, _uses_local_bigquery_endpoint, build_dml_rowcount, build_inlined_script, @@ -29,7 +30,6 @@ is_simple_insert, normalize_script_rowcount, resolve_column_names, - run_query_and_wait, run_query_job, storage_api_available, try_bulk_insert, @@ -54,7 +54,7 @@ from google.api_core.retry import Retry from google.cloud import bigquery_storage # type: ignore[attr-defined, unused-ignore] - from google.cloud.bigquery import ExtractJob, ExtractJobConfig, QueryJob, QueryJobConfig + from google.cloud.bigquery import QueryJob, QueryJobConfig from sqlspec.builder import QueryBuilder from sqlspec.core import SQL, ArrowResult, Statement, StatementFilter @@ -201,7 +201,7 @@ def dispatch_execute(self, cursor: Any, statement: "SQL") -> ExecutionResult: """ sql, parameters = self._get_compiled_sql(statement, self.statement_config) if self._use_query_and_wait: - row_iterator = run_query_and_wait( + row_iterator = _run_query_and_wait( cursor, sql, parameters, @@ -605,33 +605,6 @@ def load_from_storage( self._attach_partition_telemetry(telemetry_payload, partitioner) return self._create_storage_job(telemetry_payload) - def export_table_to_storage( - self, - table: str, - destination_uris: "str | list[str]", - *, - job_config: "ExtractJobConfig | None" = None, - job_id: "str | None" = None, - job_id_prefix: "str | None" = None, - location: "str | None" = None, - ) -> "ExtractJob": - """Export a BigQuery table to Cloud Storage via an extract job.""" - job = cast( - "ExtractJob", - self.connection.extract_table( - table, - destination_uris, - job_config=job_config, - job_id=job_id, - job_id_prefix=job_id_prefix, - location=location, - retry=self._job_retry, - timeout=self._job_request_timeout(), - ), - ) - job.result(timeout=self._job_request_timeout()) - return job - # ───────────────────────────────────────────────────────────────────────────── # UTILITY METHODS # ───────────────────────────────────────────────────────────────────────────── diff --git a/sqlspec/adapters/spanner/config.py b/sqlspec/adapters/spanner/config.py index 8d7f86fe6..af9258033 100644 --- a/sqlspec/adapters/spanner/config.py +++ b/sqlspec/adapters/spanner/config.py @@ -1,6 +1,5 @@ """Spanner configuration.""" -from contextlib import contextmanager from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast from google.cloud.spanner_v1 import Client @@ -18,7 +17,7 @@ from sqlspec.utils.type_guards import supports_close if TYPE_CHECKING: - from collections.abc import Callable, Iterator + from collections.abc import Callable from logging import Logger from types import TracebackType @@ -27,8 +26,8 @@ from google.api_core.retry import Retry from google.auth.credentials import Credentials from google.cloud.spanner_admin_database_v1.types import DatabaseDialect, EncryptionConfig - from google.cloud.spanner_v1 import DirectedReadOptions, ExecuteSqlRequest - from google.cloud.spanner_v1.database import BatchSnapshot, Database + from google.cloud.spanner_v1 import DirectedReadOptions, ExecuteSqlRequest, RequestOptions + from google.cloud.spanner_v1.database import Database from google.cloud.spanner_v1.transaction import DefaultTransactionOptions from sqlspec.config import ExtensionConfigs @@ -139,8 +138,9 @@ class SpannerDriverFeatures(TypedDict): retry: Per-request retry policy passed to execute_sql(), execute_update(), and batch_update(). timeout: Per-request timeout in seconds passed to execute_sql(), execute_update(), and batch_update(). request_options: Default RequestOptions forwarded to execute_sql(), execute_update(), - and batch_update(). Per-call overrides are available via - SpannerSyncDriver.execute_with_options(). + and batch_update(). Per-call overrides are available through normal + driver execution methods. + directed_read_options: Default DirectedReadOptions forwarded to execute_sql(). session_labels: Deprecated compatibility alias for pool session labels. Prefer ``connection_config["session_labels"]``. enable_events: Enable database event channel support. @@ -154,7 +154,8 @@ class SpannerDriverFeatures(TypedDict): json_deserializer: "NotRequired[Callable[[str], Any]]" retry: "NotRequired[Retry | None]" timeout: "NotRequired[float | None]" - request_options: "NotRequired[Any]" + request_options: "NotRequired[RequestOptions | dict[str, Any] | None]" + directed_read_options: "NotRequired[DirectedReadOptions | None]" session_labels: "NotRequired[dict[str, str]]" enable_events: "NotRequired[bool]" events_backend: "NotRequired[str]" @@ -404,6 +405,10 @@ def provide_session( *args: Any, statement_config: "StatementConfig | None" = None, transaction: "bool" = _DEFAULT_SESSION_TRANSACTION, + request_options: "RequestOptions | dict[str, Any] | None" = None, + directed_read_options: "DirectedReadOptions | None" = None, + retry: "Retry | None" = None, + timeout: "float | None" = None, **kwargs: Any, ) -> "SpannerSessionContext": """Provide a Spanner driver session context manager. @@ -417,6 +422,10 @@ def provide_session( statement_config: Optional statement configuration override. transaction: Whether to use a Transaction (True, default) or Snapshot (False). + request_options: Session-scoped RequestOptions for Spanner statements. + directed_read_options: Session-scoped DirectedReadOptions for reads. + retry: Session-scoped retry policy for Spanner statement calls. + timeout: Session-scoped timeout for Spanner statement calls. **kwargs: Additional keyword arguments. Returns: @@ -424,45 +433,88 @@ def provide_session( """ connection_ctx = SpannerConnectionContext(self, transaction=transaction) handler = _SpannerSessionConnectionHandler(self, connection_ctx) - session_driver_features: dict[str, Any] = dict(self.driver_features) - session_driver_features["database_provider"] = self.get_database return SpannerSessionContext( acquire_connection=handler.acquire_connection, release_connection=handler.release_connection, statement_config=statement_config or self.statement_config or default_statement_config, - driver_features=session_driver_features, + driver_features=self._session_driver_features( + request_options=request_options, + directed_read_options=directed_read_options, + retry=retry, + timeout=timeout, + ), prepare_driver=self._prepare_driver, ) def provide_write_session( - self, *args: Any, statement_config: "StatementConfig | None" = None, **kwargs: Any + self, + *args: Any, + statement_config: "StatementConfig | None" = None, + request_options: "RequestOptions | dict[str, Any] | None" = None, + directed_read_options: "DirectedReadOptions | None" = None, + retry: "Retry | None" = None, + timeout: "float | None" = None, + **kwargs: Any, ) -> "SpannerSessionContext": """Provide a write-capable Spanner session (alias for :meth:`provide_session`).""" - return self.provide_session(*args, statement_config=statement_config, transaction=True, **kwargs) + return self.provide_session( + *args, + statement_config=statement_config, + transaction=True, + request_options=request_options, + directed_read_options=directed_read_options, + retry=retry, + timeout=timeout, + **kwargs, + ) def provide_read_session( - self, *args: Any, statement_config: "StatementConfig | None" = None, **kwargs: Any + self, + *args: Any, + statement_config: "StatementConfig | None" = None, + request_options: "RequestOptions | dict[str, Any] | None" = None, + directed_read_options: "DirectedReadOptions | None" = None, + retry: "Retry | None" = None, + timeout: "float | None" = None, + **kwargs: Any, ) -> "SpannerSessionContext": """Provide a read-only Snapshot Spanner session. Use for query workloads that benefit from Spanner's snapshot reads. For DDL/DML, use :meth:`provide_session` (write-capable by default). """ - return self.provide_session(*args, statement_config=statement_config, transaction=False, **kwargs) - - @contextmanager - def provide_batch_snapshot( - self, *, read_timestamp: "Any | None" = None, exact_staleness: "Any | None" = None - ) -> "Iterator[BatchSnapshot]": - """Yield a BatchSnapshot for partitioned reads across parallel workers.""" - snapshot = self.get_database().batch_snapshot( # type: ignore[no-untyped-call] - read_timestamp=read_timestamp, exact_staleness=exact_staleness + return self.provide_session( + *args, + statement_config=statement_config, + transaction=False, + request_options=request_options, + directed_read_options=directed_read_options, + retry=retry, + timeout=timeout, + **kwargs, ) - try: - yield snapshot - finally: - snapshot.close() + + def _session_driver_features( + self, + *, + request_options: "RequestOptions | dict[str, Any] | None", + directed_read_options: "DirectedReadOptions | None", + retry: "Retry | None", + timeout: "float | None", + ) -> "dict[str, Any]": + if request_options is None and directed_read_options is None and retry is None and timeout is None: + return self.driver_features + driver_features = dict(self.driver_features) + if request_options is not None: + driver_features["request_options"] = request_options + if directed_read_options is not None: + driver_features["directed_read_options"] = directed_read_options + if retry is not None: + driver_features["retry"] = retry + if timeout is not None: + driver_features["timeout"] = timeout + return driver_features def get_signature_namespace(self) -> "dict[str, Any]": """Get the signature namespace for SpannerSyncConfig types. diff --git a/sqlspec/adapters/spanner/driver.py b/sqlspec/adapters/spanner/driver.py index 31f4332ba..28a8c58b9 100644 --- a/sqlspec/adapters/spanner/driver.py +++ b/sqlspec/adapters/spanner/driver.py @@ -5,7 +5,6 @@ import sqlglot as _sqlglot from google.api_core import exceptions as api_exceptions -from google.cloud.spanner_v1.keyset import KeySet from google.cloud.spanner_v1.transaction import Transaction from sqlglot import exp as _sqlglot_exp @@ -27,7 +26,7 @@ from sqlspec.adapters.spanner.type_converter import SpannerOutputConverter from sqlspec.core import StatementConfig, create_arrow_result, register_driver_profile from sqlspec.driver import BaseSyncExceptionHandler, ExecutionResult, SyncDriverAdapterBase -from sqlspec.exceptions import ImproperConfigurationError, SQLConversionError +from sqlspec.exceptions import SQLConversionError from sqlspec.utils.serializers import from_json _READ_ONLY_SNAPSHOT_ERROR_MESSAGE = ( @@ -39,13 +38,16 @@ if TYPE_CHECKING: from collections.abc import Callable, Sequence + from google.api_core.retry import Retry + from google.cloud.spanner_v1 import DirectedReadOptions, RequestOptions from sqlglot.dialects.dialect import DialectType from sqlspec.adapters.spanner._typing import SpannerConnection - from sqlspec.core import ArrowResult, SQLResult + from sqlspec.builder import QueryBuilder + from sqlspec.core import ArrowResult, SQLResult, Statement, StatementFilter from sqlspec.core.statement import SQL from sqlspec.storage import StorageBridgeJob, StorageDestination, StorageFormat, StorageTelemetry - from sqlspec.typing import ArrowReturnFormat + from sqlspec.typing import ArrowReturnFormat, StatementParameters __all__ = ( "SpannerDataDictionary", @@ -120,9 +122,9 @@ class _PerCallExecuteOptions: def __init__( self, *, - request_options: "Any | None" = None, - directed_read_options: "Any | None" = None, - retry: "Any | None" = None, + request_options: "RequestOptions | dict[str, Any] | None" = None, + directed_read_options: "DirectedReadOptions | None" = None, + retry: "Retry | None" = None, timeout: "float | None" = None, ) -> None: self.request_options = request_options @@ -295,11 +297,72 @@ def with_cursor(self, connection: "SpannerConnection") -> "SpannerSyncCursor": def handle_database_exceptions(self) -> "SpannerExceptionHandler": return SpannerExceptionHandler() + def execute( + self, + statement: "SQL | Statement | QueryBuilder", + /, + *parameters: "StatementParameters | StatementFilter", + statement_config: "StatementConfig | None" = None, + **kwargs: Any, + ) -> "SQLResult": + """Execute a statement with optional Spanner per-call request options.""" + execute_options = self._pop_execute_options(kwargs) + if execute_options is None: + return super().execute(statement, *parameters, statement_config=statement_config, **kwargs) + previous_options = self._pending_execute_options + self._pending_execute_options = execute_options + try: + return super().execute(statement, *parameters, statement_config=statement_config, **kwargs) + finally: + self._pending_execute_options = previous_options + + def execute_many( + self, + statement: "SQL | Statement | QueryBuilder", + /, + parameters: "Sequence[StatementParameters]", + *filters: "StatementParameters | StatementFilter", + statement_config: "StatementConfig | None" = None, + **kwargs: Any, + ) -> "SQLResult": + """Execute a batch statement with optional Spanner per-call request options.""" + execute_options = self._pop_execute_options(kwargs) + if execute_options is None: + return super().execute_many(statement, parameters, *filters, statement_config=statement_config, **kwargs) + previous_options = self._pending_execute_options + self._pending_execute_options = execute_options + try: + return super().execute_many(statement, parameters, *filters, statement_config=statement_config, **kwargs) + finally: + self._pending_execute_options = previous_options + + def execute_script( + self, + statement: "str | SQL", + /, + *parameters: "StatementParameters | StatementFilter", + statement_config: "StatementConfig | None" = None, + **kwargs: Any, + ) -> "SQLResult": + """Execute a multi-statement script with optional Spanner per-call request options.""" + execute_options = self._pop_execute_options(kwargs) + if execute_options is None: + return super().execute_script(statement, *parameters, statement_config=statement_config, **kwargs) + previous_options = self._pending_execute_options + self._pending_execute_options = execute_options + try: + return super().execute_script(statement, *parameters, statement_config=statement_config, **kwargs) + finally: + self._pending_execute_options = previous_options + def _execute_kwargs(self, *, for_read: bool = False) -> dict[str, Any]: kwargs = {key: self.driver_features[key] for key in ("retry", "timeout") if key in self.driver_features} request_options = self.driver_features.get("request_options") if request_options is not None: kwargs["request_options"] = request_options + directed_read_options = self.driver_features.get("directed_read_options") + if for_read and directed_read_options is not None: + kwargs["directed_read_options"] = directed_read_options pending = self._pending_execute_options if pending is not None: if pending.request_options is not None: @@ -312,112 +375,15 @@ def _execute_kwargs(self, *, for_read: bool = False) -> dict[str, Any]: kwargs["directed_read_options"] = pending.directed_read_options return kwargs - def execute_with_options( - self, - statement: "Any", - /, - *parameters: "Any", - request_options: "Any | None" = None, - directed_read_options: "Any | None" = None, - retry: "Any | None" = None, - timeout: "float | None" = None, - statement_config: "StatementConfig | None" = None, - **kwargs: Any, - ) -> "SQLResult": - """Execute a single statement with per-call Spanner request options.""" - config = statement_config or self.statement_config - sql_statement = self.prepare_statement(statement, parameters, statement_config=config, kwargs=kwargs) - self._pending_execute_options = _PerCallExecuteOptions( - request_options=request_options, directed_read_options=directed_read_options, retry=retry, timeout=timeout - ) - try: - return self.dispatch_statement_execution(statement=sql_statement, connection=self.connection) - finally: - self._pending_execute_options = None - - def _require_database(self) -> "Any": - provider = self.driver_features.get("database_provider") - if provider is None: - msg = "Spanner database-level operations require a session created via SpannerSyncConfig.provide_session()." - raise ImproperConfigurationError(msg) - return provider() - - def execute_partitioned_dml( - self, - statement: "Any", - /, - *parameters: "Any", - request_options: "Any | None" = None, - exclude_txn_from_change_streams: bool = False, - statement_config: "StatementConfig | None" = None, - **kwargs: Any, - ) -> int: - """Execute a partitioned DML statement across the whole table.""" - database = self._require_database() - config = statement_config or self.statement_config - sql_statement = self.prepare_statement(statement, parameters, statement_config=config, kwargs=kwargs) - sql, params = self._get_compiled_sql(sql_statement, config) - coerced_params = self._coerce_params(cast("dict[str, Any] | None", params)) - param_types_map = self._infer_param_types(coerced_params) - exc_handler = self.handle_database_exceptions() - row_count = 0 - with exc_handler: - row_count = int( - database.execute_partitioned_dml( - sql, - params=coerced_params, - param_types=param_types_map, - request_options=request_options - if request_options is not None - else self.driver_features.get("request_options"), - exclude_txn_from_change_streams=exclude_txn_from_change_streams, - ) - ) - self._check_pending_exception(exc_handler) - return row_count - - def apply_mutations( - self, - table: str, - *, - columns: "Sequence[str] | None" = None, - insert: "Sequence[Sequence[Any]] | None" = None, - update: "Sequence[Sequence[Any]] | None" = None, - insert_or_update: "Sequence[Sequence[Any]] | None" = None, - replace: "Sequence[Sequence[Any]] | None" = None, - delete_keys: "Sequence[Sequence[Any]] | None" = None, - delete_all: bool = False, - request_options: "Any | None" = None, - max_commit_delay: "Any | None" = None, - ) -> None: - """Apply blind-write mutations to a table in a single atomic commit.""" - row_groups = (insert, update, insert_or_update, replace) - if any(group is not None for group in row_groups) and columns is None: - msg = "apply_mutations() requires 'columns' when row mutations are provided." - raise ImproperConfigurationError(msg) - - database = self._require_database() - resolved_request_options = ( - request_options if request_options is not None else self.driver_features.get("request_options") + def _pop_execute_options(self, kwargs: dict[str, Any]) -> "_PerCallExecuteOptions | None": + if not any(key in kwargs for key in ("request_options", "directed_read_options", "retry", "timeout")): + return None + return _PerCallExecuteOptions( + request_options=kwargs.pop("request_options", None), + directed_read_options=kwargs.pop("directed_read_options", None), + retry=kwargs.pop("retry", None), + timeout=kwargs.pop("timeout", None), ) - exc_handler = self.handle_database_exceptions() - with ( - exc_handler, - database.batch(request_options=resolved_request_options, max_commit_delay=max_commit_delay) as batch, - ): - if insert is not None: - batch.insert(table, columns, insert) - if update is not None: - batch.update(table, columns, update) - if insert_or_update is not None: - batch.insert_or_update(table, columns, insert_or_update) - if replace is not None: - batch.replace(table, columns, replace) - if delete_all: - batch.delete(table, KeySet(all_=True)) # type: ignore[no-untyped-call] - elif delete_keys is not None: - batch.delete(table, KeySet(keys=list(delete_keys))) # type: ignore[no-untyped-call] - self._check_pending_exception(exc_handler) # ───────────────────────────────────────────────────────────────────────────── # ARROW API METHODS diff --git a/tests/unit/adapters/test_bigquery/test_job_controls.py b/tests/unit/adapters/test_bigquery/test_job_controls.py index e18defb90..ff9fa7762 100644 --- a/tests/unit/adapters/test_bigquery/test_job_controls.py +++ b/tests/unit/adapters/test_bigquery/test_job_controls.py @@ -48,15 +48,6 @@ def result(self, **kwargs: Any) -> list[dict[str, object]]: return self.rows -class _RecordingExtractJob: - def __init__(self) -> None: - self.result_calls: list[dict[str, Any]] = [] - self.job_id = "extract_123" - - def result(self, **kwargs: Any) -> None: - self.result_calls.append(kwargs) - - class _RecordingRowIterator: def __init__( self, @@ -79,11 +70,9 @@ def __init__(self) -> None: self.query_and_wait_calls: list[tuple[str, dict[str, Any]]] = [] self.load_file_calls: list[tuple[Any, Any, dict[str, Any]]] = [] self.load_uri_calls: list[tuple[Any, Any, dict[str, Any]]] = [] - self.extract_calls: list[tuple[Any, Any, dict[str, Any]]] = [] self.query_job = _RecordingJob() self.row_iterator = _RecordingRowIterator() self.load_job = _RecordingJob(statement_type="LOAD", schema=None) - self.extract_job = _RecordingExtractJob() def query(self, sql: str, **kwargs: Any) -> _RecordingJob: self.query_calls.append((sql, kwargs)) @@ -101,10 +90,6 @@ def load_table_from_uri(self, source_uris: Any, destination: Any, **kwargs: Any) self.load_uri_calls.append((source_uris, destination, kwargs)) return self.load_job - def extract_table(self, source: Any, destination_uris: Any, **kwargs: Any) -> _RecordingExtractJob: - self.extract_calls.append((source, destination_uris, kwargs)) - return self.extract_job - def _schema(*names: str) -> list[SimpleNamespace]: return [SimpleNamespace(name=name) for name in names] @@ -226,15 +211,17 @@ def test_load_job_config_fill_from_default_preserves_defaults() -> None: assert filled_job_config.labels == {"source": "default"} -def test_export_table_to_storage_forwards_job_controls() -> None: +def test_select_to_storage_uses_existing_export_surface_and_job_controls(tmp_path: Any) -> None: connection = _RecordingConnection() - driver = BigQueryDriver(cast(Any, connection)) + cast(Any, connection)._connection = SimpleNamespace(API_BASE_URL="http://localhost") + connection.query_job = _RecordingJob(rows=[{"id": 1}], schema=_schema("id")) + driver = BigQueryDriver( + cast(Any, connection), driver_features={"job_result_timeout": 5.0, "storage_capabilities": CAPABILITIES} + ) - job = driver.export_table_to_storage("dataset.table", "gs://bucket/object.csv") + job = driver.select_to_storage("SELECT 1 AS id", tmp_path / "result.parquet", format_hint="parquet") - assert cast(object, job) is connection.extract_job - assert connection.extract_calls[0][0] == "dataset.table" - assert connection.extract_calls[0][1] == "gs://bucket/object.csv" - assert connection.extract_calls[0][2]["retry"] is driver._job_retry - assert connection.extract_calls[0][2]["timeout"] == driver._job_request_timeout() - assert connection.extract_job.result_calls[0]["timeout"] == driver._job_request_timeout() + assert job.telemetry["rows_processed"] == 1 + assert connection.query_calls[0][0] == "SELECT 1 AS id" + assert connection.query_calls[0][1]["timeout"] == driver._job_request_timeout() + assert connection.query_job.result_calls[0]["timeout"] == driver._job_result_timeout diff --git a/tests/unit/adapters/test_spanner/test_session_controls.py b/tests/unit/adapters/test_spanner/test_session_controls.py index 23fea9e2e..abb90d5bb 100644 --- a/tests/unit/adapters/test_spanner/test_session_controls.py +++ b/tests/unit/adapters/test_spanner/test_session_controls.py @@ -6,13 +6,11 @@ import pytest from google.cloud.spanner_v1 import Transaction -from google.cloud.spanner_v1.keyset import KeySet from google.cloud.spanner_v1.streamed import StreamedResultSet import sqlspec.adapters.spanner.config as spanner_config from sqlspec.adapters.spanner.config import SpannerSyncConfig from sqlspec.adapters.spanner.driver import SpannerSyncDriver -from sqlspec.exceptions import ImproperConfigurationError def _mock_result_set() -> MagicMock: @@ -32,31 +30,66 @@ def _mock_transaction() -> MagicMock: return transaction -def test_provide_session_injects_database_provider(monkeypatch: pytest.MonkeyPatch) -> None: +def test_provide_session_uses_config_driver_features(monkeypatch: pytest.MonkeyPatch) -> None: captured: dict[str, Any] = {} class _SessionContext: def __init__(self, **kwargs: Any) -> None: captured.update(kwargs) - config = SpannerSyncConfig(connection_config={"project": "p", "instance_id": "i", "database_id": "d"}) - sentinel_database = object() - config.get_database = lambda: sentinel_database # type: ignore[assignment] + request_options = {"priority": 1, "request_tag": "sqlspec-test"} + config = SpannerSyncConfig( + connection_config={"project": "p", "instance_id": "i", "database_id": "d"}, + driver_features={"request_options": request_options}, + ) monkeypatch.setattr(spanner_config, "SpannerSessionContext", _SessionContext) context = config.provide_session() assert isinstance(context, _SessionContext) + assert captured["driver_features"] is config.driver_features + assert captured["driver_features"]["request_options"] is request_options + assert "database_provider" not in captured["driver_features"] + + +def test_provide_session_accepts_spanner_execution_overrides(monkeypatch: pytest.MonkeyPatch) -> None: + captured: dict[str, Any] = {} + + class _SessionContext: + def __init__(self, **kwargs: Any) -> None: + captured.update(kwargs) + + config_request_options = {"request_tag": "config"} + session_request_options = {"request_tag": "session"} + directed_read_options = cast(Any, SimpleNamespace(tag="directed")) + retry = cast(Any, object()) + config = SpannerSyncConfig( + connection_config={"project": "p", "instance_id": "i", "database_id": "d"}, + driver_features={"request_options": config_request_options}, + ) + monkeypatch.setattr(spanner_config, "SpannerSessionContext", _SessionContext) + + config.provide_session( + request_options=session_request_options, directed_read_options=directed_read_options, retry=retry, timeout=12.0 + ) + assert captured["driver_features"] is not config.driver_features - assert captured["driver_features"]["database_provider"] is config.get_database - assert "database_provider" not in config.driver_features + assert captured["driver_features"]["request_options"] is session_request_options + assert captured["driver_features"]["directed_read_options"] is directed_read_options + assert captured["driver_features"]["retry"] is retry + assert captured["driver_features"]["timeout"] == 12.0 + assert config.driver_features["request_options"] is config_request_options + assert "database_provider" not in captured["driver_features"] -def test_require_database_raises_without_provider() -> None: +def test_no_extra_public_database_operation_surface() -> None: driver = SpannerSyncDriver(MagicMock()) + config = SpannerSyncConfig(connection_config={"project": "p", "instance_id": "i", "database_id": "d"}) - with pytest.raises(ImproperConfigurationError, match="provide_session"): - driver._require_database() # type: ignore[attr-defined] + assert not hasattr(driver, "execute_with_options") + assert not hasattr(driver, "execute_partitioned_dml") + assert not hasattr(driver, "apply_mutations") + assert not hasattr(config, "provide_batch_snapshot") def test_driver_feature_request_options_forwarded_to_select() -> None: @@ -77,6 +110,20 @@ def test_driver_feature_request_options_forwarded_to_select() -> None: assert result.selected_data[0] == (1,) +def test_driver_feature_directed_read_options_forwarded_to_select() -> None: + directed_read_options = SimpleNamespace(tag="directed") + connection = _mock_transaction() + result_set = _mock_result_set() + connection.execute_sql.return_value = result_set + driver = SpannerSyncDriver(cast(Any, connection), driver_features={"directed_read_options": directed_read_options}) + + statement = driver.prepare_statement("SELECT id FROM users", statement_config=driver.statement_config) + result = driver.dispatch_execute(connection, statement) # type: ignore[protected-access] + + assert result.is_select_result + assert connection.execute_sql.call_args.kwargs["directed_read_options"] is directed_read_options + + def test_driver_feature_request_options_forwarded_to_dml() -> None: request_options = {"priority": 1, "request_tag": "sqlspec-test"} connection = _mock_transaction() @@ -110,7 +157,7 @@ def test_driver_feature_request_options_forwarded_to_batch_update() -> None: assert connection.batch_update.call_args.kwargs["request_options"] == request_options -def test_execute_with_options_overrides_feature_defaults() -> None: +def test_execute_overrides_feature_defaults() -> None: feature_request_options = {"priority": 0, "request_tag": "feature"} override_request_options = {"priority": 1, "request_tag": "override"} connection = _mock_transaction() @@ -118,7 +165,7 @@ def test_execute_with_options_overrides_feature_defaults() -> None: connection.execute_sql.return_value = result_set driver = SpannerSyncDriver(cast(Any, connection), driver_features={"request_options": feature_request_options}) - result = driver.execute_with_options("SELECT id FROM users", request_options=override_request_options) + result = driver.execute("SELECT id FROM users", request_options=override_request_options) connection.execute_sql.assert_called_once_with( "SELECT id FROM users", params=None, param_types={}, request_options=override_request_options @@ -127,141 +174,68 @@ def test_execute_with_options_overrides_feature_defaults() -> None: assert driver._pending_execute_options is None # type: ignore[attr-defined] -def test_execute_with_options_directed_read_only_on_select() -> None: +def test_execute_directed_read_only_on_select() -> None: connection = _mock_transaction() result_set = _mock_result_set() connection.execute_sql.return_value = result_set driver = SpannerSyncDriver(cast(Any, connection)) directed_read_options = SimpleNamespace(tag="directed") - driver.execute_with_options("SELECT id FROM users", directed_read_options=directed_read_options) + driver.execute("SELECT id FROM users", directed_read_options=directed_read_options) assert connection.execute_sql.call_args.kwargs["directed_read_options"] is directed_read_options -def test_execute_with_options_directed_read_not_forwarded_to_dml() -> None: +def test_execute_directed_read_not_forwarded_to_dml() -> None: connection = _mock_transaction() connection.execute_update.return_value = 10 driver = SpannerSyncDriver(cast(Any, connection)) directed_read_options = SimpleNamespace(tag="directed") - result = driver.execute_with_options("UPDATE users SET name = 'Bob'", directed_read_options=directed_read_options) + result = driver.execute("UPDATE users SET name = 'Bob'", directed_read_options=directed_read_options) assert result.rows_affected == 10 assert "directed_read_options" not in connection.execute_update.call_args.kwargs -def test_execute_with_options_clears_stash_on_error() -> None: +def test_execute_many_overrides_feature_defaults() -> None: + feature_request_options = {"priority": 0, "request_tag": "feature"} + override_request_options = {"priority": 1, "request_tag": "override"} connection = _mock_transaction() - connection.execute_sql.side_effect = RuntimeError("boom") - driver = SpannerSyncDriver(cast(Any, connection)) - - with pytest.raises(RuntimeError, match="boom"): - driver.execute_with_options("SELECT id FROM users") - - assert driver._pending_execute_options is None # type: ignore[attr-defined] - - -def test_execute_partitioned_dml_raises_without_database_provider() -> None: - driver = SpannerSyncDriver(MagicMock()) - - with pytest.raises(ImproperConfigurationError, match="provide_session"): - driver.execute_partitioned_dml("UPDATE users SET name = 'Bob'") - - -def test_execute_partitioned_dml_forwards_to_database() -> None: - request_options = {"priority": 1, "request_tag": "sqlspec-test"} - database = MagicMock() - database.execute_partitioned_dml.return_value = 7 - driver = SpannerSyncDriver( - MagicMock(), driver_features={"request_options": request_options, "database_provider": lambda: database} - ) - - row_count = driver.execute_partitioned_dml( - "UPDATE users SET name = 'Bob' WHERE TRUE", request_options=request_options - ) - - assert row_count == 7 - database.execute_partitioned_dml.assert_called_once_with( - "UPDATE users SET name = 'Bob' WHERE TRUE", - params=None, - param_types={}, - request_options=request_options, - exclude_txn_from_change_streams=False, - ) - - -def test_apply_mutations_requires_columns_for_rows() -> None: - driver = SpannerSyncDriver(MagicMock(), driver_features={"database_provider": lambda: MagicMock()}) - - with pytest.raises(ImproperConfigurationError, match="columns"): - driver.apply_mutations("users", insert=[(1, "alice")]) - - -def test_apply_mutations_routes_each_group() -> None: - batch = MagicMock() - batch.__enter__.return_value = batch - batch.__exit__.return_value = False - database = MagicMock() - database.batch.return_value = batch - driver = SpannerSyncDriver(MagicMock(), driver_features={"database_provider": lambda: database}) + connection.batch_update.return_value = (None, [1]) + driver = SpannerSyncDriver(cast(Any, connection), driver_features={"request_options": feature_request_options}) - driver.apply_mutations( - "users", - columns=("id", "name"), - insert=[(1, "alice")], - update=[(2, "bob")], - insert_or_update=[(3, "carol")], - replace=[(4, "dave")], - delete_keys=[(5,)], - request_options={"priority": 1}, - max_commit_delay=3.0, + result = driver.execute_many( + "UPDATE users SET name = @name WHERE id = @id", + [{"id": 1, "name": "alice"}], + request_options=override_request_options, ) - database.batch.assert_called_once_with(request_options={"priority": 1}, max_commit_delay=3.0) - batch.insert.assert_called_once_with("users", ("id", "name"), [(1, "alice")]) - batch.update.assert_called_once_with("users", ("id", "name"), [(2, "bob")]) - batch.insert_or_update.assert_called_once_with("users", ("id", "name"), [(3, "carol")]) - batch.replace.assert_called_once_with("users", ("id", "name"), [(4, "dave")]) - batch.delete.assert_called_once_with("users", KeySet(keys=[(5,)])) # type: ignore[no-untyped-call] - - -def test_apply_mutations_delete_all_wins_over_delete_keys() -> None: - batch = MagicMock() - batch.__enter__.return_value = batch - batch.__exit__.return_value = False - database = MagicMock() - database.batch.return_value = batch - driver = SpannerSyncDriver(MagicMock(), driver_features={"database_provider": lambda: database}) - - driver.apply_mutations("users", delete_keys=[(5,)], delete_all=True) - - batch.delete.assert_called_once_with("users", KeySet(all_=True)) # type: ignore[no-untyped-call] + assert result.rows_affected == 1 + assert connection.batch_update.call_args.kwargs["request_options"] is override_request_options + assert driver._pending_execute_options is None # type: ignore[attr-defined] -def test_provide_batch_snapshot_closes_on_normal_exit(monkeypatch: pytest.MonkeyPatch) -> None: - snapshot = MagicMock() - database = MagicMock() - database.batch_snapshot.return_value = snapshot - config = SpannerSyncConfig(connection_config={"project": "p", "instance_id": "i", "database_id": "d"}) - config.get_database = lambda: database # type: ignore[assignment] +def test_execute_script_overrides_feature_defaults() -> None: + feature_request_options = {"priority": 0, "request_tag": "feature"} + override_request_options = {"priority": 1, "request_tag": "override"} + connection = _mock_transaction() + connection.execute_sql.return_value = _mock_result_set() + driver = SpannerSyncDriver(cast(Any, connection), driver_features={"request_options": feature_request_options}) - with config.provide_batch_snapshot() as yielded: - assert yielded is snapshot + driver.execute_script("SELECT id FROM users", request_options=override_request_options) - snapshot.close.assert_called_once() + connection.execute_sql.assert_called_once() + assert connection.execute_sql.call_args.kwargs["request_options"] is override_request_options + assert driver._pending_execute_options is None # type: ignore[attr-defined] -def test_provide_batch_snapshot_closes_on_exception(monkeypatch: pytest.MonkeyPatch) -> None: - snapshot = MagicMock() - database = MagicMock() - database.batch_snapshot.return_value = snapshot - config = SpannerSyncConfig(connection_config={"project": "p", "instance_id": "i", "database_id": "d"}) - config.get_database = lambda: database # type: ignore[assignment] +def test_execute_clears_stash_on_error() -> None: + connection = _mock_transaction() + connection.execute_sql.side_effect = RuntimeError("boom") + driver = SpannerSyncDriver(cast(Any, connection)) with pytest.raises(RuntimeError, match="boom"): - with config.provide_batch_snapshot() as yielded: - assert yielded is snapshot - raise RuntimeError("boom") + driver.execute("SELECT id FROM users", request_options={"request_tag": "x"}) - snapshot.close.assert_called_once() + assert driver._pending_execute_options is None # type: ignore[attr-defined]