diff --git a/sqlspec/adapters/bigquery/config.py b/sqlspec/adapters/bigquery/config.py index 6a0916806..5b864726b 100644 --- a/sqlspec/adapters/bigquery/config.py +++ b/sqlspec/adapters/bigquery/config.py @@ -94,6 +94,9 @@ class BigQueryDriverFeatures(TypedDict): job_result_timeout: Timeout (seconds) for polling ``QueryJob.result()``. Defaults to the client polling default (waits indefinitely for the job using the API's per-call default timeouts). Also used as the per-request HTTP timeout when ``request_timeout`` is unset. + use_query_and_wait: Use ``Client.query_and_wait()`` for single-statement queries. + Pair with ``connection_config["default_job_creation_mode"]="JOB_CREATION_OPTIONAL"`` + to allow short queries to run without creating a job. Defaults to False. request_timeout: Per-request HTTP transport timeout (seconds) for the API calls that start query jobs. Bounds each request so a server that accepts the request but never responds (e.g. a wedged emulator) raises instead of blocking indefinitely. Defaults to @@ -111,6 +114,7 @@ class BigQueryDriverFeatures(TypedDict): events_backend: NotRequired[str] job_retry_deadline: NotRequired[float] job_result_timeout: NotRequired[float] + use_query_and_wait: NotRequired[bool] request_timeout: NotRequired[float] query_page_size: NotRequired[int] query_max_results: NotRequired[int] diff --git a/sqlspec/adapters/bigquery/core.py b/sqlspec/adapters/bigquery/core.py index ef8091441..81be63175 100644 --- a/sqlspec/adapters/bigquery/core.py +++ b/sqlspec/adapters/bigquery/core.py @@ -170,6 +170,7 @@ def try_bulk_insert( expression: "exp.Expr | None" = None, *, allow_parse: bool = True, + result_timeout: float | None = None, ) -> "int | None": """Attempt bulk insert via Parquet load. @@ -179,6 +180,7 @@ def try_bulk_insert( parameters: Parameter dictionaries for the insert. expression: Optional parsed expression to reuse. allow_parse: Whether to parse SQL when expression is unavailable. + result_timeout: Timeout forwarded to the load job request and result wait. Returns: Inserted row count if bulk insert succeeds, otherwise None. @@ -204,8 +206,8 @@ def try_bulk_insert( buffer.seek(0) job_config = build_load_job_config("parquet", overwrite=False) - job = connection.load_table_from_file(buffer, table_name, job_config=job_config) - job.result() + job = connection.load_table_from_file(buffer, table_name, job_config=job_config, timeout=result_timeout) + job.result(timeout=result_timeout) return len(parameters) except ImportError: logger.debug("pyarrow not available, falling back to literal inlining") @@ -535,6 +537,10 @@ def run_query_job( retry: Retry | None = None, timeout: float | None = None, job_retry: Retry | None = None, + api_method: str | None = None, + timestamp_precision: Any | None = None, + job_id: str | None = None, + job_id_prefix: str | None = None, ) -> QueryJob: """Execute a BigQuery query job with merged configuration. @@ -552,6 +558,11 @@ def run_query_job( job_retry: Retry policy attached to the returned query job. ``None`` disables job retries and the client's built-in ``jobs.insert`` retry wrapper (which carries a fixed 600s deadline). + api_method: Optional query API method override. + timestamp_precision: Optional timestamp precision override. + job_id: Explicit BigQuery job ID. + job_id_prefix: Prefix used by BigQuery to generate a job ID when + ``job_id`` is not provided. Returns: QueryJob object representing the executed job. @@ -569,9 +580,51 @@ def run_query_job( "timeout": timeout, "job_retry": job_retry, } + if api_method is not None: + query_kwargs["api_method"] = api_method + if timestamp_precision is not None: + query_kwargs["timestamp_precision"] = timestamp_precision + if job_id is not None: + query_kwargs["job_id"] = job_id + elif job_id_prefix is not None: + query_kwargs["job_id_prefix"] = job_id_prefix return connection.query(sql, **query_kwargs) +def _run_query_and_wait( + connection: "BigQueryConnection", + sql: str, + parameters: Any, + *, + default_job_config: QueryJobConfig | None, + json_serializer: "Callable[[Any], str]", + retry: Retry | None = None, + wait_timeout: float | None = None, + job_retry: Retry | None = None, + page_size: int | None = None, + max_results: int | None = None, +) -> Any: + """Execute a BigQuery query via query_and_wait and return the row iterator.""" + final_job_config = QueryJobConfig() + if default_job_config: + copy_job_config(default_job_config, final_job_config) + final_job_config.query_parameters = create_parameters(parameters, json_serializer) + + query_kwargs: dict[str, Any] = {"job_config": final_job_config} + if retry is not None: + query_kwargs["retry"] = retry + if wait_timeout is not None: + query_kwargs["api_timeout"] = wait_timeout + query_kwargs["wait_timeout"] = wait_timeout + if job_retry is not None: + query_kwargs["job_retry"] = job_retry + if page_size is not None: + query_kwargs["page_size"] = page_size + if max_results is not None: + query_kwargs["max_results"] = max_results + return connection.query_and_wait(sql, **query_kwargs) + + def build_load_job_config(file_format: "StorageFormat", overwrite: bool) -> "LoadJobConfig": job_config = LoadJobConfig() job_config.source_format = _map_bigquery_source_format(file_format) diff --git a/sqlspec/adapters/bigquery/driver.py b/sqlspec/adapters/bigquery/driver.py index a3f381f9c..534032cf8 100644 --- a/sqlspec/adapters/bigquery/driver.py +++ b/sqlspec/adapters/bigquery/driver.py @@ -16,6 +16,7 @@ from sqlspec.adapters.bigquery.core import ( DEFAULT_REQUEST_TIMEOUT, BigQueryStreamSource, + _run_query_and_wait, _uses_local_bigquery_endpoint, build_dml_rowcount, build_inlined_script, @@ -56,7 +57,7 @@ from google.cloud.bigquery import QueryJob, QueryJobConfig from sqlspec.builder import QueryBuilder - from sqlspec.core import SQL, ArrowResult, SQLResult, Statement, StatementFilter + from sqlspec.core import SQL, ArrowResult, Statement, StatementFilter from sqlspec.storage import ( StorageBridgeJob, StorageDestination, @@ -104,13 +105,14 @@ class BigQueryDriver(SyncDriverAdapterBase): "_column_name_cache", "_data_dictionary", "_default_query_job_config", - "_job_result_kwargs", + "_job_result_kwargs_defaults", "_job_result_timeout", "_job_retry", "_job_retry_deadline", "_json_serializer", "_literal_inliner", "_request_timeout", + "_use_query_and_wait", ) dialect = "bigquery" @@ -138,11 +140,12 @@ def __init__( self._default_query_job_config: QueryJobConfig | None = (driver_features or {}).get("default_query_job_config") self._data_dictionary: BigQueryDataDictionary | None = None self._column_name_cache: dict[int, tuple[Any, list[str]]] = {} - self._job_result_kwargs = self._build_job_result_kwargs(features) + self._job_result_kwargs_defaults = self._build_job_result_kwargs(features) self._job_retry_deadline = float(features.get("job_retry_deadline", 60.0)) self._job_retry: Retry | None = build_retry(self._job_retry_deadline) if self._job_retry_deadline > 0 else None self._job_result_timeout: float | object = features.get("job_result_timeout", POLLING_DEFAULT_VALUE) self._request_timeout = self._resolve_request_timeout(features) + self._use_query_and_wait = bool(features.get("use_query_and_wait", False)) def _resolve_request_timeout(self, features: "dict[str, Any]") -> float: timeout = features.get("request_timeout") @@ -163,6 +166,9 @@ def _build_job_result_kwargs(self, features: dict[str, Any]) -> dict[str, Any]: job_result_kwargs["max_results"] = query_max_results return job_result_kwargs + def _job_request_timeout(self) -> float: + return self._request_timeout + def _run_query_job(self, connection: "BigQueryConnection", sql: str, parameters: Any) -> "QueryJob": return run_query_job( connection, @@ -172,10 +178,13 @@ def _run_query_job(self, connection: "BigQueryConnection", sql: str, parameters: job_config=None, json_serializer=self._json_serializer, retry=self._job_retry, - timeout=self._request_timeout, + timeout=self._job_request_timeout(), job_retry=self._job_retry, ) + def _job_result_kwargs(self) -> dict[str, Any]: + return dict(self._job_result_kwargs_defaults) + # ───────────────────────────────────────────────────────────────────────────── # CORE DISPATCH METHODS # ───────────────────────────────────────────────────────────────────────────── @@ -191,6 +200,35 @@ def dispatch_execute(self, cursor: Any, statement: "SQL") -> ExecutionResult: ExecutionResult with query results and metadata """ sql, parameters = self._get_compiled_sql(statement, self.statement_config) + if self._use_query_and_wait: + row_iterator = _run_query_and_wait( + cursor, + sql, + parameters, + default_job_config=self._default_query_job_config, + json_serializer=self._json_serializer, + retry=self._job_retry, + wait_timeout=self._job_request_timeout(), + job_retry=self._job_retry, + ) + cursor.job = None + iterator_schema = getattr(row_iterator, "schema", None) + if statement.returns_rows() or iterator_schema: + column_names = resolve_column_names(iterator_schema, self._column_name_cache) + rows_list, _ = collect_rows(row_iterator, iterator_schema, column_names=column_names) + + return self.create_execution_result( + cursor, + selected_data=rows_list, + column_names=column_names, + data_row_count=len(rows_list), + is_select_result=True, + row_format="record", + ) + + affected_rows = build_dml_rowcount(row_iterator, 0) + return self.create_execution_result(cursor, rowcount_override=affected_rows) + cursor.job = self._run_query_job(cursor, sql, parameters) statement_type = str(cursor.job.statement_type or "").upper() is_select_like = ( @@ -199,7 +237,7 @@ def dispatch_execute(self, cursor: Any, statement: "SQL") -> ExecutionResult: if is_select_like: job_result = cursor.job.result( - job_retry=self._job_retry, timeout=self._job_result_timeout, **self._job_result_kwargs + job_retry=self._job_retry, timeout=self._job_result_timeout, **self._job_result_kwargs() ) job_schema = cursor.job.schema or getattr(job_result, "schema", None) column_names = resolve_column_names(job_schema, self._column_name_cache) @@ -246,7 +284,12 @@ def dispatch_execute_many(self, cursor: Any, statement: "SQL") -> ExecutionResul allow_parse = statement.statement_config.enable_parsing if is_simple_insert(sql, parsed_expression, allow_parse=allow_parse): rowcount = try_bulk_insert( - self.connection, sql, prepared_parameters, parsed_expression, allow_parse=allow_parse + self.connection, + sql, + prepared_parameters, + parsed_expression, + allow_parse=allow_parse, + result_timeout=self._job_request_timeout(), ) if rowcount is not None: return self.create_execution_result(cursor, rowcount_override=rowcount, is_many_result=True) @@ -446,7 +489,7 @@ def select_to_arrow( with exc_handler: query_job = self._run_query_job(self.connection, sql, driver_params) query_job.result( - job_retry=self._job_retry, timeout=self._job_result_timeout, **self._job_result_kwargs + job_retry=self._job_retry, timeout=self._job_result_timeout, **self._job_result_kwargs() ) # Wait for completion # Native Arrow via Storage API @@ -528,8 +571,10 @@ def load_from_arrow( pq.write_table(arrow_table, buffer) buffer.seek(0) job_config = build_load_job_config("parquet", overwrite) - job = self.connection.load_table_from_file(buffer, table, job_config=job_config) - job.result() + job = self.connection.load_table_from_file( + buffer, table, job_config=job_config, timeout=self._job_request_timeout() + ) + job.result(timeout=self._job_request_timeout()) telemetry_payload = build_load_job_telemetry(job, table, format_label="parquet") if telemetry: telemetry_payload.setdefault("extra", {}) @@ -552,8 +597,10 @@ def load_from_storage( msg = "BigQuery storage bridge currently supports Parquet ingest only" raise StorageCapabilityError(msg, capability="parquet_import_enabled") job_config = build_load_job_config(file_format, overwrite) - job = self.connection.load_table_from_uri(source, table, job_config=job_config) - job.result() + job = self.connection.load_table_from_uri( + source, table, job_config=job_config, retry=self._job_retry, timeout=self._job_request_timeout() + ) + job.result(timeout=self._job_request_timeout()) telemetry_payload = build_load_job_telemetry(job, table, format_label=file_format) self._attach_partition_telemetry(telemetry_payload, partitioner) return self._create_storage_job(telemetry_payload) diff --git a/sqlspec/adapters/spanner/adk/store.py b/sqlspec/adapters/spanner/adk/store.py index b3b8e030b..c94c1c695 100644 --- a/sqlspec/adapters/spanner/adk/store.py +++ b/sqlspec/adapters/spanner/adk/store.py @@ -7,6 +7,7 @@ from google.cloud.spanner_v1 import param_types from sqlspec.adapters.spanner.config import SpannerSyncConfig +from sqlspec.exceptions import OperationalError from sqlspec.extensions.adk import BaseAsyncADKStore, EventRecord, SessionRecord from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore from sqlspec.protocols import SpannerParamTypesProtocol @@ -726,6 +727,12 @@ def __init__(self, statements: "list[tuple[str, dict[str, Any], dict[str, Any]]] self._statements = statements def __call__(self, transaction: "Transaction") -> None: + if len(self._statements) > 1: + status, _row_counts = transaction.batch_update(self._statements) # type: ignore[no-untyped-call] + if status.code != 0: + msg = f"Spanner batch update failed (code {status.code}): {status.message}" + raise OperationalError(msg) + return for sql, params, types in self._statements: transaction.execute_update(sql, params=params, param_types=types) # type: ignore[no-untyped-call] @@ -737,6 +744,12 @@ def __init__(self, statements: "list[tuple[str, dict[str, Any], dict[str, Any]]] self._statements = statements def __call__(self, transaction: "Transaction") -> None: + if len(self._statements) > 1: + status, _row_counts = transaction.batch_update(self._statements) # type: ignore[no-untyped-call] + if status.code != 0: + msg = f"Spanner batch update failed (code {status.code}): {status.message}" + raise OperationalError(msg) + return for sql, params, types in self._statements: transaction.execute_update(sql, params=params, param_types=types) # type: ignore[no-untyped-call] diff --git a/sqlspec/adapters/spanner/config.py b/sqlspec/adapters/spanner/config.py index 6b376845a..af9258033 100644 --- a/sqlspec/adapters/spanner/config.py +++ b/sqlspec/adapters/spanner/config.py @@ -26,7 +26,7 @@ from google.api_core.retry import Retry from google.auth.credentials import Credentials from google.cloud.spanner_admin_database_v1.types import DatabaseDialect, EncryptionConfig - from google.cloud.spanner_v1 import DirectedReadOptions, ExecuteSqlRequest + from google.cloud.spanner_v1 import DirectedReadOptions, ExecuteSqlRequest, RequestOptions from google.cloud.spanner_v1.database import Database from google.cloud.spanner_v1.transaction import DefaultTransactionOptions @@ -137,6 +137,10 @@ class SpannerDriverFeatures(TypedDict): json_deserializer: Custom JSON deserializer for result conversion. retry: Per-request retry policy passed to execute_sql(), execute_update(), and batch_update(). timeout: Per-request timeout in seconds passed to execute_sql(), execute_update(), and batch_update(). + request_options: Default RequestOptions forwarded to execute_sql(), execute_update(), + and batch_update(). Per-call overrides are available through normal + driver execution methods. + directed_read_options: Default DirectedReadOptions forwarded to execute_sql(). session_labels: Deprecated compatibility alias for pool session labels. Prefer ``connection_config["session_labels"]``. enable_events: Enable database event channel support. @@ -150,6 +154,8 @@ class SpannerDriverFeatures(TypedDict): json_deserializer: "NotRequired[Callable[[str], Any]]" retry: "NotRequired[Retry | None]" timeout: "NotRequired[float | None]" + request_options: "NotRequired[RequestOptions | dict[str, Any] | None]" + directed_read_options: "NotRequired[DirectedReadOptions | None]" session_labels: "NotRequired[dict[str, str]]" enable_events: "NotRequired[bool]" events_backend: "NotRequired[str]" @@ -399,6 +405,10 @@ def provide_session( *args: Any, statement_config: "StatementConfig | None" = None, transaction: "bool" = _DEFAULT_SESSION_TRANSACTION, + request_options: "RequestOptions | dict[str, Any] | None" = None, + directed_read_options: "DirectedReadOptions | None" = None, + retry: "Retry | None" = None, + timeout: "float | None" = None, **kwargs: Any, ) -> "SpannerSessionContext": """Provide a Spanner driver session context manager. @@ -412,6 +422,10 @@ def provide_session( statement_config: Optional statement configuration override. transaction: Whether to use a Transaction (True, default) or Snapshot (False). + request_options: Session-scoped RequestOptions for Spanner statements. + directed_read_options: Session-scoped DirectedReadOptions for reads. + retry: Session-scoped retry policy for Spanner statement calls. + timeout: Session-scoped timeout for Spanner statement calls. **kwargs: Additional keyword arguments. Returns: @@ -424,25 +438,83 @@ def provide_session( acquire_connection=handler.acquire_connection, release_connection=handler.release_connection, statement_config=statement_config or self.statement_config or default_statement_config, - driver_features=self.driver_features, + driver_features=self._session_driver_features( + request_options=request_options, + directed_read_options=directed_read_options, + retry=retry, + timeout=timeout, + ), prepare_driver=self._prepare_driver, ) def provide_write_session( - self, *args: Any, statement_config: "StatementConfig | None" = None, **kwargs: Any + self, + *args: Any, + statement_config: "StatementConfig | None" = None, + request_options: "RequestOptions | dict[str, Any] | None" = None, + directed_read_options: "DirectedReadOptions | None" = None, + retry: "Retry | None" = None, + timeout: "float | None" = None, + **kwargs: Any, ) -> "SpannerSessionContext": """Provide a write-capable Spanner session (alias for :meth:`provide_session`).""" - return self.provide_session(*args, statement_config=statement_config, transaction=True, **kwargs) + return self.provide_session( + *args, + statement_config=statement_config, + transaction=True, + request_options=request_options, + directed_read_options=directed_read_options, + retry=retry, + timeout=timeout, + **kwargs, + ) def provide_read_session( - self, *args: Any, statement_config: "StatementConfig | None" = None, **kwargs: Any + self, + *args: Any, + statement_config: "StatementConfig | None" = None, + request_options: "RequestOptions | dict[str, Any] | None" = None, + directed_read_options: "DirectedReadOptions | None" = None, + retry: "Retry | None" = None, + timeout: "float | None" = None, + **kwargs: Any, ) -> "SpannerSessionContext": """Provide a read-only Snapshot Spanner session. Use for query workloads that benefit from Spanner's snapshot reads. For DDL/DML, use :meth:`provide_session` (write-capable by default). """ - return self.provide_session(*args, statement_config=statement_config, transaction=False, **kwargs) + return self.provide_session( + *args, + statement_config=statement_config, + transaction=False, + request_options=request_options, + directed_read_options=directed_read_options, + retry=retry, + timeout=timeout, + **kwargs, + ) + + def _session_driver_features( + self, + *, + request_options: "RequestOptions | dict[str, Any] | None", + directed_read_options: "DirectedReadOptions | None", + retry: "Retry | None", + timeout: "float | None", + ) -> "dict[str, Any]": + if request_options is None and directed_read_options is None and retry is None and timeout is None: + return self.driver_features + driver_features = dict(self.driver_features) + if request_options is not None: + driver_features["request_options"] = request_options + if directed_read_options is not None: + driver_features["directed_read_options"] = directed_read_options + if retry is not None: + driver_features["retry"] = retry + if timeout is not None: + driver_features["timeout"] = timeout + return driver_features def get_signature_namespace(self) -> "dict[str, Any]": """Get the signature namespace for SpannerSyncConfig types. diff --git a/sqlspec/adapters/spanner/driver.py b/sqlspec/adapters/spanner/driver.py index 8a35850e8..28a8c58b9 100644 --- a/sqlspec/adapters/spanner/driver.py +++ b/sqlspec/adapters/spanner/driver.py @@ -36,15 +36,18 @@ ) if TYPE_CHECKING: - from collections.abc import Callable + from collections.abc import Callable, Sequence + from google.api_core.retry import Retry + from google.cloud.spanner_v1 import DirectedReadOptions, RequestOptions from sqlglot.dialects.dialect import DialectType from sqlspec.adapters.spanner._typing import SpannerConnection - from sqlspec.core import ArrowResult + from sqlspec.builder import QueryBuilder + from sqlspec.core import ArrowResult, SQLResult, Statement, StatementFilter from sqlspec.core.statement import SQL from sqlspec.storage import StorageBridgeJob, StorageDestination, StorageFormat, StorageTelemetry - from sqlspec.typing import ArrowReturnFormat + from sqlspec.typing import ArrowReturnFormat, StatementParameters __all__ = ( "SpannerDataDictionary", @@ -111,11 +114,30 @@ def _handle_exception(self, exc_type: "type[BaseException] | None", exc_val: "Ba return False +class _PerCallExecuteOptions: + """Per-call Spanner execution options captured for a single dispatch.""" + + __slots__ = ("directed_read_options", "request_options", "retry", "timeout") + + def __init__( + self, + *, + request_options: "RequestOptions | dict[str, Any] | None" = None, + directed_read_options: "DirectedReadOptions | None" = None, + retry: "Retry | None" = None, + timeout: "float | None" = None, + ) -> None: + self.request_options = request_options + self.directed_read_options = directed_read_options + self.retry = retry + self.timeout = timeout + + class SpannerSyncDriver(SyncDriverAdapterBase): """Synchronous Spanner driver operating on Snapshot or Transaction contexts.""" dialect: "DialectType" = "spanner" - __slots__ = ("_column_name_cache", "_data_dictionary", "_type_converter") + __slots__ = ("_column_name_cache", "_data_dictionary", "_pending_execute_options", "_type_converter") def __init__( self, @@ -136,6 +158,7 @@ def __init__( ) self._column_name_cache: dict[int, tuple[Any, list[str]]] = {} self._data_dictionary: SpannerDataDictionary | None = None + self._pending_execute_options: _PerCallExecuteOptions | None = None # ───────────────────────────────────────────────────────────────────────────── # CORE DISPATCH METHODS - The Execution Engine @@ -146,10 +169,10 @@ def dispatch_execute(self, cursor: "SpannerConnection", statement: "SQL") -> Exe params = cast("dict[str, Any] | None", params) coerced_params = self._coerce_params(params) param_types_map = self._infer_param_types(coerced_params) - execute_kwargs = self._execute_kwargs() if statement.returns_rows(): reader = cast("_SpannerReadProtocol", cursor) + execute_kwargs = self._execute_kwargs(for_read=True) result_set = reader.execute_sql(sql, params=coerced_params, param_types=param_types_map, **execute_kwargs) rows = list(result_set) try: @@ -174,6 +197,7 @@ def dispatch_execute(self, cursor: "SpannerConnection", statement: "SQL") -> Exe if supports_write(cursor): writer = cast("_SpannerWriteProtocol", cursor) + execute_kwargs = self._execute_kwargs() row_count = writer.execute_update(sql, params=coerced_params, param_types=param_types_map, **execute_kwargs) return self.create_execution_result(cursor, rowcount_override=row_count) @@ -225,7 +249,8 @@ def dispatch_execute_script(self, cursor: "SpannerConnection", statement: "SQL") script_params = cast("dict[str, Any] | None", params) coerced_params = self._coerce_params(script_params) param_types_map = self._infer_param_types(coerced_params) - execute_kwargs = self._execute_kwargs() + read_execute_kwargs = self._execute_kwargs(for_read=True) + write_execute_kwargs = self._execute_kwargs() for stmt in statements: try: parsed = _sqlglot.parse_one(stmt) @@ -236,9 +261,11 @@ def dispatch_execute_script(self, cursor: "SpannerConnection", statement: "SQL") raise SQLConversionError(_READ_ONLY_SNAPSHOT_ERROR_MESSAGE) if not is_select and is_transaction: writer = cast("_SpannerWriteProtocol", cursor) - writer.execute_update(stmt, params=coerced_params, param_types=param_types_map, **execute_kwargs) + writer.execute_update(stmt, params=coerced_params, param_types=param_types_map, **write_execute_kwargs) else: - _ = list(reader.execute_sql(stmt, params=coerced_params, param_types=param_types_map, **execute_kwargs)) + _ = list( + reader.execute_sql(stmt, params=coerced_params, param_types=param_types_map, **read_execute_kwargs) + ) count += 1 return self.create_execution_result( @@ -270,8 +297,93 @@ def with_cursor(self, connection: "SpannerConnection") -> "SpannerSyncCursor": def handle_database_exceptions(self) -> "SpannerExceptionHandler": return SpannerExceptionHandler() - def _execute_kwargs(self) -> dict[str, Any]: - return {key: self.driver_features[key] for key in ("retry", "timeout") if key in self.driver_features} + def execute( + self, + statement: "SQL | Statement | QueryBuilder", + /, + *parameters: "StatementParameters | StatementFilter", + statement_config: "StatementConfig | None" = None, + **kwargs: Any, + ) -> "SQLResult": + """Execute a statement with optional Spanner per-call request options.""" + execute_options = self._pop_execute_options(kwargs) + if execute_options is None: + return super().execute(statement, *parameters, statement_config=statement_config, **kwargs) + previous_options = self._pending_execute_options + self._pending_execute_options = execute_options + try: + return super().execute(statement, *parameters, statement_config=statement_config, **kwargs) + finally: + self._pending_execute_options = previous_options + + def execute_many( + self, + statement: "SQL | Statement | QueryBuilder", + /, + parameters: "Sequence[StatementParameters]", + *filters: "StatementParameters | StatementFilter", + statement_config: "StatementConfig | None" = None, + **kwargs: Any, + ) -> "SQLResult": + """Execute a batch statement with optional Spanner per-call request options.""" + execute_options = self._pop_execute_options(kwargs) + if execute_options is None: + return super().execute_many(statement, parameters, *filters, statement_config=statement_config, **kwargs) + previous_options = self._pending_execute_options + self._pending_execute_options = execute_options + try: + return super().execute_many(statement, parameters, *filters, statement_config=statement_config, **kwargs) + finally: + self._pending_execute_options = previous_options + + def execute_script( + self, + statement: "str | SQL", + /, + *parameters: "StatementParameters | StatementFilter", + statement_config: "StatementConfig | None" = None, + **kwargs: Any, + ) -> "SQLResult": + """Execute a multi-statement script with optional Spanner per-call request options.""" + execute_options = self._pop_execute_options(kwargs) + if execute_options is None: + return super().execute_script(statement, *parameters, statement_config=statement_config, **kwargs) + previous_options = self._pending_execute_options + self._pending_execute_options = execute_options + try: + return super().execute_script(statement, *parameters, statement_config=statement_config, **kwargs) + finally: + self._pending_execute_options = previous_options + + def _execute_kwargs(self, *, for_read: bool = False) -> dict[str, Any]: + kwargs = {key: self.driver_features[key] for key in ("retry", "timeout") if key in self.driver_features} + request_options = self.driver_features.get("request_options") + if request_options is not None: + kwargs["request_options"] = request_options + directed_read_options = self.driver_features.get("directed_read_options") + if for_read and directed_read_options is not None: + kwargs["directed_read_options"] = directed_read_options + pending = self._pending_execute_options + if pending is not None: + if pending.request_options is not None: + kwargs["request_options"] = pending.request_options + if pending.retry is not None: + kwargs["retry"] = pending.retry + if pending.timeout is not None: + kwargs["timeout"] = pending.timeout + if for_read and pending.directed_read_options is not None: + kwargs["directed_read_options"] = pending.directed_read_options + return kwargs + + def _pop_execute_options(self, kwargs: dict[str, Any]) -> "_PerCallExecuteOptions | None": + if not any(key in kwargs for key in ("request_options", "directed_read_options", "retry", "timeout")): + return None + return _PerCallExecuteOptions( + request_options=kwargs.pop("request_options", None), + directed_read_options=kwargs.pop("directed_read_options", None), + retry=kwargs.pop("retry", None), + timeout=kwargs.pop("timeout", None), + ) # ───────────────────────────────────────────────────────────────────────────── # ARROW API METHODS diff --git a/tests/unit/adapters/test_bigquery/test_config.py b/tests/unit/adapters/test_bigquery/test_config.py index c5d8d5c40..9c7145d2a 100644 --- a/tests/unit/adapters/test_bigquery/test_config.py +++ b/tests/unit/adapters/test_bigquery/test_config.py @@ -95,5 +95,6 @@ def test_bigquery_config_routes_current_client_level_settings(monkeypatch: Monke def test_bigquery_config_typed_surfaces_do_not_advertise_inert_settings() -> None: """Typed public settings should only include options that SQLSpec routes.""" assert "credentials_path" not in BigQueryConnectionParams.__annotations__ + assert "use_query_and_wait" in BigQueryDriverFeatures.__annotations__ assert "on_job_start" not in BigQueryDriverFeatures.__annotations__ assert "on_job_complete" not in BigQueryDriverFeatures.__annotations__ diff --git a/tests/unit/adapters/test_bigquery/test_job_controls.py b/tests/unit/adapters/test_bigquery/test_job_controls.py new file mode 100644 index 000000000..ff9fa7762 --- /dev/null +++ b/tests/unit/adapters/test_bigquery/test_job_controls.py @@ -0,0 +1,227 @@ +"""Unit tests for BigQuery job-control behavior.""" + +from types import SimpleNamespace +from typing import Any, cast + +import pyarrow as pa +from google.cloud.bigquery import LoadJobConfig +from google.cloud.bigquery.enums import QueryApiMethod, TimestampPrecision + +from sqlspec.adapters.bigquery.core import build_load_job_config, run_query_job, try_bulk_insert +from sqlspec.adapters.bigquery.driver import BigQueryDriver +from sqlspec.utils.serializers import to_json + +CAPABILITIES = { + "arrow_export_enabled": True, + "arrow_import_enabled": True, + "parquet_export_enabled": True, + "parquet_import_enabled": True, + "requires_staging_for_load": False, + "staging_protocols": [], + "partition_strategies": ["fixed"], +} + + +class _RecordingJob: + def __init__( + self, + rows: list[dict[str, object]] | None = None, + *, + statement_type: str = "SELECT", + schema: list[SimpleNamespace] | None = None, + num_dml_affected_rows: int | None = None, + job_id: str = "job_123", + ) -> None: + self.rows = rows or [] + self.statement_type = statement_type + self.schema = schema + self.num_dml_affected_rows = num_dml_affected_rows + self.job_id = job_id + self.labels: dict[str, str] = {} + self.started = None + self.ended = None + self._properties: dict[str, Any] = {} + self.result_calls: list[dict[str, Any]] = [] + + def result(self, **kwargs: Any) -> list[dict[str, object]]: + self.result_calls.append(kwargs) + return self.rows + + +class _RecordingRowIterator: + def __init__( + self, + rows: list[dict[str, object]] | None = None, + *, + schema: list[SimpleNamespace] | None = None, + num_dml_affected_rows: int | None = None, + ) -> None: + self.rows = rows or [] + self.schema = schema + self.num_dml_affected_rows = num_dml_affected_rows + + def __iter__(self) -> Any: + return iter(self.rows) + + +class _RecordingConnection: + def __init__(self) -> None: + self.query_calls: list[tuple[str, dict[str, Any]]] = [] + self.query_and_wait_calls: list[tuple[str, dict[str, Any]]] = [] + self.load_file_calls: list[tuple[Any, Any, dict[str, Any]]] = [] + self.load_uri_calls: list[tuple[Any, Any, dict[str, Any]]] = [] + self.query_job = _RecordingJob() + self.row_iterator = _RecordingRowIterator() + self.load_job = _RecordingJob(statement_type="LOAD", schema=None) + + def query(self, sql: str, **kwargs: Any) -> _RecordingJob: + self.query_calls.append((sql, kwargs)) + return self.query_job + + def query_and_wait(self, sql: str, **kwargs: Any) -> _RecordingRowIterator: + self.query_and_wait_calls.append((sql, kwargs)) + return self.row_iterator + + def load_table_from_file(self, file_obj: Any, destination: Any, **kwargs: Any) -> _RecordingJob: + self.load_file_calls.append((file_obj, destination, kwargs)) + return self.load_job + + def load_table_from_uri(self, source_uris: Any, destination: Any, **kwargs: Any) -> _RecordingJob: + self.load_uri_calls.append((source_uris, destination, kwargs)) + return self.load_job + + +def _schema(*names: str) -> list[SimpleNamespace]: + return [SimpleNamespace(name=name) for name in names] + + +def test_run_query_job_passes_job_id_prefix_only_without_job_id() -> None: + connection = _RecordingConnection() + + run_query_job( + cast(Any, connection), + "SELECT @name", + {"name": "alpha"}, + default_job_config=None, + job_config=None, + json_serializer=to_json, + retry=None, + timeout=3.0, + job_retry=None, + api_method=QueryApiMethod.INSERT, + timestamp_precision=TimestampPrecision.MICROSECOND, + job_id_prefix="prefix-", + ) + + sql, kwargs = connection.query_calls[0] + assert sql == "SELECT @name" + assert kwargs["api_method"] == QueryApiMethod.INSERT + assert kwargs["timestamp_precision"] == TimestampPrecision.MICROSECOND + assert kwargs["job_id_prefix"] == "prefix-" + assert "job_id" not in kwargs + + +def test_query_and_wait_used_when_enabled() -> None: + connection = _RecordingConnection() + connection.row_iterator = _RecordingRowIterator(rows=[{"v": 1}], schema=_schema("v")) + driver = BigQueryDriver(cast(Any, connection), driver_features={"use_query_and_wait": True}) + + result = driver.execute("SELECT 1 AS v") + + assert connection.query_calls == [] + assert connection.query_and_wait_calls[0][0] == "SELECT 1 AS v" + assert connection.query_and_wait_calls[0][1]["api_timeout"] == driver._job_request_timeout() + assert connection.query_and_wait_calls[0][1]["wait_timeout"] == driver._job_request_timeout() + assert result.get_data()[0]["v"] == 1 + + +def test_query_and_wait_dml_rowcount() -> None: + connection = _RecordingConnection() + connection.row_iterator = _RecordingRowIterator(num_dml_affected_rows=3) + driver = BigQueryDriver(cast(Any, connection), driver_features={"use_query_and_wait": True}) + + result = driver.execute("UPDATE t SET v = 1") + + assert connection.query_calls == [] + assert result.rows_affected == 3 + + +def test_query_and_wait_default_off() -> None: + connection = _RecordingConnection() + connection.query_job = _RecordingJob(rows=[{"v": 1}], schema=_schema("v")) + connection.row_iterator = _RecordingRowIterator(rows=[{"v": 1}], schema=_schema("v")) + driver = BigQueryDriver(cast(Any, connection)) + + result = driver.execute("SELECT 1 AS v") + + assert connection.query_calls + assert connection.query_and_wait_calls == [] + assert result.get_data()[0]["v"] == 1 + + +def test_try_bulk_insert_bounds_result_timeout() -> None: + connection = _RecordingConnection() + connection.load_job = _RecordingJob(statement_type="LOAD") + + rowcount = try_bulk_insert( + cast(Any, connection), "INSERT INTO dataset.table (id) VALUES (@id)", [{"id": 1}], result_timeout=5.0 + ) + + assert rowcount == 1 + assert connection.load_file_calls[0][2]["timeout"] == 5.0 + assert connection.load_job.result_calls[0]["timeout"] == 5.0 + + +def test_load_from_arrow_bounds_result_timeout() -> None: + connection = _RecordingConnection() + connection.load_job = _RecordingJob(statement_type="LOAD") + driver = BigQueryDriver( + cast(Any, connection), driver_features={"job_result_timeout": 5.0, "storage_capabilities": CAPABILITIES} + ) + + result = driver.load_from_arrow("dataset.table", pa.table({"id": [1]})) + + assert result.telemetry["rows_processed"] == 0 + assert connection.load_file_calls[0][2]["timeout"] == driver._job_request_timeout() + assert connection.load_job.result_calls[0]["timeout"] == driver._job_request_timeout() + + +def test_load_from_storage_forwards_retry_and_bounds_result_timeout() -> None: + connection = _RecordingConnection() + connection.load_job = _RecordingJob(statement_type="LOAD") + driver = BigQueryDriver(cast(Any, connection), driver_features={"job_result_timeout": 5.0}) + + result = driver.load_from_storage("dataset.table", "gs://bucket/object.parquet", file_format="parquet") + + assert result.telemetry["rows_processed"] == 0 + assert connection.load_uri_calls[0][2]["retry"] is driver._job_retry + assert connection.load_uri_calls[0][2]["timeout"] == driver._job_request_timeout() + assert connection.load_job.result_calls[0]["timeout"] == driver._job_request_timeout() + + +def test_load_job_config_fill_from_default_preserves_defaults() -> None: + default_job_config = LoadJobConfig(labels={"source": "default"}) + + filled_job_config = build_load_job_config("parquet", overwrite=False)._fill_from_default( # type: ignore[no-untyped-call] + default_job_config + ) + + assert filled_job_config.source_format == "PARQUET" + assert filled_job_config.write_disposition == "WRITE_APPEND" + assert filled_job_config.labels == {"source": "default"} + + +def test_select_to_storage_uses_existing_export_surface_and_job_controls(tmp_path: Any) -> None: + connection = _RecordingConnection() + cast(Any, connection)._connection = SimpleNamespace(API_BASE_URL="http://localhost") + connection.query_job = _RecordingJob(rows=[{"id": 1}], schema=_schema("id")) + driver = BigQueryDriver( + cast(Any, connection), driver_features={"job_result_timeout": 5.0, "storage_capabilities": CAPABILITIES} + ) + + job = driver.select_to_storage("SELECT 1 AS id", tmp_path / "result.parquet", format_hint="parquet") + + assert job.telemetry["rows_processed"] == 1 + assert connection.query_calls[0][0] == "SELECT 1 AS id" + assert connection.query_calls[0][1]["timeout"] == driver._job_request_timeout() + assert connection.query_job.result_calls[0]["timeout"] == driver._job_result_timeout diff --git a/tests/unit/adapters/test_spanner/test_config.py b/tests/unit/adapters/test_spanner/test_config.py index 52e5aeaa6..7264af465 100644 --- a/tests/unit/adapters/test_spanner/test_config.py +++ b/tests/unit/adapters/test_spanner/test_config.py @@ -5,7 +5,12 @@ from google.cloud.spanner_v1.pool import AbstractSessionPool, BurstyPool, FixedSizePool from sqlspec.adapters.spanner import config as config_module -from sqlspec.adapters.spanner.config import SpannerConnectionParams, SpannerPoolParams, SpannerSyncConfig +from sqlspec.adapters.spanner.config import ( + SpannerConnectionParams, + SpannerDriverFeatures, + SpannerPoolParams, + SpannerSyncConfig, +) from sqlspec.adapters.spanner.core import default_statement_config from sqlspec.driver import SyncDriverAdapterBase from sqlspec.exceptions import ImproperConfigurationError @@ -70,6 +75,7 @@ def test_driver_features_defaults() -> None: config = SpannerSyncConfig(connection_config={"project": "p", "instance_id": "i", "database_id": "d"}) assert config.driver_features["enable_uuid_conversion"] is True assert config.driver_features["json_serializer"] is not None + assert "request_options" in SpannerDriverFeatures.__annotations__ def test_driver_feature_session_labels_are_routed_to_pool() -> None: diff --git a/tests/unit/adapters/test_spanner/test_session_controls.py b/tests/unit/adapters/test_spanner/test_session_controls.py new file mode 100644 index 000000000..abb90d5bb --- /dev/null +++ b/tests/unit/adapters/test_spanner/test_session_controls.py @@ -0,0 +1,241 @@ +"""Unit tests for Spanner session-control behavior.""" + +from types import SimpleNamespace +from typing import Any, cast +from unittest.mock import MagicMock + +import pytest +from google.cloud.spanner_v1 import Transaction +from google.cloud.spanner_v1.streamed import StreamedResultSet + +import sqlspec.adapters.spanner.config as spanner_config +from sqlspec.adapters.spanner.config import SpannerSyncConfig +from sqlspec.adapters.spanner.driver import SpannerSyncDriver + + +def _mock_result_set() -> MagicMock: + result = MagicMock(spec=StreamedResultSet) + field = MagicMock() + field.name = "id" + result.metadata.row_type.fields = [field] + result.__iter__.return_value = iter([(1,)]) + return result + + +def _mock_transaction() -> MagicMock: + transaction = MagicMock(spec=Transaction) + transaction.execute_sql = MagicMock() + transaction.execute_update = MagicMock() + transaction.batch_update = MagicMock() + return transaction + + +def test_provide_session_uses_config_driver_features(monkeypatch: pytest.MonkeyPatch) -> None: + captured: dict[str, Any] = {} + + class _SessionContext: + def __init__(self, **kwargs: Any) -> None: + captured.update(kwargs) + + request_options = {"priority": 1, "request_tag": "sqlspec-test"} + config = SpannerSyncConfig( + connection_config={"project": "p", "instance_id": "i", "database_id": "d"}, + driver_features={"request_options": request_options}, + ) + monkeypatch.setattr(spanner_config, "SpannerSessionContext", _SessionContext) + + context = config.provide_session() + + assert isinstance(context, _SessionContext) + assert captured["driver_features"] is config.driver_features + assert captured["driver_features"]["request_options"] is request_options + assert "database_provider" not in captured["driver_features"] + + +def test_provide_session_accepts_spanner_execution_overrides(monkeypatch: pytest.MonkeyPatch) -> None: + captured: dict[str, Any] = {} + + class _SessionContext: + def __init__(self, **kwargs: Any) -> None: + captured.update(kwargs) + + config_request_options = {"request_tag": "config"} + session_request_options = {"request_tag": "session"} + directed_read_options = cast(Any, SimpleNamespace(tag="directed")) + retry = cast(Any, object()) + config = SpannerSyncConfig( + connection_config={"project": "p", "instance_id": "i", "database_id": "d"}, + driver_features={"request_options": config_request_options}, + ) + monkeypatch.setattr(spanner_config, "SpannerSessionContext", _SessionContext) + + config.provide_session( + request_options=session_request_options, directed_read_options=directed_read_options, retry=retry, timeout=12.0 + ) + + assert captured["driver_features"] is not config.driver_features + assert captured["driver_features"]["request_options"] is session_request_options + assert captured["driver_features"]["directed_read_options"] is directed_read_options + assert captured["driver_features"]["retry"] is retry + assert captured["driver_features"]["timeout"] == 12.0 + assert config.driver_features["request_options"] is config_request_options + assert "database_provider" not in captured["driver_features"] + + +def test_no_extra_public_database_operation_surface() -> None: + driver = SpannerSyncDriver(MagicMock()) + config = SpannerSyncConfig(connection_config={"project": "p", "instance_id": "i", "database_id": "d"}) + + assert not hasattr(driver, "execute_with_options") + assert not hasattr(driver, "execute_partitioned_dml") + assert not hasattr(driver, "apply_mutations") + assert not hasattr(config, "provide_batch_snapshot") + + +def test_driver_feature_request_options_forwarded_to_select() -> None: + request_options = {"priority": 1, "request_tag": "sqlspec-test"} + connection = _mock_transaction() + result_set = _mock_result_set() + connection.execute_sql.return_value = result_set + driver = SpannerSyncDriver(cast(Any, connection), driver_features={"request_options": request_options}) + + statement = driver.prepare_statement("SELECT id FROM users", statement_config=driver.statement_config) + result = driver.dispatch_execute(connection, statement) # type: ignore[protected-access] + + assert result.is_select_result + connection.execute_sql.assert_called_once_with( + "SELECT id FROM users", params=None, param_types={}, request_options=request_options + ) + assert result.selected_data is not None + assert result.selected_data[0] == (1,) + + +def test_driver_feature_directed_read_options_forwarded_to_select() -> None: + directed_read_options = SimpleNamespace(tag="directed") + connection = _mock_transaction() + result_set = _mock_result_set() + connection.execute_sql.return_value = result_set + driver = SpannerSyncDriver(cast(Any, connection), driver_features={"directed_read_options": directed_read_options}) + + statement = driver.prepare_statement("SELECT id FROM users", statement_config=driver.statement_config) + result = driver.dispatch_execute(connection, statement) # type: ignore[protected-access] + + assert result.is_select_result + assert connection.execute_sql.call_args.kwargs["directed_read_options"] is directed_read_options + + +def test_driver_feature_request_options_forwarded_to_dml() -> None: + request_options = {"priority": 1, "request_tag": "sqlspec-test"} + connection = _mock_transaction() + connection.execute_update.return_value = 10 + driver = SpannerSyncDriver(cast(Any, connection), driver_features={"request_options": request_options}) + + statement = driver.prepare_statement("UPDATE users SET name = 'Bob'", statement_config=driver.statement_config) + result = driver.dispatch_execute(connection, statement) # type: ignore[protected-access] + + assert result.rowcount_override == 10 + connection.execute_update.assert_called_once_with( + "UPDATE users SET name = 'Bob'", params=None, param_types={}, request_options=request_options + ) + + +def test_driver_feature_request_options_forwarded_to_batch_update() -> None: + request_options = {"priority": 1, "request_tag": "sqlspec-test"} + connection = _mock_transaction() + connection.batch_update.return_value = (None, [1, 1]) + driver = SpannerSyncDriver(cast(Any, connection), driver_features={"request_options": request_options}) + + statement = driver.prepare_statement( + "UPDATE users SET name = @name WHERE id = @id", + parameters=([{"id": 1, "name": "alice"}, {"id": 2, "name": "bob"}],), + statement_config=driver.statement_config, + ) + result = driver.dispatch_execute_many(connection, statement) # type: ignore[protected-access] + + assert result.rowcount_override == 2 + connection.batch_update.assert_called_once() + assert connection.batch_update.call_args.kwargs["request_options"] == request_options + + +def test_execute_overrides_feature_defaults() -> None: + feature_request_options = {"priority": 0, "request_tag": "feature"} + override_request_options = {"priority": 1, "request_tag": "override"} + connection = _mock_transaction() + result_set = _mock_result_set() + connection.execute_sql.return_value = result_set + driver = SpannerSyncDriver(cast(Any, connection), driver_features={"request_options": feature_request_options}) + + result = driver.execute("SELECT id FROM users", request_options=override_request_options) + + connection.execute_sql.assert_called_once_with( + "SELECT id FROM users", params=None, param_types={}, request_options=override_request_options + ) + assert result.get_data()[0]["id"] == 1 + assert driver._pending_execute_options is None # type: ignore[attr-defined] + + +def test_execute_directed_read_only_on_select() -> None: + connection = _mock_transaction() + result_set = _mock_result_set() + connection.execute_sql.return_value = result_set + driver = SpannerSyncDriver(cast(Any, connection)) + + directed_read_options = SimpleNamespace(tag="directed") + driver.execute("SELECT id FROM users", directed_read_options=directed_read_options) + + assert connection.execute_sql.call_args.kwargs["directed_read_options"] is directed_read_options + + +def test_execute_directed_read_not_forwarded_to_dml() -> None: + connection = _mock_transaction() + connection.execute_update.return_value = 10 + driver = SpannerSyncDriver(cast(Any, connection)) + + directed_read_options = SimpleNamespace(tag="directed") + result = driver.execute("UPDATE users SET name = 'Bob'", directed_read_options=directed_read_options) + + assert result.rows_affected == 10 + assert "directed_read_options" not in connection.execute_update.call_args.kwargs + + +def test_execute_many_overrides_feature_defaults() -> None: + feature_request_options = {"priority": 0, "request_tag": "feature"} + override_request_options = {"priority": 1, "request_tag": "override"} + connection = _mock_transaction() + connection.batch_update.return_value = (None, [1]) + driver = SpannerSyncDriver(cast(Any, connection), driver_features={"request_options": feature_request_options}) + + result = driver.execute_many( + "UPDATE users SET name = @name WHERE id = @id", + [{"id": 1, "name": "alice"}], + request_options=override_request_options, + ) + + assert result.rows_affected == 1 + assert connection.batch_update.call_args.kwargs["request_options"] is override_request_options + assert driver._pending_execute_options is None # type: ignore[attr-defined] + + +def test_execute_script_overrides_feature_defaults() -> None: + feature_request_options = {"priority": 0, "request_tag": "feature"} + override_request_options = {"priority": 1, "request_tag": "override"} + connection = _mock_transaction() + connection.execute_sql.return_value = _mock_result_set() + driver = SpannerSyncDriver(cast(Any, connection), driver_features={"request_options": feature_request_options}) + + driver.execute_script("SELECT id FROM users", request_options=override_request_options) + + connection.execute_sql.assert_called_once() + assert connection.execute_sql.call_args.kwargs["request_options"] is override_request_options + assert driver._pending_execute_options is None # type: ignore[attr-defined] + + +def test_execute_clears_stash_on_error() -> None: + connection = _mock_transaction() + connection.execute_sql.side_effect = RuntimeError("boom") + driver = SpannerSyncDriver(cast(Any, connection)) + + with pytest.raises(RuntimeError, match="boom"): + driver.execute("SELECT id FROM users", request_options={"request_tag": "x"}) + + assert driver._pending_execute_options is None # type: ignore[attr-defined]