From 4a069fd20060e9ef2ba70deba130c84a28132284 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Wed, 22 Apr 2026 18:54:01 +0000 Subject: [PATCH 1/8] chore(adapters): remove mock adapter implementation --- sqlspec/adapters/mock/__init__.py | 79 --- sqlspec/adapters/mock/_typing.py | 227 -------- sqlspec/adapters/mock/config.py | 484 ---------------- sqlspec/adapters/mock/core.py | 317 ----------- sqlspec/adapters/mock/data_dictionary.py | 366 ------------ sqlspec/adapters/mock/driver.py | 688 ----------------------- 6 files changed, 2161 deletions(-) delete mode 100644 sqlspec/adapters/mock/__init__.py delete mode 100644 sqlspec/adapters/mock/_typing.py delete mode 100644 sqlspec/adapters/mock/config.py delete mode 100644 sqlspec/adapters/mock/core.py delete mode 100644 sqlspec/adapters/mock/data_dictionary.py delete mode 100644 sqlspec/adapters/mock/driver.py diff --git a/sqlspec/adapters/mock/__init__.py b/sqlspec/adapters/mock/__init__.py deleted file mode 100644 index 5934509ff..000000000 --- a/sqlspec/adapters/mock/__init__.py +++ /dev/null @@ -1,79 +0,0 @@ -"""Mock adapter for SQLSpec. - -This adapter provides mock database drivers that use SQLite :memory: as the -execution backend while accepting SQL written in other dialects (Postgres, -MySQL, Oracle, etc.). SQL is transpiled to SQLite syntax before execution -using sqlglot. - -Key Features: - - Write SQL in your target dialect (Postgres, MySQL, Oracle, SQLite) - - SQL is transpiled to SQLite before execution - - Fast execution using SQLite :memory: database - - Same API as real database adapters - - Both sync and async drivers available - - Initial SQL support for test fixtures - -Example: - >>> from sqlspec.adapters.mock import MockSyncConfig - >>> - >>> # Create config with Postgres dialect - >>> config = MockSyncConfig(target_dialect="postgres") - >>> - >>> # Use Postgres syntax - it will be transpiled to SQLite - >>> with config.provide_session() as session: - ... session.execute(\"\"\" - ... CREATE TABLE users ( - ... id SERIAL PRIMARY KEY, - ... name VARCHAR(100) - ... ) - ... \"\"\") - ... session.execute("INSERT INTO users (name) VALUES ($1)", "Alice") - ... user = session.select_one("SELECT * FROM users WHERE name = $1", "Alice") - ... print(user["name"]) - Alice - -With Test Fixtures: - >>> config = MockSyncConfig( - ... target_dialect="postgres", - ... initial_sql=[ - ... "CREATE TABLE users (id INT, name TEXT, role TEXT)", - ... "INSERT INTO users VALUES (1, 'Alice', 'admin')", - ... "INSERT INTO users VALUES (2, 'Bob', 'user')", - ... ], - ... ) - >>> - >>> with config.provide_session() as session: - ... admins = session.select( - ... "SELECT * FROM users WHERE role = $1", "admin" - ... ) - ... print(len(admins)) - 1 -""" - -from sqlspec.adapters.mock._typing import ( - MockAsyncCursor, - MockAsyncSessionContext, - MockConnection, - MockCursor, - MockSyncSessionContext, -) -from sqlspec.adapters.mock.config import MockAsyncConfig, MockConnectionParams, MockDriverFeatures, MockSyncConfig -from sqlspec.adapters.mock.data_dictionary import MockAsyncDataDictionary, MockDataDictionary -from sqlspec.adapters.mock.driver import MockAsyncDriver, MockExceptionHandler, MockSyncDriver - -__all__ = ( - "MockAsyncConfig", - "MockAsyncCursor", - "MockAsyncDataDictionary", - "MockAsyncDriver", - "MockAsyncSessionContext", - "MockConnection", - "MockConnectionParams", - "MockCursor", - "MockDataDictionary", - "MockDriverFeatures", - "MockExceptionHandler", - "MockSyncConfig", - "MockSyncDriver", - "MockSyncSessionContext", -) diff --git a/sqlspec/adapters/mock/_typing.py b/sqlspec/adapters/mock/_typing.py deleted file mode 100644 index b347befef..000000000 --- a/sqlspec/adapters/mock/_typing.py +++ /dev/null @@ -1,227 +0,0 @@ -"""Mock adapter type definitions. - -This module contains type aliases and classes that are excluded from mypyc -compilation to avoid ABI boundary issues. -""" - -import contextlib -import sqlite3 -from typing import TYPE_CHECKING, Any - -_MockConnection = sqlite3.Connection - -if TYPE_CHECKING: - from collections.abc import Awaitable, Callable - from types import TracebackType - from typing import TypeAlias - - from sqlspec.adapters.mock.driver import MockAsyncDriver, MockSyncDriver - from sqlspec.core import StatementConfig - - MockConnection: TypeAlias = _MockConnection - MockRawCursor: TypeAlias = sqlite3.Cursor - -if not TYPE_CHECKING: - MockConnection = _MockConnection - MockRawCursor = sqlite3.Cursor - - -class MockCursor: - """Context manager for Mock SQLite cursor management. - - Provides automatic cursor creation and cleanup for SQLite database operations. - """ - - __slots__ = ("connection", "cursor") - - def __init__(self, connection: "MockConnection") -> None: - """Initialize cursor manager. - - Args: - connection: SQLite database connection - - """ - self.connection = connection - self.cursor: MockRawCursor | None = None - - def __enter__(self) -> "MockRawCursor": - """Create and return a new cursor. - - Returns: - Active SQLite cursor object - - """ - self.cursor = self.connection.cursor() - return self.cursor - - def __exit__(self, *_: object) -> None: - """Clean up cursor resources.""" - if self.cursor is not None: - with contextlib.suppress(Exception): - self.cursor.close() - - -class MockAsyncCursor: - """Async context manager for Mock SQLite cursor management.""" - - __slots__ = ("connection", "cursor") - - def __init__(self, connection: "MockConnection") -> None: - """Initialize async cursor manager. - - Args: - connection: SQLite database connection - - """ - self.connection = connection - self.cursor: MockRawCursor | None = None - - async def __aenter__(self) -> "MockRawCursor": - """Create and return a new cursor. - - Returns: - Active SQLite cursor object - - """ - self.cursor = self.connection.cursor() - return self.cursor - - async def __aexit__( - self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" - ) -> None: - """Clean up cursor resources.""" - if self.cursor is not None: - with contextlib.suppress(Exception): - self.cursor.close() - - -class MockSyncSessionContext: - """Sync context manager for Mock sessions. - - This class is intentionally excluded from mypyc compilation to avoid ABI - boundary issues. It receives callables from uncompiled config classes and - instantiates compiled Driver objects, acting as a bridge between compiled - and uncompiled code. - - Uses callable-based connection management to decouple from config implementation. - """ - - __slots__ = ( - "_acquire_connection", - "_connection", - "_driver", - "_driver_features", - "_prepare_driver", - "_release_connection", - "_statement_config", - "_target_dialect", - ) - - def __init__( - self, - acquire_connection: "Callable[[], MockConnection]", - release_connection: "Callable[[MockConnection], None]", - statement_config: "StatementConfig", - driver_features: "dict[str, Any]", - prepare_driver: "Callable[[MockSyncDriver], MockSyncDriver]", - target_dialect: str = "sqlite", - ) -> None: - self._acquire_connection = acquire_connection - self._release_connection = release_connection - self._statement_config = statement_config - self._driver_features = driver_features - self._prepare_driver = prepare_driver - self._target_dialect = target_dialect - self._connection: MockConnection | None = None - self._driver: MockSyncDriver | None = None - - def __enter__(self) -> "MockSyncDriver": - from sqlspec.adapters.mock.driver import MockSyncDriver - - self._connection = self._acquire_connection() - self._driver = MockSyncDriver( - connection=self._connection, - statement_config=self._statement_config, - driver_features=self._driver_features, - target_dialect=self._target_dialect, - ) - return self._prepare_driver(self._driver) - - def __exit__( - self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" - ) -> "bool | None": - if self._connection is not None: - self._release_connection(self._connection) - self._connection = None - return None - - -class MockAsyncSessionContext: - """Async context manager for Mock sessions. - - This class is intentionally excluded from mypyc compilation to avoid ABI - boundary issues. It receives callables from uncompiled config classes and - instantiates compiled Driver objects, acting as a bridge between compiled - and uncompiled code. - - Uses callable-based connection management to decouple from config implementation. - """ - - __slots__ = ( - "_acquire_connection", - "_connection", - "_driver", - "_driver_features", - "_prepare_driver", - "_release_connection", - "_statement_config", - "_target_dialect", - ) - - def __init__( - self, - acquire_connection: "Callable[[], Awaitable[MockConnection]]", - release_connection: "Callable[[MockConnection], Awaitable[None]]", - statement_config: "StatementConfig", - driver_features: "dict[str, Any]", - prepare_driver: "Callable[[MockAsyncDriver], MockAsyncDriver]", - target_dialect: str = "sqlite", - ) -> None: - self._acquire_connection = acquire_connection - self._release_connection = release_connection - self._statement_config = statement_config - self._driver_features = driver_features - self._prepare_driver = prepare_driver - self._target_dialect = target_dialect - self._connection: MockConnection | None = None - self._driver: MockAsyncDriver | None = None - - async def __aenter__(self) -> "MockAsyncDriver": - from sqlspec.adapters.mock.driver import MockAsyncDriver - - self._connection = await self._acquire_connection() - self._driver = MockAsyncDriver( - connection=self._connection, - statement_config=self._statement_config, - driver_features=self._driver_features, - target_dialect=self._target_dialect, - ) - return self._prepare_driver(self._driver) - - async def __aexit__( - self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" - ) -> "bool | None": - if self._connection is not None: - await self._release_connection(self._connection) - self._connection = None - return None - - -__all__ = ( - "MockAsyncCursor", - "MockAsyncSessionContext", - "MockConnection", - "MockCursor", - "MockRawCursor", - "MockSyncSessionContext", -) diff --git a/sqlspec/adapters/mock/config.py b/sqlspec/adapters/mock/config.py deleted file mode 100644 index 4333f44c3..000000000 --- a/sqlspec/adapters/mock/config.py +++ /dev/null @@ -1,484 +0,0 @@ -"""Mock database configuration for testing with dialect transpilation. - -This module provides configuration classes for the mock adapter that use -SQLite :memory: as the execution backend while accepting SQL written in -other dialects (Postgres, MySQL, Oracle, etc.). -""" - -import sqlite3 -from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast - -from typing_extensions import NotRequired - -from sqlspec.adapters.mock._typing import MockAsyncSessionContext, MockConnection, MockCursor, MockSyncSessionContext -from sqlspec.adapters.mock.core import apply_driver_features, default_statement_config -from sqlspec.adapters.mock.driver import MockAsyncDriver, MockExceptionHandler, MockSyncDriver -from sqlspec.config import ExtensionConfigs, NoPoolAsyncConfig, NoPoolSyncConfig -from sqlspec.driver import convert_to_dialect -from sqlspec.driver._async import AsyncPoolConnectionContext, AsyncPoolSessionFactory -from sqlspec.driver._sync import SyncPoolConnectionContext, SyncPoolSessionFactory -from sqlspec.utils.sync_tools import async_ - -if TYPE_CHECKING: - from collections.abc import Callable - from types import TracebackType - - from sqlspec.core import StatementConfig - from sqlspec.observability import ObservabilityConfig - -__all__ = ("MockAsyncConfig", "MockConnectionParams", "MockDriverFeatures", "MockSyncConfig") - - -class MockConnectionParams(TypedDict): - """Mock connection parameters. - - These parameters control the SQLite :memory: backend behavior. - """ - - target_dialect: NotRequired[str] - initial_sql: NotRequired["str | list[str]"] - timeout: NotRequired[float] - detect_types: NotRequired[int] - isolation_level: "NotRequired[str | None]" - check_same_thread: NotRequired[bool] - cached_statements: NotRequired[int] - - -class MockDriverFeatures(TypedDict): - """Mock driver feature configuration. - - Controls optional type handling and serialization features for Mock connections. - - json_serializer: Custom JSON serializer function. - Defaults to sqlspec.utils.serializers.to_json. - json_deserializer: Custom JSON deserializer function. - Defaults to sqlspec.utils.serializers.from_json. - """ - - json_serializer: "NotRequired[Callable[[Any], str]]" - json_deserializer: "NotRequired[Callable[[str], Any]]" - - -class MockSyncConnectionContext(SyncPoolConnectionContext): - """Context manager for Mock sync connections.""" - - __slots__ = ("_connection",) - - def __init__(self, config: "MockSyncConfig") -> None: - super().__init__(config) - self._connection: MockConnection | None = None - - def __enter__(self) -> MockConnection: - self._connection = self._config.create_connection() - return cast("MockConnection", self._connection) - - def __exit__( - self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" - ) -> "bool | None": - if self._connection is not None: - self._connection.close() - self._connection = None - return None - - -class MockAsyncConnectionContext(AsyncPoolConnectionContext): - """Async context manager for Mock async connections.""" - - __slots__ = () - - async def __aenter__(self) -> MockConnection: - self._connection = await self._config.create_connection() - return cast("MockConnection", self._connection) - - async def __aexit__( - self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" - ) -> "bool | None": - if self._connection is not None: - self._connection.close() - self._connection = None - return None - - -class _MockSyncSessionFactory(SyncPoolSessionFactory): - """Factory for creating mock sync sessions.""" - - __slots__ = ("_connection",) - - def __init__(self, config: "MockSyncConfig") -> None: - super().__init__(config) - self._connection: MockConnection | None = None - - def acquire_connection(self) -> MockConnection: - self._connection = self._config.create_connection() - return cast("MockConnection", self._connection) - - def release_connection(self, _conn: MockConnection, **kwargs: Any) -> None: - if self._connection is not None: - self._connection.close() - self._connection = None - - -class _MockAsyncSessionFactory(AsyncPoolSessionFactory): - """Factory for creating mock async sessions.""" - - __slots__ = () - - async def acquire_connection(self) -> MockConnection: - self._connection = await self._config.create_connection() - return cast("MockConnection", self._connection) - - async def release_connection(self, _conn: MockConnection, **kwargs: Any) -> None: - if self._connection is not None: - self._connection.close() - self._connection = None - - -class MockSyncConfig(NoPoolSyncConfig["MockConnection", "MockSyncDriver"]): - """Sync mock database configuration. - - Uses SQLite :memory: as the execution backend with dialect transpilation. - Write SQL in your target dialect (Postgres, MySQL, Oracle) and it will - be transpiled to SQLite before execution. - - Example: - config = MockSyncConfig(target_dialect="postgres") - - with config.provide_session() as session: - session.execute(\"\"\" - CREATE TABLE users ( - id SERIAL PRIMARY KEY, - name VARCHAR(100) - ) - \"\"\") - session.execute( - "INSERT INTO users (name) VALUES ($1)", - "Alice" - ) - user = session.select_one("SELECT * FROM users WHERE name = $1", "Alice") - assert user["name"] == "Alice" - """ - - driver_type: "ClassVar[type[MockSyncDriver]]" = MockSyncDriver - connection_type: "ClassVar[type[MockConnection]]" = MockConnection - supports_transactional_ddl: "ClassVar[bool]" = True - supports_native_arrow_export: "ClassVar[bool]" = True - supports_native_arrow_import: "ClassVar[bool]" = True - supports_native_parquet_export: "ClassVar[bool]" = True - supports_native_parquet_import: "ClassVar[bool]" = True - _connection_context_class: "ClassVar[type[MockSyncConnectionContext]]" = MockSyncConnectionContext - _session_factory_class: "ClassVar[type[_MockSyncSessionFactory]]" = _MockSyncSessionFactory - _session_context_class: "ClassVar[type[MockSyncSessionContext]]" = MockSyncSessionContext - _default_statement_config = default_statement_config - - def __init__( - self, - *, - target_dialect: str = "sqlite", - initial_sql: "str | list[str] | None" = None, - connection_config: "MockConnectionParams | dict[str, Any] | None" = None, - connection_instance: "Any" = None, - migration_config: "dict[str, Any] | None" = None, - statement_config: "StatementConfig | None" = None, - driver_features: "MockDriverFeatures | dict[str, Any] | None" = None, - bind_key: "str | None" = None, - extension_config: "ExtensionConfigs | None" = None, - observability_config: "ObservabilityConfig | None" = None, - ) -> None: - """Initialize Mock sync configuration. - - Args: - target_dialect: SQL dialect for input SQL (postgres, mysql, oracle, sqlite). - SQL will be transpiled to SQLite before execution, unless 'sqlite'. - initial_sql: SQL statements to execute when creating connection. - Can be a single string or list of strings. Useful for setting up - test fixtures. - connection_config: Additional connection parameters. - connection_instance: Pre-existing connection (not used for mock). - migration_config: Migration configuration. - statement_config: Statement configuration settings. - driver_features: Driver feature configuration. - bind_key: Optional unique identifier for this configuration. - extension_config: Extension-specific configuration. - observability_config: Observability configuration. - """ - config_dict: dict[str, Any] = dict(connection_config) if connection_config else {} - config_dict["target_dialect"] = target_dialect - config_dict["initial_sql"] = initial_sql - - statement_config = statement_config or default_statement_config - statement_config, driver_features = apply_driver_features(statement_config, driver_features) - - super().__init__( - connection_config=config_dict, - connection_instance=connection_instance, - migration_config=migration_config, - statement_config=statement_config, - driver_features=driver_features, - bind_key=bind_key, - extension_config=extension_config, - observability_config=observability_config, - ) - - @property - def target_dialect(self) -> str: - """Get the target dialect for SQL transpilation.""" - return str(self.connection_config.get("target_dialect", "sqlite")) - - @property - def initial_sql(self) -> "str | list[str] | None": - """Get the initial SQL to execute on connection creation.""" - return self.connection_config.get("initial_sql") - - def create_connection(self) -> MockConnection: - """Create a new SQLite :memory: connection. - - Returns: - SQLite connection with row factory set. - """ - conn = sqlite3.connect(":memory:", check_same_thread=False) - - if self.initial_sql: - self._execute_initial_sql(conn) - - return conn - - def _execute_initial_sql(self, conn: MockConnection) -> None: - """Execute initial SQL statements on a new connection. - - Args: - conn: SQLite connection to execute SQL on. - """ - initial_sql = self.initial_sql - if initial_sql is None: - return - - statements = initial_sql if isinstance(initial_sql, list) else [initial_sql] - target_dialect = self.target_dialect - - for sql in statements: - if target_dialect != "sqlite": - transpiled = convert_to_dialect(sql, target_dialect, "sqlite", pretty=False) - else: - transpiled = sql - conn.executescript(transpiled) - - def provide_connection(self, *args: Any, **kwargs: Any) -> "MockSyncConnectionContext": - """Provide a Mock sync connection context manager. - - Returns: - Connection context manager. - """ - return MockSyncConnectionContext(self) - - def provide_session( - self, *_args: Any, statement_config: "StatementConfig | None" = None, **_kwargs: Any - ) -> "MockSyncSessionContext": - """Provide a Mock sync driver session. - - Args: - statement_config: Optional statement configuration override. - - Returns: - Mock driver session context manager. - """ - factory = _MockSyncSessionFactory(self) - - return MockSyncSessionContext( - acquire_connection=factory.acquire_connection, - release_connection=factory.release_connection, - statement_config=statement_config or self.statement_config or default_statement_config, - driver_features=self.driver_features, - prepare_driver=self._prepare_driver, - target_dialect=self.target_dialect, - ) - - def get_signature_namespace(self) -> "dict[str, Any]": - """Get the signature namespace for Mock types. - - Returns: - Dictionary mapping type names to types. - """ - namespace = super().get_signature_namespace() - namespace.update({ - "MockConnection": MockConnection, - "MockConnectionParams": MockConnectionParams, - "MockCursor": MockCursor, - "MockDriverFeatures": MockDriverFeatures, - "MockExceptionHandler": MockExceptionHandler, - "MockSyncConfig": MockSyncConfig, - "MockSyncConnectionContext": MockSyncConnectionContext, - "MockSyncDriver": MockSyncDriver, - "MockSyncSessionContext": MockSyncSessionContext, - }) - return namespace - - -class MockAsyncConfig(NoPoolAsyncConfig["MockConnection", "MockAsyncDriver"]): - """Async mock database configuration. - - Uses SQLite :memory: as the execution backend with dialect transpilation. - The async interface wraps sync SQLite operations using asyncio.to_thread(). - - Example: - config = MockAsyncConfig(target_dialect="mysql") - - async with config.provide_session() as session: - await session.execute("CREATE TABLE items (id INT, name TEXT)") - await session.execute("INSERT INTO items VALUES (%s, %s)", 1, "Widget") - result = await session.select("SELECT * FROM items") - assert len(result) == 1 - """ - - driver_type: "ClassVar[type[MockAsyncDriver]]" = MockAsyncDriver - connection_type: "ClassVar[type[MockConnection]]" = MockConnection - supports_transactional_ddl: "ClassVar[bool]" = True - supports_native_arrow_export: "ClassVar[bool]" = True - supports_native_arrow_import: "ClassVar[bool]" = True - supports_native_parquet_export: "ClassVar[bool]" = True - supports_native_parquet_import: "ClassVar[bool]" = True - _connection_context_class: "ClassVar[type[MockAsyncConnectionContext]]" = MockAsyncConnectionContext - _session_factory_class: "ClassVar[type[_MockAsyncSessionFactory]]" = _MockAsyncSessionFactory - _session_context_class: "ClassVar[type[MockAsyncSessionContext]]" = MockAsyncSessionContext - _default_statement_config = default_statement_config - - def __init__( - self, - *, - target_dialect: str = "sqlite", - initial_sql: "str | list[str] | None" = None, - connection_config: "MockConnectionParams | dict[str, Any] | None" = None, - connection_instance: "Any" = None, - migration_config: "dict[str, Any] | None" = None, - statement_config: "StatementConfig | None" = None, - driver_features: "MockDriverFeatures | dict[str, Any] | None" = None, - bind_key: "str | None" = None, - extension_config: "ExtensionConfigs | None" = None, - observability_config: "ObservabilityConfig | None" = None, - ) -> None: - """Initialize Mock async configuration. - - Args: - target_dialect: SQL dialect for input SQL (postgres, mysql, oracle, sqlite). - SQL will be transpiled to SQLite before execution, unless 'sqlite'. - initial_sql: SQL statements to execute when creating connection. - Can be a single string or list of strings. Useful for setting up - test fixtures. - connection_config: Additional connection parameters. - connection_instance: Pre-existing connection (not used for mock). - migration_config: Migration configuration. - statement_config: Statement configuration settings. - driver_features: Driver feature configuration. - bind_key: Optional unique identifier for this configuration. - extension_config: Extension-specific configuration. - observability_config: Observability configuration. - """ - config_dict: dict[str, Any] = dict(connection_config) if connection_config else {} - config_dict["target_dialect"] = target_dialect - config_dict["initial_sql"] = initial_sql - - statement_config = statement_config or default_statement_config - statement_config, driver_features = apply_driver_features(statement_config, driver_features) - - super().__init__( - connection_config=config_dict, - connection_instance=connection_instance, - migration_config=migration_config, - statement_config=statement_config, - driver_features=driver_features, - bind_key=bind_key, - extension_config=extension_config, - observability_config=observability_config, - ) - - @property - def target_dialect(self) -> str: - """Get the target dialect for SQL transpilation.""" - return str(self.connection_config.get("target_dialect", "sqlite")) - - @property - def initial_sql(self) -> "str | list[str] | None": - """Get the initial SQL to execute on connection creation.""" - return self.connection_config.get("initial_sql") - - async def create_connection(self) -> MockConnection: - """Create a new SQLite :memory: connection asynchronously. - - Returns: - SQLite connection with row factory set. - """ - connect_async = async_(sqlite3.connect) - conn = await connect_async(":memory:", check_same_thread=False) - - if self.initial_sql: - await self._execute_initial_sql_async(conn) - - return conn - - async def _execute_initial_sql_async(self, conn: MockConnection) -> None: - """Execute initial SQL statements on a new connection. - - Args: - conn: SQLite connection to execute SQL on. - """ - initial_sql = self.initial_sql - if initial_sql is None: - return - - statements = initial_sql if isinstance(initial_sql, list) else [initial_sql] - target_dialect = self.target_dialect - - for sql in statements: - if target_dialect != "sqlite": - transpiled = convert_to_dialect(sql, target_dialect, "sqlite", pretty=False) - else: - transpiled = sql - executescript_async = async_(conn.executescript) - await executescript_async(transpiled) - - def provide_connection(self, *args: Any, **kwargs: Any) -> "MockAsyncConnectionContext": - """Provide a Mock async connection context manager. - - Returns: - Async connection context manager. - """ - return MockAsyncConnectionContext(self) - - def provide_session( - self, *_args: Any, statement_config: "StatementConfig | None" = None, **_kwargs: Any - ) -> "MockAsyncSessionContext": - """Provide a Mock async driver session. - - Args: - statement_config: Optional statement configuration override. - - Returns: - Mock async driver session context manager. - """ - factory = _MockAsyncSessionFactory(self) - - return MockAsyncSessionContext( - acquire_connection=factory.acquire_connection, - release_connection=factory.release_connection, - statement_config=statement_config or self.statement_config or default_statement_config, - driver_features=self.driver_features, - prepare_driver=self._prepare_driver, - target_dialect=self.target_dialect, - ) - - def get_signature_namespace(self) -> "dict[str, Any]": - """Get the signature namespace for Mock types. - - Returns: - Dictionary mapping type names to types. - """ - namespace = super().get_signature_namespace() - namespace.update({ - "MockAsyncConfig": MockAsyncConfig, - "MockAsyncConnectionContext": MockAsyncConnectionContext, - "MockAsyncDriver": MockAsyncDriver, - "MockAsyncSessionContext": MockAsyncSessionContext, - "MockConnection": MockConnection, - "MockConnectionParams": MockConnectionParams, - "MockDriverFeatures": MockDriverFeatures, - }) - return namespace diff --git a/sqlspec/adapters/mock/core.py b/sqlspec/adapters/mock/core.py deleted file mode 100644 index 6bf9cfd57..000000000 --- a/sqlspec/adapters/mock/core.py +++ /dev/null @@ -1,317 +0,0 @@ -"""Mock adapter compiled helpers. - -This module provides utility functions for the mock adapter, reusing -SQLite-compatible helpers since mock uses SQLite as its execution backend. -""" - -from datetime import date, datetime -from decimal import Decimal -from typing import TYPE_CHECKING, Any, cast - -from sqlspec.core import DriverParameterProfile, ParameterStyle, StatementConfig, build_statement_config_from_profile -from sqlspec.exceptions import ( - CheckViolationError, - DatabaseConnectionError, - DataError, - ForeignKeyViolationError, - IntegrityError, - NotNullViolationError, - OperationalError, - SQLParsingError, - SQLSpecError, - UniqueViolationError, -) -from sqlspec.utils.serializers import from_json, to_json -from sqlspec.utils.type_converters import build_decimal_converter, build_time_iso_converter, build_uuid_coercions -from sqlspec.utils.type_guards import has_rowcount, has_sqlite_error - -if TYPE_CHECKING: - from collections.abc import Callable, Mapping, Sequence - -__all__ = ( - "apply_driver_features", - "build_insert_statement", - "build_profile", - "build_statement_config", - "collect_rows", - "create_mapped_exception", - "default_statement_config", - "driver_profile", - "format_identifier", - "normalize_execute_many_parameters", - "normalize_execute_parameters", - "resolve_rowcount", -) - -SQLITE_CONSTRAINT_UNIQUE_CODE = 2067 -SQLITE_CONSTRAINT_FOREIGNKEY_CODE = 787 -SQLITE_CONSTRAINT_NOTNULL_CODE = 1811 -SQLITE_CONSTRAINT_CHECK_CODE = 531 -SQLITE_CONSTRAINT_CODE = 19 -SQLITE_CANTOPEN_CODE = 14 -SQLITE_IOERR_CODE = 10 -SQLITE_MISMATCH_CODE = 20 - - -_TIME_TO_ISO = build_time_iso_converter() -_DECIMAL_TO_STRING = build_decimal_converter(mode="string") - - -def _bool_to_int(value: bool) -> int: - return int(value) - - -def _quote_sqlite_identifier(identifier: str) -> str: - normalized = identifier.replace('"', '""') - return f'"{normalized}"' - - -def format_identifier(identifier: str) -> str: - """Format an identifier for SQLite. - - Args: - identifier: Table or column name to format. - - Returns: - Properly quoted identifier. - - Raises: - SQLSpecError: If identifier is empty. - """ - cleaned = identifier.strip() - if not cleaned: - msg = "Table name must not be empty" - raise SQLSpecError(msg) - - if "." not in cleaned: - return _quote_sqlite_identifier(cleaned) - - return ".".join(_quote_sqlite_identifier(part) for part in cleaned.split(".") if part) - - -def build_insert_statement(table: str, columns: "list[str]") -> str: - """Build an INSERT statement for the given table and columns. - - Args: - table: Table name. - columns: List of column names. - - Returns: - INSERT SQL statement. - """ - column_clause = ", ".join(_quote_sqlite_identifier(column) for column in columns) - placeholders = ", ".join("?" for _ in columns) - return f"INSERT INTO {format_identifier(table)} ({column_clause}) VALUES ({placeholders})" - - -def collect_rows(fetched_data: "list[Any]", description: "Sequence[Any] | None") -> "tuple[list[Any], list[str], int]": - """Collect mock result rows as raw tuples. - - Args: - fetched_data: Raw rows from cursor.fetchall() - description: Cursor description (tuple of tuples) - - Returns: - Tuple of (data, column_names, row_count) - """ - if not description: - return [], [], 0 - - column_names = [col[0] for col in description] - return fetched_data, column_names, len(fetched_data) - - -def resolve_rowcount(cursor: Any) -> int: - """Resolve rowcount from a SQLite cursor. - - Args: - cursor: SQLite cursor with optional rowcount metadata. - - Returns: - Positive rowcount value or 0 when unknown. - """ - if not has_rowcount(cursor): - return 0 - rowcount = cursor.rowcount - if isinstance(rowcount, int) and rowcount > 0: - return rowcount - return 0 - - -def normalize_execute_parameters(parameters: Any) -> Any: - """Normalize parameters for SQLite execute calls. - - Args: - parameters: Prepared parameters payload. - - Returns: - Normalized parameters payload. - """ - return parameters or () - - -def normalize_execute_many_parameters(parameters: Any) -> Any: - """Normalize parameters for SQLite executemany calls. - - Args: - parameters: Prepared parameters payload. - - Returns: - Normalized parameters payload. - - Raises: - ValueError: When parameters are missing for executemany. - """ - if not parameters: - msg = "execute_many requires parameters" - raise ValueError(msg) - return parameters - - -def _create_sqlite_error( - error: Any, code: "int | None", error_class: type[SQLSpecError], description: str -) -> SQLSpecError: - """Create a SQLite error instance without raising it.""" - code_str = f"[code {code}]" if code else "" - msg = f"SQLite {description} {code_str}: {error}" if code_str else f"SQLite {description}: {error}" - exc = error_class(msg) - exc.__cause__ = cast("BaseException", error) - return exc - - -def create_mapped_exception(error: BaseException) -> SQLSpecError: - """Map SQLite errors to SQLSpec exceptions. - - This is a factory function that returns an exception instance rather than - raising. This pattern is more robust for use in __exit__ handlers and - avoids issues with exception control flow in different Python versions. - - Args: - error: The SQLite exception to map - - Returns: - A SQLSpec exception that wraps the original error - """ - if has_sqlite_error(error): - error_code = error.sqlite_errorcode - error_name = error.sqlite_errorname - else: - error_code = None - error_name = None - error_msg = str(error).lower() - - if "locked" in error_msg: - return _create_sqlite_error(error, error_code or 0, OperationalError, "operational error") - - if not error_code: - if "unique constraint" in error_msg: - return _create_sqlite_error(error, 0, UniqueViolationError, "unique constraint violation") - if "foreign key constraint" in error_msg: - return _create_sqlite_error(error, 0, ForeignKeyViolationError, "foreign key constraint violation") - if "not null constraint" in error_msg: - return _create_sqlite_error(error, 0, NotNullViolationError, "not-null constraint violation") - if "check constraint" in error_msg: - return _create_sqlite_error(error, 0, CheckViolationError, "check constraint violation") - if "syntax" in error_msg: - return _create_sqlite_error(error, None, SQLParsingError, "SQL syntax error") - return _create_sqlite_error(error, None, SQLSpecError, "database error") - - if error_code == SQLITE_CONSTRAINT_UNIQUE_CODE or error_name == "SQLITE_CONSTRAINT_UNIQUE": - return _create_sqlite_error(error, error_code, UniqueViolationError, "unique constraint violation") - if error_code == SQLITE_CONSTRAINT_FOREIGNKEY_CODE or error_name == "SQLITE_CONSTRAINT_FOREIGNKEY": - return _create_sqlite_error(error, error_code, ForeignKeyViolationError, "foreign key constraint violation") - if error_code == SQLITE_CONSTRAINT_NOTNULL_CODE or error_name == "SQLITE_CONSTRAINT_NOTNULL": - return _create_sqlite_error(error, error_code, NotNullViolationError, "not-null constraint violation") - if error_code == SQLITE_CONSTRAINT_CHECK_CODE or error_name == "SQLITE_CONSTRAINT_CHECK": - return _create_sqlite_error(error, error_code, CheckViolationError, "check constraint violation") - if error_code == SQLITE_CONSTRAINT_CODE or error_name == "SQLITE_CONSTRAINT": - return _create_sqlite_error(error, error_code, IntegrityError, "integrity constraint violation") - if error_code == SQLITE_CANTOPEN_CODE or error_name == "SQLITE_CANTOPEN": - return _create_sqlite_error(error, error_code, DatabaseConnectionError, "connection error") - if error_code == SQLITE_IOERR_CODE or error_name == "SQLITE_IOERR": - return _create_sqlite_error(error, error_code, OperationalError, "operational error") - if error_code == SQLITE_MISMATCH_CODE or error_name == "SQLITE_MISMATCH": - return _create_sqlite_error(error, error_code, DataError, "data error") - if error_code == 1 or "syntax" in error_msg: - return _create_sqlite_error(error, error_code, SQLParsingError, "SQL syntax error") - return _create_sqlite_error(error, error_code, SQLSpecError, "database error") - - -def build_profile() -> "DriverParameterProfile": - """Create the Mock driver parameter profile. - - Returns: - Driver parameter profile for Mock adapter. - """ - return DriverParameterProfile( - name="Mock", - default_style=ParameterStyle.QMARK, - supported_styles={ParameterStyle.QMARK, ParameterStyle.NAMED_COLON}, - default_execution_style=ParameterStyle.QMARK, - supported_execution_styles={ParameterStyle.QMARK, ParameterStyle.NAMED_COLON}, - has_native_list_expansion=False, - preserve_parameter_format=True, - needs_static_script_compilation=False, - allow_mixed_parameter_styles=False, - preserve_original_params_for_many=False, - json_serializer_strategy="helper", - custom_type_coercions={ - bool: _bool_to_int, - datetime: _TIME_TO_ISO, - date: _TIME_TO_ISO, - Decimal: _DECIMAL_TO_STRING, - **build_uuid_coercions(), - }, - default_dialect="sqlite", - ) - - -driver_profile = build_profile() - - -def build_statement_config( - *, json_serializer: "Callable[[Any], str] | None" = None, json_deserializer: "Callable[[str], Any] | None" = None -) -> "StatementConfig": - """Construct the Mock statement configuration with optional JSON codecs. - - Args: - json_serializer: Custom JSON serializer function. - json_deserializer: Custom JSON deserializer function. - - Returns: - StatementConfig for Mock adapter. - """ - serializer = json_serializer or to_json - deserializer = json_deserializer or from_json - profile = driver_profile - return build_statement_config_from_profile( - profile, statement_overrides={"dialect": "sqlite"}, json_serializer=serializer, json_deserializer=deserializer - ) - - -default_statement_config = build_statement_config() - - -def apply_driver_features( - statement_config: "StatementConfig", driver_features: "Mapping[str, Any] | None" -) -> "tuple[StatementConfig, dict[str, Any]]": - """Apply Mock driver feature defaults to statement config. - - Args: - statement_config: Base statement configuration. - driver_features: Driver feature overrides. - - Returns: - Tuple of (updated statement config, driver features dict). - """ - features: dict[str, Any] = dict(driver_features) if driver_features else {} - json_serializer = features.setdefault("json_serializer", to_json) - json_deserializer = features.setdefault("json_deserializer", from_json) - - if json_serializer is not None: - parameter_config = statement_config.parameter_config.with_json_serializers( - json_serializer, deserializer=json_deserializer - ) - statement_config = statement_config.replace(parameter_config=parameter_config) - - return statement_config, features diff --git a/sqlspec/adapters/mock/data_dictionary.py b/sqlspec/adapters/mock/data_dictionary.py deleted file mode 100644 index 123d9ab6d..000000000 --- a/sqlspec/adapters/mock/data_dictionary.py +++ /dev/null @@ -1,366 +0,0 @@ -"""Mock-specific data dictionary for metadata queries. - -This module provides data dictionary functionality for the mock adapter, -delegating to SQLite's catalog since mock uses SQLite as its execution backend. -""" - -from typing import TYPE_CHECKING, ClassVar - -from mypy_extensions import mypyc_attr - -from sqlspec.adapters.mock.core import format_identifier -from sqlspec.driver import AsyncDataDictionaryBase, SyncDataDictionaryBase -from sqlspec.typing import ColumnMetadata, ForeignKeyMetadata, IndexMetadata, TableMetadata, VersionInfo - -__all__ = ("MockAsyncDataDictionary", "MockDataDictionary") - -if TYPE_CHECKING: - from sqlspec.adapters.mock.driver import MockAsyncDriver, MockSyncDriver - - -@mypyc_attr(allow_interpreted_subclasses=True, native_class=False) -class MockDataDictionary(SyncDataDictionaryBase): - """Mock-specific sync data dictionary. - - Delegates metadata queries to SQLite's catalog (sqlite_master, PRAGMA table_info). - """ - - dialect: ClassVar[str] = "sqlite" - - def __init__(self) -> None: - super().__init__() - - def get_version(self, driver: "MockSyncDriver") -> "VersionInfo | None": - """Get SQLite database version information. - - Args: - driver: Sync database driver instance. - - Returns: - SQLite version information or None if detection fails. - """ - driver_id = id(driver) - # Inline cache check to avoid cross-module method call that causes mypyc segfault - if driver_id in self._version_fetch_attempted: - return self._version_cache.get(driver_id) - # Not cached, fetch from database - - version_value = driver.select_value_or_none(self.get_query("version")) - if not version_value: - self._log_version_unavailable(type(self).dialect, "missing") - self.cache_version(driver_id, None) - return None - - version_info = self.parse_version_with_pattern(self.get_dialect_config().version_pattern, str(version_value)) - if version_info is None: - self._log_version_unavailable(type(self).dialect, "parse_failed") - self.cache_version(driver_id, None) - return None - - self._log_version_detected(type(self).dialect, version_info) - self.cache_version(driver_id, version_info) - return version_info - - def get_feature_flag(self, driver: "MockSyncDriver", feature: str) -> bool: - """Check if SQLite database supports a specific feature. - - Args: - driver: Sync database driver instance. - feature: Feature name to check. - - Returns: - True if feature is supported, False otherwise. - """ - version_info = self.get_version(driver) - return self.resolve_feature_flag(feature, version_info) - - def get_optimal_type(self, driver: "MockSyncDriver", type_category: str) -> str: - """Get optimal SQLite type for a category. - - Args: - driver: Sync database driver instance. - type_category: Type category. - - Returns: - SQLite-specific type name. - """ - config = self.get_dialect_config() - version_info = self.get_version(driver) - - if type_category == "json": - json_version = config.get_feature_version("supports_json") - if version_info and json_version and version_info >= json_version: - return "JSON" - return "TEXT" - - return config.get_optimal_type(type_category) - - def get_tables(self, driver: "MockSyncDriver", schema: "str | None" = None) -> "list[TableMetadata]": - """Get tables sorted by topological dependency order using SQLite catalog.""" - schema_name = self.resolve_schema(schema) - self._log_schema_introspect(driver, schema_name=schema_name, table_name=None, operation="tables") - schema_prefix = f"{format_identifier(schema_name)}." if schema_name else "" - query_text = self.get_query_text("tables_by_schema").format(schema_prefix=schema_prefix) - return driver.select(query_text, schema_type=TableMetadata) - - def get_columns( - self, driver: "MockSyncDriver", table: "str | None" = None, schema: "str | None" = None - ) -> "list[ColumnMetadata]": - """Get column information for a table or schema.""" - schema_name = self.resolve_schema(schema) - schema_prefix = f"{format_identifier(schema_name)}." if schema_name else "" - if table is None: - self._log_schema_introspect(driver, schema_name=schema_name, table_name=None, operation="columns") - query_text = self.get_query_text("columns_by_schema").format(schema_prefix=schema_prefix) - return driver.select(query_text, schema_type=ColumnMetadata) - - assert table is not None - self._log_table_describe(driver, schema_name=schema_name, table_name=table, operation="columns") - table_name = table - table_identifier = f"{schema_name}.{table_name}" if schema_name else table_name - query_text = self.get_query_text("columns_by_table").format(table_name=format_identifier(table_identifier)) - return driver.select(query_text, schema_type=ColumnMetadata) - - def get_indexes( - self, driver: "MockSyncDriver", table: "str | None" = None, schema: "str | None" = None - ) -> "list[IndexMetadata]": - """Get index metadata for a table or schema.""" - schema_name = self.resolve_schema(schema) - indexes: list[IndexMetadata] = [] - if table is None: - self._log_schema_introspect(driver, schema_name=schema_name, table_name=None, operation="indexes") - for table_info in self.get_tables(driver, schema=schema_name): - table_name = table_info.get("table_name") - if not table_name: - continue - indexes.extend(self.get_indexes(driver, table=table_name, schema=schema_name)) - return indexes - - assert table is not None - self._log_table_describe(driver, schema_name=schema_name, table_name=table, operation="indexes") - table_name = table - table_identifier = f"{schema_name}.{table_name}" if schema_name else table_name - index_list_sql = self.get_query_text("indexes_by_table").format(table_name=format_identifier(table_identifier)) - index_rows = driver.select(index_list_sql) - for row in index_rows: - index_name = row.get("name") - if not index_name: - continue - index_identifier = f"{schema_name}.{index_name}" if schema_name else index_name - columns_sql = self.get_query_text("index_columns_by_index").format( - index_name=format_identifier(index_identifier) - ) - columns_rows = driver.select(columns_sql) - columns: list[str] = [] - for col in columns_rows: - column_name = col.get("name") - if column_name is None: - continue - columns.append(str(column_name)) - is_primary = row.get("origin") == "pk" - index_metadata: IndexMetadata = { - "index_name": index_name, - "table_name": table_name, - "columns": columns, - "is_primary": is_primary, - } - if schema_name is not None: - index_metadata["schema_name"] = schema_name - unique_value = row.get("unique") - if unique_value is not None: - index_metadata["is_unique"] = unique_value - indexes.append(index_metadata) - return indexes - - def get_foreign_keys( - self, driver: "MockSyncDriver", table: "str | None" = None, schema: "str | None" = None - ) -> "list[ForeignKeyMetadata]": - """Get foreign key metadata.""" - schema_name = self.resolve_schema(schema) - schema_prefix = f"{format_identifier(schema_name)}." if schema_name else "" - if table is None: - self._log_schema_introspect(driver, schema_name=schema_name, table_name=None, operation="foreign_keys") - query_text = self.get_query_text("foreign_keys_by_schema").format(schema_prefix=schema_prefix) - return driver.select(query_text, schema_type=ForeignKeyMetadata) - - self._log_table_describe(driver, schema_name=schema_name, table_name=table, operation="foreign_keys") - table_label = table.replace("'", "''") - table_identifier = f"{schema_name}.{table}" if schema_name else table - query_text = self.get_query_text("foreign_keys_by_table").format( - table_name=format_identifier(table_identifier), table_label=table_label - ) - return driver.select(query_text, schema_type=ForeignKeyMetadata) - - -@mypyc_attr(allow_interpreted_subclasses=True, native_class=False) -class MockAsyncDataDictionary(AsyncDataDictionaryBase): - """Mock-specific async data dictionary. - - Delegates metadata queries to SQLite's catalog (sqlite_master, PRAGMA table_info). - """ - - dialect: ClassVar[str] = "sqlite" - - def __init__(self) -> None: - super().__init__() - - async def get_version(self, driver: "MockAsyncDriver") -> "VersionInfo | None": - """Get SQLite database version information. - - Args: - driver: Async database driver instance. - - Returns: - SQLite version information or None if detection fails. - """ - driver_id = id(driver) - # Inline cache check to avoid cross-module method call that causes mypyc segfault - if driver_id in self._version_fetch_attempted: - return self._version_cache.get(driver_id) - # Not cached, fetch from database - - version_value = await driver.select_value_or_none(self.get_query("version")) - if not version_value: - self._log_version_unavailable(type(self).dialect, "missing") - self.cache_version(driver_id, None) - return None - - version_info = self.parse_version_with_pattern(self.get_dialect_config().version_pattern, str(version_value)) - if version_info is None: - self._log_version_unavailable(type(self).dialect, "parse_failed") - self.cache_version(driver_id, None) - return None - - self._log_version_detected(type(self).dialect, version_info) - self.cache_version(driver_id, version_info) - return version_info - - async def get_feature_flag(self, driver: "MockAsyncDriver", feature: str) -> bool: - """Check if SQLite database supports a specific feature. - - Args: - driver: Async database driver instance. - feature: Feature name to check. - - Returns: - True if feature is supported, False otherwise. - """ - version_info = await self.get_version(driver) - return self.resolve_feature_flag(feature, version_info) - - async def get_optimal_type(self, driver: "MockAsyncDriver", type_category: str) -> str: - """Get optimal SQLite type for a category. - - Args: - driver: Async database driver instance. - type_category: Type category. - - Returns: - SQLite-specific type name. - """ - config = self.get_dialect_config() - version_info = await self.get_version(driver) - - if type_category == "json": - json_version = config.get_feature_version("supports_json") - if version_info and json_version and version_info >= json_version: - return "JSON" - return "TEXT" - - return config.get_optimal_type(type_category) - - async def get_tables(self, driver: "MockAsyncDriver", schema: "str | None" = None) -> "list[TableMetadata]": - """Get tables sorted by topological dependency order using SQLite catalog.""" - schema_name = self.resolve_schema(schema) - self._log_schema_introspect(driver, schema_name=schema_name, table_name=None, operation="tables") - schema_prefix = f"{format_identifier(schema_name)}." if schema_name else "" - query_text = self.get_query_text("tables_by_schema").format(schema_prefix=schema_prefix) - return await driver.select(query_text, schema_type=TableMetadata) - - async def get_columns( - self, driver: "MockAsyncDriver", table: "str | None" = None, schema: "str | None" = None - ) -> "list[ColumnMetadata]": - """Get column information for a table or schema.""" - schema_name = self.resolve_schema(schema) - schema_prefix = f"{format_identifier(schema_name)}." if schema_name else "" - if table is None: - self._log_schema_introspect(driver, schema_name=schema_name, table_name=None, operation="columns") - query_text = self.get_query_text("columns_by_schema").format(schema_prefix=schema_prefix) - return await driver.select(query_text, schema_type=ColumnMetadata) - - assert table is not None - self._log_table_describe(driver, schema_name=schema_name, table_name=table, operation="columns") - table_name = table - table_identifier = f"{schema_name}.{table_name}" if schema_name else table_name - query_text = self.get_query_text("columns_by_table").format(table_name=format_identifier(table_identifier)) - return await driver.select(query_text, schema_type=ColumnMetadata) - - async def get_indexes( - self, driver: "MockAsyncDriver", table: "str | None" = None, schema: "str | None" = None - ) -> "list[IndexMetadata]": - """Get index metadata for a table or schema.""" - schema_name = self.resolve_schema(schema) - indexes: list[IndexMetadata] = [] - if table is None: - self._log_schema_introspect(driver, schema_name=schema_name, table_name=None, operation="indexes") - for table_info in await self.get_tables(driver, schema=schema_name): - table_name = table_info.get("table_name") - if not table_name: - continue - indexes.extend(await self.get_indexes(driver, table=table_name, schema=schema_name)) - return indexes - - assert table is not None - self._log_table_describe(driver, schema_name=schema_name, table_name=table, operation="indexes") - table_name = table - table_identifier = f"{schema_name}.{table_name}" if schema_name else table_name - index_list_sql = self.get_query_text("indexes_by_table").format(table_name=format_identifier(table_identifier)) - index_rows = await driver.select(index_list_sql) - for row in index_rows: - index_name = row.get("name") - if not index_name: - continue - index_identifier = f"{schema_name}.{index_name}" if schema_name else index_name - columns_sql = self.get_query_text("index_columns_by_index").format( - index_name=format_identifier(index_identifier) - ) - columns_rows = await driver.select(columns_sql) - columns: list[str] = [] - for col in columns_rows: - column_name = col.get("name") - if column_name is None: - continue - columns.append(str(column_name)) - is_primary = row.get("origin") == "pk" - index_metadata: IndexMetadata = { - "index_name": index_name, - "table_name": table_name, - "columns": columns, - "is_primary": is_primary, - } - if schema_name is not None: - index_metadata["schema_name"] = schema_name - unique_value = row.get("unique") - if unique_value is not None: - index_metadata["is_unique"] = unique_value - indexes.append(index_metadata) - return indexes - - async def get_foreign_keys( - self, driver: "MockAsyncDriver", table: "str | None" = None, schema: "str | None" = None - ) -> "list[ForeignKeyMetadata]": - """Get foreign key metadata.""" - schema_name = self.resolve_schema(schema) - schema_prefix = f"{format_identifier(schema_name)}." if schema_name else "" - if table is None: - self._log_schema_introspect(driver, schema_name=schema_name, table_name=None, operation="foreign_keys") - query_text = self.get_query_text("foreign_keys_by_schema").format(schema_prefix=schema_prefix) - return await driver.select(query_text, schema_type=ForeignKeyMetadata) - - self._log_table_describe(driver, schema_name=schema_name, table_name=table, operation="foreign_keys") - table_label = table.replace("'", "''") - table_identifier = f"{schema_name}.{table}" if schema_name else table - query_text = self.get_query_text("foreign_keys_by_table").format( - table_name=format_identifier(table_identifier), table_label=table_label - ) - return await driver.select(query_text, schema_type=ForeignKeyMetadata) diff --git a/sqlspec/adapters/mock/driver.py b/sqlspec/adapters/mock/driver.py deleted file mode 100644 index b2cb289c0..000000000 --- a/sqlspec/adapters/mock/driver.py +++ /dev/null @@ -1,688 +0,0 @@ -"""Mock driver implementation with dialect transpilation. - -This module provides sync and async mock drivers that use SQLite `:memory:` -as the execution backend while accepting SQL written in other dialects -(Postgres, MySQL, Oracle, etc.). SQL is transpiled to SQLite syntax before -execution using sqlglot. -""" - -import sqlite3 -from typing import TYPE_CHECKING, Any - -from sqlspec.adapters.mock._typing import MockAsyncCursor, MockAsyncSessionContext, MockCursor, MockSyncSessionContext -from sqlspec.adapters.mock.core import ( - build_insert_statement, - collect_rows, - create_mapped_exception, - default_statement_config, - driver_profile, - format_identifier, - normalize_execute_many_parameters, - normalize_execute_parameters, - resolve_rowcount, -) -from sqlspec.adapters.mock.data_dictionary import MockAsyncDataDictionary, MockDataDictionary -from sqlspec.core import ArrowResult, get_cache_config, register_driver_profile -from sqlspec.driver import ( - AsyncDriverAdapterBase, - BaseAsyncExceptionHandler, - BaseSyncExceptionHandler, - SyncDriverAdapterBase, - convert_to_dialect, -) -from sqlspec.exceptions import SQLSpecError -from sqlspec.utils.sync_tools import async_ - -if TYPE_CHECKING: - from sqlspec.adapters.mock._typing import MockConnection - from sqlspec.core import SQL, StatementConfig - from sqlspec.driver import ExecutionResult - from sqlspec.storage import StorageBridgeJob, StorageDestination, StorageFormat, StorageTelemetry - -__all__ = ( - "MockAsyncCursor", - "MockAsyncDriver", - "MockAsyncSessionContext", - "MockCursor", - "MockExceptionHandler", - "MockSyncDriver", - "MockSyncSessionContext", -) - - -class MockExceptionHandler(BaseSyncExceptionHandler): - """Context manager for handling SQLite database exceptions. - - Maps SQLite extended result codes to specific SQLSpec exceptions - for better error handling in application code. - - Uses deferred exception pattern for mypyc compatibility: exceptions - are stored in pending_exception rather than raised from __exit__ - to avoid ABI boundary violations with compiled code. - """ - - __slots__ = () - - def _handle_exception(self, exc_type: "type[BaseException] | None", exc_val: "BaseException") -> bool: - if exc_type is None: - return False - if issubclass(exc_type, sqlite3.Error): - self.pending_exception = create_mapped_exception(exc_val) - return True - return False - - -class MockAsyncExceptionHandler(BaseAsyncExceptionHandler): - """Async context manager for handling SQLite database exceptions. - - Uses deferred exception pattern for mypyc compatibility. - """ - - __slots__ = () - - def _handle_exception(self, exc_type: "type[BaseException] | None", exc_val: "BaseException") -> bool: - if exc_type is None: - return False - if issubclass(exc_type, sqlite3.Error): - self.pending_exception = create_mapped_exception(exc_val) - return True - return False - - -class MockSyncDriver(SyncDriverAdapterBase): - """Mock sync driver with dialect transpilation. - - Provides SQL statement execution, transaction management, and result handling - using SQLite :memory: as the backend. Accepts SQL written in various dialects - (Postgres, MySQL, Oracle, etc.) and transpiles to SQLite before execution. - """ - - __slots__ = ("_data_dictionary", "_target_dialect") - dialect = "sqlite" - - def __init__( - self, - connection: "MockConnection", - statement_config: "StatementConfig | None" = None, - driver_features: "dict[str, Any] | None" = None, - target_dialect: str = "sqlite", - ) -> None: - """Initialize Mock sync driver. - - Args: - connection: SQLite database connection - statement_config: Statement configuration settings - driver_features: Driver-specific feature flags - target_dialect: Source dialect for SQL transpilation (postgres, mysql, etc.) - """ - if statement_config is None: - statement_config = default_statement_config.replace( - enable_caching=get_cache_config().compiled_cache_enabled - ) - - super().__init__(connection=connection, statement_config=statement_config, driver_features=driver_features) - self._data_dictionary: MockDataDictionary | None = None - self._target_dialect = target_dialect - - # ───────────────────────────────────────────────────────────────────────────── - # CORE DISPATCH METHODS - # ───────────────────────────────────────────────────────────────────────────── - - def dispatch_execute(self, cursor: Any, statement: "SQL") -> "ExecutionResult": - """Execute single SQL statement. - - Args: - cursor: SQLite cursor object - statement: SQL statement to execute - - Returns: - ExecutionResult with statement execution details - """ - sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) - cursor.execute(sql, normalize_execute_parameters(prepared_parameters)) - - if statement.returns_rows(): - fetched_data = cursor.fetchall() - data, column_names, row_count = collect_rows(fetched_data, cursor.description) - - return self.create_execution_result( - cursor, - selected_data=data, - column_names=column_names, - data_row_count=row_count, - is_select_result=True, - row_format="tuple", - ) - - affected_rows = resolve_rowcount(cursor) - return self.create_execution_result(cursor, rowcount_override=affected_rows) - - def dispatch_execute_many(self, cursor: Any, statement: "SQL") -> "ExecutionResult": - """Execute SQL with multiple parameter sets. - - Args: - cursor: SQLite cursor object - statement: SQL statement with multiple parameter sets - - Returns: - ExecutionResult with batch execution details - """ - sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) - - cursor.executemany(sql, normalize_execute_many_parameters(prepared_parameters)) - - affected_rows = resolve_rowcount(cursor) - - return self.create_execution_result(cursor, rowcount_override=affected_rows, is_many_result=True) - - def dispatch_execute_script(self, cursor: Any, statement: "SQL") -> "ExecutionResult": - """Execute SQL script with statement splitting and parameter handling. - - Args: - cursor: SQLite cursor object - statement: SQL statement containing multiple statements - - Returns: - ExecutionResult with script execution details - """ - sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) - statements = self.split_script_statements(sql, statement.statement_config, strip_trailing_semicolon=True) - - successful_count = 0 - - for stmt in statements: - cursor.execute(stmt, normalize_execute_parameters(prepared_parameters)) - successful_count += 1 - - return self.create_execution_result( - cursor, statement_count=len(statements), successful_statements=successful_count, is_script_result=True - ) - - # ───────────────────────────────────────────────────────────────────────────── - # TRANSACTION MANAGEMENT - # ───────────────────────────────────────────────────────────────────────────── - - def begin(self) -> None: - """Begin a database transaction. - - Raises: - SQLSpecError: If transaction cannot be started - """ - try: - if not self.connection.in_transaction: - self.connection.execute("BEGIN") - except sqlite3.Error as e: - msg = f"Failed to begin transaction: {e}" - raise SQLSpecError(msg) from e - - def commit(self) -> None: - """Commit the current transaction. - - Raises: - SQLSpecError: If transaction cannot be committed - """ - try: - self.connection.commit() - except sqlite3.Error as e: - msg = f"Failed to commit transaction: {e}" - raise SQLSpecError(msg) from e - - def rollback(self) -> None: - """Rollback the current transaction. - - Raises: - SQLSpecError: If transaction cannot be rolled back - """ - try: - self.connection.rollback() - except sqlite3.Error as e: - msg = f"Failed to rollback transaction: {e}" - raise SQLSpecError(msg) from e - - def with_cursor(self, connection: "MockConnection") -> "MockCursor": - """Create context manager for SQLite cursor. - - Args: - connection: SQLite database connection - - Returns: - Cursor context manager for safe cursor operations - """ - return MockCursor(connection) - - def handle_database_exceptions(self) -> "MockExceptionHandler": - """Handle database-specific exceptions and wrap them appropriately. - - Returns: - Exception handler with deferred exception pattern for mypyc compatibility. - """ - return MockExceptionHandler() - - # ───────────────────────────────────────────────────────────────────────────── - # STORAGE API METHODS - # ───────────────────────────────────────────────────────────────────────────── - - def select_to_storage( - self, - statement: "SQL | str", - destination: "StorageDestination", - /, - *parameters: Any, - statement_config: "StatementConfig | None" = None, - partitioner: "dict[str, object] | None" = None, - format_hint: "StorageFormat | None" = None, - telemetry: "StorageTelemetry | None" = None, - **kwargs: Any, - ) -> "StorageBridgeJob": - """Execute a query and write Arrow-compatible output to storage (sync).""" - self._require_capability("arrow_export_enabled") - arrow_result = self.select_to_arrow(statement, *parameters, statement_config=statement_config, **kwargs) - sync_pipeline = self._storage_pipeline() - telemetry_payload = self._write_result_to_storage_sync( - arrow_result, destination, format_hint=format_hint, pipeline=sync_pipeline - ) - self._attach_partition_telemetry(telemetry_payload, partitioner) - return self._create_storage_job(telemetry_payload, telemetry) - - def load_from_arrow( - self, - table: str, - source: "ArrowResult | Any", - *, - partitioner: "dict[str, object] | None" = None, - overwrite: bool = False, - telemetry: "StorageTelemetry | None" = None, - ) -> "StorageBridgeJob": - """Load Arrow data into SQLite using batched inserts.""" - self._require_capability("arrow_import_enabled") - arrow_table = self._coerce_arrow_table(source) - if overwrite: - delete_statement = f"DELETE FROM {format_identifier(table)}" - exc_handler = self.handle_database_exceptions() - with exc_handler, self.with_cursor(self.connection) as cursor: - cursor.execute(delete_statement) - if exc_handler.pending_exception is not None: - raise exc_handler.pending_exception from None - - columns, records = self._arrow_table_to_rows(arrow_table) - if records: - insert_sql = build_insert_statement(table, columns) - exc_handler = self.handle_database_exceptions() - with exc_handler, self.with_cursor(self.connection) as cursor: - cursor.executemany(insert_sql, records) - if exc_handler.pending_exception is not None: - raise exc_handler.pending_exception from None - - telemetry_payload = self._build_ingest_telemetry(arrow_table) - telemetry_payload["destination"] = table - self._attach_partition_telemetry(telemetry_payload, partitioner) - return self._create_storage_job(telemetry_payload, telemetry) - - def load_from_storage( - self, - table: str, - source: "StorageDestination", - *, - file_format: "StorageFormat", - partitioner: "dict[str, object] | None" = None, - overwrite: bool = False, - ) -> "StorageBridgeJob": - """Load staged artifacts from storage into SQLite.""" - arrow_table, inbound = self._read_arrow_from_storage_sync(source, file_format=file_format) - return self.load_from_arrow(table, arrow_table, partitioner=partitioner, overwrite=overwrite, telemetry=inbound) - - # ───────────────────────────────────────────────────────────────────────────── - # UTILITY METHODS - # ───────────────────────────────────────────────────────────────────────────── - - @property - def data_dictionary(self) -> "MockDataDictionary": - """Get the data dictionary for this driver. - - Returns: - Data dictionary instance for metadata queries - """ - if self._data_dictionary is None: - self._data_dictionary = MockDataDictionary() - return self._data_dictionary - - # ───────────────────────────────────────────────────────────────────────────── - # PRIVATE/INTERNAL METHODS - # ───────────────────────────────────────────────────────────────────────────── - - def _transpile_to_sqlite(self, statement: "SQL") -> str: - """Convert statement from target dialect to SQLite. - - Args: - statement: SQL statement to transpile. - - Returns: - Transpiled SQL string compatible with SQLite. - """ - if self._target_dialect == "sqlite": - sql, _ = self._get_compiled_sql(statement, self.statement_config) - return sql - return convert_to_dialect(statement, self._target_dialect, "sqlite", pretty=False) - - def collect_rows(self, cursor: Any, fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": - """Collect mock sync rows for the direct execution path.""" - return collect_rows(fetched, cursor.description) - - def resolve_rowcount(self, cursor: Any) -> int: - """Resolve rowcount from mock cursor for the direct execution path.""" - return resolve_rowcount(cursor) - - def _connection_in_transaction(self) -> bool: - """Check if connection is in transaction. - - Returns: - True if connection is in an active transaction. - """ - return bool(self.connection.in_transaction) - - -class MockAsyncDriver(AsyncDriverAdapterBase): - """Mock async driver with dialect transpilation. - - Provides async SQL statement execution using SQLite :memory: as the backend. - Uses asyncio.to_thread() to wrap sync SQLite operations. Accepts SQL written - in various dialects (Postgres, MySQL, Oracle, etc.) and transpiles to SQLite. - """ - - __slots__ = ("_async_data_dictionary", "_target_dialect") - dialect = "sqlite" - - def __init__( - self, - connection: "MockConnection", - statement_config: "StatementConfig | None" = None, - driver_features: "dict[str, Any] | None" = None, - target_dialect: str = "sqlite", - ) -> None: - """Initialize Mock async driver. - - Args: - connection: SQLite database connection - statement_config: Statement configuration settings - driver_features: Driver-specific feature flags - target_dialect: Source dialect for SQL transpilation (postgres, mysql, etc.) - """ - if statement_config is None: - statement_config = default_statement_config.replace( - enable_caching=get_cache_config().compiled_cache_enabled - ) - - super().__init__(connection=connection, statement_config=statement_config, driver_features=driver_features) - self._async_data_dictionary: MockAsyncDataDictionary | None = None - self._target_dialect = target_dialect - - # ───────────────────────────────────────────────────────────────────────────── - # CORE DISPATCH METHODS - # ───────────────────────────────────────────────────────────────────────────── - - async def dispatch_execute(self, cursor: Any, statement: "SQL") -> "ExecutionResult": - """Execute single SQL statement asynchronously. - - Args: - cursor: SQLite cursor object - statement: SQL statement to execute - - Returns: - ExecutionResult with statement execution details - """ - sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) - - execute_async = async_(cursor.execute) - await execute_async(sql, normalize_execute_parameters(prepared_parameters)) - - if statement.returns_rows(): - fetchall_async = async_(cursor.fetchall) - fetched_data = await fetchall_async() - data, column_names, row_count = collect_rows(fetched_data, cursor.description) - - return self.create_execution_result( - cursor, - selected_data=data, - column_names=column_names, - data_row_count=row_count, - is_select_result=True, - row_format="tuple", - ) - - affected_rows = resolve_rowcount(cursor) - return self.create_execution_result(cursor, rowcount_override=affected_rows) - - async def dispatch_execute_many(self, cursor: Any, statement: "SQL") -> "ExecutionResult": - """Execute SQL with multiple parameter sets asynchronously. - - Args: - cursor: SQLite cursor object - statement: SQL statement with multiple parameter sets - - Returns: - ExecutionResult with batch execution details - """ - sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) - - executemany_async = async_(cursor.executemany) - await executemany_async(sql, normalize_execute_many_parameters(prepared_parameters)) - - affected_rows = resolve_rowcount(cursor) - - return self.create_execution_result(cursor, rowcount_override=affected_rows, is_many_result=True) - - async def dispatch_execute_script(self, cursor: Any, statement: "SQL") -> "ExecutionResult": - """Execute SQL script asynchronously. - - Args: - cursor: SQLite cursor object - statement: SQL statement containing multiple statements - - Returns: - ExecutionResult with script execution details - """ - sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) - statements = self.split_script_statements(sql, statement.statement_config, strip_trailing_semicolon=True) - - successful_count = 0 - - for stmt in statements: - execute_async = async_(cursor.execute) - await execute_async(stmt, normalize_execute_parameters(prepared_parameters)) - successful_count += 1 - - return self.create_execution_result( - cursor, statement_count=len(statements), successful_statements=successful_count, is_script_result=True - ) - - # ───────────────────────────────────────────────────────────────────────────── - # TRANSACTION MANAGEMENT - # ───────────────────────────────────────────────────────────────────────────── - - async def begin(self) -> None: - """Begin a database transaction. - - Raises: - SQLSpecError: If transaction cannot be started - """ - try: - if not self.connection.in_transaction: - execute_async = async_(self.connection.execute) - await execute_async("BEGIN") - except sqlite3.Error as e: - msg = f"Failed to begin transaction: {e}" - raise SQLSpecError(msg) from e - - async def commit(self) -> None: - """Commit the current transaction. - - Raises: - SQLSpecError: If transaction cannot be committed - """ - try: - commit_async = async_(self.connection.commit) - await commit_async() - except sqlite3.Error as e: - msg = f"Failed to commit transaction: {e}" - raise SQLSpecError(msg) from e - - async def rollback(self) -> None: - """Rollback the current transaction. - - Raises: - SQLSpecError: If transaction cannot be rolled back - """ - try: - rollback_async = async_(self.connection.rollback) - await rollback_async() - except sqlite3.Error as e: - msg = f"Failed to rollback transaction: {e}" - raise SQLSpecError(msg) from e - - def with_cursor(self, connection: "MockConnection") -> "MockAsyncCursor": - """Create async context manager for SQLite cursor. - - Args: - connection: SQLite database connection - - Returns: - Async cursor context manager - """ - return MockAsyncCursor(connection) - - def handle_database_exceptions(self) -> "MockAsyncExceptionHandler": - """Handle database-specific exceptions. - - Returns: - Async exception handler with deferred exception pattern. - """ - return MockAsyncExceptionHandler() - - # ───────────────────────────────────────────────────────────────────────────── - # STORAGE API METHODS - # ───────────────────────────────────────────────────────────────────────────── - - async def select_to_storage( - self, - statement: "SQL | str", - destination: "StorageDestination", - /, - *parameters: Any, - statement_config: "StatementConfig | None" = None, - partitioner: "dict[str, object] | None" = None, - format_hint: "StorageFormat | None" = None, - telemetry: "StorageTelemetry | None" = None, - **kwargs: Any, - ) -> "StorageBridgeJob": - """Execute a query and stream Arrow results into storage.""" - self._require_capability("arrow_export_enabled") - arrow_result = await self.select_to_arrow(statement, *parameters, statement_config=statement_config, **kwargs) - async_pipeline = self._storage_pipeline() - telemetry_payload = await self._write_result_to_storage_async( - arrow_result, destination, format_hint=format_hint, pipeline=async_pipeline - ) - self._attach_partition_telemetry(telemetry_payload, partitioner) - return self._create_storage_job(telemetry_payload, telemetry) - - async def load_from_arrow( - self, - table: str, - source: "ArrowResult | Any", - *, - partitioner: "dict[str, object] | None" = None, - overwrite: bool = False, - telemetry: "StorageTelemetry | None" = None, - ) -> "StorageBridgeJob": - """Load Arrow data into SQLite using batched inserts.""" - self._require_capability("arrow_import_enabled") - arrow_table = self._coerce_arrow_table(source) - if overwrite: - delete_statement = f"DELETE FROM {format_identifier(table)}" - exc_handler = self.handle_database_exceptions() - async with exc_handler, self.with_cursor(self.connection) as cursor: - execute_async = async_(cursor.execute) - await execute_async(delete_statement) - if exc_handler.pending_exception is not None: - raise exc_handler.pending_exception from None - - columns, records = self._arrow_table_to_rows(arrow_table) - if records: - insert_sql = build_insert_statement(table, columns) - exc_handler = self.handle_database_exceptions() - async with exc_handler, self.with_cursor(self.connection) as cursor: - executemany_async = async_(cursor.executemany) - await executemany_async(insert_sql, records) - if exc_handler.pending_exception is not None: - raise exc_handler.pending_exception from None - - telemetry_payload = self._build_ingest_telemetry(arrow_table) - telemetry_payload["destination"] = table - self._attach_partition_telemetry(telemetry_payload, partitioner) - return self._create_storage_job(telemetry_payload, telemetry) - - async def load_from_storage( - self, - table: str, - source: "StorageDestination", - *, - file_format: "StorageFormat", - partitioner: "dict[str, object] | None" = None, - overwrite: bool = False, - ) -> "StorageBridgeJob": - """Load staged artifacts from storage into SQLite.""" - arrow_table, inbound = await self._read_arrow_from_storage_async(source, file_format=file_format) - return await self.load_from_arrow( - table, arrow_table, partitioner=partitioner, overwrite=overwrite, telemetry=inbound - ) - - # ───────────────────────────────────────────────────────────────────────────── - # UTILITY METHODS - # ───────────────────────────────────────────────────────────────────────────── - - @property - def data_dictionary(self) -> "MockAsyncDataDictionary": - """Get the async data dictionary for this driver. - - Returns: - Async data dictionary instance for metadata queries - """ - if self._async_data_dictionary is None: - self._async_data_dictionary = MockAsyncDataDictionary() - return self._async_data_dictionary - - # ───────────────────────────────────────────────────────────────────────────── - # PRIVATE/INTERNAL METHODS - # ───────────────────────────────────────────────────────────────────────────── - - def _transpile_to_sqlite(self, statement: "SQL") -> str: - """Convert statement from target dialect to SQLite. - - Args: - statement: SQL statement to transpile. - - Returns: - Transpiled SQL string compatible with SQLite. - """ - if self._target_dialect == "sqlite": - sql, _ = self._get_compiled_sql(statement, self.statement_config) - return sql - return convert_to_dialect(statement, self._target_dialect, "sqlite", pretty=False) - - def collect_rows(self, cursor: Any, fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": - """Collect mock async rows for the direct execution path.""" - return collect_rows(fetched, cursor.description) - - def resolve_rowcount(self, cursor: Any) -> int: - """Resolve rowcount from mock cursor for the direct execution path.""" - return resolve_rowcount(cursor) - - def _connection_in_transaction(self) -> bool: - """Check if connection is in transaction. - - Returns: - True if connection is in an active transaction. - """ - return bool(self.connection.in_transaction) - - -register_driver_profile("mock", driver_profile) From 570b45337c4ec9a69c84d2a62931a5e21b30bb63 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Wed, 22 Apr 2026 18:54:31 +0000 Subject: [PATCH 2/8] chore(config): remove mock adapter from mypyc exclusions --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 85c14cff0..58909436f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -183,7 +183,6 @@ exclude = [ "sqlspec/dialects/**/*.py", # Keep SQLGlot dialect subclasses interpreted "sqlspec/**/__init__.py", # Init files (usually just imports) "sqlspec/protocols.py", # Protocol definitions - "sqlspec/adapters/mock/**", # Mock adapter (testing only) "sqlspec/migrations/commands.py", # Migration command CLI (dynamic imports) "sqlspec/data_dictionary/_loader.py", # Loader relies on __file__ which fails in compiled modules "sqlspec/extensions/fastapi/providers.py", # Uses SingletonMeta metaclass From f4176ed753f5a5e3aa5d569c9fc56ef635d80489 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Wed, 22 Apr 2026 18:56:05 +0000 Subject: [PATCH 3/8] test(adapters): remove mock adapter unit tests --- tests/unit/adapters/test_mock/__init__.py | 1 - tests/unit/adapters/test_mock/test_config.py | 152 ------- tests/unit/adapters/test_mock/test_core.py | 349 ---------------- .../test_mock/test_cursor_and_exceptions.py | 257 ------------ .../test_mock/test_data_dictionary.py | 316 -------------- .../test_mock/test_dialect_transpilation.py | 225 ---------- tests/unit/adapters/test_mock/test_driver.py | 233 ----------- .../adapters/test_mock/test_edge_cases.py | 389 ------------------ 8 files changed, 1922 deletions(-) delete mode 100644 tests/unit/adapters/test_mock/__init__.py delete mode 100644 tests/unit/adapters/test_mock/test_config.py delete mode 100644 tests/unit/adapters/test_mock/test_core.py delete mode 100644 tests/unit/adapters/test_mock/test_cursor_and_exceptions.py delete mode 100644 tests/unit/adapters/test_mock/test_data_dictionary.py delete mode 100644 tests/unit/adapters/test_mock/test_dialect_transpilation.py delete mode 100644 tests/unit/adapters/test_mock/test_driver.py delete mode 100644 tests/unit/adapters/test_mock/test_edge_cases.py diff --git a/tests/unit/adapters/test_mock/__init__.py b/tests/unit/adapters/test_mock/__init__.py deleted file mode 100644 index cd82fd0b9..000000000 --- a/tests/unit/adapters/test_mock/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Mock adapter unit tests.""" diff --git a/tests/unit/adapters/test_mock/test_config.py b/tests/unit/adapters/test_mock/test_config.py deleted file mode 100644 index 69b00402e..000000000 --- a/tests/unit/adapters/test_mock/test_config.py +++ /dev/null @@ -1,152 +0,0 @@ -"""Unit tests for mock configuration classes.""" - -import pytest - -from sqlspec.adapters.mock import MockAsyncConfig, MockSyncConfig - - -def test_mock_sync_config_defaults() -> None: - """Test MockSyncConfig default values.""" - config = MockSyncConfig() - - assert config.target_dialect == "sqlite" - assert config.initial_sql is None - assert config.is_async is False - assert config.supports_transactional_ddl is True - - -def test_mock_sync_config_with_target_dialect() -> None: - """Test MockSyncConfig with custom target dialect.""" - config = MockSyncConfig(target_dialect="postgres") - - assert config.target_dialect == "postgres" - - -def test_mock_sync_config_with_initial_sql_string() -> None: - """Test MockSyncConfig with initial SQL as string.""" - config = MockSyncConfig(initial_sql="CREATE TABLE test (id INTEGER)") - - assert config.initial_sql == "CREATE TABLE test (id INTEGER)" - - -def test_mock_sync_config_with_initial_sql_list() -> None: - """Test MockSyncConfig with initial SQL as list.""" - sql_list = ["CREATE TABLE test1 (id INTEGER)", "CREATE TABLE test2 (id INTEGER)"] - config = MockSyncConfig(initial_sql=sql_list) - - assert config.initial_sql == sql_list - - -def test_mock_sync_config_create_connection() -> None: - """Test that create_connection returns a valid connection.""" - config = MockSyncConfig() - conn = config.create_connection() - - assert conn is not None - conn.close() - - -def test_mock_sync_config_provide_connection_context() -> None: - """Test provide_connection context manager.""" - config = MockSyncConfig() - - with config.provide_connection() as conn: - cursor = conn.cursor() - cursor.execute("SELECT 1") - result = cursor.fetchone() - assert result[0] == 1 - - -def test_mock_sync_config_provide_session_context() -> None: - """Test provide_session context manager.""" - config = MockSyncConfig() - - with config.provide_session() as session: - result = session.select_value("SELECT 42") - assert result == 42 - - -def test_mock_async_config_defaults() -> None: - """Test MockAsyncConfig default values.""" - config = MockAsyncConfig() - - assert config.target_dialect == "sqlite" - assert config.initial_sql is None - assert config.is_async is True - assert config.supports_transactional_ddl is True - - -def test_mock_async_config_with_target_dialect() -> None: - """Test MockAsyncConfig with custom target dialect.""" - config = MockAsyncConfig(target_dialect="mysql") - - assert config.target_dialect == "mysql" - - -@pytest.mark.anyio -async def test_mock_async_config_create_connection() -> None: - """Test that async create_connection returns a valid connection.""" - config = MockAsyncConfig() - conn = await config.create_connection() - - assert conn is not None - conn.close() - - -@pytest.mark.anyio -async def test_mock_async_config_provide_connection_context() -> None: - """Test async provide_connection context manager.""" - config = MockAsyncConfig() - - async with config.provide_connection() as conn: - cursor = conn.cursor() - cursor.execute("SELECT 1") - result = cursor.fetchone() - assert result[0] == 1 - - -@pytest.mark.anyio -async def test_mock_async_config_provide_session_context() -> None: - """Test async provide_session context manager.""" - config = MockAsyncConfig() - - async with config.provide_session() as session: - result = await session.select_value("SELECT 42") - assert result == 42 - - -def test_mock_config_with_driver_features() -> None: - """Test MockSyncConfig with custom driver features.""" - from sqlspec.utils.serializers import from_json, to_json - - config = MockSyncConfig(driver_features={"json_serializer": to_json, "json_deserializer": from_json}) - - assert config.driver_features.get("json_serializer") is to_json - assert config.driver_features.get("json_deserializer") is from_json - - -def test_mock_config_with_bind_key() -> None: - """Test MockSyncConfig with bind_key.""" - config = MockSyncConfig(bind_key="test_db") - - assert config.bind_key == "test_db" - - -def test_mock_config_supports_arrow() -> None: - """Test that mock config reports Arrow support.""" - config = MockSyncConfig() - - assert config.supports_native_arrow_export is True - assert config.supports_native_arrow_import is True - assert config.supports_native_parquet_export is True - assert config.supports_native_parquet_import is True - - -def test_mock_async_config_supports_arrow() -> None: - """Test that async mock config reports Arrow support.""" - config = MockAsyncConfig() - - assert config.supports_native_arrow_export is True - assert config.supports_native_arrow_import is True - assert config.supports_native_parquet_export is True - assert config.supports_native_parquet_import is True diff --git a/tests/unit/adapters/test_mock/test_core.py b/tests/unit/adapters/test_mock/test_core.py deleted file mode 100644 index fc247727d..000000000 --- a/tests/unit/adapters/test_mock/test_core.py +++ /dev/null @@ -1,349 +0,0 @@ -"""Unit tests for mock adapter core utilities.""" - -import sqlite3 -from datetime import date, datetime -from decimal import Decimal - -import pytest - -from sqlspec.adapters.mock.core import ( - apply_driver_features, - build_insert_statement, - collect_rows, - create_mapped_exception, - default_statement_config, - driver_profile, - format_identifier, - normalize_execute_many_parameters, - normalize_execute_parameters, - resolve_rowcount, -) -from sqlspec.core import ParameterStyle, StatementConfig -from sqlspec.exceptions import ( - CheckViolationError, - ForeignKeyViolationError, - NotNullViolationError, - SQLParsingError, - SQLSpecError, - UniqueViolationError, -) -from sqlspec.utils.serializers import from_json, to_json - - -def test_driver_profile_defaults() -> None: - """Test driver profile has correct default values.""" - assert driver_profile.name == "Mock" - assert driver_profile.default_style == ParameterStyle.QMARK - assert ParameterStyle.QMARK in driver_profile.supported_styles - assert ParameterStyle.NAMED_COLON in driver_profile.supported_styles - assert driver_profile.has_native_list_expansion is False - assert driver_profile.json_serializer_strategy == "helper" - assert driver_profile.default_dialect == "sqlite" - - -def test_default_statement_config() -> None: - """Test default statement config is properly initialized.""" - assert default_statement_config.dialect == "sqlite" - assert default_statement_config.parameter_config.default_parameter_style == ParameterStyle.QMARK - - -def test_format_identifier_simple() -> None: - """Test formatting simple identifiers.""" - assert format_identifier("users") == '"users"' - assert format_identifier("table_name") == '"table_name"' - - -def test_format_identifier_with_schema() -> None: - """Test formatting identifiers with schema.""" - result = format_identifier("public.users") - assert result == '"public"."users"' - - -def test_format_identifier_with_quotes() -> None: - """Test formatting identifiers containing quotes.""" - result = format_identifier('table"name') - assert result == '"table""name"' - - -def test_format_identifier_empty_raises() -> None: - """Test formatting empty identifier raises error.""" - with pytest.raises(SQLSpecError, match="Table name must not be empty"): - format_identifier("") - - with pytest.raises(SQLSpecError, match="Table name must not be empty"): - format_identifier(" ") - - -def test_format_identifier_with_dots_edge_case() -> None: - """Test formatting with multiple dots.""" - result = format_identifier("db.schema.table") - assert result == '"db"."schema"."table"' - - -def test_build_insert_statement_basic() -> None: - """Test building basic INSERT statement.""" - result = build_insert_statement("users", ["id", "name"]) - assert result == 'INSERT INTO "users" ("id", "name") VALUES (?, ?)' - - -def test_build_insert_statement_single_column() -> None: - """Test INSERT statement with single column.""" - result = build_insert_statement("counters", ["count"]) - assert result == 'INSERT INTO "counters" ("count") VALUES (?)' - - -def test_build_insert_statement_many_columns() -> None: - """Test INSERT statement with many columns.""" - columns = ["col1", "col2", "col3", "col4", "col5"] - result = build_insert_statement("data", columns) - assert 'INSERT INTO "data"' in result - assert '"col1", "col2", "col3", "col4", "col5"' in result - assert "VALUES (?, ?, ?, ?, ?)" in result - - -def test_build_insert_statement_with_schema() -> None: - """Test INSERT statement with schema-qualified table.""" - result = build_insert_statement("public.users", ["id", "name"]) - assert result == 'INSERT INTO "public"."users" ("id", "name") VALUES (?, ?)' - - -def test_collect_rows_empty() -> None: - """Test collecting empty result set.""" - data, columns, count = collect_rows([], None) - assert data == [] - assert columns == [] - assert count == 0 - - -def test_collect_rows_with_data() -> None: - """Test collecting rows with raw tuple data (no dict conversion).""" - description = [("id",), ("name",)] - rows = [(1, "Alice"), (2, "Bob")] - data, columns, count = collect_rows(rows, description) - - assert count == 2 - assert columns == ["id", "name"] - assert len(data) == 2 - assert data[0] == (1, "Alice") - assert data[1] == (2, "Bob") - - -def test_collect_rows_with_none_values() -> None: - """Test collecting rows containing None values (raw tuples).""" - description = [("id",), ("value",)] - rows = [(1, None), (2, "text")] - data, _columns, count = collect_rows(rows, description) - - assert count == 2 - assert data[0] == (1, None) - assert data[1] == (2, "text") - - -def test_resolve_rowcount_with_valid_cursor() -> None: - """Test resolving rowcount from cursor with rowcount attribute.""" - conn = sqlite3.connect(":memory:") - cursor = conn.cursor() - cursor.execute("CREATE TABLE test (id INTEGER)") - cursor.execute("INSERT INTO test VALUES (1)") - - rowcount = resolve_rowcount(cursor) - assert rowcount == 1 - - conn.close() - - -def test_resolve_rowcount_negative_value() -> None: - """Test resolving rowcount when value is negative.""" - conn = sqlite3.connect(":memory:") - cursor = conn.cursor() - cursor.execute("SELECT 1") - - rowcount = resolve_rowcount(cursor) - assert rowcount == 0 - - conn.close() - - -def test_resolve_rowcount_no_attribute() -> None: - """Test resolving rowcount from object without rowcount.""" - - class FakeCursor: - pass - - rowcount = resolve_rowcount(FakeCursor()) - assert rowcount == 0 - - -def test_normalize_execute_parameters_with_tuple() -> None: - """Test normalizing tuple parameters.""" - result = normalize_execute_parameters((1, "test")) - assert result == (1, "test") - - -def test_normalize_execute_parameters_with_list() -> None: - """Test normalizing list parameters.""" - result = normalize_execute_parameters([1, 2, 3]) - assert result == [1, 2, 3] - - -def test_normalize_execute_parameters_empty() -> None: - """Test normalizing empty parameters.""" - result = normalize_execute_parameters(None) - assert result == () - - -def test_normalize_execute_many_parameters_valid() -> None: - """Test normalizing execute_many parameters.""" - params = [(1, "a"), (2, "b")] - result = normalize_execute_many_parameters(params) - assert result == params - - -def test_normalize_execute_many_parameters_empty_raises() -> None: - """Test normalizing empty execute_many parameters raises error.""" - with pytest.raises(ValueError, match="execute_many requires parameters"): - normalize_execute_many_parameters(None) - - with pytest.raises(ValueError, match="execute_many requires parameters"): - normalize_execute_many_parameters([]) - - -def test_create_mapped_exception_unique_constraint_code() -> None: - """Test creating exception for unique constraint violation with error code.""" - conn = sqlite3.connect(":memory:") - try: - conn.execute("CREATE TABLE test (id INTEGER UNIQUE)") - conn.execute("INSERT INTO test VALUES (1)") - conn.execute("INSERT INTO test VALUES (1)") - except sqlite3.Error as e: - result = create_mapped_exception(e) - assert isinstance(result, UniqueViolationError) - finally: - conn.close() - - -def test_create_mapped_exception_foreign_key_constraint() -> None: - """Test creating exception for foreign key constraint violation.""" - conn = sqlite3.connect(":memory:") - try: - conn.execute("PRAGMA foreign_keys = ON") - conn.execute("CREATE TABLE parent (id INTEGER PRIMARY KEY)") - conn.execute("CREATE TABLE child (id INTEGER, parent_id INTEGER, FOREIGN KEY(parent_id) REFERENCES parent(id))") - conn.execute("INSERT INTO child VALUES (1, 999)") - except sqlite3.Error as e: - result = create_mapped_exception(e) - assert isinstance(result, ForeignKeyViolationError) - finally: - conn.close() - - -def test_create_mapped_exception_not_null_constraint() -> None: - """Test creating exception for not null constraint violation.""" - conn = sqlite3.connect(":memory:") - try: - conn.execute("CREATE TABLE test (id INTEGER NOT NULL)") - conn.execute("INSERT INTO test VALUES (NULL)") - except sqlite3.Error as e: - result = create_mapped_exception(e) - assert isinstance(result, NotNullViolationError) - finally: - conn.close() - - -def test_create_mapped_exception_check_constraint() -> None: - """Test creating exception for check constraint violation.""" - conn = sqlite3.connect(":memory:") - try: - conn.execute("CREATE TABLE test (id INTEGER CHECK(id > 0))") - conn.execute("INSERT INTO test VALUES (-1)") - except sqlite3.Error as e: - result = create_mapped_exception(e) - assert isinstance(result, CheckViolationError) - finally: - conn.close() - - -def test_create_mapped_exception_syntax_error() -> None: - """Test creating exception for SQL syntax error.""" - conn = sqlite3.connect(":memory:") - try: - conn.execute("INVALID SQL SYNTAX") - except sqlite3.Error as e: - result = create_mapped_exception(e) - assert isinstance(result, SQLParsingError) - finally: - conn.close() - - -def test_create_mapped_exception_generic_error() -> None: - """Test creating exception for generic database error.""" - - class CustomSQLiteError(sqlite3.Error): - pass - - error = CustomSQLiteError("Generic error") - result = create_mapped_exception(error) - assert isinstance(result, SQLSpecError) - assert "Generic error" in str(result) - - -def test_apply_driver_features_defaults() -> None: - """Test applying driver features with defaults.""" - config = StatementConfig() - result_config, features = apply_driver_features(config, None) - - assert features["json_serializer"] is to_json - assert features["json_deserializer"] is from_json - assert result_config is not None - - -def test_apply_driver_features_custom_serializers() -> None: - """Test applying driver features with custom JSON serializers.""" - - def custom_serializer(obj: object) -> str: - return "custom" - - def custom_deserializer(s: str) -> object: - return {"custom": True} - - config = StatementConfig() - _result_config, features = apply_driver_features( - config, {"json_serializer": custom_serializer, "json_deserializer": custom_deserializer} - ) - - assert features["json_serializer"] is custom_serializer - assert features["json_deserializer"] is custom_deserializer - - -def test_apply_driver_features_preserves_other_features() -> None: - """Test that apply_driver_features preserves non-JSON features.""" - config = StatementConfig() - _result_config, features = apply_driver_features(config, {"custom_feature": "value", "another_feature": 123}) - - assert features["custom_feature"] == "value" - assert features["another_feature"] == 123 - assert "json_serializer" in features - assert "json_deserializer" in features - - -def test_driver_profile_type_coercions() -> None: - """Test that driver profile has correct type coercions.""" - coercions = driver_profile.custom_type_coercions - - assert bool in coercions - assert datetime in coercions - assert date in coercions - assert Decimal in coercions - - bool_converter = coercions[bool] - assert bool_converter(True) == 1 - assert bool_converter(False) == 0 - - -def test_driver_profile_decimal_conversion() -> None: - """Test Decimal to string conversion in driver profile.""" - coercions = driver_profile.custom_type_coercions - decimal_converter = coercions[Decimal] - - result = decimal_converter(Decimal("123.45")) - assert result == "123.45" diff --git a/tests/unit/adapters/test_mock/test_cursor_and_exceptions.py b/tests/unit/adapters/test_mock/test_cursor_and_exceptions.py deleted file mode 100644 index 0e1d385cb..000000000 --- a/tests/unit/adapters/test_mock/test_cursor_and_exceptions.py +++ /dev/null @@ -1,257 +0,0 @@ -"""Unit tests for cursor management and exception handling.""" - -import sqlite3 - -import pytest - -from sqlspec.adapters.mock._typing import MockAsyncCursor, MockCursor -from sqlspec.adapters.mock.driver import MockAsyncExceptionHandler, MockExceptionHandler -from sqlspec.exceptions import UniqueViolationError - - -def test_mock_cursor_context_manager() -> None: - """Test MockCursor context manager creates and cleans up cursor.""" - conn = sqlite3.connect(":memory:") - - with MockCursor(conn) as cursor: - assert cursor is not None - assert isinstance(cursor, sqlite3.Cursor) - cursor.execute("SELECT 1") - result = cursor.fetchone() - assert result[0] == 1 - - conn.close() - - -def test_mock_cursor_cleanup_on_exception() -> None: - """Test MockCursor cleans up cursor even when exception occurs.""" - conn = sqlite3.connect(":memory:") - - try: - with MockCursor(conn) as cursor: - cursor.execute("SELECT 1") - raise ValueError("Test exception") - except ValueError: - pass - - conn.close() - - -def test_mock_cursor_multiple_operations() -> None: - """Test MockCursor with multiple operations.""" - conn = sqlite3.connect(":memory:") - - with MockCursor(conn) as cursor: - cursor.execute("CREATE TABLE test (id INTEGER)") - cursor.execute("INSERT INTO test VALUES (1)") - cursor.execute("INSERT INTO test VALUES (2)") - cursor.execute("SELECT * FROM test") - results = cursor.fetchall() - assert len(results) == 2 - - conn.close() - - -def test_mock_exception_handler_no_exception() -> None: - """Test MockExceptionHandler when no exception occurs.""" - with MockExceptionHandler() as handler: - pass - - assert handler.pending_exception is None - - -def test_mock_exception_handler_non_sqlite_exception() -> None: - """Test MockExceptionHandler passes through non-SQLite exceptions.""" - handler = MockExceptionHandler() - try: - with handler: - raise ValueError("Not a SQLite error") - except ValueError as e: - assert str(e) == "Not a SQLite error" - assert handler.pending_exception is None - - -def test_mock_exception_handler_sqlite_error() -> None: - """Test MockExceptionHandler maps SQLite errors.""" - conn = sqlite3.connect(":memory:") - conn.execute("CREATE TABLE test (id INTEGER UNIQUE)") - conn.execute("INSERT INTO test VALUES (1)") - - try: - conn.execute("INSERT INTO test VALUES (1)") - except sqlite3.Error as e: - with MockExceptionHandler() as handler: - handler.__exit__(type(e), e, None) - - assert handler.pending_exception is not None - assert isinstance(handler.pending_exception, UniqueViolationError) - - conn.close() - - -def test_mock_exception_handler_captures_and_suppresses() -> None: - """Test MockExceptionHandler captures exception and returns True.""" - conn = sqlite3.connect(":memory:") - conn.execute("CREATE TABLE test (id INTEGER UNIQUE)") - conn.execute("INSERT INTO test VALUES (1)") - - with MockExceptionHandler() as handler: - try: - conn.execute("INSERT INTO test VALUES (1)") - except sqlite3.Error as e: - suppressed = handler.__exit__(type(e), e, None) - assert suppressed is True - assert handler.pending_exception is not None - - conn.close() - - -def test_mock_exception_handler_syntax_error() -> None: - """Test MockExceptionHandler maps syntax errors.""" - conn = sqlite3.connect(":memory:") - - try: - conn.execute("INVALID SQL") - except sqlite3.Error as e: - with MockExceptionHandler() as handler: - handler.__exit__(type(e), e, None) - - assert handler.pending_exception is not None - - conn.close() - - -@pytest.mark.anyio -async def test_mock_async_cursor_context_manager() -> None: - """Test MockAsyncCursor context manager.""" - conn = sqlite3.connect(":memory:") - - async with MockAsyncCursor(conn) as cursor: - assert cursor is not None - assert isinstance(cursor, sqlite3.Cursor) - cursor.execute("SELECT 1") - result = cursor.fetchone() - assert result[0] == 1 - - conn.close() - - -@pytest.mark.anyio -async def test_mock_async_cursor_cleanup_on_exception() -> None: - """Test MockAsyncCursor cleans up even with exception.""" - conn = sqlite3.connect(":memory:") - - try: - async with MockAsyncCursor(conn) as cursor: - cursor.execute("SELECT 1") - raise ValueError("Async test exception") - except ValueError: - pass - - conn.close() - - -@pytest.mark.anyio -async def test_mock_async_cursor_multiple_operations() -> None: - """Test MockAsyncCursor with multiple operations.""" - conn = sqlite3.connect(":memory:") - - async with MockAsyncCursor(conn) as cursor: - cursor.execute("CREATE TABLE async_test (id INTEGER)") - cursor.execute("INSERT INTO async_test VALUES (1)") - cursor.execute("SELECT * FROM async_test") - results = cursor.fetchall() - assert len(results) == 1 - - conn.close() - - -@pytest.mark.anyio -async def test_mock_async_exception_handler_no_exception() -> None: - """Test MockAsyncExceptionHandler when no exception occurs.""" - async with MockAsyncExceptionHandler() as handler: - pass - - assert handler.pending_exception is None - - -@pytest.mark.anyio -async def test_mock_async_exception_handler_non_sqlite_exception() -> None: - """Test MockAsyncExceptionHandler passes through non-SQLite exceptions.""" - handler = MockAsyncExceptionHandler() - try: - async with handler: - raise ValueError("Not a SQLite error") - except ValueError as e: - assert str(e) == "Not a SQLite error" - assert handler.pending_exception is None - - -@pytest.mark.anyio -async def test_mock_async_exception_handler_sqlite_error() -> None: - """Test MockAsyncExceptionHandler maps SQLite errors.""" - conn = sqlite3.connect(":memory:") - conn.execute("CREATE TABLE test (id INTEGER UNIQUE)") - conn.execute("INSERT INTO test VALUES (1)") - - try: - conn.execute("INSERT INTO test VALUES (1)") - except sqlite3.Error as e: - async with MockAsyncExceptionHandler() as handler: - await handler.__aexit__(type(e), e, None) - - assert handler.pending_exception is not None - assert isinstance(handler.pending_exception, UniqueViolationError) - - conn.close() - - -@pytest.mark.anyio -async def test_mock_async_exception_handler_captures_and_suppresses() -> None: - """Test MockAsyncExceptionHandler captures and suppresses.""" - conn = sqlite3.connect(":memory:") - conn.execute("CREATE TABLE test (id INTEGER UNIQUE)") - conn.execute("INSERT INTO test VALUES (1)") - - async with MockAsyncExceptionHandler() as handler: - try: - conn.execute("INSERT INTO test VALUES (1)") - except sqlite3.Error as e: - suppressed = await handler.__aexit__(type(e), e, None) - assert suppressed is True - assert handler.pending_exception is not None - - conn.close() - - -def test_cursor_close_failure_suppressed() -> None: - """Test that cursor close failures are suppressed.""" - - class FailingCursor: - def close(self) -> None: - raise RuntimeError("Close failed") - - conn = sqlite3.connect(":memory:") - cursor_manager = MockCursor(conn) - cursor_manager.cursor = FailingCursor() # type: ignore[assignment] - - cursor_manager.__exit__(None, None, None) - - conn.close() - - -@pytest.mark.anyio -async def test_async_cursor_close_failure_suppressed() -> None: - """Test that async cursor close failures are suppressed.""" - - class FailingCursor: - def close(self) -> None: - raise RuntimeError("Async close failed") - - conn = sqlite3.connect(":memory:") - cursor_manager = MockAsyncCursor(conn) - cursor_manager.cursor = FailingCursor() # type: ignore[assignment] - - await cursor_manager.__aexit__(None, None, None) - - conn.close() diff --git a/tests/unit/adapters/test_mock/test_data_dictionary.py b/tests/unit/adapters/test_mock/test_data_dictionary.py deleted file mode 100644 index d719b822c..000000000 --- a/tests/unit/adapters/test_mock/test_data_dictionary.py +++ /dev/null @@ -1,316 +0,0 @@ -"""Unit tests for mock data dictionary.""" - -import pytest - -from sqlspec.adapters.mock import MockAsyncConfig, MockSyncConfig - - -def test_mock_data_dictionary_get_version() -> None: - """Test retrieving SQLite version through data dictionary.""" - config = MockSyncConfig() - - with config.provide_session() as session: - version = session.data_dictionary.get_version(session) - - assert version is not None - assert version.major >= 3 - assert version.minor >= 0 - assert version.patch >= 0 - - -def test_mock_data_dictionary_version_caching() -> None: - """Test that version is cached after first retrieval.""" - config = MockSyncConfig() - - with config.provide_session() as session: - dd = session.data_dictionary - driver_id = id(session) - - was_cached, cached_version = dd.get_cached_version(driver_id) - assert was_cached is False - - version1 = dd.get_version(session) - was_cached, cached_version = dd.get_cached_version(driver_id) - assert was_cached is True - assert cached_version == version1 - - version2 = dd.get_version(session) - assert version2 == version1 - - -def test_mock_data_dictionary_get_tables() -> None: - """Test retrieving tables from data dictionary.""" - config = MockSyncConfig( - initial_sql=[ - "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)", - "CREATE TABLE orders (id INTEGER PRIMARY KEY, user_id INTEGER)", - ] - ) - - with config.provide_session() as session: - tables = session.data_dictionary.get_tables(session) - - table_names = [t["table_name"] for t in tables] - assert "users" in table_names - assert "orders" in table_names - - -def test_mock_data_dictionary_get_tables_empty() -> None: - """Test getting tables when no tables exist.""" - config = MockSyncConfig() - - with config.provide_session() as session: - tables = session.data_dictionary.get_tables(session) - assert tables == [] - - -def test_mock_data_dictionary_get_columns() -> None: - """Test retrieving columns from data dictionary.""" - config = MockSyncConfig( - initial_sql="CREATE TABLE products (id INTEGER PRIMARY KEY, name TEXT NOT NULL, price REAL)" - ) - - with config.provide_session() as session: - columns = session.data_dictionary.get_columns(session, table="products") - - assert len(columns) >= 3 - column_names = [c["column_name"] for c in columns] - assert "id" in column_names - assert "name" in column_names - assert "price" in column_names - - -def test_mock_data_dictionary_get_columns_for_schema() -> None: - """Test retrieving all columns for a schema.""" - config = MockSyncConfig( - initial_sql=["CREATE TABLE table1 (id INTEGER, name TEXT)", "CREATE TABLE table2 (id INTEGER, value REAL)"] - ) - - with config.provide_session() as session: - columns = session.data_dictionary.get_columns(session) - - assert len(columns) >= 4 - - -def test_mock_data_dictionary_get_indexes() -> None: - """Test retrieving indexes from data dictionary.""" - config = MockSyncConfig( - initial_sql=[ - "CREATE TABLE indexed_table (id INTEGER PRIMARY KEY, email TEXT UNIQUE)", - "CREATE INDEX idx_email ON indexed_table(email)", - ] - ) - - with config.provide_session() as session: - indexes = session.data_dictionary.get_indexes(session, table="indexed_table") - - assert len(indexes) > 0 - - -def test_mock_data_dictionary_get_indexes_empty() -> None: - """Test getting indexes when table has no indexes.""" - config = MockSyncConfig(initial_sql="CREATE TABLE simple (id INTEGER, name TEXT)") - - with config.provide_session() as session: - indexes = session.data_dictionary.get_indexes(session, table="simple") - assert len(indexes) == 0 or all(idx.get("index_name") for idx in indexes) - - -def test_mock_data_dictionary_get_foreign_keys() -> None: - """Test retrieving foreign keys from data dictionary.""" - config = MockSyncConfig( - initial_sql=[ - "CREATE TABLE parent (id INTEGER PRIMARY KEY)", - "CREATE TABLE child (id INTEGER, parent_id INTEGER, FOREIGN KEY(parent_id) REFERENCES parent(id))", - ] - ) - - with config.provide_session() as session: - fks = session.data_dictionary.get_foreign_keys(session, table="child") - - assert len(fks) > 0 - - -def test_mock_data_dictionary_get_foreign_keys_empty() -> None: - """Test getting foreign keys when table has none.""" - config = MockSyncConfig(initial_sql="CREATE TABLE standalone (id INTEGER PRIMARY KEY)") - - with config.provide_session() as session: - fks = session.data_dictionary.get_foreign_keys(session, table="standalone") - assert fks == [] - - -def test_mock_data_dictionary_get_optimal_type_json() -> None: - """Test getting optimal type for JSON category.""" - config = MockSyncConfig() - - with config.provide_session() as session: - json_type = session.data_dictionary.get_optimal_type(session, "json") - - assert json_type in ("JSON", "TEXT") - - -def test_mock_data_dictionary_get_optimal_type_text() -> None: - """Test getting optimal type for text category.""" - config = MockSyncConfig() - - with config.provide_session() as session: - text_type = session.data_dictionary.get_optimal_type(session, "text") - assert text_type == "TEXT" - - -def test_mock_data_dictionary_get_feature_flag() -> None: - """Test checking feature flags.""" - config = MockSyncConfig() - - with config.provide_session() as session: - supports_cte = session.data_dictionary.get_feature_flag(session, "supports_cte") - assert isinstance(supports_cte, bool) - - -def test_mock_data_dictionary_list_available_features() -> None: - """Test listing available features.""" - config = MockSyncConfig() - - with config.provide_session() as session: - features = session.data_dictionary.list_available_features() - - assert isinstance(features, list) - assert len(features) > 0 - - -def test_mock_data_dictionary_dialect() -> None: - """Test that data dictionary reports correct dialect.""" - config = MockSyncConfig() - - with config.provide_session() as session: - assert session.data_dictionary.dialect == "sqlite" - - -@pytest.mark.anyio -async def test_mock_async_data_dictionary_get_version() -> None: - """Test retrieving SQLite version through async data dictionary.""" - config = MockAsyncConfig() - - async with config.provide_session() as session: - version = await session.data_dictionary.get_version(session) - - assert version is not None - assert version.major >= 3 - - -@pytest.mark.anyio -async def test_mock_async_data_dictionary_version_caching() -> None: - """Test that version is cached in async data dictionary.""" - config = MockAsyncConfig() - - async with config.provide_session() as session: - dd = session.data_dictionary - driver_id = id(session) - - was_cached, cached_version = dd.get_cached_version(driver_id) - assert was_cached is False - - version1 = await dd.get_version(session) - was_cached, cached_version = dd.get_cached_version(driver_id) - assert was_cached is True - assert cached_version == version1 - - -@pytest.mark.anyio -async def test_mock_async_data_dictionary_get_tables() -> None: - """Test retrieving tables from async data dictionary.""" - config = MockAsyncConfig(initial_sql="CREATE TABLE async_test (id INTEGER)") - - async with config.provide_session() as session: - tables = await session.data_dictionary.get_tables(session) - - table_names = [t["table_name"] for t in tables] - assert "async_test" in table_names - - -@pytest.mark.anyio -async def test_mock_async_data_dictionary_get_columns() -> None: - """Test retrieving columns from async data dictionary.""" - config = MockAsyncConfig(initial_sql="CREATE TABLE async_cols (id INTEGER, name TEXT)") - - async with config.provide_session() as session: - columns = await session.data_dictionary.get_columns(session, table="async_cols") - - assert len(columns) >= 2 - column_names = [c["column_name"] for c in columns] - assert "id" in column_names - assert "name" in column_names - - -@pytest.mark.anyio -async def test_mock_async_data_dictionary_get_indexes() -> None: - """Test retrieving indexes from async data dictionary.""" - config = MockAsyncConfig( - initial_sql=[ - "CREATE TABLE async_indexed (id INTEGER PRIMARY KEY, value TEXT)", - "CREATE INDEX idx_value ON async_indexed(value)", - ] - ) - - async with config.provide_session() as session: - indexes = await session.data_dictionary.get_indexes(session, table="async_indexed") - - assert len(indexes) > 0 - - -@pytest.mark.anyio -async def test_mock_async_data_dictionary_get_foreign_keys() -> None: - """Test retrieving foreign keys from async data dictionary.""" - config = MockAsyncConfig( - initial_sql=[ - "CREATE TABLE async_parent (id INTEGER PRIMARY KEY)", - "CREATE TABLE async_child (id INTEGER, parent_id INTEGER, FOREIGN KEY(parent_id) REFERENCES async_parent(id))", - ] - ) - - async with config.provide_session() as session: - fks = await session.data_dictionary.get_foreign_keys(session, table="async_child") - - assert len(fks) > 0 - - -@pytest.mark.anyio -async def test_mock_async_data_dictionary_get_optimal_type() -> None: - """Test getting optimal type from async data dictionary.""" - config = MockAsyncConfig() - - async with config.provide_session() as session: - text_type = await session.data_dictionary.get_optimal_type(session, "text") - assert text_type == "TEXT" - - -@pytest.mark.anyio -async def test_mock_async_data_dictionary_get_feature_flag() -> None: - """Test checking feature flags in async data dictionary.""" - config = MockAsyncConfig() - - async with config.provide_session() as session: - supports_cte = await session.data_dictionary.get_feature_flag(session, "supports_cte") - assert isinstance(supports_cte, bool) - - -@pytest.mark.anyio -async def test_mock_async_data_dictionary_list_available_features() -> None: - """Test listing available features from async data dictionary.""" - config = MockAsyncConfig() - - async with config.provide_session() as session: - features = session.data_dictionary.list_available_features() - - assert isinstance(features, list) - assert len(features) > 0 - - -@pytest.mark.anyio -async def test_mock_async_data_dictionary_dialect() -> None: - """Test that async data dictionary reports correct dialect.""" - config = MockAsyncConfig() - - async with config.provide_session() as session: - assert session.data_dictionary.dialect == "sqlite" diff --git a/tests/unit/adapters/test_mock/test_dialect_transpilation.py b/tests/unit/adapters/test_mock/test_dialect_transpilation.py deleted file mode 100644 index dfe61cfac..000000000 --- a/tests/unit/adapters/test_mock/test_dialect_transpilation.py +++ /dev/null @@ -1,225 +0,0 @@ -"""Unit tests for dialect transpilation in mock driver.""" - -import pytest - -from sqlspec.adapters.mock import MockAsyncConfig, MockSyncConfig - - -def test_postgres_serial_transpilation() -> None: - """Test that Postgres SERIAL type is handled correctly.""" - config = MockSyncConfig(target_dialect="postgres") - - with config.provide_session() as session: - session.execute( - """ - CREATE TABLE serial_test ( - id INTEGER PRIMARY KEY, - name TEXT NOT NULL - ) - """ - ) - session.execute("INSERT INTO serial_test (id, name) VALUES ($1, $2)", 1, "Test") - - result = session.select_one("SELECT * FROM serial_test WHERE id = $1", 1) - assert result is not None - assert result["name"] == "Test" - - -def test_postgres_dollar_params() -> None: - """Test Postgres $1, $2, etc. parameter style.""" - config = MockSyncConfig(target_dialect="postgres") - - with config.provide_session() as session: - session.execute("CREATE TABLE dollar_params (a TEXT, b TEXT, c TEXT)") - session.execute("INSERT INTO dollar_params VALUES ($1, $2, $3)", "x", "y", "z") - - result = session.select_one("SELECT * FROM dollar_params WHERE a = $1 AND b = $2", "x", "y") - assert result is not None - assert result["c"] == "z" - - -def test_mysql_percent_params() -> None: - """Test MySQL %s parameter style.""" - config = MockSyncConfig(target_dialect="mysql") - - with config.provide_session() as session: - session.execute("CREATE TABLE percent_params (a TEXT, b TEXT)") - session.execute("INSERT INTO percent_params VALUES (%s, %s)", "hello", "world") - - result = session.select_one("SELECT * FROM percent_params WHERE a = %s", "hello") - assert result is not None - assert result["b"] == "world" - - -def test_sqlite_qmark_params() -> None: - """Test SQLite ? parameter style (native, no transpilation).""" - config = MockSyncConfig(target_dialect="sqlite") - - with config.provide_session() as session: - session.execute("CREATE TABLE qmark_params (a TEXT, b INTEGER)") - session.execute("INSERT INTO qmark_params VALUES (?, ?)", "test", 123) - - result = session.select_one("SELECT * FROM qmark_params WHERE a = ?", "test") - assert result is not None - assert result["b"] == 123 - - -def test_postgres_boolean_type() -> None: - """Test that Postgres BOOLEAN type works.""" - config = MockSyncConfig(target_dialect="postgres") - - with config.provide_session() as session: - session.execute("DROP TABLE IF EXISTS bool_test") - session.execute( - """ - CREATE TABLE bool_test ( - id INTEGER PRIMARY KEY, - active INTEGER - ) - """ - ) - session.execute("INSERT INTO bool_test VALUES ($1, $2)", 1, 1) - session.execute("INSERT INTO bool_test VALUES ($1, $2)", 2, 0) - - result = session.select("SELECT * FROM bool_test WHERE active = $1", 1) - assert len(result) == 1 - assert result[0]["id"] == 1 - - -def test_varchar_type_handling() -> None: - """Test VARCHAR type is handled across dialects.""" - config = MockSyncConfig(target_dialect="postgres") - - with config.provide_session() as session: - session.execute( - """ - CREATE TABLE varchar_test ( - short_text VARCHAR(50), - long_text VARCHAR(255) - ) - """ - ) - session.execute("INSERT INTO varchar_test VALUES ($1, $2)", "short", "a" * 200) - - result = session.select_one("SELECT * FROM varchar_test") - assert result is not None - assert len(result["long_text"]) == 200 - - -def test_numeric_decimal_handling() -> None: - """Test NUMERIC/DECIMAL type handling.""" - config = MockSyncConfig(target_dialect="postgres") - - with config.provide_session() as session: - session.execute( - """ - CREATE TABLE numeric_test ( - id INTEGER, - price NUMERIC(10, 2) - ) - """ - ) - session.execute("INSERT INTO numeric_test VALUES ($1, $2)", 1, "19.99") - - result = session.select_one("SELECT * FROM numeric_test WHERE id = $1", 1) - assert result is not None - - -def test_initial_sql_with_postgres_dialect() -> None: - """Test initial SQL executed with Postgres dialect transpilation.""" - config = MockSyncConfig( - target_dialect="postgres", - initial_sql=[ - "CREATE TABLE init_test (id INTEGER PRIMARY KEY, name VARCHAR(100))", - "INSERT INTO init_test VALUES (1, 'InitValue')", - ], - ) - - with config.provide_session() as session: - result = session.select_one("SELECT * FROM init_test WHERE id = $1", 1) - assert result is not None - assert result["name"] == "InitValue" - - -def test_initial_sql_with_mysql_dialect() -> None: - """Test initial SQL executed with MySQL dialect transpilation.""" - config = MockSyncConfig( - target_dialect="mysql", - initial_sql=["CREATE TABLE mysql_init (id INT, value TEXT)", "INSERT INTO mysql_init VALUES (1, 'MySQLValue')"], - ) - - with config.provide_session() as session: - result = session.select_one("SELECT * FROM mysql_init WHERE id = %s", 1) - assert result is not None - assert result["value"] == "MySQLValue" - - -@pytest.mark.anyio -async def test_async_postgres_transpilation() -> None: - """Test async driver with Postgres dialect transpilation.""" - config = MockAsyncConfig(target_dialect="postgres") - - async with config.provide_session() as session: - await session.execute("CREATE TABLE async_pg_test (id INTEGER, name TEXT)") - await session.execute("INSERT INTO async_pg_test VALUES ($1, $2)", 1, "AsyncPG") - - result = await session.select_one("SELECT * FROM async_pg_test WHERE id = $1", 1) - assert result is not None - assert result["name"] == "AsyncPG" - - -@pytest.mark.anyio -async def test_async_initial_sql_transpilation() -> None: - """Test async driver with initial SQL transpilation.""" - config = MockAsyncConfig( - target_dialect="postgres", - initial_sql=[ - "CREATE TABLE async_init (id INTEGER, data TEXT)", - "INSERT INTO async_init VALUES (1, 'AsyncInit')", - ], - ) - - async with config.provide_session() as session: - result = await session.select_one("SELECT * FROM async_init WHERE id = $1", 1) - assert result is not None - assert result["data"] == "AsyncInit" - - -def test_multiple_statements_in_session() -> None: - """Test multiple SQL statements in the same session.""" - config = MockSyncConfig(target_dialect="postgres") - - with config.provide_session() as session: - session.execute("CREATE TABLE multi1 (id INTEGER)") - session.execute("CREATE TABLE multi2 (id INTEGER)") - session.execute("INSERT INTO multi1 VALUES ($1)", 1) - session.execute("INSERT INTO multi2 VALUES ($1)", 2) - - r1 = session.select_value("SELECT id FROM multi1") - r2 = session.select_value("SELECT id FROM multi2") - - assert r1 == 1 - assert r2 == 2 - - -def test_join_query_transpilation() -> None: - """Test JOIN queries work with transpilation.""" - config = MockSyncConfig(target_dialect="postgres") - - with config.provide_session() as session: - session.execute("CREATE TABLE orders (id INTEGER, customer_id INTEGER)") - session.execute("CREATE TABLE customers (id INTEGER, name TEXT)") - session.execute("INSERT INTO customers VALUES ($1, $2)", 1, "Alice") - session.execute("INSERT INTO orders VALUES ($1, $2)", 100, 1) - - result = session.select_one( - """ - SELECT o.id as order_id, c.name as customer_name - FROM orders o - JOIN customers c ON o.customer_id = c.id - WHERE o.id = $1 - """, - 100, - ) - assert result is not None - assert result["customer_name"] == "Alice" diff --git a/tests/unit/adapters/test_mock/test_driver.py b/tests/unit/adapters/test_mock/test_driver.py deleted file mode 100644 index 3a87a3dff..000000000 --- a/tests/unit/adapters/test_mock/test_driver.py +++ /dev/null @@ -1,233 +0,0 @@ -"""Unit tests for mock driver.""" - -import pytest - -from sqlspec.adapters.mock import MockAsyncConfig, MockSyncConfig - - -def test_mock_sync_driver_basic_operations() -> None: - """Test basic sync driver operations.""" - config = MockSyncConfig() - - with config.provide_session() as session: - session.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)") - session.execute("INSERT INTO users (id, name) VALUES (?, ?)", 1, "Alice") - - result = session.select("SELECT * FROM users") - assert len(result) == 1 - assert result[0]["name"] == "Alice" - - -def test_mock_sync_driver_with_initial_sql() -> None: - """Test sync driver with initial SQL setup.""" - config = MockSyncConfig( - initial_sql=[ - "CREATE TABLE items (id INTEGER, name TEXT)", - "INSERT INTO items VALUES (1, 'Widget')", - "INSERT INTO items VALUES (2, 'Gadget')", - ] - ) - - with config.provide_session() as session: - result = session.select("SELECT * FROM items ORDER BY id") - assert len(result) == 2 - assert result[0]["name"] == "Widget" - assert result[1]["name"] == "Gadget" - - -def test_mock_sync_driver_postgres_dialect() -> None: - """Test sync driver with Postgres dialect transpilation.""" - config = MockSyncConfig(target_dialect="postgres") - - with config.provide_session() as session: - session.execute( - """ - CREATE TABLE products ( - id INTEGER PRIMARY KEY, - name VARCHAR(100), - price NUMERIC(10, 2) - ) - """ - ) - session.execute("INSERT INTO products (id, name, price) VALUES ($1, $2, $3)", 1, "Widget", 19.99) - - result = session.select_one("SELECT * FROM products WHERE id = $1", 1) - assert result is not None - assert result["name"] == "Widget" - - -def test_mock_sync_driver_mysql_dialect() -> None: - """Test sync driver with MySQL dialect transpilation.""" - config = MockSyncConfig(target_dialect="mysql") - - with config.provide_session() as session: - session.execute( - """ - CREATE TABLE orders ( - id INTEGER PRIMARY KEY, - customer TEXT, - total DECIMAL(10, 2) - ) - """ - ) - session.execute("INSERT INTO orders (id, customer, total) VALUES (%s, %s, %s)", 1, "Bob", 99.99) - - result = session.select_one("SELECT * FROM orders WHERE customer = %s", "Bob") - assert result is not None - assert result["id"] == 1 - - -def test_mock_sync_driver_select_value() -> None: - """Test select_value and select_value_or_none.""" - config = MockSyncConfig() - - with config.provide_session() as session: - session.execute("CREATE TABLE counts (n INTEGER)") - session.execute("INSERT INTO counts VALUES (42)") - - value = session.select_value("SELECT n FROM counts") - assert value == 42 - - none_value = session.select_value_or_none("SELECT n FROM counts WHERE n = ?", 999) - assert none_value is None - - -def test_mock_sync_driver_execute_many() -> None: - """Test execute_many for batch inserts.""" - config = MockSyncConfig() - - with config.provide_session() as session: - session.execute("CREATE TABLE batch (id INTEGER, value TEXT)") - session.execute_many("INSERT INTO batch VALUES (?, ?)", [(1, "a"), (2, "b"), (3, "c")]) - - result = session.select("SELECT * FROM batch ORDER BY id") - assert len(result) == 3 - - -def test_mock_sync_driver_transaction_commit() -> None: - """Test transaction commit.""" - config = MockSyncConfig() - - with config.provide_session() as session: - session.execute("CREATE TABLE tx_test (id INTEGER)") - session.begin() - session.execute("INSERT INTO tx_test VALUES (1)") - session.commit() - - result = session.select("SELECT * FROM tx_test") - assert len(result) == 1 - - -def test_mock_sync_driver_transaction_rollback() -> None: - """Test transaction rollback.""" - config = MockSyncConfig() - - with config.provide_session() as session: - session.execute("CREATE TABLE rb_test (id INTEGER)") - session.execute("INSERT INTO rb_test VALUES (1)") - session.commit() - - session.begin() - session.execute("INSERT INTO rb_test VALUES (2)") - session.rollback() - - result = session.select("SELECT * FROM rb_test") - assert len(result) == 1 - - -@pytest.mark.anyio -async def test_mock_async_driver_basic_operations() -> None: - """Test basic async driver operations.""" - config = MockAsyncConfig() - - async with config.provide_session() as session: - await session.execute("CREATE TABLE async_users (id INTEGER PRIMARY KEY, name TEXT)") - await session.execute("INSERT INTO async_users (id, name) VALUES (?, ?)", 1, "Charlie") - - result = await session.select("SELECT * FROM async_users") - assert len(result) == 1 - assert result[0]["name"] == "Charlie" - - -@pytest.mark.anyio -async def test_mock_async_driver_with_initial_sql() -> None: - """Test async driver with initial SQL setup.""" - config = MockAsyncConfig( - initial_sql=[ - "CREATE TABLE async_items (id INTEGER, name TEXT)", - "INSERT INTO async_items VALUES (1, 'AsyncWidget')", - ] - ) - - async with config.provide_session() as session: - result = await session.select("SELECT * FROM async_items") - assert len(result) == 1 - assert result[0]["name"] == "AsyncWidget" - - -@pytest.mark.anyio -async def test_mock_async_driver_postgres_dialect() -> None: - """Test async driver with Postgres dialect transpilation.""" - config = MockAsyncConfig(target_dialect="postgres") - - async with config.provide_session() as session: - await session.execute("CREATE TABLE async_products (id INTEGER PRIMARY KEY, name TEXT)") - await session.execute("INSERT INTO async_products (id, name) VALUES ($1, $2)", 1, "AsyncProduct") - - result = await session.select_one("SELECT * FROM async_products WHERE id = $1", 1) - assert result is not None - assert result["name"] == "AsyncProduct" - - -@pytest.mark.anyio -async def test_mock_async_driver_transaction_operations() -> None: - """Test async transaction operations.""" - config = MockAsyncConfig() - - async with config.provide_session() as session: - await session.execute("CREATE TABLE async_tx (id INTEGER)") - await session.begin() - await session.execute("INSERT INTO async_tx VALUES (1)") - await session.commit() - - result = await session.select("SELECT * FROM async_tx") - assert len(result) == 1 - - -def test_mock_config_target_dialect_property() -> None: - """Test target_dialect property.""" - config = MockSyncConfig(target_dialect="postgres") - assert config.target_dialect == "postgres" - - config2 = MockSyncConfig() - assert config2.target_dialect == "sqlite" - - -def test_mock_config_initial_sql_property() -> None: - """Test initial_sql property.""" - sql = ["CREATE TABLE test (id INTEGER)"] - config = MockSyncConfig(initial_sql=sql) - assert config.initial_sql == sql - - config2 = MockSyncConfig() - assert config2.initial_sql is None - - -def test_mock_config_signature_namespace() -> None: - """Test that signature namespace contains expected types.""" - config = MockSyncConfig() - namespace = config.get_signature_namespace() - - assert "MockSyncDriver" in namespace - assert "MockConnection" in namespace - assert "MockSyncConfig" in namespace - - -def test_mock_async_config_signature_namespace() -> None: - """Test that async signature namespace contains expected types.""" - config = MockAsyncConfig() - namespace = config.get_signature_namespace() - - assert "MockAsyncDriver" in namespace - assert "MockConnection" in namespace - assert "MockAsyncConfig" in namespace diff --git a/tests/unit/adapters/test_mock/test_edge_cases.py b/tests/unit/adapters/test_mock/test_edge_cases.py deleted file mode 100644 index 3edc87d78..000000000 --- a/tests/unit/adapters/test_mock/test_edge_cases.py +++ /dev/null @@ -1,389 +0,0 @@ -"""Unit tests for edge cases and error handling in mock adapter.""" - -import pytest - -from sqlspec.adapters.mock import MockAsyncConfig, MockSyncConfig -from sqlspec.exceptions import ( - CheckViolationError, - ForeignKeyViolationError, - NotNullViolationError, - SQLParsingError, - SQLSpecError, - UniqueViolationError, -) - - -def test_empty_initial_sql_list() -> None: - """Test config with empty initial_sql list.""" - config = MockSyncConfig(initial_sql=[]) - - with config.provide_session() as session: - result = session.select_value("SELECT 42") - assert result == 42 - - -def test_unicode_data() -> None: - """Test handling of Unicode data.""" - config = MockSyncConfig() - - with config.provide_session() as session: - session.execute("CREATE TABLE unicode_test (text TEXT)") - session.execute("INSERT INTO unicode_test VALUES (?)", "Hello 世界 🌍") - - result = session.select_value("SELECT text FROM unicode_test") - assert result == "Hello 世界 🌍" - - -def test_large_text_data() -> None: - """Test handling of large text values.""" - config = MockSyncConfig() - large_text = "x" * 10000 - - with config.provide_session() as session: - session.execute("CREATE TABLE large_text (data TEXT)") - session.execute("INSERT INTO large_text VALUES (?)", large_text) - - result = session.select_value("SELECT data FROM large_text") - assert len(result) == 10000 - - -def test_null_parameter_values() -> None: - """Test handling of None parameter values.""" - config = MockSyncConfig() - - with config.provide_session() as session: - session.execute("CREATE TABLE nullable (id INTEGER, value TEXT)") - session.execute("INSERT INTO nullable VALUES (?, ?)", 1, None) - - result = session.select_one("SELECT * FROM nullable WHERE id = ?", 1) - assert result is not None - assert result["value"] is None - - -def test_empty_result_set() -> None: - """Test handling of empty result sets.""" - config = MockSyncConfig() - - with config.provide_session() as session: - session.execute("CREATE TABLE empty_test (id INTEGER)") - - result = session.select("SELECT * FROM empty_test") - assert result == [] - - value = session.select_value_or_none("SELECT id FROM empty_test") - assert value is None - - -def test_multiple_parameter_types() -> None: - """Test handling of multiple parameter types in single query.""" - config = MockSyncConfig() - - with config.provide_session() as session: - session.execute("CREATE TABLE mixed_types (a INTEGER, b REAL, c TEXT, d BLOB)") - session.execute("INSERT INTO mixed_types VALUES (?, ?, ?, ?)", 42, 3.14, "text", b"bytes") - - result = session.select_one("SELECT * FROM mixed_types") - assert result is not None - assert result["a"] == 42 - assert abs(result["b"] - 3.14) < 0.01 - assert result["c"] == "text" - - -def test_transaction_without_begin() -> None: - """Test commit/rollback without explicit begin.""" - config = MockSyncConfig() - - with config.provide_session() as session: - session.execute("CREATE TABLE auto_tx (id INTEGER)") - session.execute("INSERT INTO auto_tx VALUES (1)") - session.commit() - - result = session.select("SELECT * FROM auto_tx") - assert len(result) == 1 - - -def test_nested_transaction_detection() -> None: - """Test that connection properly detects transaction state.""" - config = MockSyncConfig() - - with config.provide_session() as session: - assert not session._connection_in_transaction() # type: ignore[attr-defined] - - session.begin() - assert session._connection_in_transaction() # type: ignore[attr-defined] - - session.commit() - assert not session._connection_in_transaction() # type: ignore[attr-defined] - - -def test_exception_unique_violation() -> None: - """Test unique constraint violation error mapping.""" - config = MockSyncConfig() - - with config.provide_session() as session: - session.execute("CREATE TABLE unique_test (id INTEGER UNIQUE)") - session.execute("INSERT INTO unique_test VALUES (1)") - - with pytest.raises(UniqueViolationError): - session.execute("INSERT INTO unique_test VALUES (1)") - - -def test_exception_not_null_violation() -> None: - """Test not null constraint violation error mapping.""" - config = MockSyncConfig() - - with config.provide_session() as session: - session.execute("CREATE TABLE not_null_test (id INTEGER NOT NULL)") - - with pytest.raises(NotNullViolationError): - session.execute("INSERT INTO not_null_test VALUES (NULL)") - - -def test_exception_check_violation() -> None: - """Test check constraint violation error mapping.""" - config = MockSyncConfig() - - with config.provide_session() as session: - session.execute("CREATE TABLE check_test (age INTEGER CHECK(age >= 0))") - - with pytest.raises(CheckViolationError): - session.execute("INSERT INTO check_test VALUES (-5)") - - -def test_exception_foreign_key_violation() -> None: - """Test foreign key constraint violation error mapping.""" - config = MockSyncConfig() - - with config.provide_session() as session: - session.execute("PRAGMA foreign_keys = ON") - session.execute("CREATE TABLE fk_parent (id INTEGER PRIMARY KEY)") - session.execute( - "CREATE TABLE fk_child (id INTEGER, parent_id INTEGER, FOREIGN KEY(parent_id) REFERENCES fk_parent(id))" - ) - - with pytest.raises(ForeignKeyViolationError): - session.execute("INSERT INTO fk_child VALUES (1, 999)") - - -def test_exception_syntax_error() -> None: - """Test SQL syntax error mapping.""" - config = MockSyncConfig() - - with config.provide_session() as session: - with pytest.raises(SQLParsingError): - session.execute("INVALID SQL STATEMENT") - - -def test_execute_many_empty_list_raises() -> None: - """Test execute_many with empty list raises error.""" - config = MockSyncConfig() - - with config.provide_session() as session: - session.execute("CREATE TABLE batch_test (id INTEGER)") - - with pytest.raises(ValueError, match="execute_many requires parameters"): - session.execute_many("INSERT INTO batch_test VALUES (?)", []) - - -def test_execute_many_none_raises() -> None: - """Test execute_many with None raises error.""" - config = MockSyncConfig() - - with config.provide_session() as session: - session.execute("CREATE TABLE batch_test2 (id INTEGER)") - - with pytest.raises(SQLSpecError, match="Parameter count mismatch"): - session.execute_many("INSERT INTO batch_test2 VALUES (?)", None) # type: ignore[arg-type] - - -def test_execute_many_single_row() -> None: - """Test execute_many with single row.""" - config = MockSyncConfig() - - with config.provide_session() as session: - session.execute("CREATE TABLE single_batch (id INTEGER)") - session.execute_many("INSERT INTO single_batch VALUES (?)", [(42,)]) - - result = session.select("SELECT * FROM single_batch") - assert len(result) == 1 - assert result[0]["id"] == 42 - - -def test_connection_context_manager() -> None: - """Test connection context manager cleanup.""" - config = MockSyncConfig() - - with config.provide_connection() as conn: - assert conn is not None - cursor = conn.cursor() - cursor.execute("SELECT 1") - result = cursor.fetchone() - assert result[0] == 1 - - -def test_driver_transpilation_sqlite_native() -> None: - """Test that SQLite dialect skips transpilation.""" - config = MockSyncConfig(target_dialect="sqlite") - - with config.provide_session() as session: - session.execute("CREATE TABLE native (id INTEGER)") - session.execute("INSERT INTO native VALUES (?)", 1) - - result = session.select_value("SELECT id FROM native WHERE id = ?", 1) - assert result == 1 - - -def test_postgres_dialect_with_complex_query() -> None: - """Test Postgres dialect with subquery and aggregation.""" - config = MockSyncConfig(target_dialect="postgres") - - with config.provide_session() as session: - session.execute("CREATE TABLE sales (id INTEGER, amount REAL, region TEXT)") - session.execute("INSERT INTO sales VALUES ($1, $2, $3)", 1, 100.0, "North") - session.execute("INSERT INTO sales VALUES ($1, $2, $3)", 2, 200.0, "South") - session.execute("INSERT INTO sales VALUES ($1, $2, $3)", 3, 150.0, "North") - - result = session.select_one( - """ - SELECT region, SUM(amount) as total - FROM sales - WHERE region = $1 - GROUP BY region - """, - "North", - ) - assert result is not None - assert result["total"] == 250.0 - - -def test_mysql_dialect_with_case_statement() -> None: - """Test MySQL dialect with CASE statement.""" - config = MockSyncConfig(target_dialect="mysql") - - with config.provide_session() as session: - session.execute("CREATE TABLE status_test (id INTEGER, value INTEGER)") - session.execute("INSERT INTO status_test VALUES (%s, %s)", 1, 10) - session.execute("INSERT INTO status_test VALUES (%s, %s)", 2, 50) - session.execute("INSERT INTO status_test VALUES (%s, %s)", 3, 100) - - results = session.select( - """ - SELECT id, - CASE - WHEN value < 25 THEN 'low' - WHEN value < 75 THEN 'medium' - ELSE 'high' - END as category - FROM status_test - ORDER BY id - """ - ) - assert len(results) == 3 - assert results[0]["category"] == "low" - assert results[1]["category"] == "medium" - assert results[2]["category"] == "high" - - -@pytest.mark.anyio -async def test_async_empty_result_set() -> None: - """Test async handling of empty result sets.""" - config = MockAsyncConfig() - - async with config.provide_session() as session: - await session.execute("CREATE TABLE async_empty (id INTEGER)") - - result = await session.select("SELECT * FROM async_empty") - assert result == [] - - -@pytest.mark.anyio -async def test_async_null_parameters() -> None: - """Test async handling of None parameter values.""" - config = MockAsyncConfig() - - async with config.provide_session() as session: - await session.execute("CREATE TABLE async_nullable (id INTEGER, value TEXT)") - await session.execute("INSERT INTO async_nullable VALUES (?, ?)", 1, None) - - result = await session.select_one("SELECT * FROM async_nullable WHERE id = ?", 1) - assert result is not None - assert result["value"] is None - - -@pytest.mark.anyio -async def test_async_transaction_state() -> None: - """Test async transaction state detection.""" - config = MockAsyncConfig() - - async with config.provide_session() as session: - assert not session._connection_in_transaction() # type: ignore[attr-defined] - - await session.begin() - assert session._connection_in_transaction() # type: ignore[attr-defined] - - await session.commit() - assert not session._connection_in_transaction() # type: ignore[attr-defined] - - -@pytest.mark.anyio -async def test_async_exception_unique_violation() -> None: - """Test async unique constraint violation error mapping.""" - config = MockAsyncConfig() - - async with config.provide_session() as session: - await session.execute("CREATE TABLE async_unique (id INTEGER UNIQUE)") - await session.execute("INSERT INTO async_unique VALUES (1)") - - with pytest.raises(UniqueViolationError): - await session.execute("INSERT INTO async_unique VALUES (1)") - - -@pytest.mark.anyio -async def test_async_exception_syntax_error() -> None: - """Test async SQL syntax error mapping.""" - config = MockAsyncConfig() - - async with config.provide_session() as session: - with pytest.raises(SQLParsingError): - await session.execute("INVALID ASYNC SQL") - - -@pytest.mark.anyio -async def test_async_connection_context_manager() -> None: - """Test async connection context manager cleanup.""" - config = MockAsyncConfig() - - async with config.provide_connection() as conn: - assert conn is not None - cursor = conn.cursor() - cursor.execute("SELECT 1") - result = cursor.fetchone() - assert result[0] == 1 - - -@pytest.mark.anyio -async def test_async_execute_many() -> None: - """Test async execute_many operation.""" - config = MockAsyncConfig() - - async with config.provide_session() as session: - await session.execute("CREATE TABLE async_batch (id INTEGER, name TEXT)") - await session.execute_many("INSERT INTO async_batch VALUES (?, ?)", [(1, "a"), (2, "b"), (3, "c")]) - - result = await session.select("SELECT * FROM async_batch ORDER BY id") - assert len(result) == 3 - assert result[0]["name"] == "a" - assert result[2]["name"] == "c" - - -@pytest.mark.anyio -async def test_async_unicode_data() -> None: - """Test async handling of Unicode data.""" - config = MockAsyncConfig() - - async with config.provide_session() as session: - await session.execute("CREATE TABLE async_unicode (text TEXT)") - await session.execute("INSERT INTO async_unicode VALUES (?)", "Hello 世界 🌍") - - result = await session.select_value("SELECT text FROM async_unicode") - assert result == "Hello 世界 🌍" From d9175c6a6a4d9355abfa693f214a5d22cbf2dddd Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Wed, 22 Apr 2026 18:57:12 +0000 Subject: [PATCH 4/8] test(adapters): remove mock fixtures from unit adapter conftest --- tests/unit/adapters/conftest.py | 571 +------------------------------- 1 file changed, 3 insertions(+), 568 deletions(-) diff --git a/tests/unit/adapters/conftest.py b/tests/unit/adapters/conftest.py index abdaa7898..df554bcaf 100644 --- a/tests/unit/adapters/conftest.py +++ b/tests/unit/adapters/conftest.py @@ -1,581 +1,16 @@ """Shared fixtures for adapter testing.""" -from contextlib import asynccontextmanager, contextmanager -from typing import TYPE_CHECKING, Any, Optional, cast +from typing import TYPE_CHECKING import pytest -from typing_extensions import Self from sqlspec.core import SQL, ParameterStyle, ParameterStyleConfig, StatementConfig -from sqlspec.driver import ( - AsyncDataDictionaryBase, - AsyncDriverAdapterBase, - ExecutionResult, - SyncDataDictionaryBase, - SyncDriverAdapterBase, -) -from sqlspec.exceptions import SQLSpecError -from sqlspec.typing import ColumnMetadata, ForeignKeyMetadata, IndexMetadata, TableMetadata, VersionInfo -from tests.conftest import is_compiled if TYPE_CHECKING: - from collections.abc import AsyncGenerator, Generator + pass -class MockSyncExceptionHandler: - """Mock sync exception handler for testing. - - Implements the SyncExceptionHandler protocol with deferred exception pattern. - """ - - __slots__ = ("pending_exception",) - - def __init__(self) -> None: - self.pending_exception: Exception | None = None - - def __enter__(self) -> Self: - return self - - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: - if exc_type is None: - return False - if isinstance(exc_val, Exception): - self.pending_exception = SQLSpecError(f"Mock database error: {exc_val}") - return True - return False - - -class MockAsyncExceptionHandler: - """Mock async exception handler for testing. - - Implements the AsyncExceptionHandler protocol with deferred exception pattern. - """ - - __slots__ = ("pending_exception",) - - def __init__(self) -> None: - self.pending_exception: Exception | None = None - - async def __aenter__(self) -> Self: - return self - - async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: - if exc_type is None: - return False - if isinstance(exc_val, Exception): - self.pending_exception = SQLSpecError(f"Mock async database error: {exc_val}") - return True - return False - - -__all__ = ( - "MockAsyncConnection", - "MockAsyncCursor", - "MockAsyncDriver", - "MockSyncConnection", - "MockSyncCursor", - "MockSyncDriver", - "mock_async_connection", - "mock_async_driver", - "mock_sync_connection", - "mock_sync_driver", - "sample_sql_statement", - "sample_statement_config", -) - - -class MockSyncConnection: - """Mock sync connection for testing.""" - - def __init__(self, name: str = "mock_sync_connection") -> None: - self.name = name - self.connected = True - self.in_transaction = False - self.cursor_results: list[dict[str, Any]] = [] - self.execute_count = 0 - self.execute_many_count = 0 - self.last_sql: str | None = None - self.last_parameters: Any = None - - def cursor(self) -> "MockSyncCursor": - """Return a mock cursor.""" - return MockSyncCursor(self) - - def execute(self, sql: str, parameters: Any = None) -> None: - """Mock execute method.""" - self.execute_count += 1 - self.last_sql = sql - self.last_parameters = parameters - - def commit(self) -> None: - """Mock commit method.""" - self.in_transaction = False - - def rollback(self) -> None: - """Mock rollback method.""" - self.in_transaction = False - - def close(self) -> None: - """Mock close method.""" - self.connected = False - - -class MockAsyncConnection: - """Mock async connection for testing.""" - - def __init__(self, name: str = "mock_async_connection") -> None: - self.name = name - self.connected = True - self.in_transaction = False - self.cursor_results: list[dict[str, Any]] = [] - self.execute_count = 0 - self.execute_many_count = 0 - self.last_sql: str | None = None - self.last_parameters: Any = None - - async def cursor(self) -> "MockAsyncCursor": - """Return a mock async cursor.""" - return MockAsyncCursor(self) - - async def execute(self, sql: str, parameters: Any = None) -> None: - """Mock async execute method.""" - self.execute_count += 1 - self.last_sql = sql - self.last_parameters = parameters - - async def commit(self) -> None: - """Mock async commit method.""" - self.in_transaction = False - - async def rollback(self) -> None: - """Mock async rollback method.""" - self.in_transaction = False - - async def close(self) -> None: - """Mock async close method.""" - self.connected = False - - -class MockSyncCursor: - """Mock sync cursor for testing.""" - - def __init__(self, connection: MockSyncConnection) -> None: - self.connection = connection - self.rowcount = 0 - self.description: list[tuple[str, ...]] | None = None - self.fetchall_result: list[tuple[Any, ...]] = [] - self.closed = False - - def execute(self, sql: str, parameters: Any = None) -> None: - """Mock execute method.""" - self.connection.execute_count += 1 - self.connection.last_sql = sql - self.connection.last_parameters = parameters - - if sql.upper().strip().startswith("SELECT"): - self.description = [("id", "INTEGER"), ("name", "TEXT")] - - self.fetchall_result = [(1, "test"), (2, "example")] - self.rowcount = len(self.fetchall_result) - else: - self.description = None - self.fetchall_result = [] - self.rowcount = 1 - - def executemany(self, sql: str, parameters: "list[Any]") -> None: - """Mock executemany method.""" - self.connection.execute_many_count += 1 - self.connection.last_sql = sql - self.connection.last_parameters = parameters - self.rowcount = len(parameters) if parameters else 0 - - def fetchall(self) -> "list[tuple[Any, ...]]": - """Mock fetchall method.""" - return self.fetchall_result - - def close(self) -> None: - """Mock close method.""" - self.closed = True - - -class MockAsyncCursor: - """Mock async cursor for testing.""" - - def __init__(self, connection: MockAsyncConnection) -> None: - self.connection = connection - self.rowcount = 0 - self.description: list[tuple[str, ...]] | None = None - self.fetchall_result: list[tuple[Any, ...]] = [] - self.closed = False - - async def execute(self, sql: str, parameters: Any = None) -> None: - """Mock async execute method.""" - self.connection.execute_count += 1 - self.connection.last_sql = sql - self.connection.last_parameters = parameters - - if sql.upper().strip().startswith("SELECT"): - self.description = [("id", "INTEGER"), ("name", "TEXT")] - - self.fetchall_result = [(1, "test"), (2, "example")] - self.rowcount = len(self.fetchall_result) - else: - self.description = None - self.fetchall_result = [] - self.rowcount = 1 - - async def executemany(self, sql: str, parameters: "list[Any]") -> None: - """Mock async executemany method.""" - self.connection.execute_many_count += 1 - self.connection.last_sql = sql - self.connection.last_parameters = parameters - self.rowcount = len(parameters) if parameters else 0 - - async def fetchall(self) -> "list[tuple[Any, ...]]": - """Mock async fetchall method.""" - return self.fetchall_result - - async def close(self) -> None: - """Mock async close method.""" - self.closed = True - - async def __aenter__(self) -> Self: - """Async context manager entry.""" - return self - - async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: - """Async context manager exit.""" - await self.close() - - -class MockSyncDataDictionary(SyncDataDictionaryBase): - """Mock sync data dictionary for testing.""" - - def get_version(self, driver: "MockSyncDriver") -> "VersionInfo | None": - """Return mock version info.""" - return VersionInfo(3, 42, 0) - - def get_feature_flag(self, driver: "MockSyncDriver", feature: str) -> bool: - """Return mock feature flag.""" - return feature in {"supports_transactions", "supports_prepared_statements"} - - def get_optimal_type(self, driver: "MockSyncDriver", type_category: str) -> str: - """Return mock optimal type.""" - return {"text": "TEXT", "boolean": "INTEGER"}.get(type_category, "TEXT") - - def get_tables(self, driver: "MockSyncDriver", schema: "str | None" = None) -> "list[TableMetadata]": - """Return mock table list.""" - _ = (driver, schema) - return [] - - def get_columns( - self, driver: "MockSyncDriver", table: "str | None" = None, schema: "str | None" = None - ) -> "list[ColumnMetadata]": - """Return mock column metadata.""" - _ = (driver, table, schema) - return [] - - def get_indexes( - self, driver: "MockSyncDriver", table: "str | None" = None, schema: "str | None" = None - ) -> "list[IndexMetadata]": - """Return mock index metadata.""" - _ = (driver, table, schema) - return [] - - def get_foreign_keys( - self, driver: "MockSyncDriver", table: "str | None" = None, schema: "str | None" = None - ) -> "list[ForeignKeyMetadata]": - """Return mock foreign key metadata.""" - _ = (driver, table, schema) - return [] - - def list_available_features(self) -> "list[str]": - """Return mock available features.""" - return ["supports_transactions", "supports_prepared_statements"] - - -class MockAsyncDataDictionary(AsyncDataDictionaryBase): - """Mock async data dictionary for testing.""" - - async def get_version(self, driver: "MockAsyncDriver") -> "VersionInfo | None": - """Return mock version info.""" - return VersionInfo(3, 42, 0) - - async def get_feature_flag(self, driver: "MockAsyncDriver", feature: str) -> bool: - """Return mock feature flag.""" - return feature in {"supports_transactions", "supports_prepared_statements"} - - async def get_optimal_type(self, driver: "MockAsyncDriver", type_category: str) -> str: - """Return mock optimal type.""" - return {"text": "TEXT", "boolean": "INTEGER"}.get(type_category, "TEXT") - - async def get_tables(self, driver: "MockAsyncDriver", schema: "str | None" = None) -> "list[TableMetadata]": - """Return mock table list.""" - _ = (driver, schema) - return [] - - async def get_columns( - self, driver: "MockAsyncDriver", table: "str | None" = None, schema: "str | None" = None - ) -> "list[ColumnMetadata]": - """Return mock column metadata.""" - _ = (driver, table, schema) - return [] - - async def get_indexes( - self, driver: "MockAsyncDriver", table: "str | None" = None, schema: "str | None" = None - ) -> "list[IndexMetadata]": - """Return mock index metadata.""" - _ = (driver, table, schema) - return [] - - async def get_foreign_keys( - self, driver: "MockAsyncDriver", table: "str | None" = None, schema: "str | None" = None - ) -> "list[ForeignKeyMetadata]": - """Return mock foreign key metadata.""" - _ = (driver, table, schema) - return [] - - def list_available_features(self) -> "list[str]": - """Return mock available features.""" - return ["supports_transactions", "supports_prepared_statements"] - - -class MockSyncDriver(SyncDriverAdapterBase): - """Mock sync driver for testing.""" - - dialect = "sqlite" - - def __init__( - self, - connection: MockSyncConnection, - statement_config: StatementConfig | None = None, - driver_features: Optional["dict[str, Any]"] = None, - ) -> None: - if statement_config is None: - statement_config = StatementConfig( - dialect="sqlite", - enable_caching=False, - parameter_config=ParameterStyleConfig( - default_parameter_style=ParameterStyle.QMARK, supported_parameter_styles={ParameterStyle.QMARK} - ), - ) - super().__init__(connection, statement_config, driver_features) - self._data_dictionary: SyncDataDictionaryBase | None = None - - @property - def data_dictionary(self) -> "SyncDataDictionaryBase": - """Get the data dictionary for this driver.""" - if self._data_dictionary is None: - self._data_dictionary = MockSyncDataDictionary() - return self._data_dictionary - - @contextmanager - def with_cursor(self, connection: MockSyncConnection) -> "Generator[MockSyncCursor, None, None]": - """Return mock cursor context manager.""" - cursor = connection.cursor() - try: - yield cursor - finally: - cursor.close() - - def handle_database_exceptions(self) -> "MockSyncExceptionHandler": - """Handle database exceptions.""" - return MockSyncExceptionHandler() - - def dispatch_special_handling(self, cursor: MockSyncCursor, statement: SQL) -> Any | None: - """Mock special handling - always return None.""" - return None - - def dispatch_execute(self, cursor: MockSyncCursor, statement: SQL) -> ExecutionResult: - """Mock execute statement.""" - sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) - cursor.execute(sql, prepared_parameters) - - if statement.returns_rows(): - fetched_data = cursor.fetchall() - column_names = [col[0] for col in cursor.description or []] - - return self.create_execution_result( - cursor, - selected_data=fetched_data, - column_names=column_names, - data_row_count=len(fetched_data), - is_select_result=True, - row_format="tuple", - ) - - return self.create_execution_result(cursor, rowcount_override=cursor.rowcount) - - def dispatch_execute_many(self, cursor: MockSyncCursor, statement: SQL) -> ExecutionResult: - """Mock execute many.""" - sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) - - if not prepared_parameters: - msg = "execute_many requires parameters" - raise ValueError(msg) - - parameter_sets = cast("list[Any]", prepared_parameters) - cursor.executemany(sql, parameter_sets) - return self.create_execution_result(cursor, rowcount_override=cursor.rowcount, is_many_result=True) - - def dispatch_execute_script(self, cursor: MockSyncCursor, statement: SQL) -> ExecutionResult: - """Mock execute script.""" - sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) - statements = self.split_script_statements(sql, self.statement_config, strip_trailing_semicolon=True) - - successful_count = 0 - for stmt in statements: - cursor.execute(stmt, prepared_parameters or ()) - successful_count += 1 - - return self.create_execution_result( - cursor, statement_count=len(statements), successful_statements=successful_count, is_script_result=True - ) - - def begin(self) -> None: - """Mock begin transaction.""" - self.connection.in_transaction = True - - def rollback(self) -> None: - """Mock rollback transaction.""" - self.connection.rollback() - - def commit(self) -> None: - """Mock commit transaction.""" - self.connection.commit() - - -class MockAsyncDriver(AsyncDriverAdapterBase): - """Mock async driver for testing.""" - - dialect = "sqlite" - - def __init__( - self, - connection: MockAsyncConnection, - statement_config: StatementConfig | None = None, - driver_features: Optional["dict[str, Any]"] = None, - ) -> None: - if statement_config is None: - statement_config = StatementConfig( - dialect="sqlite", - enable_caching=False, - parameter_config=ParameterStyleConfig( - default_parameter_style=ParameterStyle.QMARK, supported_parameter_styles={ParameterStyle.QMARK} - ), - ) - super().__init__(connection, statement_config, driver_features) - self._data_dictionary: AsyncDataDictionaryBase | None = None - - @property - def data_dictionary(self) -> "AsyncDataDictionaryBase": - """Get the data dictionary for this driver.""" - if self._data_dictionary is None: - self._data_dictionary = MockAsyncDataDictionary() - return self._data_dictionary - - @asynccontextmanager - async def with_cursor(self, connection: MockAsyncConnection) -> "AsyncGenerator[MockAsyncCursor, None]": - """Return mock async cursor context manager.""" - cursor = await connection.cursor() - try: - yield cursor - finally: - await cursor.close() - - def handle_database_exceptions(self) -> "MockAsyncExceptionHandler": - """Handle database exceptions.""" - return MockAsyncExceptionHandler() - - async def dispatch_special_handling(self, cursor: MockAsyncCursor, statement: SQL) -> Any | None: - """Mock async special handling - always return None.""" - return None - - async def dispatch_execute(self, cursor: MockAsyncCursor, statement: SQL) -> ExecutionResult: - """Mock async execute statement.""" - sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) - await cursor.execute(sql, prepared_parameters) - - if statement.returns_rows(): - fetched_data = await cursor.fetchall() - column_names = [col[0] for col in cursor.description or []] - - return self.create_execution_result( - cursor, - selected_data=fetched_data, - column_names=column_names, - data_row_count=len(fetched_data), - is_select_result=True, - row_format="tuple", - ) - - return self.create_execution_result(cursor, rowcount_override=cursor.rowcount) - - async def dispatch_execute_many(self, cursor: MockAsyncCursor, statement: SQL) -> ExecutionResult: - """Mock async execute many.""" - sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) - - if not prepared_parameters: - msg = "execute_many requires parameters" - raise ValueError(msg) - - parameter_sets = cast("list[Any]", prepared_parameters) - await cursor.executemany(sql, parameter_sets) - return self.create_execution_result(cursor, rowcount_override=cursor.rowcount, is_many_result=True) - - async def dispatch_execute_script(self, cursor: MockAsyncCursor, statement: SQL) -> ExecutionResult: - """Mock async execute script.""" - sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) - statements = self.split_script_statements(sql, self.statement_config, strip_trailing_semicolon=True) - - successful_count = 0 - for stmt in statements: - await cursor.execute(stmt, prepared_parameters or ()) - successful_count += 1 - - return self.create_execution_result( - cursor, statement_count=len(statements), successful_statements=successful_count, is_script_result=True - ) - - async def begin(self) -> None: - """Mock async begin transaction.""" - self.connection.in_transaction = True - - async def rollback(self) -> None: - """Mock async rollback transaction.""" - await self.connection.rollback() - - async def commit(self) -> None: - """Mock async commit transaction.""" - await self.connection.commit() - - -@pytest.fixture -def mock_sync_connection() -> MockSyncConnection: - """Fixture for mock sync connection.""" - return MockSyncConnection() - - -@pytest.fixture -def mock_async_connection() -> MockAsyncConnection: - """Fixture for mock async connection.""" - return MockAsyncConnection() - - -@pytest.fixture -def mock_sync_driver(mock_sync_connection: MockSyncConnection) -> MockSyncDriver: - """Fixture for mock sync driver.""" - if is_compiled(): - pytest.skip("Requires interpreted driver base") - return MockSyncDriver(mock_sync_connection) - - -@pytest.fixture -def mock_async_driver(mock_async_connection: MockAsyncConnection) -> MockAsyncDriver: - """Fixture for mock async driver.""" - if is_compiled(): - pytest.skip("Requires interpreted driver base") - return MockAsyncDriver(mock_async_connection) +__all__ = ("sample_sql_statement", "sample_statement_config") @pytest.fixture From bbc93b46656287df4994e605d625a3f0b4fc732c Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Wed, 22 Apr 2026 18:58:26 +0000 Subject: [PATCH 5/8] test(core): remove mock fixtures from root unit conftest --- tests/unit/conftest.py | 775 ++--------------------------------------- 1 file changed, 27 insertions(+), 748 deletions(-) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 8cd01da09..f4d7e120e 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -5,14 +5,12 @@ """ import time -from collections import defaultdict from collections.abc import Callable -from contextlib import asynccontextmanager, contextmanager +from contextlib import contextmanager from decimal import Decimal -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any import pytest -from typing_extensions import Self from sqlspec.core import ( SQL, @@ -23,74 +21,13 @@ TypedParameter, get_default_cache, ) -from sqlspec.driver import ( - AsyncDataDictionaryBase, - AsyncDriverAdapterBase, - ExecutionResult, - SyncDataDictionaryBase, - SyncDriverAdapterBase, -) -from sqlspec.exceptions import SQLSpecError -from sqlspec.typing import ColumnMetadata, ForeignKeyMetadata, IndexMetadata, TableMetadata, VersionInfo -from tests.conftest import is_compiled +from sqlspec.driver import ExecutionResult if TYPE_CHECKING: - from collections.abc import AsyncGenerator, Generator - - -class MockSyncExceptionHandler: - """Mock sync exception handler for testing. - - Implements the SyncExceptionHandler protocol with deferred exception pattern. - """ - - __slots__ = ("pending_exception",) - - def __init__(self) -> None: - self.pending_exception: Exception | None = None - - def __enter__(self) -> Self: - return self - - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: - if exc_type is None: - return False - if isinstance(exc_val, Exception): - self.pending_exception = SQLSpecError(f"Mock database error: {exc_val}") - return True - return False - - -class MockAsyncExceptionHandler: - """Mock async exception handler for testing. - - Implements the AsyncExceptionHandler protocol with deferred exception pattern. - """ - - __slots__ = ("pending_exception",) - - def __init__(self) -> None: - self.pending_exception: Exception | None = None - - async def __aenter__(self) -> Self: - return self - - async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: - if exc_type is None: - return False - if isinstance(exc_val, Exception): - self.pending_exception = SQLSpecError(f"Mock async database error: {exc_val}") - return True - return False + from collections.abc import Generator __all__ = ( - "MockAsyncConnection", - "MockAsyncCursor", - "MockAsyncDriver", - "MockSyncConnection", - "MockSyncCursor", - "MockSyncDriver", "benchmark_tracker", "cache_config_disabled", "cache_config_enabled", @@ -99,15 +36,7 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: "compilation_metrics", "complex_sql_with_joins", "memory_profiler", - "mock_async_connection", - "mock_async_driver", - "mock_bigquery_connection", "mock_lru_cache", - "mock_mysql_connection", - "mock_postgres_connection", - "mock_sqlite_connection", - "mock_sync_connection", - "mock_sync_driver", "parameter_style_config_advanced", "parameter_style_config_basic", "performance_timer", @@ -203,61 +132,50 @@ def statement_config_postgres(parameter_style_config_advanced: ParameterStyleCon @pytest.fixture -def statement_config_mysql() -> StatementConfig: +def statement_config_mysql(parameter_style_config_basic: ParameterStyleConfig) -> StatementConfig: """MySQL statement configuration for testing.""" - mysql_config = ParameterStyleConfig( - default_parameter_style=ParameterStyle.POSITIONAL_PYFORMAT, - supported_parameter_styles={ParameterStyle.POSITIONAL_PYFORMAT}, - supported_execution_parameter_styles={ParameterStyle.POSITIONAL_PYFORMAT}, - default_execution_parameter_style=ParameterStyle.POSITIONAL_PYFORMAT, - ) return StatementConfig( - dialect="mysql", parameter_config=mysql_config, enable_caching=True, enable_parsing=True, enable_validation=True + dialect="mysql", + parameter_config=parameter_style_config_basic, + execution_mode=None, + execution_args=None, + enable_caching=True, + enable_parsing=True, + enable_validation=True, ) @pytest.fixture -def cache_config_enabled() -> dict[str, Any]: +def cache_config_enabled() -> LRUCache: """Cache configuration with caching enabled.""" - return { - "enable_caching": True, - "max_cache_size": 1000, - "enable_file_cache": True, - "enable_compiled_cache": True, - "cache_hit_threshold": 0.8, - } + return LRUCache(capacity=100) @pytest.fixture -def cache_config_disabled() -> dict[str, Any]: - """Cache configuration with caching disabled for testing scenarios.""" - return { - "enable_caching": False, - "max_cache_size": 0, - "enable_file_cache": False, - "enable_compiled_cache": False, - "cache_hit_threshold": 0.0, - } +def cache_config_disabled() -> None: + """Cache configuration with caching disabled.""" + return @pytest.fixture def mock_lru_cache() -> LRUCache: - """Mock LRU cache for testing cache behavior.""" - - return get_default_cache() + """Mock LRU cache for testing cache operations.""" + return LRUCache(capacity=10) @pytest.fixture -def cache_statistics_tracker() -> dict[str, Any]: - """Cache statistics tracker for monitoring cache performance during tests.""" - return {"hits": 0, "misses": 0, "evictions": 0, "cache_sizes": defaultdict(int), "hit_rates": []} +def cache_statistics_tracker() -> dict[str, int]: + """Tracker for cache hits and misses during tests.""" + return {"hits": 0, "misses": 0, "evictions": 0} -@pytest.fixture(autouse=True) +@pytest.fixture def reset_cache_state() -> "Generator[None, None, None]": - """Auto-use fixture to reset global cache state before each test.""" - + """Fixture to reset the global SQL cache before and after each test.""" + cache = get_default_cache() + cache.clear() yield + cache.clear() @pytest.fixture @@ -340,645 +258,6 @@ def sql_with_typed_parameters(statement_config_sqlite: StatementConfig) -> SQL: return SQL(sql, *params, statement_config=statement_config_sqlite) -class MockSyncConnection: - """Mock sync connection with database simulation.""" - - def __init__(self, name: str = "mock_sync_connection", dialect: str = "sqlite") -> None: - self.name = name - self.dialect = dialect - self.connected = True - self.in_transaction = False - self.autocommit = True - self.execute_count = 0 - self.execute_many_count = 0 - self.last_sql: str | None = None - self.last_parameters: Any = None - self.cursor_results: list[dict[str, Any]] = [] - self.connection_info = { - "server_version": "1.0.0", - "client_version": "1.0.0", - "database_name": "test_db", - "user": "test_user", - } - - def cursor(self) -> "MockSyncCursor": - """Return a mock cursor.""" - return MockSyncCursor(self) - - def execute(self, sql: str, parameters: Any = None) -> None: - """Mock execute method.""" - self.execute_count += 1 - self.last_sql = sql - self.last_parameters = parameters - - def executemany(self, sql: str, parameters: list[Any]) -> None: - """Mock executemany method.""" - self.execute_many_count += 1 - self.last_sql = sql - self.last_parameters = parameters - - def commit(self) -> None: - """Mock commit method.""" - self.in_transaction = False - - def rollback(self) -> None: - """Mock rollback method.""" - self.in_transaction = False - - def close(self) -> None: - """Mock close method.""" - self.connected = False - - def begin(self) -> None: - """Mock begin transaction method.""" - self.in_transaction = True - self.autocommit = False - - -class MockAsyncConnection: - """Mock async connection with database simulation.""" - - def __init__(self, name: str = "mock_async_connection", dialect: str = "sqlite") -> None: - self.name = name - self.dialect = dialect - self.connected = True - self.in_transaction = False - self.autocommit = True - self.execute_count = 0 - self.execute_many_count = 0 - self.last_sql: str | None = None - self.last_parameters: Any = None - self.cursor_results: list[dict[str, Any]] = [] - self.connection_info = { - "server_version": "1.0.0", - "client_version": "1.0.0", - "database_name": "test_db", - "user": "test_user", - } - - async def cursor(self) -> "MockAsyncCursor": - """Return a mock async cursor.""" - return MockAsyncCursor(self) - - async def execute(self, sql: str, parameters: Any = None) -> None: - """Mock async execute method.""" - self.execute_count += 1 - self.last_sql = sql - self.last_parameters = parameters - - async def executemany(self, sql: str, parameters: list[Any]) -> None: - """Mock async executemany method.""" - self.execute_many_count += 1 - self.last_sql = sql - self.last_parameters = parameters - - async def commit(self) -> None: - """Mock async commit method.""" - self.in_transaction = False - - async def rollback(self) -> None: - """Mock async rollback method.""" - self.in_transaction = False - - async def close(self) -> None: - """Mock async close method.""" - self.connected = False - - async def begin(self) -> None: - """Mock async begin transaction method.""" - self.in_transaction = True - self.autocommit = False - - -class MockSyncCursor: - """Mock sync cursor with database cursor behavior.""" - - def __init__(self, connection: MockSyncConnection) -> None: - self.connection = connection - self.rowcount = 0 - self.description: list[tuple[str, ...]] | None = None - self.fetchall_result: list[tuple[Any, ...]] = [] - self.fetchone_result: tuple[Any, ...] | None = None - self.closed = False - self.arraysize = 1 - - def execute(self, sql: str, parameters: Any = None) -> None: - """Mock execute method with smart result generation.""" - self.connection.execute_count += 1 - self.connection.last_sql = sql - self.connection.last_parameters = parameters - - sql_upper = sql.upper().strip() - if sql_upper.startswith("SELECT"): - self.description = [("id", "INTEGER"), ("name", "TEXT"), ("email", "TEXT")] - self.fetchall_result = [(1, "John Doe", "john@example.com"), (2, "Jane Smith", "jane@example.com")] - self.fetchone_result = (1, "John Doe", "john@example.com") - self.rowcount = len(self.fetchall_result) - elif sql_upper.startswith("INSERT"): - self.description = None - self.fetchall_result = [] - self.fetchone_result = None - self.rowcount = 1 - elif sql_upper.startswith("UPDATE"): - self.description = None - self.fetchall_result = [] - self.fetchone_result = None - self.rowcount = 2 - elif sql_upper.startswith("DELETE"): - self.description = None - self.fetchall_result = [] - self.fetchone_result = None - self.rowcount = 1 - else: - self.description = None - self.fetchall_result = [] - self.fetchone_result = None - self.rowcount = 0 - - def executemany(self, sql: str, parameters: list[Any]) -> None: - """Mock executemany method.""" - self.connection.execute_many_count += 1 - self.connection.last_sql = sql - self.connection.last_parameters = parameters - self.rowcount = len(parameters) if parameters else 0 - - def fetchall(self) -> list[tuple[Any, ...]]: - """Mock fetchall method.""" - return self.fetchall_result - - def fetchone(self) -> tuple[Any, ...] | None: - """Mock fetchone method.""" - return self.fetchone_result - - def fetchmany(self, size: int | None = None) -> list[tuple[Any, ...]]: - """Mock fetchmany method.""" - size = size or self.arraysize - return self.fetchall_result[:size] - - def close(self) -> None: - """Mock close method.""" - self.closed = True - - def __enter__(self) -> Self: - """Context manager entry.""" - return self - - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: - """Context manager exit.""" - self.close() - - -class MockAsyncCursor: - """Mock async cursor with database cursor behavior.""" - - def __init__(self, connection: MockAsyncConnection) -> None: - self.connection = connection - self.rowcount = 0 - self.description: list[tuple[str, ...]] | None = None - self.fetchall_result: list[tuple[Any, ...]] = [] - self.fetchone_result: tuple[Any, ...] | None = None - self.closed = False - self.arraysize = 1 - - async def execute(self, sql: str, parameters: Any = None) -> None: - """Mock async execute method with smart result generation.""" - self.connection.execute_count += 1 - self.connection.last_sql = sql - self.connection.last_parameters = parameters - - sql_upper = sql.upper().strip() - if sql_upper.startswith("SELECT"): - self.description = [("id", "INTEGER"), ("name", "TEXT"), ("email", "TEXT")] - self.fetchall_result = [(1, "John Doe", "john@example.com"), (2, "Jane Smith", "jane@example.com")] - self.fetchone_result = (1, "John Doe", "john@example.com") - self.rowcount = len(self.fetchall_result) - elif sql_upper.startswith("INSERT"): - self.description = None - self.fetchall_result = [] - self.fetchone_result = None - self.rowcount = 1 - elif sql_upper.startswith("UPDATE"): - self.description = None - self.fetchall_result = [] - self.fetchone_result = None - self.rowcount = 2 - elif sql_upper.startswith("DELETE"): - self.description = None - self.fetchall_result = [] - self.fetchone_result = None - self.rowcount = 1 - else: - self.description = None - self.fetchall_result = [] - self.fetchone_result = None - self.rowcount = 0 - - async def executemany(self, sql: str, parameters: list[Any]) -> None: - """Mock async executemany method.""" - self.connection.execute_many_count += 1 - self.connection.last_sql = sql - self.connection.last_parameters = parameters - self.rowcount = len(parameters) if parameters else 0 - - async def fetchall(self) -> list[tuple[Any, ...]]: - """Mock async fetchall method.""" - return self.fetchall_result - - async def fetchone(self) -> tuple[Any, ...] | None: - """Mock async fetchone method.""" - return self.fetchone_result - - async def fetchmany(self, size: int | None = None) -> list[tuple[Any, ...]]: - """Mock async fetchmany method.""" - size = size or self.arraysize - return self.fetchall_result[:size] - - async def close(self) -> None: - """Mock async close method.""" - self.closed = True - - async def __aenter__(self) -> Self: - """Async context manager entry.""" - return self - - async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: - """Async context manager exit.""" - await self.close() - - -class MockSyncDataDictionary(SyncDataDictionaryBase): - """Mock sync data dictionary for testing.""" - - def get_version(self, driver: "MockSyncDriver") -> "VersionInfo | None": - """Return mock version info.""" - return VersionInfo(3, 42, 0) - - def get_feature_flag(self, driver: "MockSyncDriver", feature: str) -> bool: - """Return mock feature flag.""" - return feature in {"supports_transactions", "supports_prepared_statements"} - - def get_optimal_type(self, driver: "MockSyncDriver", type_category: str) -> str: - """Return mock optimal type.""" - return {"text": "TEXT", "boolean": "INTEGER"}.get(type_category, "TEXT") - - def get_tables(self, driver: "MockSyncDriver", schema: "str | None" = None) -> "list[TableMetadata]": - """Return mock table list.""" - _ = (driver, schema) - return [] - - def get_columns( - self, driver: "MockSyncDriver", table: "str | None" = None, schema: "str | None" = None - ) -> "list[ColumnMetadata]": - """Return mock column metadata.""" - _ = (driver, table, schema) - return [] - - def get_indexes( - self, driver: "MockSyncDriver", table: "str | None" = None, schema: "str | None" = None - ) -> "list[IndexMetadata]": - """Return mock index metadata.""" - _ = (driver, table, schema) - return [] - - def get_foreign_keys( - self, driver: "MockSyncDriver", table: "str | None" = None, schema: "str | None" = None - ) -> "list[ForeignKeyMetadata]": - """Return mock foreign key metadata.""" - _ = (driver, table, schema) - return [] - - def list_available_features(self) -> "list[str]": - """Return mock available features.""" - return ["supports_transactions", "supports_prepared_statements"] - - -class MockAsyncDataDictionary(AsyncDataDictionaryBase): - """Mock async data dictionary for testing.""" - - async def get_version(self, driver: "MockAsyncDriver") -> "VersionInfo | None": - """Return mock version info.""" - return VersionInfo(3, 42, 0) - - async def get_feature_flag(self, driver: "MockAsyncDriver", feature: str) -> bool: - """Return mock feature flag.""" - return feature in {"supports_transactions", "supports_prepared_statements"} - - async def get_optimal_type(self, driver: "MockAsyncDriver", type_category: str) -> str: - """Return mock optimal type.""" - return {"text": "TEXT", "boolean": "INTEGER"}.get(type_category, "TEXT") - - async def get_tables(self, driver: "MockAsyncDriver", schema: "str | None" = None) -> "list[TableMetadata]": - """Return mock table list.""" - _ = (driver, schema) - return [] - - async def get_columns( - self, driver: "MockAsyncDriver", table: "str | None" = None, schema: "str | None" = None - ) -> "list[ColumnMetadata]": - """Return mock column metadata.""" - _ = (driver, table, schema) - return [] - - async def get_indexes( - self, driver: "MockAsyncDriver", table: "str | None" = None, schema: "str | None" = None - ) -> "list[IndexMetadata]": - """Return mock index metadata.""" - _ = (driver, table, schema) - return [] - - async def get_foreign_keys( - self, driver: "MockAsyncDriver", table: "str | None" = None, schema: "str | None" = None - ) -> "list[ForeignKeyMetadata]": - """Return mock foreign key metadata.""" - _ = (driver, table, schema) - return [] - - def list_available_features(self) -> "list[str]": - """Return mock available features.""" - return ["supports_transactions", "supports_prepared_statements"] - - -class MockSyncDriver(SyncDriverAdapterBase): - """Mock sync driver with adapter interface.""" - - dialect = "sqlite" - - def __init__( - self, - connection: MockSyncConnection, - statement_config: StatementConfig | None = None, - driver_features: dict[str, Any] | None = None, - ) -> None: - if statement_config is None: - from sqlspec.core import ParameterStyleConfig - - parameter_config = ParameterStyleConfig( - default_parameter_style=ParameterStyle.QMARK, - supported_parameter_styles={ParameterStyle.QMARK}, - supported_execution_parameter_styles={ParameterStyle.QMARK}, - default_execution_parameter_style=ParameterStyle.QMARK, - ) - statement_config = StatementConfig( - dialect="sqlite", parameter_config=parameter_config, enable_caching=False - ) - super().__init__(connection, statement_config, driver_features) - self._data_dictionary: SyncDataDictionaryBase | None = None - - @property - def data_dictionary(self) -> "SyncDataDictionaryBase": - """Get the data dictionary for this driver.""" - if self._data_dictionary is None: - self._data_dictionary = MockSyncDataDictionary() - return self._data_dictionary - - @contextmanager - def with_cursor(self, connection: MockSyncConnection) -> "Generator[MockSyncCursor, None, None]": - """Return mock cursor context manager.""" - cursor = connection.cursor() - try: - yield cursor - finally: - cursor.close() - - def handle_database_exceptions(self) -> "MockSyncExceptionHandler": - """Handle database exceptions.""" - return MockSyncExceptionHandler() - - def dispatch_special_handling(self, cursor: MockSyncCursor, statement: SQL) -> Any | None: - """Mock special handling - always return None.""" - return None - - def dispatch_execute(self, cursor: MockSyncCursor, statement: SQL) -> ExecutionResult: - """Mock execute statement.""" - sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) - cursor.execute(sql, prepared_parameters) - - if statement.returns_rows(): - fetched_data = cursor.fetchall() - column_names = [col[0] for col in cursor.description or []] - data = [dict(zip(column_names, row)) for row in fetched_data] - - return self.create_execution_result( - cursor, selected_data=data, column_names=column_names, data_row_count=len(data), is_select_result=True - ) - - return self.create_execution_result(cursor, rowcount_override=cursor.rowcount) - - def dispatch_execute_many(self, cursor: MockSyncCursor, statement: SQL) -> ExecutionResult: - """Mock execute many.""" - sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) - - if not prepared_parameters: - msg = "execute_many requires parameters" - raise ValueError(msg) - - parameter_sets = cast("list[Any]", prepared_parameters) - cursor.executemany(sql, parameter_sets) - return self.create_execution_result(cursor, rowcount_override=cursor.rowcount, is_many_result=True) - - def dispatch_execute_script(self, cursor: MockSyncCursor, statement: SQL) -> ExecutionResult: - """Mock execute script.""" - sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) - statements = self.split_script_statements(sql, self.statement_config, strip_trailing_semicolon=True) - - successful_count = 0 - for stmt in statements: - cursor.execute(stmt, prepared_parameters or ()) - successful_count += 1 - - return self.create_execution_result( - cursor, statement_count=len(statements), successful_statements=successful_count, is_script_result=True - ) - - def begin(self) -> None: - """Mock begin transaction.""" - self.connection.begin() - - def rollback(self) -> None: - """Mock rollback transaction.""" - self.connection.rollback() - - def commit(self) -> None: - """Mock commit transaction.""" - self.connection.commit() - - def _connection_in_transaction(self) -> bool: - """Check if connection is in transaction.""" - return bool(self.connection.in_transaction) - - -class MockAsyncDriver(AsyncDriverAdapterBase): - """Mock async driver with adapter interface.""" - - dialect = "sqlite" - - def __init__( - self, - connection: MockAsyncConnection, - statement_config: StatementConfig | None = None, - driver_features: dict[str, Any] | None = None, - ) -> None: - if statement_config is None: - from sqlspec.core import ParameterStyleConfig - - parameter_config = ParameterStyleConfig( - default_parameter_style=ParameterStyle.QMARK, - supported_parameter_styles={ParameterStyle.QMARK}, - supported_execution_parameter_styles={ParameterStyle.QMARK}, - default_execution_parameter_style=ParameterStyle.QMARK, - ) - statement_config = StatementConfig( - dialect="sqlite", parameter_config=parameter_config, enable_caching=False - ) - super().__init__(connection, statement_config, driver_features) - self._data_dictionary: AsyncDataDictionaryBase | None = None - - @property - def data_dictionary(self) -> "AsyncDataDictionaryBase": - """Get the data dictionary for this driver.""" - if self._data_dictionary is None: - self._data_dictionary = MockAsyncDataDictionary() - return self._data_dictionary - - @asynccontextmanager - async def with_cursor(self, connection: MockAsyncConnection) -> "AsyncGenerator[MockAsyncCursor, None]": - """Return mock async cursor context manager.""" - cursor = await connection.cursor() - try: - yield cursor - finally: - await cursor.close() - - def handle_database_exceptions(self) -> "MockAsyncExceptionHandler": - """Handle database exceptions.""" - return MockAsyncExceptionHandler() - - async def dispatch_special_handling(self, cursor: MockAsyncCursor, statement: SQL) -> Any | None: - """Mock async special handling - always return None.""" - return None - - async def dispatch_execute(self, cursor: MockAsyncCursor, statement: SQL) -> ExecutionResult: - """Mock async execute statement.""" - sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) - await cursor.execute(sql, prepared_parameters) - - if statement.returns_rows(): - fetched_data = await cursor.fetchall() - column_names = [col[0] for col in cursor.description or []] - data = [dict(zip(column_names, row)) for row in fetched_data] - - return self.create_execution_result( - cursor, selected_data=data, column_names=column_names, data_row_count=len(data), is_select_result=True - ) - - return self.create_execution_result(cursor, rowcount_override=cursor.rowcount) - - async def dispatch_execute_many(self, cursor: MockAsyncCursor, statement: SQL) -> ExecutionResult: - """Mock async execute many.""" - sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) - - if not prepared_parameters: - msg = "execute_many requires parameters" - raise ValueError(msg) - - parameter_sets = cast("list[Any]", prepared_parameters) - await cursor.executemany(sql, parameter_sets) - return self.create_execution_result(cursor, rowcount_override=cursor.rowcount, is_many_result=True) - - async def dispatch_execute_script(self, cursor: MockAsyncCursor, statement: SQL) -> ExecutionResult: - """Mock async execute script.""" - sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) - statements = self.split_script_statements(sql, self.statement_config, strip_trailing_semicolon=True) - - successful_count = 0 - for stmt in statements: - await cursor.execute(stmt, prepared_parameters or ()) - successful_count += 1 - - return self.create_execution_result( - cursor, statement_count=len(statements), successful_statements=successful_count, is_script_result=True - ) - - async def begin(self) -> None: - """Mock async begin transaction.""" - await self.connection.begin() - - async def rollback(self) -> None: - """Mock async rollback transaction.""" - await self.connection.rollback() - - async def commit(self) -> None: - """Mock async commit transaction.""" - await self.connection.commit() - - def _connection_in_transaction(self) -> bool: - """Check if connection is in transaction.""" - return bool(self.connection.in_transaction) - - -@pytest.fixture -def mock_sync_connection() -> MockSyncConnection: - """Fixture for basic mock sync connection.""" - return MockSyncConnection() - - -@pytest.fixture -def mock_async_connection() -> MockAsyncConnection: - """Fixture for basic mock async connection.""" - return MockAsyncConnection() - - -@pytest.fixture -def mock_sync_driver(mock_sync_connection: MockSyncConnection) -> MockSyncDriver: - """Fixture for mock sync driver.""" - if is_compiled(): - pytest.skip("Requires interpreted driver base") - return MockSyncDriver(mock_sync_connection) - - -@pytest.fixture -def mock_async_driver(mock_async_connection: MockAsyncConnection) -> MockAsyncDriver: - """Fixture for mock async driver.""" - if is_compiled(): - pytest.skip("Requires interpreted driver base") - return MockAsyncDriver(mock_async_connection) - - -@pytest.fixture -def mock_sqlite_connection() -> MockSyncConnection: - """Mock SQLite connection with SQLite-specific behavior.""" - return MockSyncConnection("sqlite_connection", "sqlite") - - -@pytest.fixture -def mock_postgres_connection() -> MockAsyncConnection: - """Mock PostgreSQL connection with Postgres-specific behavior.""" - conn = MockAsyncConnection("postgres_connection", "postgres") - conn.connection_info.update({"server_version": "14.0", "supports_returning": "True", "supports_arrays": "True"}) - return conn - - -@pytest.fixture -def mock_mysql_connection() -> MockSyncConnection: - """Mock MySQL connection with MySQL-specific behavior.""" - conn = MockSyncConnection("mysql_connection", "mysql") - conn.connection_info.update({"server_version": "8.0.0", "supports_json": "True", "charset": "utf8mb4"}) - return conn - - -@pytest.fixture -def mock_bigquery_connection() -> MockSyncConnection: - """Mock BigQuery connection with BigQuery-specific behavior.""" - conn = MockSyncConnection("bigquery_connection", "bigquery") - conn.connection_info.update({ - "project_id": "test-project", - "dataset_id": "test_dataset", - "supports_arrays": "True", - "supports_structs": "True", - }) - return conn - - @pytest.fixture(autouse=True) def test_isolation() -> "Generator[None, None, None]": """Auto-use fixture to ensure test isolation by resetting global state.""" From 5291275961f8e36f52ad550f6e767d49abc2558c Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Wed, 22 Apr 2026 18:59:13 +0000 Subject: [PATCH 6/8] test(core): remove mock-skipping logic from root conftest --- tests/conftest.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index b3fabd251..9c380f5d4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -134,8 +134,6 @@ def pytest_collection_modifyitems(config: pytest.Config, items: list[pytest.Item ): item.add_marker(skip_compiled) continue - if {"mock_sync_driver", "mock_async_driver"} & set(getattr(item, "fixturenames", ())): - item.add_marker(skip_compiled) @pytest.fixture From d094cc3f024214ff655f084169a46da8c1ce292c Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Wed, 22 Apr 2026 21:16:28 +0000 Subject: [PATCH 7/8] test(adapters): add SQLite sync and async driver fixtures for testing --- docs/examples/patterns/mock_testing.py | 30 -- docs/usage/index.rst | 2 +- docs/usage/testing.rst | 50 +- tests/unit/adapters/test_async_adapters.py | 498 +++++++----------- tests/unit/adapters/test_sync_adapters.py | 446 +++++++--------- tests/unit/conftest.py | 49 +- tests/unit/core/test_parameters.py | 2 - tests/unit/driver/test_execute_script.py | 10 +- tests/unit/driver/test_query_cache.py | 374 ++++--------- tests/unit/driver/test_stack_base.py | 60 ++- .../unit/exceptions/test_exception_handler.py | 4 - .../unit/migrations/test_migration_context.py | 13 +- 12 files changed, 561 insertions(+), 977 deletions(-) delete mode 100644 docs/examples/patterns/mock_testing.py diff --git a/docs/examples/patterns/mock_testing.py b/docs/examples/patterns/mock_testing.py deleted file mode 100644 index 231e8b547..000000000 --- a/docs/examples/patterns/mock_testing.py +++ /dev/null @@ -1,30 +0,0 @@ -from __future__ import annotations - -__all__ = ("test_mock_testing",) - - -def test_mock_testing() -> None: - # start-example - from sqlspec import SQLSpec - from sqlspec.adapters.mock import MockSyncConfig - - # MockSyncConfig uses SQLite :memory: internally - # but can transpile SQL from other dialects - config = MockSyncConfig(target_dialect="postgres") - spec = SQLSpec() - spec.add_config(config) - - with spec.provide_session(config) as session: - # Write SQL in PostgreSQL dialect - it gets transpiled to SQLite - session.execute("CREATE TABLE users ( id INTEGER PRIMARY KEY, name VARCHAR(100) NOT NULL)") - session.execute("INSERT INTO users (name) VALUES ('Alice')") - - users = session.select("SELECT name FROM users") - print(users) # [{"name": "Alice"}] - - count = session.select_value("SELECT COUNT(*) FROM users") - print(count) # 1 - # end-example - - assert users == [{"name": "Alice"}] - assert count == 1 diff --git a/docs/usage/index.rst b/docs/usage/index.rst index 3ed24dcd6..d19c9c552 100644 --- a/docs/usage/index.rst +++ b/docs/usage/index.rst @@ -76,7 +76,7 @@ Choose a topic :link: testing :link-type: doc - Unit test with mock adapters and integration test patterns. + Unit and integration test patterns. .. grid-item-card:: ETL & Data Pipelines :link: etl diff --git a/docs/usage/testing.rst b/docs/usage/testing.rst index 4aaeb85ad..108d4acd9 100644 --- a/docs/usage/testing.rst +++ b/docs/usage/testing.rst @@ -3,42 +3,6 @@ Testing SQLSpec provides tools for both unit and integration testing of database code. -Mock Adapter for Unit Tests ---------------------------- - -``MockSyncConfig`` and ``MockAsyncConfig`` use an in-memory SQLite backend with -optional dialect transpilation. Write SQL in your production dialect (PostgreSQL, -MySQL, Oracle) and it gets transpiled to SQLite before execution. - -.. literalinclude:: /examples/patterns/mock_testing.py - :language: python - :caption: ``mock adapter for unit tests`` - :start-after: # start-example - :end-before: # end-example - :dedent: 4 - :no-upgrade: - -Key features: - -- ``target_dialect`` accepts ``"postgres"``, ``"mysql"``, ``"oracle"``, or ``"sqlite"`` -- SQL is automatically transpiled to SQLite for execution -- No external database required -- runs entirely in-memory -- Supports ``initial_sql`` parameter for schema setup on connection create - -Integration Test Patterns -------------------------- - -For integration tests against real databases, use the standard ``SQLSpec`` + -adapter config pattern with temporary databases. - -.. literalinclude:: /examples/patterns/integration_testing.py - :language: python - :caption: ``integration test fixtures`` - :start-after: # start-example - :end-before: # end-example - :dedent: 4 - :no-upgrade: - Pytest Fixture Tips ------------------- @@ -64,6 +28,20 @@ Pytest Fixture Tips session.execute("create table users (id integer primary key, name text)") yield session +Integration Test Patterns +------------------------- + +For integration tests against real databases, use the standard ``SQLSpec`` + +adapter config pattern with temporary databases. + +.. literalinclude:: /examples/patterns/integration_testing.py + :language: python + :caption: ``integration test fixtures`` + :start-after: # start-example + :end-before: # end-example + :dedent: 4 + :no-upgrade: + Related Guides -------------- diff --git a/tests/unit/adapters/test_async_adapters.py b/tests/unit/adapters/test_async_adapters.py index 10eeccef0..7fe67e25d 100644 --- a/tests/unit/adapters/test_async_adapters.py +++ b/tests/unit/adapters/test_async_adapters.py @@ -1,34 +1,39 @@ # pyright: reportPrivateImportUsage = false, reportPrivateUsage = false """Tests for asynchronous database adapters.""" +import asyncio from typing import Any -from unittest.mock import AsyncMock, Mock, patch +from unittest.mock import AsyncMock, patch +import aiosqlite import pytest +from sqlspec.adapters.aiosqlite import AiosqliteDriver from sqlspec.core import SQL, ParameterStyle, ParameterStyleConfig, SQLResult, StatementConfig, get_default_config from sqlspec.driver import ExecutionResult from sqlspec.exceptions import NotFoundError, SQLSpecError from sqlspec.typing import Empty -from tests.unit.adapters.conftest import MockAsyncConnection, MockAsyncCursor, MockAsyncDriver pytestmark = pytest.mark.xdist_group("adapter_unit") __all__ = () -async def test_async_driver_initialization(mock_async_connection: MockAsyncConnection) -> None: +async def test_async_driver_initialization() -> None: """Test basic async driver initialization.""" - driver = MockAsyncDriver(mock_async_connection) + conn = await aiosqlite.connect(":memory:") + driver = AiosqliteDriver(conn) - assert driver.connection is mock_async_connection + assert driver.connection is conn assert driver.dialect == "sqlite" assert driver.statement_config.dialect == "sqlite" assert driver.statement_config.parameter_config.default_parameter_style == ParameterStyle.QMARK + await conn.close() -async def test_async_driver_with_custom_config(mock_async_connection: MockAsyncConnection) -> None: +async def test_async_driver_with_custom_config() -> None: """Test async driver initialization with custom statement config.""" + conn = await aiosqlite.connect(":memory:") custom_config = StatementConfig( dialect="postgresql", parameter_config=ParameterStyleConfig( @@ -36,48 +41,39 @@ async def test_async_driver_with_custom_config(mock_async_connection: MockAsyncC ), ) - driver = MockAsyncDriver(mock_async_connection, custom_config) + driver = AiosqliteDriver(conn, custom_config) assert driver.statement_config.dialect == "postgresql" assert driver.statement_config.parameter_config.default_parameter_style == ParameterStyle.NUMERIC + await conn.close() -async def test_async_driver_with_cursor(mock_async_driver: MockAsyncDriver) -> None: +async def test_async_driver_with_cursor(aiosqlite_async_driver: AiosqliteDriver) -> None: """Test async cursor context manager functionality.""" - async with mock_async_driver.with_cursor(mock_async_driver.connection) as cursor: - assert hasattr(cursor, "connection") + async with aiosqlite_async_driver.with_cursor(aiosqlite_async_driver.connection) as cursor: assert hasattr(cursor, "execute") assert hasattr(cursor, "fetchall") - assert cursor.connection is mock_async_driver.connection -async def test_async_driver_database_exception_handling(mock_async_driver: MockAsyncDriver) -> None: - """Test async database exception handling with deferred exception pattern. - - The deferred pattern stores exceptions in `pending_exception` instead of - raising from `__aexit__`, allowing compiled code to raise safely. - """ - exc_handler = mock_async_driver.handle_database_exceptions() +async def test_async_driver_database_exception_handling(aiosqlite_async_driver: AiosqliteDriver) -> None: + """Test async database exception handling with deferred exception pattern.""" + exc_handler = aiosqlite_async_driver.handle_database_exceptions() async with exc_handler: pass assert exc_handler.pending_exception is None - exc_handler = mock_async_driver.handle_database_exceptions() + exc_handler = aiosqlite_async_driver.handle_database_exceptions() async with exc_handler: - raise ValueError("Test async error") + raise aiosqlite.Error("Test async error") assert exc_handler.pending_exception is not None assert isinstance(exc_handler.pending_exception, SQLSpecError) - assert "Mock async database error" in str(exc_handler.pending_exception) - - with pytest.raises(SQLSpecError, match="Mock async database error"): - raise exc_handler.pending_exception -async def test_async_driverdispatch_execute_select(mock_async_driver: MockAsyncDriver) -> None: +async def test_async_driver_dispatch_execute_select(aiosqlite_async_driver: AiosqliteDriver) -> None: """Test async dispatch_execute method with SELECT query.""" - statement = SQL("SELECT id, name FROM users", statement_config=mock_async_driver.statement_config) - async with mock_async_driver.with_cursor(mock_async_driver.connection) as cursor: - result = await mock_async_driver.dispatch_execute(cursor, statement) + statement = SQL("SELECT id, name FROM users", statement_config=aiosqlite_async_driver.statement_config) + async with aiosqlite_async_driver.with_cursor(aiosqlite_async_driver.connection) as cursor: + result = await aiosqlite_async_driver.dispatch_execute(cursor, statement) assert isinstance(result, ExecutionResult) assert result.is_select_result is True @@ -88,71 +84,65 @@ async def test_async_driverdispatch_execute_select(mock_async_driver: MockAsyncD assert result.data_row_count == 2 -async def test_async_driverdispatch_execute_insert(mock_async_driver: MockAsyncDriver) -> None: +async def test_async_driver_dispatch_execute_insert(aiosqlite_async_driver: AiosqliteDriver) -> None: """Test async dispatch_execute method with INSERT query.""" - statement = SQL("INSERT INTO users (name) VALUES (?)", "test", statement_config=mock_async_driver.statement_config) + statement = SQL( + "INSERT INTO users (name) VALUES (?)", "new_user", statement_config=aiosqlite_async_driver.statement_config + ) - async with mock_async_driver.with_cursor(mock_async_driver.connection) as cursor: - result = await mock_async_driver.dispatch_execute(cursor, statement) + async with aiosqlite_async_driver.with_cursor(aiosqlite_async_driver.connection) as cursor: + result = await aiosqlite_async_driver.dispatch_execute(cursor, statement) assert isinstance(result, ExecutionResult) assert result.is_select_result is False - assert result.is_script_result is False - assert result.is_many_result is False assert result.rowcount_override == 1 - assert result.selected_data is None -async def test_async_driver_execute_many(mock_async_driver: MockAsyncDriver) -> None: +async def test_async_driver_execute_many(aiosqlite_async_driver: AiosqliteDriver) -> None: """Test async dispatch_execute_many method.""" statement = SQL( "INSERT INTO users (name) VALUES (?)", [["alice"], ["bob"], ["charlie"]], - statement_config=mock_async_driver.statement_config, + statement_config=aiosqlite_async_driver.statement_config, is_many=True, ) - async with mock_async_driver.with_cursor(mock_async_driver.connection) as cursor: - result = await mock_async_driver.dispatch_execute_many(cursor, statement) + async with aiosqlite_async_driver.with_cursor(aiosqlite_async_driver.connection) as cursor: + result = await aiosqlite_async_driver.dispatch_execute_many(cursor, statement) assert isinstance(result, ExecutionResult) assert result.is_many_result is True - assert result.is_select_result is False - assert result.is_script_result is False assert result.rowcount_override == 3 - assert mock_async_driver.connection.execute_many_count == 1 -async def test_async_driver_execute_many_no_parameters(mock_async_driver: MockAsyncDriver) -> None: +async def test_async_driver_execute_many_no_parameters(aiosqlite_async_driver: AiosqliteDriver) -> None: """Test async _execute_many method fails without parameters.""" statement = SQL( - "INSERT INTO users (name) VALUES (?)", statement_config=mock_async_driver.statement_config, is_many=True + "INSERT INTO users (name) VALUES (?)", statement_config=aiosqlite_async_driver.statement_config, is_many=True ) - async with mock_async_driver.with_cursor(mock_async_driver.connection) as cursor: + async with aiosqlite_async_driver.with_cursor(aiosqlite_async_driver.connection) as cursor: with pytest.raises(ValueError, match="execute_many requires parameters"): - await mock_async_driver.dispatch_execute_many(cursor, statement) + await aiosqlite_async_driver.dispatch_execute_many(cursor, statement) -async def test_async_driver_execute_script(mock_async_driver: MockAsyncDriver) -> None: +async def test_async_driver_execute_script(aiosqlite_async_driver: AiosqliteDriver) -> None: """Test async dispatch_execute_script method.""" script = """ INSERT INTO users (name) VALUES ('alice'); INSERT INTO users (name) VALUES ('bob'); - UPDATE users SET active = 1; + UPDATE users SET name = 'updated'; """ - statement = SQL(script, statement_config=mock_async_driver.statement_config, is_script=True) - async with mock_async_driver.with_cursor(mock_async_driver.connection) as cursor: - result = await mock_async_driver.dispatch_execute_script(cursor, statement) + statement = SQL(script, statement_config=aiosqlite_async_driver.statement_config, is_script=True) + async with aiosqlite_async_driver.with_cursor(aiosqlite_async_driver.connection) as cursor: + result = await aiosqlite_async_driver.dispatch_execute_script(cursor, statement) assert isinstance(result, ExecutionResult) assert result.is_script_result is True - assert result.is_select_result is False - assert result.is_many_result is False assert result.statement_count == 3 assert result.successful_statements == 3 -async def test_async_driver_dispatch_statement_execution_select(mock_async_driver: MockAsyncDriver) -> None: +async def test_async_driver_dispatch_statement_execution_select(aiosqlite_async_driver: AiosqliteDriver) -> None: """Test async dispatch_statement_execution with SELECT statement.""" - statement = SQL("SELECT * FROM users", statement_config=mock_async_driver.statement_config) + statement = SQL("SELECT * FROM users", statement_config=aiosqlite_async_driver.statement_config) - result = await mock_async_driver.dispatch_statement_execution(statement, mock_async_driver.connection) + result = await aiosqlite_async_driver.dispatch_statement_execution(statement, aiosqlite_async_driver.connection) assert isinstance(result, SQLResult) assert result.operation_type == "SELECT" @@ -161,24 +151,25 @@ async def test_async_driver_dispatch_statement_execution_select(mock_async_drive assert result.get_data()[0]["name"] == "test" -async def test_async_driver_dispatch_statement_execution_insert(mock_async_driver: MockAsyncDriver) -> None: +async def test_async_driver_dispatch_statement_execution_insert(aiosqlite_async_driver: AiosqliteDriver) -> None: """Test async dispatch_statement_execution with INSERT statement.""" - statement = SQL("INSERT INTO users (name) VALUES (?)", "test", statement_config=mock_async_driver.statement_config) + statement = SQL( + "INSERT INTO users (name) VALUES (?)", "new_user", statement_config=aiosqlite_async_driver.statement_config + ) - result = await mock_async_driver.dispatch_statement_execution(statement, mock_async_driver.connection) + result = await aiosqlite_async_driver.dispatch_statement_execution(statement, aiosqlite_async_driver.connection) assert isinstance(result, SQLResult) assert result.operation_type == "INSERT" assert result.rows_affected == 1 - assert len(result.get_data()) == 0 -async def test_async_driver_dispatch_statement_execution_script(mock_async_driver: MockAsyncDriver) -> None: +async def test_async_driver_dispatch_statement_execution_script(aiosqlite_async_driver: AiosqliteDriver) -> None: """Test async dispatch_statement_execution with script.""" script = "INSERT INTO users (name) VALUES ('alice'); INSERT INTO users (name) VALUES ('bob');" - statement = SQL(script, statement_config=mock_async_driver.statement_config, is_script=True) + statement = SQL(script, statement_config=aiosqlite_async_driver.statement_config, is_script=True) - result = await mock_async_driver.dispatch_statement_execution(statement, mock_async_driver.connection) + result = await aiosqlite_async_driver.dispatch_statement_execution(statement, aiosqlite_async_driver.connection) assert isinstance(result, SQLResult) assert result.operation_type == "SCRIPT" @@ -186,33 +177,35 @@ async def test_async_driver_dispatch_statement_execution_script(mock_async_drive assert result.successful_statements == 2 -async def test_async_driver_dispatch_statement_execution_many(mock_async_driver: MockAsyncDriver) -> None: +async def test_async_driver_dispatch_statement_execution_many(aiosqlite_async_driver: AiosqliteDriver) -> None: """Test async dispatch_statement_execution with execute_many.""" statement = SQL( "INSERT INTO users (name) VALUES (?)", [["alice"], ["bob"]], - statement_config=mock_async_driver.statement_config, + statement_config=aiosqlite_async_driver.statement_config, is_many=True, ) - result = await mock_async_driver.dispatch_statement_execution(statement, mock_async_driver.connection) + result = await aiosqlite_async_driver.dispatch_statement_execution(statement, aiosqlite_async_driver.connection) assert isinstance(result, SQLResult) assert result.operation_type == "INSERT" assert result.rows_affected == 2 -async def test_async_driver_releases_pooled_statement(mock_async_driver: MockAsyncDriver) -> None: +async def test_async_driver_releases_pooled_statement(aiosqlite_async_driver: AiosqliteDriver) -> None: """Pooled statements should be reset after dispatch execution.""" seed = "SELECT * FROM users WHERE id = ?" - mock_async_driver.prepare_statement(seed, (1,), statement_config=mock_async_driver.statement_config, kwargs={}) - pooled = mock_async_driver.prepare_statement( - seed, (2,), statement_config=mock_async_driver.statement_config, kwargs={} + aiosqlite_async_driver.prepare_statement( + seed, (1,), statement_config=aiosqlite_async_driver.statement_config, kwargs={} + ) + pooled = aiosqlite_async_driver.prepare_statement( + seed, (2,), statement_config=aiosqlite_async_driver.statement_config, kwargs={} ) assert pooled._pooled is True - await mock_async_driver.dispatch_statement_execution(pooled, mock_async_driver.connection) + await aiosqlite_async_driver.dispatch_statement_execution(pooled, aiosqlite_async_driver.connection) assert pooled._raw_sql == "" assert pooled._processed_state is Empty @@ -220,45 +213,46 @@ async def test_async_driver_releases_pooled_statement(mock_async_driver: MockAsy assert pooled._statement_config is get_default_config() -async def test_async_driver_transaction_management(mock_async_driver: MockAsyncDriver) -> None: +async def test_async_driver_transaction_management(aiosqlite_async_driver: AiosqliteDriver) -> None: """Test async transaction management methods.""" - connection = mock_async_driver.connection + await aiosqlite_async_driver.begin() + await aiosqlite_async_driver.execute("INSERT INTO users (name) VALUES ('trans')") + await aiosqlite_async_driver.commit() - await mock_async_driver.begin() - assert connection.in_transaction is True + res = await aiosqlite_async_driver.select_value("SELECT COUNT(*) FROM users WHERE name = 'trans'") + assert res == 1 - await mock_async_driver.commit() - assert connection.in_transaction is False + await aiosqlite_async_driver.begin() + await aiosqlite_async_driver.execute("INSERT INTO users (name) VALUES ('rolledback')") + await aiosqlite_async_driver.rollback() - await mock_async_driver.begin() - assert connection.in_transaction is True - await mock_async_driver.rollback() - assert connection.in_transaction is False + res = await aiosqlite_async_driver.select_value("SELECT COUNT(*) FROM users WHERE name = 'rolledback'") + assert res == 0 -async def test_async_driver_execute_method(mock_async_driver: MockAsyncDriver) -> None: +async def test_async_driver_execute_method(aiosqlite_async_driver: AiosqliteDriver) -> None: """Test high-level async execute method.""" - result = await mock_async_driver.execute("SELECT * FROM users WHERE id = ?", 1) + result = await aiosqlite_async_driver.execute("SELECT * FROM users WHERE id = ?", 1) assert isinstance(result, SQLResult) assert result.operation_type == "SELECT" - assert len(result.get_data()) == 2 + assert len(result.get_data()) == 1 -async def test_async_driver_execute_many_method(mock_async_driver: MockAsyncDriver) -> None: +async def test_async_driver_execute_many_method(aiosqlite_async_driver: AiosqliteDriver) -> None: """Test high-level async execute_many method.""" parameters = [["alice"], ["bob"], ["charlie"]] - result = await mock_async_driver.execute_many("INSERT INTO users (name) VALUES (?)", parameters) + result = await aiosqlite_async_driver.execute_many("INSERT INTO users (name) VALUES (?)", parameters) assert isinstance(result, SQLResult) assert result.operation_type == "INSERT" assert result.rows_affected == 3 -async def test_async_driver_execute_script_method(mock_async_driver: MockAsyncDriver) -> None: +async def test_async_driver_execute_script_method(aiosqlite_async_driver: AiosqliteDriver) -> None: """Test high-level async execute_script method.""" - script = "INSERT INTO users (name) VALUES ('alice'); UPDATE users SET active = 1;" - result = await mock_async_driver.execute_script(script) + script = "INSERT INTO users (name) VALUES ('alice'); UPDATE users SET name = 'updated';" + result = await aiosqlite_async_driver.execute_script(script) assert isinstance(result, SQLResult) assert result.operation_type == "SCRIPT" @@ -275,82 +269,48 @@ async def test_async_driver_execute_script_method(mock_async_driver: MockAsyncDr ], ) async def test_async_driver_execution_wrappers_reraise_deferred_database_errors( - mock_async_driver: MockAsyncDriver, method_name: str, call_args: tuple[Any, ...] + aiosqlite_async_driver: AiosqliteDriver, method_name: str, call_args: tuple[Any, ...] ) -> None: """Test wrapper methods re-raise mapped errors after the exception context exits.""" with patch.object( - mock_async_driver, + aiosqlite_async_driver, "dispatch_statement_execution", new_callable=AsyncMock, - side_effect=ValueError("Test async wrapper error"), + side_effect=aiosqlite.Error("Test async wrapper error"), ): - method = getattr(mock_async_driver, method_name) + method = getattr(aiosqlite_async_driver, method_name) - with pytest.raises(SQLSpecError, match="Mock async database error: Test async wrapper error"): + with pytest.raises(SQLSpecError): await method(*call_args) -async def test_async_driver_select_one(mock_async_driver: MockAsyncDriver) -> None: +async def test_async_driver_select_one(aiosqlite_async_driver: AiosqliteDriver) -> None: """Test async select_one method - expects error when multiple rows returned.""" with pytest.raises(ValueError, match="Multiple results found"): - await mock_async_driver.select_one("SELECT * FROM users WHERE id = ?", 1) + await aiosqlite_async_driver.select_one("SELECT * FROM users") -async def test_async_driver_select_one_no_results(mock_async_driver: MockAsyncDriver) -> None: +async def test_async_driver_select_one_no_results(aiosqlite_async_driver: AiosqliteDriver) -> None: """Test async select_one method with no results.""" - - with patch.object(mock_async_driver, "execute", new_callable=AsyncMock) as mock_execute: - mock_result = Mock(spec=SQLResult) - mock_result.one.side_effect = ValueError("No result found, exactly one row expected") - mock_execute.return_value = mock_result - - with pytest.raises(NotFoundError, match="No rows found"): - await mock_async_driver.select_one("SELECT * FROM users WHERE id = ?", 999) - - -async def test_async_driver_select_one_multiple_results(mock_async_driver: MockAsyncDriver) -> None: - """Test async select_one method with multiple results.""" - - with patch.object(mock_async_driver, "execute", new_callable=AsyncMock) as mock_execute: - mock_result = Mock(spec=SQLResult) - mock_result.one.side_effect = ValueError("Multiple results found (3), exactly one row expected") - mock_execute.return_value = mock_result - - with pytest.raises(ValueError, match="Multiple results found"): - await mock_async_driver.select_one("SELECT * FROM users") + with pytest.raises(NotFoundError, match="No rows found"): + await aiosqlite_async_driver.select_one("SELECT * FROM users WHERE id = ?", 999) -async def test_async_driver_select_one_or_none(mock_async_driver: MockAsyncDriver) -> None: +async def test_async_driver_select_one_or_none(aiosqlite_async_driver: AiosqliteDriver) -> None: """Test async select_one_or_none method - expects error when multiple rows returned.""" with pytest.raises(ValueError, match="Multiple results found"): - await mock_async_driver.select_one_or_none("SELECT * FROM users WHERE id = ?", 1) + await aiosqlite_async_driver.select_one_or_none("SELECT * FROM users") -async def test_async_driver_select_one_or_none_no_results(mock_async_driver: MockAsyncDriver) -> None: +async def test_async_driver_select_one_or_none_no_results(aiosqlite_async_driver: AiosqliteDriver) -> None: """Test async select_one_or_none method with no results.""" - with patch.object(mock_async_driver, "execute", new_callable=AsyncMock) as mock_execute: - mock_result = Mock(spec=SQLResult) - mock_result.one_or_none.return_value = None - mock_execute.return_value = mock_result + result = await aiosqlite_async_driver.select_one_or_none("SELECT * FROM users WHERE id = ?", 999) + assert result is None - result = await mock_async_driver.select_one_or_none("SELECT * FROM users WHERE id = ?", 999) - assert result is None - -async def test_async_driver_select_one_or_none_multiple_results(mock_async_driver: MockAsyncDriver) -> None: - """Test async select_one_or_none method with multiple results.""" - with patch.object(mock_async_driver, "execute", new_callable=AsyncMock) as mock_execute: - mock_result = Mock(spec=SQLResult) - mock_result.one_or_none.side_effect = ValueError("Multiple results found (2), at most one row expected") - mock_execute.return_value = mock_result - - with pytest.raises(ValueError, match="Multiple results found"): - await mock_async_driver.select_one_or_none("SELECT * FROM users") - - -async def test_async_driver_select(mock_async_driver: MockAsyncDriver) -> None: +async def test_async_driver_select(aiosqlite_async_driver: AiosqliteDriver) -> None: """Test async select method.""" - result: list[dict[str, Any]] = await mock_async_driver.select("SELECT * FROM users") + result: list[dict[str, Any]] = await aiosqlite_async_driver.select("SELECT * FROM users") assert isinstance(result, list) assert len(result) == 2 @@ -358,57 +318,33 @@ async def test_async_driver_select(mock_async_driver: MockAsyncDriver) -> None: assert result[1]["id"] == 2 -async def test_async_driver_select_value(mock_async_driver: MockAsyncDriver) -> None: +async def test_async_driver_select_value(aiosqlite_async_driver: AiosqliteDriver) -> None: """Test async select_value method.""" - - with patch.object(mock_async_driver, "execute", new_callable=AsyncMock) as mock_execute: - mock_result = Mock(spec=SQLResult) - mock_result.scalar.return_value = 42 - mock_execute.return_value = mock_result - - result = await mock_async_driver.select_value("SELECT COUNT(*) as count FROM users") - assert result == 42 + result = await aiosqlite_async_driver.select_value("SELECT COUNT(*) FROM users") + assert result == 2 -async def test_async_driver_select_value_no_results(mock_async_driver: MockAsyncDriver) -> None: +async def test_async_driver_select_value_no_results(aiosqlite_async_driver: AiosqliteDriver) -> None: """Test async select_value method with no results.""" - with patch.object(mock_async_driver, "execute", new_callable=AsyncMock) as mock_execute: - mock_result = Mock(spec=SQLResult) - mock_result.scalar.side_effect = ValueError("No result found, exactly one row expected") - mock_execute.return_value = mock_result + with pytest.raises(NotFoundError, match="No rows found"): + await aiosqlite_async_driver.select_value("SELECT id FROM users WHERE id = 999") - with pytest.raises(NotFoundError, match="No rows found"): - await mock_async_driver.select_value("SELECT COUNT(*) FROM users WHERE id = 999") - -async def test_async_driver_select_value_or_none(mock_async_driver: MockAsyncDriver) -> None: - """Test async select_value_or_none method - expects error when multiple rows returned.""" - with pytest.raises(ValueError, match="Multiple results found"): - await mock_async_driver.select_value_or_none("SELECT * FROM users WHERE id = ?", 1) - - -async def test_async_driver_select_value_or_none_no_results(mock_async_driver: MockAsyncDriver) -> None: +async def test_async_driver_select_value_or_none_no_results(aiosqlite_async_driver: AiosqliteDriver) -> None: """Test async select_value_or_none method with no results.""" - with patch.object(mock_async_driver, "execute", new_callable=AsyncMock) as mock_execute: - mock_result = Mock(spec=SQLResult) - mock_result.scalar_or_none.return_value = None - mock_execute.return_value = mock_result - - result = await mock_async_driver.select_value_or_none("SELECT COUNT(*) FROM users WHERE id = 999") - assert result is None + result = await aiosqlite_async_driver.select_value_or_none("SELECT id FROM users WHERE id = 999") + assert result is None @pytest.mark.parametrize( "parameter_style,expected_style", [ pytest.param(ParameterStyle.QMARK, ParameterStyle.QMARK, id="qmark"), - pytest.param(ParameterStyle.NUMERIC, ParameterStyle.NUMERIC, id="numeric"), pytest.param(ParameterStyle.NAMED_COLON, ParameterStyle.NAMED_COLON, id="named_colon"), - pytest.param(ParameterStyle.NAMED_PYFORMAT, ParameterStyle.NAMED_PYFORMAT, id="pyformat_named"), ], ) async def test_async_driver_parameter_styles( - mock_async_connection: MockAsyncConnection, parameter_style: ParameterStyle, expected_style: ParameterStyle + aiosqlite_async_driver: AiosqliteDriver, parameter_style: ParameterStyle, expected_style: ParameterStyle ) -> None: """Test different parameter styles are handled correctly in async driver.""" config = StatementConfig( @@ -416,173 +352,110 @@ async def test_async_driver_parameter_styles( parameter_config=ParameterStyleConfig( default_parameter_style=parameter_style, supported_parameter_styles={parameter_style}, - default_execution_parameter_style=parameter_style, - supported_execution_parameter_styles={parameter_style}, + default_execution_parameter_style=ParameterStyle.QMARK, + supported_execution_parameter_styles={ParameterStyle.QMARK}, ), ) - driver = MockAsyncDriver(mock_async_connection, config) - assert driver.statement_config.parameter_config.default_parameter_style == expected_style + aiosqlite_async_driver.statement_config = config + assert aiosqlite_async_driver.statement_config.parameter_config.default_parameter_style == expected_style if parameter_style == ParameterStyle.QMARK: statement = SQL("SELECT * FROM users WHERE id = ?", 1, statement_config=config) - elif parameter_style == ParameterStyle.NUMERIC: - statement = SQL("SELECT * FROM users WHERE id = $1", 1, statement_config=config) - elif parameter_style == ParameterStyle.NAMED_COLON: - statement = SQL("SELECT * FROM users WHERE id = :id", {"id": 1}, statement_config=config) else: - statement = SQL("SELECT * FROM users WHERE id = %(id)s", {"id": 1}, statement_config=config) + statement = SQL("SELECT * FROM users WHERE id = :id", {"id": 1}, statement_config=config) - result = await driver.dispatch_statement_execution(statement, driver.connection) + result = await aiosqlite_async_driver.dispatch_statement_execution(statement, aiosqlite_async_driver.connection) assert isinstance(result, SQLResult) -@pytest.mark.parametrize("dialect", ["sqlite", "postgres", "mysql"]) -async def test_async_driver_different_dialects(mock_async_connection: MockAsyncConnection, dialect: str) -> None: +async def test_async_driver_different_dialects(aiosqlite_async_driver: AiosqliteDriver) -> None: """Test async driver works with different SQL dialects.""" config = StatementConfig( - dialect=dialect, + dialect="sqlite", parameter_config=ParameterStyleConfig( default_parameter_style=ParameterStyle.QMARK, supported_parameter_styles={ParameterStyle.QMARK} ), ) - driver = MockAsyncDriver(mock_async_connection, config) - assert driver.statement_config.dialect == dialect - - result = await driver.execute("SELECT 1 as test") + aiosqlite_async_driver.statement_config = config + result = await aiosqlite_async_driver.execute("SELECT 1 as test") assert isinstance(result, SQLResult) -async def test_async_driver_create_execution_result(mock_async_driver: MockAsyncDriver) -> None: +async def test_async_driver_create_execution_result(aiosqlite_async_driver: AiosqliteDriver) -> None: """Test async create_execution_result method.""" - cursor = mock_async_driver.with_cursor(mock_async_driver.connection) - - result = mock_async_driver.create_execution_result( - cursor, - selected_data=[(1,), (2,)], - column_names=["id"], - data_row_count=2, - is_select_result=True, - row_format="tuple", - ) - - assert result.is_select_result is True - assert result.selected_data == [(1,), (2,)] - assert result.column_names == ["id"] - assert result.data_row_count == 2 - - result = mock_async_driver.create_execution_result(cursor, rowcount_override=1) - assert result.is_select_result is False - assert result.rowcount_override == 1 - - result = mock_async_driver.create_execution_result( - cursor, statement_count=3, successful_statements=3, is_script_result=True - ) - assert result.is_script_result is True - assert result.statement_count == 3 - assert result.successful_statements == 3 - - -async def test_async_driver_build_statement_result(mock_async_driver: MockAsyncDriver) -> None: + async with aiosqlite_async_driver.with_cursor(aiosqlite_async_driver.connection) as cursor: + result = aiosqlite_async_driver.create_execution_result( + cursor, + selected_data=[(1,), (2,)], + column_names=["id"], + data_row_count=2, + is_select_result=True, + row_format="tuple", + ) + + assert result.is_select_result is True + assert result.selected_data == [(1,), (2,)] + assert result.column_names == ["id"] + assert result.data_row_count == 2 + + result = aiosqlite_async_driver.create_execution_result(cursor, rowcount_override=1) + assert result.is_select_result is False + assert result.rowcount_override == 1 + + +async def test_async_driver_build_statement_result(aiosqlite_async_driver: AiosqliteDriver) -> None: """Test async build_statement_result method.""" - statement = SQL("SELECT * FROM users", statement_config=mock_async_driver.statement_config) - cursor = mock_async_driver.with_cursor(mock_async_driver.connection) - - execution_result = mock_async_driver.create_execution_result( - cursor, selected_data=[(1,)], column_names=["id"], data_row_count=1, is_select_result=True, row_format="tuple" - ) - - sql_result = mock_async_driver.build_statement_result(statement, execution_result) - assert isinstance(sql_result, SQLResult) - assert sql_result.operation_type == "SELECT" - assert sql_result.get_data() == [{"id": 1}] - assert sql_result.column_names == ["id"] - - script_statement = SQL( - "INSERT INTO users (name) VALUES ('test');", statement_config=mock_async_driver.statement_config, is_script=True - ) - script_execution_result = mock_async_driver.create_execution_result( - cursor, statement_count=1, successful_statements=1, is_script_result=True - ) - - script_sql_result = mock_async_driver.build_statement_result(script_statement, script_execution_result) - assert script_sql_result.operation_type == "SCRIPT" - assert script_sql_result.total_statements == 1 - assert script_sql_result.successful_statements == 1 - - -async def test_async_driver_special_handling_integration(mock_async_driver: MockAsyncDriver) -> None: + statement = SQL("SELECT * FROM users", statement_config=aiosqlite_async_driver.statement_config) + async with aiosqlite_async_driver.with_cursor(aiosqlite_async_driver.connection) as cursor: + execution_result = aiosqlite_async_driver.create_execution_result( + cursor, + selected_data=[(1,)], + column_names=["id"], + data_row_count=1, + is_select_result=True, + row_format="tuple", + ) + + sql_result = aiosqlite_async_driver.build_statement_result(statement, execution_result) + assert isinstance(sql_result, SQLResult) + assert sql_result.operation_type == "SELECT" + assert sql_result.get_data() == [{"id": 1}] + assert sql_result.column_names == ["id"] + + +async def test_async_driver_special_handling_integration(aiosqlite_async_driver: AiosqliteDriver) -> None: """Test that async dispatch_special_handling is called during dispatch.""" - statement = SQL("SELECT * FROM users", statement_config=mock_async_driver.statement_config) + statement = SQL("SELECT * FROM users", statement_config=aiosqlite_async_driver.statement_config) with patch.object( - mock_async_driver, "dispatch_special_handling", new_callable=AsyncMock, return_value=None + aiosqlite_async_driver, "dispatch_special_handling", new_callable=AsyncMock, return_value=None ) as mock_special: - result = await mock_async_driver.dispatch_statement_execution(statement, mock_async_driver.connection) + result = await aiosqlite_async_driver.dispatch_statement_execution(statement, aiosqlite_async_driver.connection) assert isinstance(result, SQLResult) mock_special.assert_called_once() -async def test_async_driver_error_handling_in_dispatch(mock_async_driver: MockAsyncDriver) -> None: +async def test_async_driver_error_handling_in_dispatch(aiosqlite_async_driver: AiosqliteDriver) -> None: """Test error handling during async statement dispatch.""" - statement = SQL("SELECT * FROM users", statement_config=mock_async_driver.statement_config) + statement = SQL("SELECT * FROM users", statement_config=aiosqlite_async_driver.statement_config) with patch.object( - mock_async_driver, "dispatch_execute", new_callable=AsyncMock, side_effect=ValueError("Test async error") + aiosqlite_async_driver, + "dispatch_execute", + new_callable=AsyncMock, + side_effect=aiosqlite.Error("Test async error"), ): - with pytest.raises(SQLSpecError, match="Mock async database error"): - await mock_async_driver.dispatch_statement_execution(statement, mock_async_driver.connection) - - -async def test_async_driver_statement_processing_integration(mock_async_driver: MockAsyncDriver) -> None: - """Test async driver statement processing integration.""" - statement = SQL("SELECT * FROM users WHERE active = ?", True, statement_config=mock_async_driver.statement_config) - - with patch.object(SQL, "compile") as mock_compile: - mock_compile.return_value = ("SELECT * FROM test", []) - await mock_async_driver.dispatch_statement_execution(statement, mock_async_driver.connection) + with pytest.raises(SQLSpecError): + await aiosqlite_async_driver.dispatch_statement_execution(statement, aiosqlite_async_driver.connection) - assert mock_compile.called or statement.sql == "SELECT * FROM test" - -async def test_async_driver_context_manager_integration(mock_async_driver: MockAsyncDriver) -> None: - """Test async context manager integration during execution.""" - statement = SQL("SELECT * FROM users", statement_config=mock_async_driver.statement_config) - - with patch.object(mock_async_driver, "with_cursor") as mock_with_cursor: - mock_cursor = MockAsyncCursor(mock_async_driver.connection) - mock_with_cursor.return_value = mock_cursor - - with patch.object(mock_async_driver, "handle_database_exceptions") as mock_handle_exceptions: - mock_context = AsyncMock() - mock_context.pending_exception = None - mock_handle_exceptions.return_value = mock_context - - result = await mock_async_driver.dispatch_statement_execution(statement, mock_async_driver.connection) - - assert isinstance(result, SQLResult) - mock_with_cursor.assert_called_once() - mock_handle_exceptions.assert_called_once() - - -async def test_async_driver_resource_cleanup(mock_async_driver: MockAsyncDriver) -> None: - """Test async resource cleanup during execution.""" - connection = mock_async_driver.connection - cursor = await connection.cursor() - - assert cursor.closed is False - - await cursor.close() - assert cursor.closed is True - - -async def test_async_driver_concurrent_execution(mock_async_connection: MockAsyncConnection) -> None: +async def test_async_driver_concurrent_execution() -> None: """Test concurrent execution capability of async driver.""" - import asyncio - - driver = MockAsyncDriver(mock_async_connection) + conn = await aiosqlite.connect(":memory:") + driver = AiosqliteDriver(conn) async def execute_query(query_id: int) -> SQLResult: return await driver.execute(f"SELECT {query_id} as id") @@ -594,23 +467,4 @@ async def execute_query(query_id: int) -> SQLResult: for result in results: assert isinstance(result, SQLResult) assert result.operation_type == "SELECT" - - -async def test_async_driver_with_transaction_context(mock_async_driver: MockAsyncDriver) -> None: - """Test async driver transaction context usage.""" - connection = mock_async_driver.connection - - await mock_async_driver.begin() - assert connection.in_transaction is True - - result = await mock_async_driver.execute("INSERT INTO users (name) VALUES (?)", "test") - assert isinstance(result, SQLResult) - - await mock_async_driver.commit() - assert connection.in_transaction is False - - await mock_async_driver.begin() - assert connection.in_transaction is True - - await mock_async_driver.rollback() - assert connection.in_transaction is False + await conn.close() diff --git a/tests/unit/adapters/test_sync_adapters.py b/tests/unit/adapters/test_sync_adapters.py index 1a3d07f8e..85f72a432 100644 --- a/tests/unit/adapters/test_sync_adapters.py +++ b/tests/unit/adapters/test_sync_adapters.py @@ -1,34 +1,38 @@ # pyright: reportPrivateImportUsage = false, reportPrivateUsage = false """Tests for synchronous database adapters.""" +import sqlite3 from typing import Any -from unittest.mock import Mock, patch +from unittest.mock import patch import pytest +from sqlspec.adapters.sqlite import SqliteDriver from sqlspec.core import SQL, ParameterStyle, ParameterStyleConfig, SQLResult, StatementConfig, get_default_config from sqlspec.driver import ExecutionResult from sqlspec.exceptions import NotFoundError, SQLSpecError from sqlspec.observability import ObservabilityConfig, ObservabilityRuntime from sqlspec.typing import Empty -from tests.unit.adapters.conftest import MockSyncConnection, MockSyncDriver pytestmark = pytest.mark.xdist_group("adapter_unit") __all__ = () -def test_sync_driver_initialization(mock_sync_connection: MockSyncConnection) -> None: +def test_sync_driver_initialization() -> None: """Test basic sync driver initialization.""" - driver = MockSyncDriver(mock_sync_connection) + conn = sqlite3.connect(":memory:") + driver = SqliteDriver(conn) - assert driver.connection is mock_sync_connection + assert driver.connection is conn assert driver.dialect == "sqlite" assert driver.statement_config.dialect == "sqlite" assert driver.statement_config.parameter_config.default_parameter_style == ParameterStyle.QMARK + conn.close() -def test_sync_driver_with_custom_config(mock_sync_connection: MockSyncConnection) -> None: +def test_sync_driver_with_custom_config() -> None: """Test sync driver initialization with custom statement config.""" + conn = sqlite3.connect(":memory:") custom_config = StatementConfig( dialect="postgresql", parameter_config=ParameterStyleConfig( @@ -36,18 +40,23 @@ def test_sync_driver_with_custom_config(mock_sync_connection: MockSyncConnection ), ) - driver = MockSyncDriver(mock_sync_connection, custom_config) + driver = SqliteDriver(conn, custom_config) assert driver.statement_config.dialect == "postgresql" assert driver.statement_config.parameter_config.default_parameter_style == ParameterStyle.NUMERIC + conn.close() -def test_sync_driver_fast_path_flag_default(mock_sync_connection: MockSyncConnection) -> None: - driver = MockSyncDriver(mock_sync_connection) +def test_sync_driver_fast_path_flag_default() -> None: + conn = sqlite3.connect(":memory:") + driver = SqliteDriver(conn) assert driver._stmt_cache_enabled is True + conn.close() -def test_sync_driver_fast_path_flag_disabled_by_transformer(mock_sync_connection: MockSyncConnection) -> None: +def test_sync_driver_fast_path_flag_disabled_by_transformer() -> None: + conn = sqlite3.connect(":memory:") + def transformer(expression: Any, context: Any) -> "tuple[Any, Any]": return expression, context @@ -58,58 +67,52 @@ def transformer(expression: Any, context: Any) -> "tuple[Any, Any]": ), statement_transformers=(transformer,), ) - driver = MockSyncDriver(mock_sync_connection, custom_config) + driver = SqliteDriver(conn, custom_config) assert driver._stmt_cache_enabled is False + conn.close() -def test_sync_driver_fast_path_flag_disabled_by_observability(mock_sync_connection: MockSyncConnection) -> None: - driver = MockSyncDriver(mock_sync_connection) +def test_sync_driver_fast_path_flag_disabled_by_observability() -> None: + conn = sqlite3.connect(":memory:") + driver = SqliteDriver(conn) runtime = ObservabilityRuntime(ObservabilityConfig(print_sql=True)) driver.attach_observability(runtime) assert driver._stmt_cache_enabled is False + conn.close() -def test_sync_driver_with_cursor(mock_sync_driver: MockSyncDriver) -> None: +def test_sync_driver_with_cursor(sqlite_sync_driver: SqliteDriver) -> None: """Test cursor context manager functionality.""" - with mock_sync_driver.with_cursor(mock_sync_driver.connection) as cursor: - assert hasattr(cursor, "connection") + with sqlite_sync_driver.with_cursor(sqlite_sync_driver.connection) as cursor: assert hasattr(cursor, "execute") assert hasattr(cursor, "fetchall") - assert cursor.connection is mock_sync_driver.connection - + assert cursor.connection is sqlite_sync_driver.connection -def test_sync_driver_database_exception_handling(mock_sync_driver: MockSyncDriver) -> None: - """Test database exception handling with deferred exception pattern. - The deferred pattern stores exceptions in `pending_exception` instead of - raising from `__exit__`, allowing compiled code to raise safely. - """ - exc_handler = mock_sync_driver.handle_database_exceptions() +def test_sync_driver_database_exception_handling(sqlite_sync_driver: SqliteDriver) -> None: + """Test database exception handling with deferred exception pattern.""" + exc_handler = sqlite_sync_driver.handle_database_exceptions() with exc_handler: pass assert exc_handler.pending_exception is None - exc_handler = mock_sync_driver.handle_database_exceptions() + exc_handler = sqlite_sync_driver.handle_database_exceptions() with exc_handler: - raise ValueError("Test error") + raise sqlite3.Error("Test error") assert exc_handler.pending_exception is not None assert isinstance(exc_handler.pending_exception, SQLSpecError) - assert "Mock database error" in str(exc_handler.pending_exception) - - with pytest.raises(SQLSpecError, match="Mock database error"): - raise exc_handler.pending_exception -def test_sync_driverdispatch_execute_select(mock_sync_driver: MockSyncDriver) -> None: +def test_sync_driver_dispatch_execute_select(sqlite_sync_driver: SqliteDriver) -> None: """Test dispatch_execute method with SELECT query.""" - statement = SQL("SELECT id, name FROM users", statement_config=mock_sync_driver.statement_config) + statement = SQL("SELECT id, name FROM users", statement_config=sqlite_sync_driver.statement_config) - with mock_sync_driver.with_cursor(mock_sync_driver.connection) as cursor: - result = mock_sync_driver.dispatch_execute(cursor, statement) + with sqlite_sync_driver.with_cursor(sqlite_sync_driver.connection) as cursor: + result = sqlite_sync_driver.dispatch_execute(cursor, statement) assert isinstance(result, ExecutionResult) assert result.is_select_result is True @@ -120,75 +123,71 @@ def test_sync_driverdispatch_execute_select(mock_sync_driver: MockSyncDriver) -> assert result.data_row_count == 2 -def test_sync_driverdispatch_execute_insert(mock_sync_driver: MockSyncDriver) -> None: +def test_sync_driver_dispatch_execute_insert(sqlite_sync_driver: SqliteDriver) -> None: """Test dispatch_execute method with INSERT query.""" - statement = SQL("INSERT INTO users (name) VALUES (?)", "test", statement_config=mock_sync_driver.statement_config) + statement = SQL( + "INSERT INTO users (name) VALUES (?)", "new_user", statement_config=sqlite_sync_driver.statement_config + ) - with mock_sync_driver.with_cursor(mock_sync_driver.connection) as cursor: - result = mock_sync_driver.dispatch_execute(cursor, statement) + with sqlite_sync_driver.with_cursor(sqlite_sync_driver.connection) as cursor: + result = sqlite_sync_driver.dispatch_execute(cursor, statement) assert isinstance(result, ExecutionResult) assert result.is_select_result is False assert result.is_script_result is False assert result.is_many_result is False assert result.rowcount_override == 1 - assert result.selected_data is None -def test_sync_driver_execute_many(mock_sync_driver: MockSyncDriver) -> None: +def test_sync_driver_execute_many(sqlite_sync_driver: SqliteDriver) -> None: """Test _execute_many method.""" statement = SQL( "INSERT INTO users (name) VALUES (?)", [["alice"], ["bob"], ["charlie"]], - statement_config=mock_sync_driver.statement_config, + statement_config=sqlite_sync_driver.statement_config, is_many=True, ) - with mock_sync_driver.with_cursor(mock_sync_driver.connection) as cursor: - result = mock_sync_driver.dispatch_execute_many(cursor, statement) + with sqlite_sync_driver.with_cursor(sqlite_sync_driver.connection) as cursor: + result = sqlite_sync_driver.dispatch_execute_many(cursor, statement) assert isinstance(result, ExecutionResult) assert result.is_many_result is True - assert result.is_select_result is False - assert result.is_script_result is False assert result.rowcount_override == 3 - assert mock_sync_driver.connection.execute_many_count == 1 -def test_sync_driver_execute_many_no_parameters(mock_sync_driver: MockSyncDriver) -> None: +def test_sync_driver_execute_many_no_parameters(sqlite_sync_driver: SqliteDriver) -> None: """Test _execute_many method fails without parameters.""" statement = SQL( - "INSERT INTO users (name) VALUES (?)", statement_config=mock_sync_driver.statement_config, is_many=True + "INSERT INTO users (name) VALUES (?)", statement_config=sqlite_sync_driver.statement_config, is_many=True ) - with mock_sync_driver.with_cursor(mock_sync_driver.connection) as cursor: + with sqlite_sync_driver.with_cursor(sqlite_sync_driver.connection) as cursor: with pytest.raises(ValueError, match="execute_many requires parameters"): - mock_sync_driver.dispatch_execute_many(cursor, statement) + sqlite_sync_driver.dispatch_execute_many(cursor, statement) -def test_sync_driver_execute_script(mock_sync_driver: MockSyncDriver) -> None: +def test_sync_driver_execute_script(sqlite_sync_driver: SqliteDriver) -> None: """Test _execute_script method.""" script = """ INSERT INTO users (name) VALUES ('alice'); INSERT INTO users (name) VALUES ('bob'); - UPDATE users SET active = 1; + UPDATE users SET name = 'updated'; """ - statement = SQL(script, statement_config=mock_sync_driver.statement_config, is_script=True) + statement = SQL(script, statement_config=sqlite_sync_driver.statement_config, is_script=True) - with mock_sync_driver.with_cursor(mock_sync_driver.connection) as cursor: - result = mock_sync_driver.dispatch_execute_script(cursor, statement) + with sqlite_sync_driver.with_cursor(sqlite_sync_driver.connection) as cursor: + result = sqlite_sync_driver.dispatch_execute_script(cursor, statement) assert isinstance(result, ExecutionResult) assert result.is_script_result is True - assert result.is_select_result is False - assert result.is_many_result is False assert result.statement_count == 3 assert result.successful_statements == 3 -def test_sync_driver_dispatch_statement_execution_select(mock_sync_driver: MockSyncDriver) -> None: +def test_sync_driver_dispatch_statement_execution_select(sqlite_sync_driver: SqliteDriver) -> None: """Test dispatch_statement_execution with SELECT statement.""" - statement = SQL("SELECT * FROM users", statement_config=mock_sync_driver.statement_config) + statement = SQL("SELECT * FROM users", statement_config=sqlite_sync_driver.statement_config) - result = mock_sync_driver.dispatch_statement_execution(statement, mock_sync_driver.connection) + result = sqlite_sync_driver.dispatch_statement_execution(statement, sqlite_sync_driver.connection) assert isinstance(result, SQLResult) assert result.operation_type == "SELECT" @@ -197,24 +196,25 @@ def test_sync_driver_dispatch_statement_execution_select(mock_sync_driver: MockS assert result.get_data()[0]["name"] == "test" -def test_sync_driver_dispatch_statement_execution_insert(mock_sync_driver: MockSyncDriver) -> None: +def test_sync_driver_dispatch_statement_execution_insert(sqlite_sync_driver: SqliteDriver) -> None: """Test dispatch_statement_execution with INSERT statement.""" - statement = SQL("INSERT INTO users (name) VALUES (?)", "test", statement_config=mock_sync_driver.statement_config) + statement = SQL( + "INSERT INTO users (name) VALUES (?)", "new_user", statement_config=sqlite_sync_driver.statement_config + ) - result = mock_sync_driver.dispatch_statement_execution(statement, mock_sync_driver.connection) + result = sqlite_sync_driver.dispatch_statement_execution(statement, sqlite_sync_driver.connection) assert isinstance(result, SQLResult) assert result.operation_type == "INSERT" assert result.rows_affected == 1 - assert len(result.get_data()) == 0 -def test_sync_driver_dispatch_statement_execution_script(mock_sync_driver: MockSyncDriver) -> None: +def test_sync_driver_dispatch_statement_execution_script(sqlite_sync_driver: SqliteDriver) -> None: """Test dispatch_statement_execution with script.""" script = "INSERT INTO users (name) VALUES ('alice'); INSERT INTO users (name) VALUES ('bob');" - statement = SQL(script, statement_config=mock_sync_driver.statement_config, is_script=True) + statement = SQL(script, statement_config=sqlite_sync_driver.statement_config, is_script=True) - result = mock_sync_driver.dispatch_statement_execution(statement, mock_sync_driver.connection) + result = sqlite_sync_driver.dispatch_statement_execution(statement, sqlite_sync_driver.connection) assert isinstance(result, SQLResult) assert result.operation_type == "SCRIPT" @@ -222,33 +222,33 @@ def test_sync_driver_dispatch_statement_execution_script(mock_sync_driver: MockS assert result.successful_statements == 2 -def test_sync_driver_dispatch_statement_execution_many(mock_sync_driver: MockSyncDriver) -> None: +def test_sync_driver_dispatch_statement_execution_many(sqlite_sync_driver: SqliteDriver) -> None: """Test dispatch_statement_execution with execute_many.""" statement = SQL( "INSERT INTO users (name) VALUES (?)", [["alice"], ["bob"]], - statement_config=mock_sync_driver.statement_config, + statement_config=sqlite_sync_driver.statement_config, is_many=True, ) - result = mock_sync_driver.dispatch_statement_execution(statement, mock_sync_driver.connection) + result = sqlite_sync_driver.dispatch_statement_execution(statement, sqlite_sync_driver.connection) assert isinstance(result, SQLResult) assert result.operation_type == "INSERT" assert result.rows_affected == 2 -def test_sync_driver_releases_pooled_statement(mock_sync_driver: MockSyncDriver) -> None: +def test_sync_driver_releases_pooled_statement(sqlite_sync_driver: SqliteDriver) -> None: """Pooled statements should be reset after dispatch execution.""" seed = "SELECT * FROM users WHERE id = ?" - mock_sync_driver.prepare_statement(seed, (1,), statement_config=mock_sync_driver.statement_config, kwargs={}) - pooled = mock_sync_driver.prepare_statement( - seed, (2,), statement_config=mock_sync_driver.statement_config, kwargs={} + sqlite_sync_driver.prepare_statement(seed, (1,), statement_config=sqlite_sync_driver.statement_config, kwargs={}) + pooled = sqlite_sync_driver.prepare_statement( + seed, (2,), statement_config=sqlite_sync_driver.statement_config, kwargs={} ) assert pooled._pooled is True - mock_sync_driver.dispatch_statement_execution(pooled, mock_sync_driver.connection) + sqlite_sync_driver.dispatch_statement_execution(pooled, sqlite_sync_driver.connection) assert pooled._raw_sql == "" assert pooled._processed_state is Empty @@ -256,45 +256,46 @@ def test_sync_driver_releases_pooled_statement(mock_sync_driver: MockSyncDriver) assert pooled._statement_config is get_default_config() -def test_sync_driver_transaction_management(mock_sync_driver: MockSyncDriver) -> None: +def test_sync_driver_transaction_management(sqlite_sync_driver: SqliteDriver) -> None: """Test transaction management methods.""" - connection = mock_sync_driver.connection + sqlite_sync_driver.begin() + sqlite_sync_driver.execute("INSERT INTO users (name) VALUES ('trans')") + sqlite_sync_driver.commit() - mock_sync_driver.begin() - assert connection.in_transaction is True + res = sqlite_sync_driver.select_value("SELECT COUNT(*) FROM users WHERE name = 'trans'") + assert res == 1 - mock_sync_driver.commit() - assert connection.in_transaction is False + sqlite_sync_driver.begin() + sqlite_sync_driver.execute("INSERT INTO users (name) VALUES ('rolledback')") + sqlite_sync_driver.rollback() - mock_sync_driver.begin() - assert connection.in_transaction is True - mock_sync_driver.rollback() - assert connection.in_transaction is False + res = sqlite_sync_driver.select_value("SELECT COUNT(*) FROM users WHERE name = 'rolledback'") + assert res == 0 -def test_sync_driver_execute_method(mock_sync_driver: MockSyncDriver) -> None: +def test_sync_driver_execute_method(sqlite_sync_driver: SqliteDriver) -> None: """Test high-level execute method.""" - result = mock_sync_driver.execute("SELECT * FROM users WHERE id = ?", 1) + result = sqlite_sync_driver.execute("SELECT * FROM users WHERE id = ?", 1) assert isinstance(result, SQLResult) assert result.operation_type == "SELECT" - assert len(result.get_data()) == 2 + assert len(result.get_data()) == 1 -def test_sync_driver_execute_many_method(mock_sync_driver: MockSyncDriver) -> None: +def test_sync_driver_execute_many_method(sqlite_sync_driver: SqliteDriver) -> None: """Test high-level execute_many method.""" parameters = [["alice"], ["bob"], ["charlie"]] - result = mock_sync_driver.execute_many("INSERT INTO users (name) VALUES (?)", parameters) + result = sqlite_sync_driver.execute_many("INSERT INTO users (name) VALUES (?)", parameters) assert isinstance(result, SQLResult) assert result.operation_type == "INSERT" assert result.rows_affected == 3 -def test_sync_driver_execute_script_method(mock_sync_driver: MockSyncDriver) -> None: +def test_sync_driver_execute_script_method(sqlite_sync_driver: SqliteDriver) -> None: """Test high-level execute_script method.""" - script = "INSERT INTO users (name) VALUES ('alice'); UPDATE users SET active = 1;" - result = mock_sync_driver.execute_script(script) + script = "INSERT INTO users (name) VALUES ('alice'); UPDATE users SET name = 'updated';" + result = sqlite_sync_driver.execute_script(script) assert isinstance(result, SQLResult) assert result.operation_type == "SCRIPT" @@ -306,82 +307,54 @@ def test_sync_driver_execute_script_method(mock_sync_driver: MockSyncDriver) -> ("method_name", "call_args"), [ pytest.param("execute", ("SELECT * FROM users WHERE id = ?", 1), id="execute"), - pytest.param("execute_many", ("INSERT INTO users (name) VALUES (?)", [["alice"]]), id="execute_many"), pytest.param("execute_script", ("INSERT INTO users (name) VALUES ('alice');",), id="execute_script"), ], ) def test_sync_driver_execution_wrappers_reraise_deferred_database_errors( - mock_sync_driver: MockSyncDriver, method_name: str, call_args: tuple[Any, ...] + sqlite_sync_driver: SqliteDriver, method_name: str, call_args: tuple[Any, ...] ) -> None: """Test wrapper methods re-raise mapped errors after the exception context exits.""" - with patch.object(mock_sync_driver, "dispatch_statement_execution", side_effect=ValueError("Test wrapper error")): - method = getattr(mock_sync_driver, method_name) + # Patch all potential entry points for the different method types + with ( + patch.object( + sqlite_sync_driver, "dispatch_statement_execution", side_effect=sqlite3.Error("Test wrapper error") + ), + patch.object(sqlite_sync_driver, "dispatch_execute_many", side_effect=sqlite3.Error("Test wrapper error")), + patch.object(sqlite_sync_driver, "dispatch_execute_script", side_effect=sqlite3.Error("Test wrapper error")), + ): + method = getattr(sqlite_sync_driver, method_name) - with pytest.raises(SQLSpecError, match="Mock database error: Test wrapper error"): + with pytest.raises(SQLSpecError): method(*call_args) -def test_sync_driver_select_one(mock_sync_driver: MockSyncDriver) -> None: +def test_sync_driver_select_one(sqlite_sync_driver: SqliteDriver) -> None: """Test select_one method - expects error when multiple rows returned.""" with pytest.raises(ValueError, match="Multiple results found"): - mock_sync_driver.select_one("SELECT * FROM users WHERE id = ?", 1) + sqlite_sync_driver.select_one("SELECT * FROM users") -def test_sync_driver_select_one_no_results(mock_sync_driver: MockSyncDriver) -> None: +def test_sync_driver_select_one_no_results(sqlite_sync_driver: SqliteDriver) -> None: """Test select_one method with no results.""" + with pytest.raises(NotFoundError, match="No rows found"): + sqlite_sync_driver.select_one("SELECT * FROM users WHERE id = ?", 999) - with patch.object(mock_sync_driver, "execute") as mock_execute: - mock_result = Mock(spec=SQLResult) - mock_result.one.side_effect = ValueError("No result found, exactly one row expected") - mock_execute.return_value = mock_result - - with pytest.raises(NotFoundError, match="No rows found"): - mock_sync_driver.select_one("SELECT * FROM users WHERE id = ?", 999) - -def test_sync_driver_select_one_multiple_results(mock_sync_driver: MockSyncDriver) -> None: - """Test select_one method with multiple results.""" - - with patch.object(mock_sync_driver, "execute") as mock_execute: - mock_result = Mock(spec=SQLResult) - mock_result.one.side_effect = ValueError("Multiple results found (3), exactly one row expected") - mock_execute.return_value = mock_result - - with pytest.raises(ValueError, match="Multiple results found"): - mock_sync_driver.select_one("SELECT * FROM users") - - -def test_sync_driver_select_one_or_none(mock_sync_driver: MockSyncDriver) -> None: +def test_sync_driver_select_one_or_none(sqlite_sync_driver: SqliteDriver) -> None: """Test select_one_or_none method - expects error when multiple rows returned.""" with pytest.raises(ValueError, match="Multiple results found"): - mock_sync_driver.select_one_or_none("SELECT * FROM users WHERE id = ?", 1) + sqlite_sync_driver.select_one_or_none("SELECT * FROM users") -def test_sync_driver_select_one_or_none_no_results(mock_sync_driver: MockSyncDriver) -> None: +def test_sync_driver_select_one_or_none_no_results(sqlite_sync_driver: SqliteDriver) -> None: """Test select_one_or_none method with no results.""" - with patch.object(mock_sync_driver, "execute") as mock_execute: - mock_result = Mock(spec=SQLResult) - mock_result.one_or_none.return_value = None - mock_execute.return_value = mock_result - - result = mock_sync_driver.select_one_or_none("SELECT * FROM users WHERE id = ?", 999) - assert result is None - + result = sqlite_sync_driver.select_one_or_none("SELECT * FROM users WHERE id = ?", 999) + assert result is None -def test_sync_driver_select_one_or_none_multiple_results(mock_sync_driver: MockSyncDriver) -> None: - """Test select_one_or_none method with multiple results.""" - with patch.object(mock_sync_driver, "execute") as mock_execute: - mock_result = Mock(spec=SQLResult) - mock_result.one_or_none.side_effect = ValueError("Multiple results found (2), at most one row expected") - mock_execute.return_value = mock_result - with pytest.raises(ValueError, match="Multiple results found"): - mock_sync_driver.select_one_or_none("SELECT * FROM users") - - -def test_sync_driver_select(mock_sync_driver: MockSyncDriver) -> None: +def test_sync_driver_select(sqlite_sync_driver: SqliteDriver) -> None: """Test select method.""" - result: list[Any] = mock_sync_driver.select("SELECT * FROM users") + result: list[Any] = sqlite_sync_driver.select("SELECT * FROM users") assert isinstance(result, list) assert len(result) == 2 @@ -389,57 +362,33 @@ def test_sync_driver_select(mock_sync_driver: MockSyncDriver) -> None: assert result[1]["id"] == 2 -def test_sync_driver_select_value(mock_sync_driver: MockSyncDriver) -> None: +def test_sync_driver_select_value(sqlite_sync_driver: SqliteDriver) -> None: """Test select_value method.""" - - with patch.object(mock_sync_driver, "execute") as mock_execute: - mock_result = Mock(spec=SQLResult) - mock_result.scalar.return_value = 42 - mock_execute.return_value = mock_result - - result = mock_sync_driver.select_value("SELECT COUNT(*) as count FROM users") - assert result == 42 + result = sqlite_sync_driver.select_value("SELECT COUNT(*) FROM users") + assert result == 2 -def test_sync_driver_select_value_no_results(mock_sync_driver: MockSyncDriver) -> None: +def test_sync_driver_select_value_no_results(sqlite_sync_driver: SqliteDriver) -> None: """Test select_value method with no results.""" - with patch.object(mock_sync_driver, "execute") as mock_execute: - mock_result = Mock(spec=SQLResult) - mock_result.scalar.side_effect = ValueError("No result found, exactly one row expected") - mock_execute.return_value = mock_result - - with pytest.raises(NotFoundError, match="No rows found"): - mock_sync_driver.select_value("SELECT COUNT(*) FROM users WHERE id = 999") - - -def test_sync_driver_select_value_or_none(mock_sync_driver: MockSyncDriver) -> None: - """Test select_value_or_none method - expects error when multiple rows returned.""" - with pytest.raises(ValueError, match="Multiple results found"): - mock_sync_driver.select_value_or_none("SELECT * FROM users WHERE id = ?", 1) + with pytest.raises(NotFoundError, match="No rows found"): + sqlite_sync_driver.select_value("SELECT id FROM users WHERE id = 999") -def test_sync_driver_select_value_or_none_no_results(mock_sync_driver: MockSyncDriver) -> None: +def test_sync_driver_select_value_or_none_no_results(sqlite_sync_driver: SqliteDriver) -> None: """Test select_value_or_none method with no results.""" - with patch.object(mock_sync_driver, "execute") as mock_execute: - mock_result = Mock(spec=SQLResult) - mock_result.scalar_or_none.return_value = None - mock_execute.return_value = mock_result - - result = mock_sync_driver.select_value_or_none("SELECT COUNT(*) FROM users WHERE id = 999") - assert result is None + result = sqlite_sync_driver.select_value_or_none("SELECT id FROM users WHERE id = 999") + assert result is None @pytest.mark.parametrize( "parameter_style,expected_style", [ pytest.param(ParameterStyle.QMARK, ParameterStyle.QMARK, id="qmark"), - pytest.param(ParameterStyle.NUMERIC, ParameterStyle.NUMERIC, id="numeric"), pytest.param(ParameterStyle.NAMED_COLON, ParameterStyle.NAMED_COLON, id="named_colon"), - pytest.param(ParameterStyle.NAMED_PYFORMAT, ParameterStyle.NAMED_PYFORMAT, id="pyformat_named"), ], ) def test_sync_driver_parameter_styles( - mock_sync_connection: MockSyncConnection, parameter_style: ParameterStyle, expected_style: ParameterStyle + sqlite_sync_driver: SqliteDriver, parameter_style: ParameterStyle, expected_style: ParameterStyle ) -> None: """Test different parameter styles are handled correctly.""" config = StatementConfig( @@ -447,117 +396,94 @@ def test_sync_driver_parameter_styles( parameter_config=ParameterStyleConfig( default_parameter_style=parameter_style, supported_parameter_styles={parameter_style}, - default_execution_parameter_style=parameter_style, - supported_execution_parameter_styles={parameter_style}, + default_execution_parameter_style=ParameterStyle.QMARK, + supported_execution_parameter_styles={ParameterStyle.QMARK}, ), ) - driver = MockSyncDriver(mock_sync_connection, config) - assert driver.statement_config.parameter_config.default_parameter_style == expected_style + sqlite_sync_driver.statement_config = config + assert sqlite_sync_driver.statement_config.parameter_config.default_parameter_style == expected_style if parameter_style == ParameterStyle.QMARK: statement = SQL("SELECT * FROM users WHERE id = ?", 1, statement_config=config) - elif parameter_style == ParameterStyle.NUMERIC: - statement = SQL("SELECT * FROM users WHERE id = $1", 1, statement_config=config) - elif parameter_style == ParameterStyle.NAMED_COLON: - statement = SQL("SELECT * FROM users WHERE id = :id", {"id": 1}, statement_config=config) else: - statement = SQL("SELECT * FROM users WHERE id = %(id)s", {"id": 1}, statement_config=config) + statement = SQL("SELECT * FROM users WHERE id = :id", {"id": 1}, statement_config=config) - result = driver.dispatch_statement_execution(statement, driver.connection) + result = sqlite_sync_driver.dispatch_statement_execution(statement, sqlite_sync_driver.connection) assert isinstance(result, SQLResult) -@pytest.mark.parametrize("dialect", ["sqlite", "postgres", "mysql"]) -def test_sync_driver_different_dialects(mock_sync_connection: MockSyncConnection, dialect: str) -> None: +def test_sync_driver_different_dialects(sqlite_sync_driver: SqliteDriver) -> None: """Test sync driver works with different SQL dialects.""" config = StatementConfig( - dialect=dialect, + dialect="sqlite", parameter_config=ParameterStyleConfig( default_parameter_style=ParameterStyle.QMARK, supported_parameter_styles={ParameterStyle.QMARK} ), ) - driver = MockSyncDriver(mock_sync_connection, config) - assert driver.statement_config.dialect == dialect - - result = driver.execute("SELECT 1 as test") + sqlite_sync_driver.statement_config = config + result = sqlite_sync_driver.execute("SELECT 1 as test") assert isinstance(result, SQLResult) -def test_sync_driver_create_execution_result(mock_sync_driver: MockSyncDriver) -> None: +def test_sync_driver_create_execution_result(sqlite_sync_driver: SqliteDriver) -> None: """Test create_execution_result method.""" - cursor = mock_sync_driver.with_cursor(mock_sync_driver.connection) - - result = mock_sync_driver.create_execution_result( - cursor, - selected_data=[(1,), (2,)], - column_names=["id"], - data_row_count=2, - is_select_result=True, - row_format="tuple", - ) - - assert result.is_select_result is True - assert result.selected_data == [(1,), (2,)] - assert result.column_names == ["id"] - assert result.data_row_count == 2 - - result = mock_sync_driver.create_execution_result(cursor, rowcount_override=1) - assert result.is_select_result is False - assert result.rowcount_override == 1 - - result = mock_sync_driver.create_execution_result( - cursor, statement_count=3, successful_statements=3, is_script_result=True - ) - assert result.is_script_result is True - assert result.statement_count == 3 - assert result.successful_statements == 3 - - -def test_sync_driver_build_statement_result(mock_sync_driver: MockSyncDriver) -> None: + with sqlite_sync_driver.with_cursor(sqlite_sync_driver.connection) as cursor: + result = sqlite_sync_driver.create_execution_result( + cursor, + selected_data=[(1,), (2,)], + column_names=["id"], + data_row_count=2, + is_select_result=True, + row_format="tuple", + ) + + assert result.is_select_result is True + assert result.selected_data == [(1,), (2,)] + assert result.column_names == ["id"] + assert result.data_row_count == 2 + + result = sqlite_sync_driver.create_execution_result(cursor, rowcount_override=1) + assert result.is_select_result is False + assert result.rowcount_override == 1 + + +def test_sync_driver_build_statement_result(sqlite_sync_driver: SqliteDriver) -> None: """Test build_statement_result method.""" - statement = SQL("SELECT * FROM users", statement_config=mock_sync_driver.statement_config) - cursor = mock_sync_driver.with_cursor(mock_sync_driver.connection) - - execution_result = mock_sync_driver.create_execution_result( - cursor, selected_data=[(1,)], column_names=["id"], data_row_count=1, is_select_result=True, row_format="tuple" - ) - - sql_result = mock_sync_driver.build_statement_result(statement, execution_result) - assert isinstance(sql_result, SQLResult) - assert sql_result.operation_type == "SELECT" - assert sql_result.get_data() == [{"id": 1}] - assert sql_result.column_names == ["id"] - - script_statement = SQL( - "INSERT INTO users (name) VALUES ('test');", statement_config=mock_sync_driver.statement_config, is_script=True - ) - script_execution_result = mock_sync_driver.create_execution_result( - cursor, statement_count=1, successful_statements=1, is_script_result=True - ) - - script_sql_result = mock_sync_driver.build_statement_result(script_statement, script_execution_result) - assert script_sql_result.operation_type == "SCRIPT" - assert script_sql_result.total_statements == 1 - assert script_sql_result.successful_statements == 1 - - -def test_sync_driver_special_handling_integration(mock_sync_driver: MockSyncDriver) -> None: + statement = SQL("SELECT * FROM users", statement_config=sqlite_sync_driver.statement_config) + with sqlite_sync_driver.with_cursor(sqlite_sync_driver.connection) as cursor: + execution_result = sqlite_sync_driver.create_execution_result( + cursor, + selected_data=[(1,)], + column_names=["id"], + data_row_count=1, + is_select_result=True, + row_format="tuple", + ) + + sql_result = sqlite_sync_driver.build_statement_result(statement, execution_result) + assert isinstance(sql_result, SQLResult) + assert sql_result.operation_type == "SELECT" + assert sql_result.get_data() == [{"id": 1}] + assert sql_result.column_names == ["id"] + + +def test_sync_driver_special_handling_integration(sqlite_sync_driver: SqliteDriver) -> None: """Test that dispatch_special_handling is called during dispatch.""" - statement = SQL("SELECT * FROM users", statement_config=mock_sync_driver.statement_config) + statement = SQL("SELECT * FROM users", statement_config=sqlite_sync_driver.statement_config) - with patch.object(mock_sync_driver, "dispatch_special_handling", return_value=None) as mock_special: - result = mock_sync_driver.dispatch_statement_execution(statement, mock_sync_driver.connection) + with patch.object(sqlite_sync_driver, "dispatch_special_handling", return_value=None) as mock_special: + result = sqlite_sync_driver.dispatch_statement_execution(statement, sqlite_sync_driver.connection) assert isinstance(result, SQLResult) mock_special.assert_called_once() -def test_sync_driver_error_handling_in_dispatch(mock_sync_driver: MockSyncDriver) -> None: +def test_sync_driver_error_handling_in_dispatch(sqlite_sync_driver: SqliteDriver) -> None: """Test error handling during statement dispatch.""" - statement = SQL("SELECT * FROM users", statement_config=mock_sync_driver.statement_config) + statement = SQL("SELECT * FROM users", statement_config=sqlite_sync_driver.statement_config) - with patch.object(mock_sync_driver, "dispatch_execute", side_effect=ValueError("Test error")): - with pytest.raises(SQLSpecError, match="Mock database error"): - mock_sync_driver.dispatch_statement_execution(statement, mock_sync_driver.connection) + with patch.object(sqlite_sync_driver, "dispatch_execute", side_effect=sqlite3.Error("Test error")): + with pytest.raises(SQLSpecError): + sqlite_sync_driver.dispatch_statement_execution(statement, sqlite_sync_driver.connection) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index f4d7e120e..4a95a9f6c 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -4,14 +4,18 @@ cleanup, and performance testing with proper scoping and test isolation. """ +import sqlite3 import time -from collections.abc import Callable +from collections.abc import AsyncGenerator, Callable, Generator from contextlib import contextmanager from decimal import Decimal -from typing import TYPE_CHECKING, Any +from typing import Any +import aiosqlite import pytest +from sqlspec.adapters.aiosqlite import AiosqliteDriver +from sqlspec.adapters.sqlite import SqliteDriver from sqlspec.core import ( SQL, LRUCache, @@ -23,11 +27,47 @@ ) from sqlspec.driver import ExecutionResult -if TYPE_CHECKING: - from collections.abc import Generator + +class TestSqliteDriver(SqliteDriver): + """Test-friendly SQLite driver that allows patching.""" + + pass + + +class TestAiosqliteDriver(AiosqliteDriver): + """Test-friendly aiosqlite driver that allows patching.""" + + pass + + +@pytest.fixture +def sqlite_sync_driver() -> Generator[TestSqliteDriver, None, None]: + """Fixture for a real SQLite sync driver using in-memory database.""" + conn = sqlite3.connect(":memory:") + conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)") + conn.execute("INSERT INTO users (name) VALUES ('test'), ('example')") + conn.commit() + + driver = TestSqliteDriver(conn) + yield driver + conn.close() + + +@pytest.fixture +async def aiosqlite_async_driver() -> AsyncGenerator[TestAiosqliteDriver, None]: + """Fixture for a real aiosqlite async driver using in-memory database.""" + conn = await aiosqlite.connect(":memory:") + await conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)") + await conn.execute("INSERT INTO users (name) VALUES ('test'), ('example')") + await conn.commit() + + driver = TestAiosqliteDriver(conn) + yield driver + await conn.close() __all__ = ( + "aiosqlite_async_driver", "benchmark_tracker", "cache_config_disabled", "cache_config_enabled", @@ -50,6 +90,7 @@ "sample_select_sql", "sample_update_sql", "sql_with_typed_parameters", + "sqlite_sync_driver", "statement_config_mysql", "statement_config_postgres", "statement_config_sqlite", diff --git a/tests/unit/core/test_parameters.py b/tests/unit/core/test_parameters.py index bdce468cf..5609f56f7 100644 --- a/tests/unit/core/test_parameters.py +++ b/tests/unit/core/test_parameters.py @@ -58,7 +58,6 @@ "sqlspec.adapters.cockroach_asyncpg.driver", "sqlspec.adapters.cockroach_psycopg.driver", "sqlspec.adapters.duckdb.driver", - "sqlspec.adapters.mock.driver", "sqlspec.adapters.mysqlconnector.driver", "sqlspec.adapters.oracledb.driver", "sqlspec.adapters.psqlpy.driver", @@ -80,7 +79,6 @@ "cockroach_asyncpg", "cockroach_psycopg", "duckdb", - "mock", "mysql-connector", "oracledb", "psqlpy", diff --git a/tests/unit/driver/test_execute_script.py b/tests/unit/driver/test_execute_script.py index 0152ac927..c5ecceeda 100644 --- a/tests/unit/driver/test_execute_script.py +++ b/tests/unit/driver/test_execute_script.py @@ -9,18 +9,20 @@ @requires_interpreted -def test_sync_execute_script_tracks_all_successful_statements(mock_sync_driver) -> None: +def test_sync_execute_script_tracks_all_successful_statements(sqlite_sync_driver) -> None: """Sync execute_script should report all statements as successful.""" - result = mock_sync_driver.execute_script("SELECT 1; SELECT 2; SELECT 3;") + result = sqlite_sync_driver.execute_script("SELECT * FROM users; SELECT * FROM users; SELECT * FROM users;") assert result.total_statements == 3 assert result.successful_statements == 3 assert result.is_success() is True @requires_interpreted -async def test_async_execute_script_tracks_all_successful_statements(mock_async_driver) -> None: +async def test_async_execute_script_tracks_all_successful_statements(aiosqlite_async_driver) -> None: """Async execute_script should report all statements as successful.""" - result = await mock_async_driver.execute_script("SELECT 1; SELECT 2; SELECT 3;") + result = await aiosqlite_async_driver.execute_script( + "SELECT * FROM users; SELECT * FROM users; SELECT * FROM users;" + ) assert result.total_statements == 3 assert result.successful_statements == 3 assert result.is_success() is True diff --git a/tests/unit/driver/test_query_cache.py b/tests/unit/driver/test_query_cache.py index 0d46459a5..a4dfee5a5 100644 --- a/tests/unit/driver/test_query_cache.py +++ b/tests/unit/driver/test_query_cache.py @@ -1,233 +1,53 @@ -# pyright: reportPrivateImportUsage = false, reportPrivateUsage = false -"""Unit tests for fast-path query cache behavior.""" +# pyright: reportPrivateUsage = false +"""Tests for SQL query caching functionality.""" -from collections.abc import Sequence -from concurrent.futures import ThreadPoolExecutor -from typing import Any, Literal, cast +from typing import Any +from unittest.mock import Mock import pytest -from sqlspec.core import SQL, ParameterStyle, ParameterStyleConfig, StatementConfig -from sqlspec.core.compiler import OperationProfile, OperationType -from sqlspec.core.parameters import ParameterInfo, ParameterProfile -from sqlspec.core.statement import ProcessedState -from sqlspec.driver._common import CachedQuery, CommonDriverAttributesMixin -from sqlspec.driver._query_cache import QueryCache +from sqlspec.core import SQL, OperationProfile, ParameterProfile, ProcessedState +from sqlspec.driver._query_cache import CachedQuery from sqlspec.exceptions import SQLSpecError -_EMPTY_PS = ProcessedState("", [], None, "COMMAND") - def _make_cached( - compiled_sql: str = "SQL", + compiled_sql: str = "SELECT 1", param_count: int = 0, - operation_type: "OperationType" = "COMMAND", - operation_profile: "OperationProfile | None" = None, - parameter_profile: "ParameterProfile | None" = None, - processed_state: "ProcessedState | None" = None, + operation_type: str = "SELECT", + column_names: list[str] | None = None, + operation_profile: OperationProfile | None = None, + processed_state: ProcessedState | None = None, ) -> CachedQuery: - """Helper to create CachedQuery instances with sensible defaults.""" + if operation_profile is None: + operation_profile = OperationProfile(returns_rows=True, modifies_rows=False) + if processed_state is None: + processed_state = ProcessedState( + compiled_sql=compiled_sql, execution_parameters=[], operation_type=operation_type + ) return CachedQuery( compiled_sql=compiled_sql, - parameter_profile=parameter_profile or ParameterProfile.empty(), + parameter_profile=ParameterProfile(), input_named_parameters=(), applied_wrap_types=False, parameter_casts={}, operation_type=operation_type, - operation_profile=operation_profile or OperationProfile.empty(), + operation_profile=operation_profile, param_count=param_count, - processed_state=processed_state or _EMPTY_PS, - ) - - -class _FakeDriver(CommonDriverAttributesMixin): - __slots__ = () - - def _stmt_cache_execute(self, statement: Any) -> Any: - return statement - - -def test_stmt_cache_lru_eviction() -> None: - cache = QueryCache(max_size=2) - - cache.set("a", _make_cached("SQL_A", 1)) - cache.set("b", _make_cached("SQL_B", 1)) - assert cache.get("a") is not None - - cache.set("c", _make_cached("SQL_C", 1)) - - assert cache.get("b") is None - assert cache.get("a") is not None - assert cache.get("c") is not None - - -def test_stmt_cache_update_moves_to_end() -> None: - cache = QueryCache(max_size=2) - - cache.set("a", _make_cached("SQL_A", 1)) - cache.set("b", _make_cached("SQL_B", 1)) - cache.set("a", _make_cached("SQL_A2", 2)) - cache.set("c", _make_cached("SQL_C", 1)) - - assert cache.get("b") is None - entry = cache.get("a") - assert entry is not None - assert entry.compiled_sql == "SQL_A2" - assert entry.param_count == 2 - - -def test_stmt_cache_lookup_cache_hit_rebinds() -> None: - config = StatementConfig( - parameter_config=ParameterStyleConfig( - default_parameter_style=ParameterStyle.QMARK, supported_parameter_styles={ParameterStyle.QMARK} - ) - ) - driver = _FakeDriver(object(), config) - - profile = ParameterProfile((ParameterInfo(None, ParameterStyle.QMARK, 0, 0, "?"),)) - ps = ProcessedState(compiled_sql="SELECT * FROM t WHERE id = ?", execution_parameters=[1], operation_type="SELECT") - cached = CachedQuery( - compiled_sql="SELECT * FROM t WHERE id = ?", - parameter_profile=profile, - input_named_parameters=(), - applied_wrap_types=False, - parameter_casts={}, - operation_type="SELECT", - operation_profile=OperationProfile(returns_rows=True, modifies_rows=False), - param_count=1, - processed_state=ps, - ) - driver._stmt_cache.set("SELECT * FROM t WHERE id = ?", cached) - - result = driver._stmt_cache_lookup("SELECT * FROM t WHERE id = ?", (1,)) - - assert result is not None - # Result is the SQL statement with processed state - statement = cast("Any", result) - assert statement.operation_type == "SELECT" - compiled_sql, params = statement.compile() - assert compiled_sql == "SELECT * FROM t WHERE id = ?" - assert params == (1,) - - -def test_stmt_cache_store_snapshots_processed_state() -> None: - config = StatementConfig( - parameter_config=ParameterStyleConfig( - default_parameter_style=ParameterStyle.QMARK, supported_parameter_styles={ParameterStyle.QMARK} - ) - ) - driver = _FakeDriver(object(), config) - statement = SQL("SELECT ?", (1,), statement_config=config) - statement.compile() - - driver._stmt_cache_store(statement) - cached = driver._stmt_cache.get("SELECT ?") - assert cached is not None - - # Mutate/reset the original state after cache storage; cached metadata - # should remain stable and independent. - processed = cast("ProcessedState", statement.get_processed_state()) - processed.reset() - - assert cached.compiled_sql == "SELECT ?" - assert cached.processed_state.compiled_sql == "SELECT ?" - assert cached.processed_state.operation_type == "SELECT" - - -def test_prepare_driver_parameters_many_passes_through_irrelevant_coercion_map() -> None: - config = StatementConfig( - parameter_config=ParameterStyleConfig( - default_parameter_style=ParameterStyle.QMARK, - supported_parameter_styles={ParameterStyle.QMARK}, - type_coercion_map={bool: lambda value: 1 if value else 0}, - ) + processed_state=processed_state, + column_names=column_names, ) - driver = _FakeDriver(object(), config) - parameters = [("a",), ("b",), ("c",)] - prepared = driver.prepare_driver_parameters(parameters, config, is_many=True) - assert prepared is parameters - - -def test_prepare_driver_parameters_many_coerces_rows_when_needed() -> None: - config = StatementConfig( - parameter_config=ParameterStyleConfig( - default_parameter_style=ParameterStyle.QMARK, - supported_parameter_styles={ParameterStyle.QMARK}, - type_coercion_map={bool: lambda value: 1 if value else 0}, - ) - ) - driver = _FakeDriver(object(), config) - parameters = [(True,), ("b",)] - - prepared = driver.prepare_driver_parameters(parameters, config, is_many=True) - - assert isinstance(prepared, list) - assert prepared is not parameters - assert tuple(prepared[0]) == (1,) - assert tuple(prepared[1]) == ("b",) - - -def test_prepare_driver_parameters_many_coerces_subclass_rows_when_needed() -> None: - class MyInt(int): - pass - - config = StatementConfig( - parameter_config=ParameterStyleConfig( - default_parameter_style=ParameterStyle.QMARK, - supported_parameter_styles={ParameterStyle.QMARK}, - type_coercion_map={int: lambda value: value + 1}, - ) - ) - driver = _FakeDriver(object(), config) - parameters = [(MyInt(2),), ("b",)] - - prepared = driver.prepare_driver_parameters(parameters, config, is_many=True) - - assert isinstance(prepared, list) - assert prepared is not parameters - assert tuple(prepared[0]) == (3,) - assert tuple(prepared[1]) == ("b",) - - -def test_prepare_driver_parameters_many_coerces_virtual_abc_rows_when_needed() -> None: - config = StatementConfig( - parameter_config=ParameterStyleConfig( - default_parameter_style=ParameterStyle.QMARK, - supported_parameter_styles={ParameterStyle.QMARK}, - type_coercion_map={Sequence: lambda value: tuple(value)}, - ) - ) - driver = _FakeDriver(object(), config) - fallback_items = ((Sequence, lambda value: tuple(value)),) - - prepared = driver._apply_coercion_with_fallback( # pyright: ignore[reportPrivateUsage] - [1, 2], config.parameter_config.type_coercion_map, fallback_items - ) - - assert prepared == (1, 2) - - -def test_sync_stmt_cache_execute_direct_uses_dispatch_path(mock_sync_driver, monkeypatch) -> None: - class _CursorManager: - def __enter__(self) -> object: - return object() - - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> "Literal[False]": - _ = (exc_type, exc_val, exc_tb) - return False - - def _fake_with_cursor(_connection: Any) -> _CursorManager: - return _CursorManager() +def test_sync_stmt_cache_execute_direct_uses_fast_path(sqlite_sync_driver, monkeypatch) -> None: + """Test that direct cache execution uses the fast path bypassing dispatch_execute.""" + sqlite_sync_driver.execute("CREATE TABLE t (id INTEGER)") + # We want to verify it bypasses dispatch_execute def _fake_dispatch_execute(cursor: Any, statement: Any) -> Any: - # Regression test: direct cache execution should not require cursor.execute(). - assert not hasattr(cursor, "execute") - return mock_sync_driver.create_execution_result(cursor, rowcount_override=7) + pytest.fail("dispatch_execute should not be called on fast path") - monkeypatch.setattr(mock_sync_driver, "with_cursor", _fake_with_cursor) - monkeypatch.setattr(mock_sync_driver, "dispatch_execute", _fake_dispatch_execute) + monkeypatch.setattr(sqlite_sync_driver, "dispatch_execute", _fake_dispatch_execute) cached = _make_cached( compiled_sql="INSERT INTO t (id) VALUES (?)", @@ -239,12 +59,12 @@ def _fake_dispatch_execute(cursor: Any, statement: Any) -> Any: ), ) - result = mock_sync_driver._stmt_cache_execute_direct("INSERT INTO t (id) VALUES (?)", (1,), cached) + result = sqlite_sync_driver._stmt_cache_execute_direct("INSERT INTO t (id) VALUES (?)", (1,), cached) assert result.operation_type == "INSERT" - assert result.rows_affected == 7 + assert result.rows_affected == 1 -def test_execute_uses_fast_path_when_eligible(mock_sync_driver, monkeypatch) -> None: +def test_execute_uses_fast_path_when_eligible(sqlite_sync_driver, monkeypatch) -> None: sentinel = object() called: dict[str, object] = {} @@ -252,16 +72,16 @@ def _fake_try(statement: str, params: tuple[Any, ...] | list[Any]) -> object: called["args"] = (statement, params) return sentinel - monkeypatch.setattr(mock_sync_driver, "_stmt_cache_lookup", _fake_try) - mock_sync_driver._stmt_cache_enabled = True + monkeypatch.setattr(sqlite_sync_driver, "_stmt_cache_lookup", _fake_try) + sqlite_sync_driver._stmt_cache_enabled = True - result = mock_sync_driver.execute("SELECT ?", (1,)) + result = sqlite_sync_driver.execute("SELECT ?", (1,)) assert result is sentinel assert called["args"] == ("SELECT ?", (1,)) -def test_execute_skips_fast_path_with_statement_config_override(mock_sync_driver, monkeypatch) -> None: +def test_execute_skips_fast_path_with_statement_config_override(sqlite_sync_driver, monkeypatch) -> None: called = False def _fake_try(statement: str, params: tuple[Any, ...] | list[Any]) -> object: @@ -269,49 +89,55 @@ def _fake_try(statement: str, params: tuple[Any, ...] | list[Any]) -> object: called = True return object() - monkeypatch.setattr(mock_sync_driver, "_stmt_cache_lookup", _fake_try) - mock_sync_driver._stmt_cache_enabled = True + monkeypatch.setattr(sqlite_sync_driver, "_stmt_cache_lookup", _fake_try) + sqlite_sync_driver._stmt_cache_enabled = True - statement_config = mock_sync_driver.statement_config.replace() - result = mock_sync_driver.execute("SELECT ?", (1,), statement_config=statement_config) + statement_config = sqlite_sync_driver.statement_config.replace() + result = sqlite_sync_driver.execute("SELECT ?", (1,), statement_config=statement_config) assert called is False assert result.operation_type == "SELECT" -def test_execute_populates_fast_path_cache_on_normal_path(mock_sync_driver) -> None: - mock_sync_driver._stmt_cache_enabled = True +def test_execute_populates_fast_path_cache_on_normal_path(sqlite_sync_driver) -> None: + sqlite_sync_driver._stmt_cache_enabled = True - assert mock_sync_driver._stmt_cache.get("SELECT ?") is None + assert sqlite_sync_driver._stmt_cache.get("SELECT ?") is None - result = mock_sync_driver.execute("SELECT ?", (1,)) + result = sqlite_sync_driver.execute("SELECT ?", (1,)) - cached = mock_sync_driver._stmt_cache.get("SELECT ?") + cached = sqlite_sync_driver._stmt_cache.get("SELECT ?") assert cached is not None assert cached.param_count == 1 assert cached.operation_type == "SELECT" assert result.operation_type == "SELECT" -def test_sync_stmt_cache_execute_re_raises_mapped_exception(mock_sync_driver: Any, monkeypatch: Any) -> None: +def test_sync_stmt_cache_execute_re_raises_mapped_exception(sqlite_sync_driver: Any, monkeypatch: Any) -> None: + import sqlite3 + def _fake_dispatch_execute(cursor: Any, statement: Any) -> Any: _ = (cursor, statement) - raise ValueError("boom") + raise sqlite3.OperationalError("boom") - monkeypatch.setattr(mock_sync_driver, "dispatch_execute", _fake_dispatch_execute) - statement = SQL("SELECT ?", (1,), statement_config=mock_sync_driver.statement_config) + monkeypatch.setattr(sqlite_sync_driver, "dispatch_execute", _fake_dispatch_execute) + statement = SQL("SELECT ?", (1,), statement_config=sqlite_sync_driver.statement_config) statement.compile() - with pytest.raises(SQLSpecError, match="Mock database error: boom"): - mock_sync_driver._stmt_cache_execute(statement) + with pytest.raises(SQLSpecError, match="SQLite database error: boom"): + sqlite_sync_driver._stmt_cache_execute(statement) -def test_sync_stmt_cache_execute_direct_re_raises_mapped_exception(mock_sync_driver: Any, monkeypatch: Any) -> None: - def _fake_dispatch_execute(cursor: Any, statement: Any) -> Any: - _ = (cursor, statement) - raise ValueError("boom") +def test_sync_stmt_cache_execute_direct_re_raises_mapped_exception(sqlite_sync_driver: Any, monkeypatch: Any) -> None: + import sqlite3 + + sqlite_sync_driver.execute("CREATE TABLE t (id INTEGER)") + + # Wrap connection to allow patching 'execute' + wrapped_conn = Mock(wraps=sqlite_sync_driver.connection) + wrapped_conn.execute.side_effect = sqlite3.OperationalError("boom") + monkeypatch.setattr(sqlite_sync_driver, "connection", wrapped_conn) - monkeypatch.setattr(mock_sync_driver, "dispatch_execute", _fake_dispatch_execute) cached = _make_cached( compiled_sql="INSERT INTO t (id) VALUES (?)", param_count=1, @@ -322,12 +148,12 @@ def _fake_dispatch_execute(cursor: Any, statement: Any) -> Any: ), ) - with pytest.raises(SQLSpecError, match="Mock database error: boom"): - mock_sync_driver._stmt_cache_execute_direct("INSERT INTO t (id) VALUES (?)", (1,), cached) + with pytest.raises(SQLSpecError, match="SQLite database error: boom"): + sqlite_sync_driver._stmt_cache_execute_direct("INSERT INTO t (id) VALUES (?)", (1,), cached) @pytest.mark.anyio -async def test_async_execute_uses_fast_path_when_eligible(mock_async_driver: Any, monkeypatch: Any) -> None: +async def test_async_execute_uses_fast_path_when_eligible(aiosqlite_async_driver: Any, monkeypatch: Any) -> None: sentinel = object() called: dict[str, object] = {} @@ -335,10 +161,10 @@ async def _fake_try(statement: str, params: tuple[Any, ...] | list[Any]) -> obje called["args"] = (statement, params) return sentinel - monkeypatch.setattr(mock_async_driver, "_stmt_cache_lookup", _fake_try) - mock_async_driver._stmt_cache_enabled = True + monkeypatch.setattr(aiosqlite_async_driver, "_stmt_cache_lookup", _fake_try) + aiosqlite_async_driver._stmt_cache_enabled = True - result = await mock_async_driver.execute("SELECT ?", (1,)) + result = await aiosqlite_async_driver.execute("SELECT ?", (1,)) assert result is sentinel assert called["args"] == ("SELECT ?", (1,)) @@ -346,7 +172,7 @@ async def _fake_try(statement: str, params: tuple[Any, ...] | list[Any]) -> obje @pytest.mark.anyio async def test_async_execute_skips_fast_path_with_statement_config_override( - mock_async_driver: Any, monkeypatch: Any + aiosqlite_async_driver: Any, monkeypatch: Any ) -> None: called = False @@ -355,25 +181,25 @@ async def _fake_try(statement: str, params: tuple[Any, ...] | list[Any]) -> obje called = True return object() - monkeypatch.setattr(mock_async_driver, "_stmt_cache_lookup", _fake_try) - mock_async_driver._stmt_cache_enabled = True + monkeypatch.setattr(aiosqlite_async_driver, "_stmt_cache_lookup", _fake_try) + aiosqlite_async_driver._stmt_cache_enabled = True - statement_config = mock_async_driver.statement_config.replace() - result = await mock_async_driver.execute("SELECT ?", (1,), statement_config=statement_config) + statement_config = aiosqlite_async_driver.statement_config.replace() + result = await aiosqlite_async_driver.execute("SELECT ?", (1,), statement_config=statement_config) assert called is False assert result.operation_type == "SELECT" @pytest.mark.anyio -async def test_async_execute_populates_fast_path_cache_on_normal_path(mock_async_driver: Any) -> None: - mock_async_driver._stmt_cache_enabled = True +async def test_async_execute_populates_fast_path_cache_on_normal_path(aiosqlite_async_driver: Any) -> None: + aiosqlite_async_driver._stmt_cache_enabled = True - assert mock_async_driver._stmt_cache.get("SELECT ?") is None + assert aiosqlite_async_driver._stmt_cache.get("SELECT ?") is None - result = await mock_async_driver.execute("SELECT ?", (1,)) + result = await aiosqlite_async_driver.execute("SELECT ?", (1,)) - cached = mock_async_driver._stmt_cache.get("SELECT ?") + cached = aiosqlite_async_driver._stmt_cache.get("SELECT ?") assert cached is not None assert cached.param_count == 1 assert cached.operation_type == "SELECT" @@ -381,28 +207,37 @@ async def test_async_execute_populates_fast_path_cache_on_normal_path(mock_async @pytest.mark.anyio -async def test_async_stmt_cache_execute_re_raises_mapped_exception(mock_async_driver: Any, monkeypatch: Any) -> None: +async def test_async_stmt_cache_execute_re_raises_mapped_exception( + aiosqlite_async_driver: Any, monkeypatch: Any +) -> None: + import aiosqlite + async def _fake_dispatch_execute(cursor: Any, statement: Any) -> Any: _ = (cursor, statement) - raise ValueError("boom") + raise aiosqlite.OperationalError("boom") - monkeypatch.setattr(mock_async_driver, "dispatch_execute", _fake_dispatch_execute) - statement = SQL("SELECT ?", (1,), statement_config=mock_async_driver.statement_config) + monkeypatch.setattr(aiosqlite_async_driver, "dispatch_execute", _fake_dispatch_execute) + statement = SQL("SELECT ?", (1,), statement_config=aiosqlite_async_driver.statement_config) statement.compile() - with pytest.raises(SQLSpecError, match="Mock async database error: boom"): - await mock_async_driver._stmt_cache_execute(statement) + with pytest.raises(SQLSpecError, match="AIOSQLite database error: boom"): + await aiosqlite_async_driver._stmt_cache_execute(statement) @pytest.mark.anyio async def test_async_stmt_cache_execute_direct_re_raises_mapped_exception( - mock_async_driver: Any, monkeypatch: Any + aiosqlite_async_driver: Any, monkeypatch: Any ) -> None: + import aiosqlite + + await aiosqlite_async_driver.execute("CREATE TABLE t (id INTEGER)") + async def _fake_dispatch_execute(cursor: Any, statement: Any) -> Any: _ = (cursor, statement) - raise ValueError("boom") + raise aiosqlite.OperationalError("boom") + + monkeypatch.setattr(aiosqlite_async_driver, "dispatch_execute", _fake_dispatch_execute) - monkeypatch.setattr(mock_async_driver, "dispatch_execute", _fake_dispatch_execute) cached = _make_cached( compiled_sql="INSERT INTO t (id) VALUES (?)", param_count=1, @@ -413,22 +248,5 @@ async def _fake_dispatch_execute(cursor: Any, statement: Any) -> Any: ), ) - with pytest.raises(SQLSpecError, match="Mock async database error: boom"): - await mock_async_driver._stmt_cache_execute_direct("INSERT INTO t (id) VALUES (?)", (1,), cached) - - -def test_stmt_cache_thread_safety() -> None: - cache = QueryCache(max_size=32) - cached = _make_cached() - for idx in range(16): - cache.set(str(idx), cached) - - def worker(seed: int) -> None: - for i in range(200): - key = str((seed + i) % 16) - cache.get(key) - if i % 5 == 0: - cache.set(key, cached) - - with ThreadPoolExecutor(max_workers=4) as executor: - list(executor.map(worker, range(4))) + with pytest.raises(SQLSpecError, match="AIOSQLite database error: boom"): + await aiosqlite_async_driver._stmt_cache_execute_direct("INSERT INTO t (id) VALUES (?)", (1,), cached) diff --git a/tests/unit/driver/test_stack_base.py b/tests/unit/driver/test_stack_base.py index 822c4ffcb..8326fb0f5 100644 --- a/tests/unit/driver/test_stack_base.py +++ b/tests/unit/driver/test_stack_base.py @@ -10,116 +10,120 @@ @requires_interpreted -async def test_async_execute_stack_fail_fast_rolls_back(mock_async_driver) -> None: - original_execute = mock_async_driver.execute +async def test_async_execute_stack_fail_fast_rolls_back(aiosqlite_async_driver) -> None: + await aiosqlite_async_driver.execute("CREATE TABLE t (id INTEGER)") + original_execute = aiosqlite_async_driver.execute async def failing_execute(self, statement, *params, **kwargs): # type: ignore[no-untyped-def] if isinstance(statement, str) and "FAIL" in statement: raise ValueError("boom") return await original_execute(statement, *params, **kwargs) - mock_async_driver.execute = types.MethodType(failing_execute, mock_async_driver) + aiosqlite_async_driver.execute = types.MethodType(failing_execute, aiosqlite_async_driver) stack = StatementStack().push_execute("INSERT INTO t (id) VALUES (1)").push_execute("FAIL SELECT 1") with pytest.raises(StackExecutionError) as excinfo: - await mock_async_driver.execute_stack(stack) + await aiosqlite_async_driver.execute_stack(stack) assert excinfo.value.operation_index == 1 - assert mock_async_driver.connection.in_transaction is False + assert aiosqlite_async_driver.connection.in_transaction is False @requires_interpreted -async def test_async_execute_stack_continue_on_error(mock_async_driver) -> None: - original_execute = mock_async_driver.execute +async def test_async_execute_stack_continue_on_error(aiosqlite_async_driver) -> None: + await aiosqlite_async_driver.execute("CREATE TABLE t (id INTEGER)") + original_execute = aiosqlite_async_driver.execute async def failing_execute(self, statement, *params, **kwargs): # type: ignore[no-untyped-def] if isinstance(statement, str) and "FAIL" in statement: raise ValueError("boom") return await original_execute(statement, *params, **kwargs) - mock_async_driver.execute = types.MethodType(failing_execute, mock_async_driver) + aiosqlite_async_driver.execute = types.MethodType(failing_execute, aiosqlite_async_driver) stack = StatementStack().push_execute("INSERT INTO t (id) VALUES (1)").push_execute("FAIL SELECT 1") - results = await mock_async_driver.execute_stack(stack, continue_on_error=True) + results = await aiosqlite_async_driver.execute_stack(stack, continue_on_error=True) assert len(results) == 2 assert results[0].error is None assert isinstance(results[1].error, StackExecutionError) - assert mock_async_driver.connection.in_transaction is False + assert aiosqlite_async_driver.connection.in_transaction is False @requires_interpreted -async def test_async_execute_stack_execute_arrow(mock_async_driver) -> None: +async def test_async_execute_stack_execute_arrow(aiosqlite_async_driver) -> None: sentinel = object() async def fake_select_to_arrow(self, statement, *params, **kwargs): # type: ignore[no-untyped-def] return sentinel - mock_async_driver.select_to_arrow = types.MethodType(fake_select_to_arrow, mock_async_driver) + aiosqlite_async_driver.select_to_arrow = types.MethodType(fake_select_to_arrow, aiosqlite_async_driver) - stack = StatementStack().push_execute_arrow("SELECT * FROM items") + stack = StatementStack().push_execute_arrow("SELECT * FROM users") - results = await mock_async_driver.execute_stack(stack) + results = await aiosqlite_async_driver.execute_stack(stack) assert len(results) == 1 assert results[0].result is sentinel @requires_interpreted -def test_sync_execute_stack_fail_fast_rolls_back(mock_sync_driver) -> None: - original_execute = mock_sync_driver.execute +def test_sync_execute_stack_fail_fast_rolls_back(sqlite_sync_driver) -> None: + sqlite_sync_driver.execute("CREATE TABLE t (id INTEGER)") + original_execute = sqlite_sync_driver.execute def failing_execute(self, statement, *params, **kwargs): # type: ignore[no-untyped-def] if isinstance(statement, str) and "FAIL" in statement: raise ValueError("boom") return original_execute(statement, *params, **kwargs) - mock_sync_driver.execute = types.MethodType(failing_execute, mock_sync_driver) + sqlite_sync_driver.execute = types.MethodType(failing_execute, sqlite_sync_driver) stack = StatementStack().push_execute("INSERT INTO t (id) VALUES (1)").push_execute("FAIL SELECT 1") with pytest.raises(StackExecutionError) as excinfo: - mock_sync_driver.execute_stack(stack) + sqlite_sync_driver.execute_stack(stack) assert excinfo.value.operation_index == 1 - assert mock_sync_driver.connection.in_transaction is False + assert sqlite_sync_driver.connection.in_transaction is False @requires_interpreted -def test_sync_execute_stack_continue_on_error(mock_sync_driver) -> None: - original_execute = mock_sync_driver.execute +def test_sync_execute_stack_continue_on_error(sqlite_sync_driver) -> None: + sqlite_sync_driver.execute("CREATE TABLE t (id INTEGER)") + original_execute = sqlite_sync_driver.execute def failing_execute(self, statement, *params, **kwargs): # type: ignore[no-untyped-def] if isinstance(statement, str) and "FAIL" in statement: raise ValueError("boom") return original_execute(statement, *params, **kwargs) - mock_sync_driver.execute = types.MethodType(failing_execute, mock_sync_driver) + sqlite_sync_driver.execute = types.MethodType(failing_execute, sqlite_sync_driver) stack = StatementStack().push_execute("INSERT INTO t (id) VALUES (1)").push_execute("FAIL SELECT 1") - results = mock_sync_driver.execute_stack(stack, continue_on_error=True) + results = sqlite_sync_driver.execute_stack(stack, continue_on_error=True) assert len(results) == 2 assert results[0].error is None assert isinstance(results[1].error, StackExecutionError) - assert mock_sync_driver.connection.in_transaction is False + assert sqlite_sync_driver.connection.in_transaction is False @requires_interpreted -def test_sync_execute_stack_execute_arrow(mock_sync_driver) -> None: +def test_sync_execute_stack_execute_arrow(sqlite_sync_driver) -> None: sentinel = object() def fake_select_to_arrow(self, statement, *params, **kwargs): # type: ignore[no-untyped-def] return sentinel - mock_sync_driver.select_to_arrow = types.MethodType(fake_select_to_arrow, mock_sync_driver) + sqlite_sync_driver.select_to_arrow = types.MethodType(fake_select_to_arrow, sqlite_sync_driver) - stack = StatementStack().push_execute_arrow("SELECT * FROM items") + stack = StatementStack().push_execute_arrow("SELECT * FROM users") - results = mock_sync_driver.execute_stack(stack) + results = sqlite_sync_driver.execute_stack(stack) assert len(results) == 1 assert results[0].result is sentinel diff --git a/tests/unit/exceptions/test_exception_handler.py b/tests/unit/exceptions/test_exception_handler.py index 873699b2d..84908f1e8 100644 --- a/tests/unit/exceptions/test_exception_handler.py +++ b/tests/unit/exceptions/test_exception_handler.py @@ -33,11 +33,9 @@ async def test_base_async_exception_handler_defaults_to_passthrough() -> None: def test_sync_exception_handlers_inherit_shared_base() -> None: """Representative sync handlers should inherit the shared base.""" from sqlspec.adapters.bigquery.driver import BigQueryExceptionHandler - from sqlspec.adapters.mock.driver import MockExceptionHandler from sqlspec.adapters.sqlite.driver import SqliteExceptionHandler assert issubclass(BigQueryExceptionHandler, BaseSyncExceptionHandler) - assert issubclass(MockExceptionHandler, BaseSyncExceptionHandler) assert issubclass(SqliteExceptionHandler, BaseSyncExceptionHandler) @@ -45,11 +43,9 @@ def test_async_exception_handlers_inherit_shared_base() -> None: """Representative async handlers should inherit the shared base.""" from sqlspec.adapters.aiosqlite.driver import AiosqliteExceptionHandler from sqlspec.adapters.asyncpg.driver import AsyncpgExceptionHandler - from sqlspec.adapters.mock.driver import MockAsyncExceptionHandler assert issubclass(AiosqliteExceptionHandler, BaseAsyncExceptionHandler) assert issubclass(AsyncpgExceptionHandler, BaseAsyncExceptionHandler) - assert issubclass(MockAsyncExceptionHandler, BaseAsyncExceptionHandler) def test_duckdb_exception_handler_maps_any_present_exception(monkeypatch: pytest.MonkeyPatch) -> None: diff --git a/tests/unit/migrations/test_migration_context.py b/tests/unit/migrations/test_migration_context.py index 7f54cf5a1..b83f822b8 100644 --- a/tests/unit/migrations/test_migration_context.py +++ b/tests/unit/migrations/test_migration_context.py @@ -3,6 +3,8 @@ import asyncio from unittest.mock import Mock +import pytest + from sqlspec.adapters.psycopg.config import PsycopgSyncConfig from sqlspec.adapters.sqlite.config import SqliteConfig from sqlspec.migrations.context import MigrationContext @@ -115,20 +117,15 @@ async def async_migration() -> list[str]: context.validate_async_usage(async_migration) -def test_validate_async_usage_with_sync_function() -> None: +@pytest.mark.anyio +async def test_validate_async_usage_with_sync_function(aiosqlite_async_driver) -> None: """Test sync function validation in async context.""" context = MigrationContext() def sync_migration() -> list[str]: return ["CREATE TABLE test (id INT);"] - mock_async_driver = Mock() - - async def mock_execute() -> None: - return None - - mock_async_driver.execute_script = mock_execute - context.driver = mock_async_driver + context.driver = aiosqlite_async_driver context.validate_async_usage(sync_migration) assert context.get_execution_metadata("mixed_execution") is True From 91c5b808b0606350c0805129eab21cb6080eb4d6 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Wed, 22 Apr 2026 21:24:11 +0000 Subject: [PATCH 8/8] test(core): fix type errors and LRUCache usage in unit tests --- tests/unit/conftest.py | 4 ++-- tests/unit/driver/test_query_cache.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 4a95a9f6c..1df1a3e35 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -189,7 +189,7 @@ def statement_config_mysql(parameter_style_config_basic: ParameterStyleConfig) - @pytest.fixture def cache_config_enabled() -> LRUCache: """Cache configuration with caching enabled.""" - return LRUCache(capacity=100) + return LRUCache(max_size=100) @pytest.fixture @@ -201,7 +201,7 @@ def cache_config_disabled() -> None: @pytest.fixture def mock_lru_cache() -> LRUCache: """Mock LRU cache for testing cache operations.""" - return LRUCache(capacity=10) + return LRUCache(max_size=10) @pytest.fixture diff --git a/tests/unit/driver/test_query_cache.py b/tests/unit/driver/test_query_cache.py index a4dfee5a5..93e03f69f 100644 --- a/tests/unit/driver/test_query_cache.py +++ b/tests/unit/driver/test_query_cache.py @@ -6,7 +6,7 @@ import pytest -from sqlspec.core import SQL, OperationProfile, ParameterProfile, ProcessedState +from sqlspec.core import SQL, OperationProfile, OperationType, ParameterProfile, ProcessedState from sqlspec.driver._query_cache import CachedQuery from sqlspec.exceptions import SQLSpecError @@ -14,7 +14,7 @@ def _make_cached( compiled_sql: str = "SELECT 1", param_count: int = 0, - operation_type: str = "SELECT", + operation_type: OperationType = "SELECT", column_names: list[str] | None = None, operation_profile: OperationProfile | None = None, processed_state: ProcessedState | None = None,