Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions sqlspec/adapters/bigquery/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ class BigQueryDriverFeatures(TypedDict):
job_result_timeout: Timeout (seconds) for polling ``QueryJob.result()``. Defaults to the
client polling default (waits indefinitely for the job using the API's per-call default
timeouts). Also used as the per-request HTTP timeout when ``request_timeout`` is unset.
use_query_and_wait: Use ``Client.query_and_wait()`` for single-statement queries.
Pair with ``connection_config["default_job_creation_mode"]="JOB_CREATION_OPTIONAL"``
to allow short queries to run without creating a job. Defaults to False.
request_timeout: Per-request HTTP transport timeout (seconds) for the API calls that start
query jobs. Bounds each request so a server that accepts the request but never responds
(e.g. a wedged emulator) raises instead of blocking indefinitely. Defaults to
Expand All @@ -111,6 +114,7 @@ class BigQueryDriverFeatures(TypedDict):
events_backend: NotRequired[str]
job_retry_deadline: NotRequired[float]
job_result_timeout: NotRequired[float]
use_query_and_wait: NotRequired[bool]
request_timeout: NotRequired[float]
query_page_size: NotRequired[int]
query_max_results: NotRequired[int]
Expand Down
57 changes: 55 additions & 2 deletions sqlspec/adapters/bigquery/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ def try_bulk_insert(
expression: "exp.Expr | None" = None,
*,
allow_parse: bool = True,
result_timeout: float | None = None,
) -> "int | None":
"""Attempt bulk insert via Parquet load.

Expand All @@ -179,6 +180,7 @@ def try_bulk_insert(
parameters: Parameter dictionaries for the insert.
expression: Optional parsed expression to reuse.
allow_parse: Whether to parse SQL when expression is unavailable.
result_timeout: Timeout forwarded to the load job request and result wait.

Returns:
Inserted row count if bulk insert succeeds, otherwise None.
Expand All @@ -204,8 +206,8 @@ def try_bulk_insert(
buffer.seek(0)

job_config = build_load_job_config("parquet", overwrite=False)
job = connection.load_table_from_file(buffer, table_name, job_config=job_config)
job.result()
job = connection.load_table_from_file(buffer, table_name, job_config=job_config, timeout=result_timeout)
job.result(timeout=result_timeout)
return len(parameters)
except ImportError:
logger.debug("pyarrow not available, falling back to literal inlining")
Expand Down Expand Up @@ -535,6 +537,10 @@ def run_query_job(
retry: Retry | None = None,
timeout: float | None = None,
job_retry: Retry | None = None,
api_method: str | None = None,
timestamp_precision: Any | None = None,
job_id: str | None = None,
job_id_prefix: str | None = None,
) -> QueryJob:
"""Execute a BigQuery query job with merged configuration.

Expand All @@ -552,6 +558,11 @@ def run_query_job(
job_retry: Retry policy attached to the returned query job. ``None``
disables job retries and the client's built-in ``jobs.insert``
retry wrapper (which carries a fixed 600s deadline).
api_method: Optional query API method override.
timestamp_precision: Optional timestamp precision override.
job_id: Explicit BigQuery job ID.
job_id_prefix: Prefix used by BigQuery to generate a job ID when
``job_id`` is not provided.

Returns:
QueryJob object representing the executed job.
Expand All @@ -569,9 +580,51 @@ def run_query_job(
"timeout": timeout,
"job_retry": job_retry,
}
if api_method is not None:
query_kwargs["api_method"] = api_method
if timestamp_precision is not None:
query_kwargs["timestamp_precision"] = timestamp_precision
if job_id is not None:
query_kwargs["job_id"] = job_id
elif job_id_prefix is not None:
query_kwargs["job_id_prefix"] = job_id_prefix
return connection.query(sql, **query_kwargs)


def _run_query_and_wait(
    connection: "BigQueryConnection",
    sql: str,
    parameters: Any,
    *,
    default_job_config: QueryJobConfig | None,
    json_serializer: "Callable[[Any], str]",
    retry: Retry | None = None,
    wait_timeout: float | None = None,
    job_retry: Retry | None = None,
    page_size: int | None = None,
    max_results: int | None = None,
) -> Any:
    """Run ``sql`` through ``connection.query_and_wait`` and return its result.

    A fresh ``QueryJobConfig`` is created for every call. When
    ``default_job_config`` is truthy its settings are copied onto the new
    config with ``copy_job_config``; the query parameters are then set from
    ``create_parameters(parameters, json_serializer)``.

    Args:
        connection: Client object exposing ``query_and_wait``.
        sql: SQL text to execute.
        parameters: Statement parameters handed to ``create_parameters``.
        default_job_config: Optional base configuration copied onto the
            per-call job config.
        json_serializer: Serializer passed through to ``create_parameters``.
        retry: Forwarded as ``retry`` when not ``None``.
        wait_timeout: Forwarded as both ``api_timeout`` and ``wait_timeout``
            when not ``None``.
        job_retry: Forwarded as ``job_retry`` when not ``None``.
        page_size: Forwarded as ``page_size`` when not ``None``.
        max_results: Forwarded as ``max_results`` when not ``None``.

    Returns:
        Whatever ``connection.query_and_wait`` returns (the row iterator).
    """
    job_config = QueryJobConfig()
    if default_job_config:
        copy_job_config(default_job_config, job_config)
    job_config.query_parameters = create_parameters(parameters, json_serializer)

    # Optional settings in the order they are forwarded; entries left at None
    # are omitted so the client falls back to its own defaults.
    optional_settings: dict[str, Any] = {
        "retry": retry,
        "api_timeout": wait_timeout,
        "wait_timeout": wait_timeout,
        "job_retry": job_retry,
        "page_size": page_size,
        "max_results": max_results,
    }
    call_kwargs: dict[str, Any] = {"job_config": job_config}
    call_kwargs.update((key, value) for key, value in optional_settings.items() if value is not None)
    return connection.query_and_wait(sql, **call_kwargs)


def build_load_job_config(file_format: "StorageFormat", overwrite: bool) -> "LoadJobConfig":
job_config = LoadJobConfig()
job_config.source_format = _map_bigquery_source_format(file_format)
Expand Down
69 changes: 58 additions & 11 deletions sqlspec/adapters/bigquery/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from sqlspec.adapters.bigquery.core import (
DEFAULT_REQUEST_TIMEOUT,
BigQueryStreamSource,
_run_query_and_wait,
_uses_local_bigquery_endpoint,
build_dml_rowcount,
build_inlined_script,
Expand Down Expand Up @@ -56,7 +57,7 @@
from google.cloud.bigquery import QueryJob, QueryJobConfig

from sqlspec.builder import QueryBuilder
from sqlspec.core import SQL, ArrowResult, SQLResult, Statement, StatementFilter
from sqlspec.core import SQL, ArrowResult, Statement, StatementFilter
from sqlspec.storage import (
StorageBridgeJob,
StorageDestination,
Expand Down Expand Up @@ -104,13 +105,14 @@ class BigQueryDriver(SyncDriverAdapterBase):
"_column_name_cache",
"_data_dictionary",
"_default_query_job_config",
"_job_result_kwargs",
"_job_result_kwargs_defaults",
"_job_result_timeout",
"_job_retry",
"_job_retry_deadline",
"_json_serializer",
"_literal_inliner",
"_request_timeout",
"_use_query_and_wait",
)
dialect = "bigquery"

Expand Down Expand Up @@ -138,11 +140,12 @@ def __init__(
self._default_query_job_config: QueryJobConfig | None = (driver_features or {}).get("default_query_job_config")
self._data_dictionary: BigQueryDataDictionary | None = None
self._column_name_cache: dict[int, tuple[Any, list[str]]] = {}
self._job_result_kwargs = self._build_job_result_kwargs(features)
self._job_result_kwargs_defaults = self._build_job_result_kwargs(features)
self._job_retry_deadline = float(features.get("job_retry_deadline", 60.0))
self._job_retry: Retry | None = build_retry(self._job_retry_deadline) if self._job_retry_deadline > 0 else None
self._job_result_timeout: float | object = features.get("job_result_timeout", POLLING_DEFAULT_VALUE)
self._request_timeout = self._resolve_request_timeout(features)
self._use_query_and_wait = bool(features.get("use_query_and_wait", False))

def _resolve_request_timeout(self, features: "dict[str, Any]") -> float:
timeout = features.get("request_timeout")
Expand All @@ -163,6 +166,9 @@ def _build_job_result_kwargs(self, features: dict[str, Any]) -> dict[str, Any]:
job_result_kwargs["max_results"] = query_max_results
return job_result_kwargs

def _job_request_timeout(self) -> float:
return self._request_timeout

def _run_query_job(self, connection: "BigQueryConnection", sql: str, parameters: Any) -> "QueryJob":
return run_query_job(
connection,
Expand All @@ -172,10 +178,13 @@ def _run_query_job(self, connection: "BigQueryConnection", sql: str, parameters:
job_config=None,
json_serializer=self._json_serializer,
retry=self._job_retry,
timeout=self._request_timeout,
timeout=self._job_request_timeout(),
job_retry=self._job_retry,
)

def _job_result_kwargs(self) -> dict[str, Any]:
return dict(self._job_result_kwargs_defaults)

# ─────────────────────────────────────────────────────────────────────────────
# CORE DISPATCH METHODS
# ─────────────────────────────────────────────────────────────────────────────
Expand All @@ -191,6 +200,35 @@ def dispatch_execute(self, cursor: Any, statement: "SQL") -> ExecutionResult:
ExecutionResult with query results and metadata
"""
sql, parameters = self._get_compiled_sql(statement, self.statement_config)
if self._use_query_and_wait:
row_iterator = _run_query_and_wait(
cursor,
sql,
parameters,
default_job_config=self._default_query_job_config,
json_serializer=self._json_serializer,
retry=self._job_retry,
wait_timeout=self._job_request_timeout(),
job_retry=self._job_retry,
)
cursor.job = None
iterator_schema = getattr(row_iterator, "schema", None)
if statement.returns_rows() or iterator_schema:
column_names = resolve_column_names(iterator_schema, self._column_name_cache)
rows_list, _ = collect_rows(row_iterator, iterator_schema, column_names=column_names)

return self.create_execution_result(
cursor,
selected_data=rows_list,
column_names=column_names,
data_row_count=len(rows_list),
is_select_result=True,
row_format="record",
)

affected_rows = build_dml_rowcount(row_iterator, 0)
return self.create_execution_result(cursor, rowcount_override=affected_rows)

cursor.job = self._run_query_job(cursor, sql, parameters)
statement_type = str(cursor.job.statement_type or "").upper()
is_select_like = (
Expand All @@ -199,7 +237,7 @@ def dispatch_execute(self, cursor: Any, statement: "SQL") -> ExecutionResult:

if is_select_like:
job_result = cursor.job.result(
job_retry=self._job_retry, timeout=self._job_result_timeout, **self._job_result_kwargs
job_retry=self._job_retry, timeout=self._job_result_timeout, **self._job_result_kwargs()
)
job_schema = cursor.job.schema or getattr(job_result, "schema", None)
column_names = resolve_column_names(job_schema, self._column_name_cache)
Expand Down Expand Up @@ -246,7 +284,12 @@ def dispatch_execute_many(self, cursor: Any, statement: "SQL") -> ExecutionResul
allow_parse = statement.statement_config.enable_parsing
if is_simple_insert(sql, parsed_expression, allow_parse=allow_parse):
rowcount = try_bulk_insert(
self.connection, sql, prepared_parameters, parsed_expression, allow_parse=allow_parse
self.connection,
sql,
prepared_parameters,
parsed_expression,
allow_parse=allow_parse,
result_timeout=self._job_request_timeout(),
)
if rowcount is not None:
return self.create_execution_result(cursor, rowcount_override=rowcount, is_many_result=True)
Expand Down Expand Up @@ -446,7 +489,7 @@ def select_to_arrow(
with exc_handler:
query_job = self._run_query_job(self.connection, sql, driver_params)
query_job.result(
job_retry=self._job_retry, timeout=self._job_result_timeout, **self._job_result_kwargs
job_retry=self._job_retry, timeout=self._job_result_timeout, **self._job_result_kwargs()
) # Wait for completion

# Native Arrow via Storage API
Expand Down Expand Up @@ -528,8 +571,10 @@ def load_from_arrow(
pq.write_table(arrow_table, buffer)
buffer.seek(0)
job_config = build_load_job_config("parquet", overwrite)
job = self.connection.load_table_from_file(buffer, table, job_config=job_config)
job.result()
job = self.connection.load_table_from_file(
buffer, table, job_config=job_config, timeout=self._job_request_timeout()
)
job.result(timeout=self._job_request_timeout())
telemetry_payload = build_load_job_telemetry(job, table, format_label="parquet")
if telemetry:
telemetry_payload.setdefault("extra", {})
Expand All @@ -552,8 +597,10 @@ def load_from_storage(
msg = "BigQuery storage bridge currently supports Parquet ingest only"
raise StorageCapabilityError(msg, capability="parquet_import_enabled")
job_config = build_load_job_config(file_format, overwrite)
job = self.connection.load_table_from_uri(source, table, job_config=job_config)
job.result()
job = self.connection.load_table_from_uri(
source, table, job_config=job_config, retry=self._job_retry, timeout=self._job_request_timeout()
)
job.result(timeout=self._job_request_timeout())
telemetry_payload = build_load_job_telemetry(job, table, format_label=file_format)
self._attach_partition_telemetry(telemetry_payload, partitioner)
return self._create_storage_job(telemetry_payload)
Expand Down
13 changes: 13 additions & 0 deletions sqlspec/adapters/spanner/adk/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from google.cloud.spanner_v1 import param_types

from sqlspec.adapters.spanner.config import SpannerSyncConfig
from sqlspec.exceptions import OperationalError
from sqlspec.extensions.adk import BaseAsyncADKStore, EventRecord, SessionRecord
from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore
from sqlspec.protocols import SpannerParamTypesProtocol
Expand Down Expand Up @@ -726,6 +727,12 @@ def __init__(self, statements: "list[tuple[str, dict[str, Any], dict[str, Any]]]
self._statements = statements

def __call__(self, transaction: "Transaction") -> None:
    """Apply the stored statements using ``transaction``.

    With zero or one statement, each is run individually through
    ``execute_update``. With more than one, all statements are sent in a
    single ``batch_update`` call.

    Raises:
        OperationalError: If ``batch_update`` reports a non-zero status code.
    """
    statements = self._statements
    if len(statements) <= 1:
        for sql, params, types in statements:
            transaction.execute_update(sql, params=params, param_types=types)  # type: ignore[no-untyped-call]
        return

    status, _row_counts = transaction.batch_update(statements)  # type: ignore[no-untyped-call]
    if status.code == 0:
        return
    msg = f"Spanner batch update failed (code {status.code}): {status.message}"
    raise OperationalError(msg)

Expand All @@ -737,6 +744,12 @@ def __init__(self, statements: "list[tuple[str, dict[str, Any], dict[str, Any]]]
self._statements = statements

def __call__(self, transaction: "Transaction") -> None:
    """Run every stored statement against ``transaction``.

    A single ``batch_update`` call is used when more than one statement is
    stored; otherwise the statements (if any) go through ``execute_update``
    one at a time.

    Raises:
        OperationalError: If the batch status code is non-zero.
    """
    pending = self._statements
    use_batch = len(pending) > 1
    if not use_batch:
        for sql, params, types in pending:
            transaction.execute_update(sql, params=params, param_types=types)  # type: ignore[no-untyped-call]
        return

    status, _row_counts = transaction.batch_update(pending)  # type: ignore[no-untyped-call]
    if status.code != 0:
        msg = f"Spanner batch update failed (code {status.code}): {status.message}"
        raise OperationalError(msg)

Expand Down
Loading
Loading