From c716237fa7cee9fe4fd84dfba06ea4f7d107e5cf Mon Sep 17 00:00:00 2001 From: Robert Date: Thu, 23 Apr 2026 13:44:31 +0200 Subject: [PATCH 1/3] feat: Reuse SQLAlchemy engine across execute_sql calls in a cell (SAL-51) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Cache the SQLAlchemy engine per IPython cell execution so that two back-to-back `_dntk.execute_sql*` calls in the same generated cell share the same physical DBAPI connection. This lets Snowflake session state set by `USE WAREHOUSE`, `USE ROLE`, `SET ...`, etc. persist onto the next query, which unblocks supporting `USE WAREHOUSE abc; SELECT 123` in Snowflake SQL blocks. Cached engines are created with `pool_size=1, max_overflow=0` so `engine.begin()` always checks out the same connection. They are disposed via a `post_run_cell` IPython hook. Caching is skipped for SSH-tunneled engines, for connections with user-supplied custom pool config, and when no IPython shell is available (script/CLI use) — all of those fall back to the previous create-then-dispose-per-call behavior. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- deepnote_toolkit/ipython_utils.py | 21 +++ deepnote_toolkit/sql/sql_execution.py | 98 +++++++++++- tests/unit/test_ipython_utils.py | 48 ++++++ tests/unit/test_sql_execution_internal.py | 172 ++++++++++++++++++++++ 4 files changed, 333 insertions(+), 6 deletions(-) create mode 100644 tests/unit/test_ipython_utils.py diff --git a/deepnote_toolkit/ipython_utils.py b/deepnote_toolkit/ipython_utils.py index a789985b..f1c9968b 100644 --- a/deepnote_toolkit/ipython_utils.py +++ b/deepnote_toolkit/ipython_utils.py @@ -3,6 +3,8 @@ # also defined in https://github.com/deepnote/deepnote/blob/a9f36659f50c84bd85aeba8ee2d3d4458f2f4998/libs/shared/src/constants.ts#L47 DEEPNOTE_SQL_METADATA_MIME_TYPE = "application/vnd.deepnote.sql-output-metadata+json" +_registered_post_run_cell_callbacks: set = set() + def output_display_data(mime_bundle): """ @@ -13,6 +15,25 @@ def output_display_data(mime_bundle): get_ipython().display_pub.publish(data=mime_bundle) +def register_post_run_cell_hook(callback) -> bool: + """Register *callback* to run after each IPython cell. Idempotent per callback. + + Returns True when the callback is registered (or already was), False when there + is no active IPython shell so the caller can fall back to non-cached behavior. + """ + + ip = get_ipython() + if ip is None: + return False + + if callback in _registered_post_run_cell_callbacks: + return True + + ip.events.register("post_run_cell", callback) + _registered_post_run_cell_callbacks.add(callback) + return True + + def output_sql_metadata(metadata: dict): """ Outputs SQL metadata to the notebook. Used for e.g. reporting on hit/miss of a SQL cache. 
or reporting the compiled query diff --git a/deepnote_toolkit/sql/sql_execution.py b/deepnote_toolkit/sql/sql_execution.py index 07b61fe2..1b917dda 100644 --- a/deepnote_toolkit/sql/sql_execution.py +++ b/deepnote_toolkit/sql/sql_execution.py @@ -28,7 +28,10 @@ get_absolute_userpod_api_url, get_project_auth_headers, ) -from deepnote_toolkit.ipython_utils import output_sql_metadata +from deepnote_toolkit.ipython_utils import ( + output_sql_metadata, + register_post_run_cell_hook, +) from deepnote_toolkit.logging import LoggerManager from deepnote_toolkit.ocelots.pandas.utils import deduplicate_columns, is_large_number from deepnote_toolkit.sql.duckdb_sql import execute_duckdb_sql @@ -450,6 +453,91 @@ def _execute_sql_with_caching( ) +# Engines created during a single IPython cell are cached here so that +# multiple _dntk.execute_sql* calls in the same generated cell share the same +# physical DBAPI connection. This is what allows Snowflake session state set +# by `USE WAREHOUSE`, `USE ROLE`, `SET ...`, etc. to persist between separate +# execute_sql calls inside the same cell. Disposed via a post_run_cell hook. +_per_cell_engine_cache: dict[str, Any] = {} +_post_run_cell_hook_registered = False + + +def _compute_engine_cache_key(sql_alchemy_dict) -> Optional[str]: + """Stable key identifying the engine for *sql_alchemy_dict*, or None when + this connection should not be cached. + + SSH-tunneled engines can't be cached because the tunnel only lives inside + `_create_sql_ssh_uri`'s context manager. Connections that already specify + a custom pool config in user `params` are also skipped to avoid clashing + with the `pool_size=1` we set on cached engines. 
+ """ + if sql_alchemy_dict.get("ssh_options", {}).get("enabled"): + return None + + params = sql_alchemy_dict.get("params") or {} + if "pool_size" in params or "max_overflow" in params or "poolclass" in params: + return None + + integration_id = sql_alchemy_dict.get("integration_id") + if integration_id: + return f"integration:{integration_id}" + + return f"url:{sql_alchemy_dict['url']}" + + +def _dispose_cell_engines(*_args, **_kwargs) -> None: + """post_run_cell callback: dispose every engine cached during the cell.""" + while _per_cell_engine_cache: + _, engine = _per_cell_engine_cache.popitem() + try: + engine.dispose() + except Exception: + logger.warning("Error disposing cached SQL engine", exc_info=True) + + +def _acquire_engine(sql_alchemy_dict, url) -> tuple[Any, bool]: + """Return ``(engine, owns_engine)``. + + When ``owns_engine`` is False the engine is owned by the per-cell cache + and the caller must not dispose it; the post_run_cell hook will. When True + the caller must dispose, matching the previous per-call behavior. + """ + cache_key = _compute_engine_cache_key(sql_alchemy_dict) + + if cache_key is not None: + cached = _per_cell_engine_cache.get(cache_key) + if cached is not None: + return cached, False + + should_cache = False + if cache_key is not None: + global _post_run_cell_hook_registered + if _post_run_cell_hook_registered: + should_cache = True + elif register_post_run_cell_hook(_dispose_cell_engines): + _post_run_cell_hook_registered = True + should_cache = True + # else: no IPython available — leave should_cache False so we don't + # leak engines in script/CLI contexts that never fire post_run_cell. + + extra_engine_params: dict[str, Any] = {"pool_pre_ping": True} + if should_cache: + # pool_size=1 + max_overflow=0 makes engine.begin() always check out + # the same physical DBAPI connection across calls within the cell, + # which is what makes session state (USE WAREHOUSE, ...) persist. 
+ extra_engine_params["pool_size"] = 1 + extra_engine_params["max_overflow"] = 0 + + with suppress_third_party_deprecation_warnings(): + engine = create_engine(url, **sql_alchemy_dict["params"], **extra_engine_params) + + if not should_cache: + return engine, True + + _per_cell_engine_cache[cache_key] = engine + return engine, False + + @contextlib.contextmanager def suppress_third_party_deprecation_warnings(): """Suppress known deprecation warnings from third-party SQL packages. @@ -485,10 +573,7 @@ def _query_data_source( if url is None: url = sql_alchemy_dict["url"] - with suppress_third_party_deprecation_warnings(): - engine = create_engine( - url, **sql_alchemy_dict["params"], pool_pre_ping=True - ) + engine, owns_engine = _acquire_engine(sql_alchemy_dict, url) try: dataframe = _execute_sql_on_engine(engine, query, bind_params) @@ -524,7 +609,8 @@ def _query_data_source( return dataframe finally: - engine.dispose() + if owns_engine: + engine.dispose() class CursorTrackingDBAPIConnection(wrapt.ObjectProxy): diff --git a/tests/unit/test_ipython_utils.py b/tests/unit/test_ipython_utils.py new file mode 100644 index 00000000..34edd531 --- /dev/null +++ b/tests/unit/test_ipython_utils.py @@ -0,0 +1,48 @@ +from unittest import mock + +from deepnote_toolkit import ipython_utils + + +def _reset_registered_callbacks(): + ipython_utils._registered_post_run_cell_callbacks.clear() + + +def test_register_post_run_cell_hook_returns_false_without_ipython(): + _reset_registered_callbacks() + callback = mock.Mock() + + with mock.patch("deepnote_toolkit.ipython_utils.get_ipython", return_value=None): + result = ipython_utils.register_post_run_cell_hook(callback) + + assert result is False + assert callback not in ipython_utils._registered_post_run_cell_callbacks + + +def test_register_post_run_cell_hook_registers_with_ipython(): + _reset_registered_callbacks() + callback = mock.Mock() + fake_ipython = mock.Mock() + + with mock.patch( + 
"deepnote_toolkit.ipython_utils.get_ipython", return_value=fake_ipython + ): + result = ipython_utils.register_post_run_cell_hook(callback) + + assert result is True + fake_ipython.events.register.assert_called_once_with("post_run_cell", callback) + assert callback in ipython_utils._registered_post_run_cell_callbacks + + +def test_register_post_run_cell_hook_is_idempotent_for_same_callback(): + _reset_registered_callbacks() + callback = mock.Mock() + fake_ipython = mock.Mock() + + with mock.patch( + "deepnote_toolkit.ipython_utils.get_ipython", return_value=fake_ipython + ): + ipython_utils.register_post_run_cell_hook(callback) + result = ipython_utils.register_post_run_cell_hook(callback) + + assert result is True + fake_ipython.events.register.assert_called_once_with("post_run_cell", callback) diff --git a/tests/unit/test_sql_execution_internal.py b/tests/unit/test_sql_execution_internal.py index ad4ec003..b035ad40 100644 --- a/tests/unit/test_sql_execution_internal.py +++ b/tests/unit/test_sql_execution_internal.py @@ -377,6 +377,178 @@ def test_create_sql_ssh_uri_no_ssh(): assert url is None +@pytest.fixture +def reset_engine_cache(): + """Reset the per-cell engine cache and registration flag around a test.""" + se._per_cell_engine_cache.clear() + original_registered = se._post_run_cell_hook_registered + se._post_run_cell_hook_registered = False + try: + yield + finally: + se._per_cell_engine_cache.clear() + se._post_run_cell_hook_registered = original_registered + + +def _make_sql_alchemy_dict(integration_id="integration_a", url=None, params=None): + return { + "url": url or "postgresql://u:p@localhost:5432/db", + "params": params if params is not None else {}, + "param_style": "qmark", + "integration_id": integration_id, + } + + +def test_acquire_engine_caches_engine_when_ipython_available(reset_engine_cache): + sql_alchemy_dict = _make_sql_alchemy_dict() + fake_engine = mock.Mock() + + with ( + mock.patch( + 
"deepnote_toolkit.sql.sql_execution.create_engine", return_value=fake_engine + ) as create_engine_mock, + mock.patch( + "deepnote_toolkit.sql.sql_execution.register_post_run_cell_hook", + return_value=True, + ), + ): + engine_a, owns_a = se._acquire_engine(sql_alchemy_dict, sql_alchemy_dict["url"]) + engine_b, owns_b = se._acquire_engine(sql_alchemy_dict, sql_alchemy_dict["url"]) + + assert engine_a is fake_engine + assert engine_b is fake_engine + assert owns_a is False + assert owns_b is False + assert create_engine_mock.call_count == 1 + # Cached engines must use a single-connection pool so engine.begin() returns + # the same physical DBAPI connection across calls — that is what makes + # session state (USE WAREHOUSE, ...) persist between execute_sql calls. + create_engine_kwargs = create_engine_mock.call_args.kwargs + assert create_engine_kwargs["pool_size"] == 1 + assert create_engine_kwargs["max_overflow"] == 0 + assert create_engine_kwargs["pool_pre_ping"] is True + + +def test_acquire_engine_skips_cache_without_ipython(reset_engine_cache): + sql_alchemy_dict = _make_sql_alchemy_dict() + + with ( + mock.patch( + "deepnote_toolkit.sql.sql_execution.create_engine", return_value=mock.Mock() + ) as create_engine_mock, + mock.patch( + "deepnote_toolkit.sql.sql_execution.register_post_run_cell_hook", + return_value=False, + ), + ): + engine_a, owns_a = se._acquire_engine(sql_alchemy_dict, sql_alchemy_dict["url"]) + engine_b, owns_b = se._acquire_engine(sql_alchemy_dict, sql_alchemy_dict["url"]) + + assert owns_a is True + assert owns_b is True + # Caller is responsible for disposal so create_engine runs each call. + assert create_engine_mock.call_count == 2 + assert se._per_cell_engine_cache == {} + # Without caching we don't impose pool_size on the user. 
+ create_engine_kwargs = create_engine_mock.call_args.kwargs + assert "pool_size" not in create_engine_kwargs + assert "max_overflow" not in create_engine_kwargs + assert create_engine_kwargs["pool_pre_ping"] is True + + +def test_acquire_engine_different_integrations_get_separate_engines(reset_engine_cache): + dict_a = _make_sql_alchemy_dict(integration_id="int_a") + dict_b = _make_sql_alchemy_dict(integration_id="int_b") + engines = [mock.Mock(name="engine_a"), mock.Mock(name="engine_b")] + + with ( + mock.patch( + "deepnote_toolkit.sql.sql_execution.create_engine", side_effect=engines + ), + mock.patch( + "deepnote_toolkit.sql.sql_execution.register_post_run_cell_hook", + return_value=True, + ), + ): + engine_a, _ = se._acquire_engine(dict_a, dict_a["url"]) + engine_b, _ = se._acquire_engine(dict_b, dict_b["url"]) + engine_a_again, _ = se._acquire_engine(dict_a, dict_a["url"]) + + assert engine_a is engines[0] + assert engine_b is engines[1] + assert engine_a_again is engines[0] + + +def test_acquire_engine_skips_cache_for_ssh_tunnel(reset_engine_cache): + sql_alchemy_dict = _make_sql_alchemy_dict() + sql_alchemy_dict["ssh_options"] = { + "enabled": True, + "host": "h", + "port": 22, + "user": "u", + } + + with ( + mock.patch( + "deepnote_toolkit.sql.sql_execution.create_engine", return_value=mock.Mock() + ), + mock.patch( + "deepnote_toolkit.sql.sql_execution.register_post_run_cell_hook", + return_value=True, + ), + ): + _, owns = se._acquire_engine(sql_alchemy_dict, sql_alchemy_dict["url"]) + + assert owns is True + assert se._per_cell_engine_cache == {} + + +def test_acquire_engine_skips_cache_when_user_provides_pool_config(reset_engine_cache): + sql_alchemy_dict = _make_sql_alchemy_dict(params={"pool_size": 5}) + + with ( + mock.patch( + "deepnote_toolkit.sql.sql_execution.create_engine", return_value=mock.Mock() + ), + mock.patch( + "deepnote_toolkit.sql.sql_execution.register_post_run_cell_hook", + return_value=True, + ), + ): + _, owns = 
se._acquire_engine(sql_alchemy_dict, sql_alchemy_dict["url"]) + + assert owns is True + assert se._per_cell_engine_cache == {} + + +def test_dispose_cell_engines_disposes_and_clears(reset_engine_cache): + fake_engine_a = mock.Mock() + fake_engine_b = mock.Mock() + se._per_cell_engine_cache["a"] = fake_engine_a + se._per_cell_engine_cache["b"] = fake_engine_b + + se._dispose_cell_engines() + + fake_engine_a.dispose.assert_called_once() + fake_engine_b.dispose.assert_called_once() + assert se._per_cell_engine_cache == {} + + +def test_dispose_cell_engines_swallows_individual_failures(reset_engine_cache): + failing_engine = mock.Mock() + failing_engine.dispose.side_effect = RuntimeError("boom") + healthy_engine = mock.Mock() + se._per_cell_engine_cache["fail"] = failing_engine + se._per_cell_engine_cache["ok"] = healthy_engine + + with mock.patch.object(se.logger, "warning") as mock_warning: + se._dispose_cell_engines() + + healthy_engine.dispose.assert_called_once() + assert se._per_cell_engine_cache == {} + mock_warning.assert_called_once() + + def test_create_sql_ssh_uri_missing_key(monkeypatch): def fake_get_env(name, default=None): if name == "PRIVATE_SSH_KEY_BLOB": From 6762926098978dc38c747b7447bc10cb689cfd16 Mon Sep 17 00:00:00 2001 From: Robert Date: Thu, 23 Apr 2026 13:56:47 +0200 Subject: [PATCH 2/3] refactor: Replace per-cell engine cache with explicit sql_session() MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Drops the implicit cache + IPython post_run_cell hook in favor of an explicit `_dntk.sql_session()` context manager. Inside the with-block each unique connection (keyed by integration_id, or URL when not present) opens its SSH tunnel and SQLAlchemy engine once and reuses both for every `execute_sql*` call inside the block. On exit the session disposes the engine and closes the tunnel together. 
This makes SSH tunnels work with SAL-51 reuse — the previous cache had to skip SSH because the tunnel's lifetime was bound to a single call by `_create_sql_ssh_uri`'s context manager. With the session owning both, the tunnel's lifetime now matches the engine's. Other gains: - No IPython coupling. Lifetime is whatever the caller (deepnote- internal codegen) wraps in the `with` block. - No "magic" cache key vs. live state divergence — outside a session behavior is exactly the previous per-call create+dispose. - Nested `sql_session()` blocks no-op on the inner block so the outer retains ownership. Co-Authored-By: Claude Opus 4.7 (1M context) --- deepnote_toolkit/__init__.py | 1 + deepnote_toolkit/ipython_utils.py | 21 -- deepnote_toolkit/sql/sql_execution.py | 292 ++++++++++++---------- tests/unit/test_ipython_utils.py | 48 ---- tests/unit/test_sql_execution_internal.py | 270 ++++++++++++-------- 5 files changed, 326 insertions(+), 306 deletions(-) delete mode 100644 tests/unit/test_ipython_utils.py diff --git a/deepnote_toolkit/__init__.py b/deepnote_toolkit/__init__.py index 39688789..961ef983 100644 --- a/deepnote_toolkit/__init__.py +++ b/deepnote_toolkit/__init__.py @@ -34,6 +34,7 @@ (".set_notebook_path", "set_notebook_path"), (".sql.sql_execution", "execute_sql"), (".sql.sql_execution", "execute_sql_with_connection_json"), + (".sql.sql_execution", "sql_session"), (".variable_explorer", "deepnote_export_df"), (".variable_explorer", "deepnote_get_data_preview_json"), (".variable_explorer", "get_var_list"), diff --git a/deepnote_toolkit/ipython_utils.py b/deepnote_toolkit/ipython_utils.py index f1c9968b..a789985b 100644 --- a/deepnote_toolkit/ipython_utils.py +++ b/deepnote_toolkit/ipython_utils.py @@ -3,8 +3,6 @@ # also defined in https://github.com/deepnote/deepnote/blob/a9f36659f50c84bd85aeba8ee2d3d4458f2f4998/libs/shared/src/constants.ts#L47 DEEPNOTE_SQL_METADATA_MIME_TYPE = "application/vnd.deepnote.sql-output-metadata+json" 
-_registered_post_run_cell_callbacks: set = set() - def output_display_data(mime_bundle): """ @@ -15,25 +13,6 @@ def output_display_data(mime_bundle): get_ipython().display_pub.publish(data=mime_bundle) -def register_post_run_cell_hook(callback) -> bool: - """Register *callback* to run after each IPython cell. Idempotent per callback. - - Returns True when the callback is registered (or already was), False when there - is no active IPython shell so the caller can fall back to non-cached behavior. - """ - - ip = get_ipython() - if ip is None: - return False - - if callback in _registered_post_run_cell_callbacks: - return True - - ip.events.register("post_run_cell", callback) - _registered_post_run_cell_callbacks.add(callback) - return True - - def output_sql_metadata(metadata: dict): """ Outputs SQL metadata to the notebook. Used for e.g. reporting on hit/miss of a SQL cache. or reporting the compiled query diff --git a/deepnote_toolkit/sql/sql_execution.py b/deepnote_toolkit/sql/sql_execution.py index 1b917dda..afa79971 100644 --- a/deepnote_toolkit/sql/sql_execution.py +++ b/deepnote_toolkit/sql/sql_execution.py @@ -1,5 +1,6 @@ import base64 import contextlib +import contextvars import json import re import uuid @@ -28,10 +29,7 @@ get_absolute_userpod_api_url, get_project_auth_headers, ) -from deepnote_toolkit.ipython_utils import ( - output_sql_metadata, - register_post_run_cell_hook, -) +from deepnote_toolkit.ipython_utils import output_sql_metadata from deepnote_toolkit.logging import LoggerManager from deepnote_toolkit.ocelots.pandas.utils import deduplicate_columns, is_large_number from deepnote_toolkit.sql.duckdb_sql import execute_duckdb_sql @@ -364,40 +362,53 @@ def _handle_federated_auth_params(sql_alchemy_dict: dict[str, Any]) -> None: ) +def _open_ssh_tunnel(sql_alchemy_dict) -> tuple[Any, Any]: + """Open an SSH tunnel for *sql_alchemy_dict* and return ``(server, url)``. 
+ + Caller is responsible for eventually closing the returned server with + ``_close_ssh_tunnel``. + """ + base64_encoded_key = dnenv.get_env("PRIVATE_SSH_KEY_BLOB") + if not base64_encoded_key: + raise Exception( + "The private key needed to establish the SSH connection is missing. Please try again or contact support." + ) + original_url = make_url(sql_alchemy_dict["url"]) + server = create_ssh_tunnel( + ssh_host=sql_alchemy_dict["ssh_options"]["host"], + ssh_port=int(sql_alchemy_dict["ssh_options"]["port"]), + ssh_user=sql_alchemy_dict["ssh_options"]["user"], + remote_host=original_url.host, + remote_port=int(original_url.port), + private_key=base64.b64decode(base64_encoded_key).decode("utf-8"), + ) + url = URL.create( + drivername=original_url.drivername, + username=original_url.username, + password=original_url.password, + host=server.local_bind_host, + port=server.local_bind_port, + database=original_url.database, + query=original_url.query, + ) + return server, url + + +def _close_ssh_tunnel(server) -> None: + if server is not None and server.is_active: + server.close() + + @contextlib.contextmanager def _create_sql_ssh_uri(ssh_enabled, sql_alchemy_dict): - server = None - if ssh_enabled: - base64_encoded_key = dnenv.get_env("PRIVATE_SSH_KEY_BLOB") - if not base64_encoded_key: - raise Exception( - "The private key needed to establish the SSH connection is missing. Please try again or contact support." 
- ) - original_url = make_url(sql_alchemy_dict["url"]) - try: - server = create_ssh_tunnel( - ssh_host=sql_alchemy_dict["ssh_options"]["host"], - ssh_port=int(sql_alchemy_dict["ssh_options"]["port"]), - ssh_user=sql_alchemy_dict["ssh_options"]["user"], - remote_host=original_url.host, - remote_port=int(original_url.port), - private_key=base64.b64decode(base64_encoded_key).decode("utf-8"), - ) - url = URL.create( - drivername=original_url.drivername, - username=original_url.username, - password=original_url.password, - host=server.local_bind_host, - port=server.local_bind_port, - database=original_url.database, - query=original_url.query, - ) - yield url - finally: - if server is not None and server.is_active: - server.close() - else: + if not ssh_enabled: yield None + return + server, url = _open_ssh_tunnel(sql_alchemy_dict) + try: + yield url + finally: + _close_ssh_tunnel(server) def _execute_sql_with_caching( @@ -453,27 +464,19 @@ def _execute_sql_with_caching( ) -# Engines created during a single IPython cell are cached here so that -# multiple _dntk.execute_sql* calls in the same generated cell share the same -# physical DBAPI connection. This is what allows Snowflake session state set -# by `USE WAREHOUSE`, `USE ROLE`, `SET ...`, etc. to persist between separate -# execute_sql calls inside the same cell. Disposed via a post_run_cell hook. -_per_cell_engine_cache: dict[str, Any] = {} -_post_run_cell_hook_registered = False - +# Active sql_session() registry, or None when no session is active. +# Maps a per-connection key to a (engine, ssh_server) bundle owned by the +# session: subsequent execute_sql* calls inside the with-block reuse them. 
+_session_registry: contextvars.ContextVar[Optional[dict[str, tuple[Any, Any]]]] = ( + contextvars.ContextVar("deepnote_sql_session_registry", default=None) +) -def _compute_engine_cache_key(sql_alchemy_dict) -> Optional[str]: - """Stable key identifying the engine for *sql_alchemy_dict*, or None when - this connection should not be cached. - SSH-tunneled engines can't be cached because the tunnel only lives inside - `_create_sql_ssh_uri`'s context manager. Connections that already specify - a custom pool config in user `params` are also skipped to avoid clashing - with the `pool_size=1` we set on cached engines. +def _session_resource_key(sql_alchemy_dict) -> Optional[str]: + """Key under which to share session resources for this connection, or None + when sharing is unsafe (e.g. user already specifies a custom pool config + that would clash with our ``pool_size=1``). """ - if sql_alchemy_dict.get("ssh_options", {}).get("enabled"): - return None - params = sql_alchemy_dict.get("params") or {} if "pool_size" in params or "max_overflow" in params or "poolclass" in params: return None @@ -485,57 +488,97 @@ def _compute_engine_cache_key(sql_alchemy_dict) -> Optional[str]: return f"url:{sql_alchemy_dict['url']}" -def _dispose_cell_engines(*_args, **_kwargs) -> None: - """post_run_cell callback: dispose every engine cached during the cell.""" - while _per_cell_engine_cache: - _, engine = _per_cell_engine_cache.popitem() - try: - engine.dispose() - except Exception: - logger.warning("Error disposing cached SQL engine", exc_info=True) +@contextlib.contextmanager +def sql_session(): + """Share SQLAlchemy engines, SSH tunnels, and connections across all + ``execute_sql*`` calls inside the with-block. + + Inside the session each unique connection (keyed by ``integration_id``, + or URL when not present) opens its SSH tunnel and SQLAlchemy engine + once; subsequent calls reuse them. 
Engines are created with + ``pool_size=1, max_overflow=0`` so ``engine.begin()`` always checks out + the same physical DBAPI connection — that is what makes Snowflake + session state set by ``USE WAREHOUSE``, ``USE ROLE``, ``SET ...`` + persist across multiple ``execute_sql`` calls in a single SQL block. + + Outside a session ``execute_sql*`` falls back to its previous per-call + behavior of opening and disposing everything for each invocation. + + Nested ``sql_session()`` blocks are no-ops; only the outermost block + owns and tears down the resources. + """ + if _session_registry.get() is not None: + yield + return + + registry: dict[str, tuple[Any, Any]] = {} + token = _session_registry.set(registry) + try: + yield + finally: + _session_registry.reset(token) + for engine, ssh_server in registry.values(): + try: + engine.dispose() + except Exception: + logger.warning( + "Error disposing SQL engine on session exit", exc_info=True + ) + try: + _close_ssh_tunnel(ssh_server) + except Exception: + logger.warning( + "Error closing SSH tunnel on session exit", exc_info=True + ) + registry.clear() -def _acquire_engine(sql_alchemy_dict, url) -> tuple[Any, bool]: - """Return ``(engine, owns_engine)``. +def _acquire_engine(sql_alchemy_dict) -> tuple[Any, Optional[Any], bool]: + """Return ``(engine, ssh_server, owns_resources)``. - When ``owns_engine`` is False the engine is owned by the per-cell cache - and the caller must not dispose it; the post_run_cell hook will. When True - the caller must dispose, matching the previous per-call behavior. + Inside an active ``sql_session()`` for a shareable connection, returns + cached resources (creating them on first use); ``owns_resources`` is + False and the caller must not dispose. Otherwise opens fresh resources + and ``owns_resources`` is True — caller must dispose the engine and + close the tunnel. 
""" - cache_key = _compute_engine_cache_key(sql_alchemy_dict) + registry = _session_registry.get() + key = _session_resource_key(sql_alchemy_dict) if registry is not None else None + in_session = registry is not None and key is not None - if cache_key is not None: - cached = _per_cell_engine_cache.get(cache_key) + if in_session: + cached = registry.get(key) if cached is not None: - return cached, False - - should_cache = False - if cache_key is not None: - global _post_run_cell_hook_registered - if _post_run_cell_hook_registered: - should_cache = True - elif register_post_run_cell_hook(_dispose_cell_engines): - _post_run_cell_hook_registered = True - should_cache = True - # else: no IPython available — leave should_cache False so we don't - # leak engines in script/CLI contexts that never fire post_run_cell. + engine, ssh_server = cached + return engine, ssh_server, False + + ssh_server: Optional[Any] = None + if sql_alchemy_dict.get("ssh_options", {}).get("enabled"): + ssh_server, url = _open_ssh_tunnel(sql_alchemy_dict) + else: + url = sql_alchemy_dict["url"] extra_engine_params: dict[str, Any] = {"pool_pre_ping": True} - if should_cache: - # pool_size=1 + max_overflow=0 makes engine.begin() always check out - # the same physical DBAPI connection across calls within the cell, - # which is what makes session state (USE WAREHOUSE, ...) persist. + if in_session: + # See sql_session() docstring for why pool_size=1 is required to make + # session state (USE WAREHOUSE, ...) persist across calls. 
extra_engine_params["pool_size"] = 1 extra_engine_params["max_overflow"] = 0 - with suppress_third_party_deprecation_warnings(): - engine = create_engine(url, **sql_alchemy_dict["params"], **extra_engine_params) + try: + with suppress_third_party_deprecation_warnings(): + engine = create_engine( + url, **sql_alchemy_dict["params"], **extra_engine_params + ) + except Exception: + _close_ssh_tunnel(ssh_server) + raise - if not should_cache: - return engine, True + if in_session: + registry[key] = (engine, ssh_server) + return engine, ssh_server, False - _per_cell_engine_cache[cache_key] = engine - return engine, False + return engine, ssh_server, True @contextlib.contextmanager @@ -567,50 +610,45 @@ def _query_data_source( return_variable_type, query_preview_source, ): - sshEnabled = sql_alchemy_dict.get("ssh_options", {}).get("enabled", False) + engine, ssh_server, owns_resources = _acquire_engine(sql_alchemy_dict) - with _create_sql_ssh_uri(sshEnabled, sql_alchemy_dict) as url: - if url is None: - url = sql_alchemy_dict["url"] + try: + dataframe = _execute_sql_on_engine(engine, query, bind_params) - engine, owns_engine = _acquire_engine(sql_alchemy_dict, url) + if dataframe is None: + return None - try: - dataframe = _execute_sql_on_engine(engine, query, bind_params) - - if dataframe is None: - return None - - # sanitize dataframe so that we can safely call .to_parquet on it - _sanitize_dataframe_for_parquet(dataframe) - - dataframe_size_in_bytes = int(dataframe.memory_usage(deep=True).sum()) - output_sql_metadata( - { - "status": "success_no_cache", - "size_in_bytes": dataframe_size_in_bytes, - "compiled_query": query, - "variable_type": return_variable_type, - "integration_id": sql_alchemy_dict.get("integration_id"), - } - ) + # sanitize dataframe so that we can safely call .to_parquet on it + _sanitize_dataframe_for_parquet(dataframe) + + dataframe_size_in_bytes = int(dataframe.memory_usage(deep=True).sum()) + output_sql_metadata( + { + "status": 
"success_no_cache", + "size_in_bytes": dataframe_size_in_bytes, + "compiled_query": query, + "variable_type": return_variable_type, + "integration_id": sql_alchemy_dict.get("integration_id"), + } + ) - # for Chained SQL we return the dataframe with the SQL source attached as DeepnoteQueryPreview object - if return_variable_type == "query_preview": - return _convert_dataframe_to_query_preview( - dataframe, query_preview_source - ) + # for Chained SQL we return the dataframe with the SQL source attached as DeepnoteQueryPreview object + if return_variable_type == "query_preview": + return _convert_dataframe_to_query_preview(dataframe, query_preview_source) - # if df is larger than 5GB, don't upload it. See NB-988 - dataframe_is_cacheable = dataframe_size_in_bytes < 5 * 1024 * 1024 * 1024 + # if df is larger than 5GB, don't upload it. See NB-988 + dataframe_is_cacheable = dataframe_size_in_bytes < 5 * 1024 * 1024 * 1024 - if cache_upload_url is not None and dataframe_is_cacheable: - upload_sql_cache(dataframe, cache_upload_url) + if cache_upload_url is not None and dataframe_is_cacheable: + upload_sql_cache(dataframe, cache_upload_url) - return dataframe - finally: - if owns_engine: + return dataframe + finally: + if owns_resources: + try: engine.dispose() + finally: + _close_ssh_tunnel(ssh_server) class CursorTrackingDBAPIConnection(wrapt.ObjectProxy): diff --git a/tests/unit/test_ipython_utils.py b/tests/unit/test_ipython_utils.py deleted file mode 100644 index 34edd531..00000000 --- a/tests/unit/test_ipython_utils.py +++ /dev/null @@ -1,48 +0,0 @@ -from unittest import mock - -from deepnote_toolkit import ipython_utils - - -def _reset_registered_callbacks(): - ipython_utils._registered_post_run_cell_callbacks.clear() - - -def test_register_post_run_cell_hook_returns_false_without_ipython(): - _reset_registered_callbacks() - callback = mock.Mock() - - with mock.patch("deepnote_toolkit.ipython_utils.get_ipython", return_value=None): - result = 
ipython_utils.register_post_run_cell_hook(callback) - - assert result is False - assert callback not in ipython_utils._registered_post_run_cell_callbacks - - -def test_register_post_run_cell_hook_registers_with_ipython(): - _reset_registered_callbacks() - callback = mock.Mock() - fake_ipython = mock.Mock() - - with mock.patch( - "deepnote_toolkit.ipython_utils.get_ipython", return_value=fake_ipython - ): - result = ipython_utils.register_post_run_cell_hook(callback) - - assert result is True - fake_ipython.events.register.assert_called_once_with("post_run_cell", callback) - assert callback in ipython_utils._registered_post_run_cell_callbacks - - -def test_register_post_run_cell_hook_is_idempotent_for_same_callback(): - _reset_registered_callbacks() - callback = mock.Mock() - fake_ipython = mock.Mock() - - with mock.patch( - "deepnote_toolkit.ipython_utils.get_ipython", return_value=fake_ipython - ): - ipython_utils.register_post_run_cell_hook(callback) - result = ipython_utils.register_post_run_cell_hook(callback) - - assert result is True - fake_ipython.events.register.assert_called_once_with("post_run_cell", callback) diff --git a/tests/unit/test_sql_execution_internal.py b/tests/unit/test_sql_execution_internal.py index b035ad40..41d36d0d 100644 --- a/tests/unit/test_sql_execution_internal.py +++ b/tests/unit/test_sql_execution_internal.py @@ -377,19 +377,6 @@ def test_create_sql_ssh_uri_no_ssh(): assert url is None -@pytest.fixture -def reset_engine_cache(): - """Reset the per-cell engine cache and registration flag around a test.""" - se._per_cell_engine_cache.clear() - original_registered = se._post_run_cell_hook_registered - se._post_run_cell_hook_registered = False - try: - yield - finally: - se._per_cell_engine_cache.clear() - se._post_run_cell_hook_registered = original_registered - - def _make_sql_alchemy_dict(integration_id="integration_a", url=None, params=None): return { "url": url or "postgresql://u:p@localhost:5432/db", @@ -399,87 +386,75 @@ def 
_make_sql_alchemy_dict(integration_id="integration_a", url=None, params=None } -def test_acquire_engine_caches_engine_when_ipython_available(reset_engine_cache): +def test_acquire_engine_outside_session_owns_resources(): + sql_alchemy_dict = _make_sql_alchemy_dict() + + with mock.patch( + "deepnote_toolkit.sql.sql_execution.create_engine", return_value=mock.Mock() + ) as create_engine_mock: + engine_a, ssh_a, owns_a = se._acquire_engine(sql_alchemy_dict) + engine_b, ssh_b, owns_b = se._acquire_engine(sql_alchemy_dict) + + assert owns_a is True + assert owns_b is True + assert ssh_a is None + assert ssh_b is None + # Two calls outside a session create two engines; caller disposes each. + assert create_engine_mock.call_count == 2 + # Outside a session we don't impose pool_size on the user. + kwargs = create_engine_mock.call_args.kwargs + assert "pool_size" not in kwargs + assert "max_overflow" not in kwargs + assert kwargs["pool_pre_ping"] is True + + +def test_sql_session_reuses_engine_within_block(): sql_alchemy_dict = _make_sql_alchemy_dict() fake_engine = mock.Mock() - with ( - mock.patch( - "deepnote_toolkit.sql.sql_execution.create_engine", return_value=fake_engine - ) as create_engine_mock, - mock.patch( - "deepnote_toolkit.sql.sql_execution.register_post_run_cell_hook", - return_value=True, - ), - ): - engine_a, owns_a = se._acquire_engine(sql_alchemy_dict, sql_alchemy_dict["url"]) - engine_b, owns_b = se._acquire_engine(sql_alchemy_dict, sql_alchemy_dict["url"]) + with mock.patch( + "deepnote_toolkit.sql.sql_execution.create_engine", return_value=fake_engine + ) as create_engine_mock: + with se.sql_session(): + engine_a, _, owns_a = se._acquire_engine(sql_alchemy_dict) + engine_b, _, owns_b = se._acquire_engine(sql_alchemy_dict) assert engine_a is fake_engine assert engine_b is fake_engine assert owns_a is False assert owns_b is False assert create_engine_mock.call_count == 1 - # Cached engines must use a single-connection pool so engine.begin() returns - # 
the same physical DBAPI connection across calls — that is what makes - # session state (USE WAREHOUSE, ...) persist between execute_sql calls. - create_engine_kwargs = create_engine_mock.call_args.kwargs - assert create_engine_kwargs["pool_size"] == 1 - assert create_engine_kwargs["max_overflow"] == 0 - assert create_engine_kwargs["pool_pre_ping"] is True - - -def test_acquire_engine_skips_cache_without_ipython(reset_engine_cache): - sql_alchemy_dict = _make_sql_alchemy_dict() - - with ( - mock.patch( - "deepnote_toolkit.sql.sql_execution.create_engine", return_value=mock.Mock() - ) as create_engine_mock, - mock.patch( - "deepnote_toolkit.sql.sql_execution.register_post_run_cell_hook", - return_value=False, - ), - ): - engine_a, owns_a = se._acquire_engine(sql_alchemy_dict, sql_alchemy_dict["url"]) - engine_b, owns_b = se._acquire_engine(sql_alchemy_dict, sql_alchemy_dict["url"]) - - assert owns_a is True - assert owns_b is True - # Caller is responsible for disposal so create_engine runs each call. - assert create_engine_mock.call_count == 2 - assert se._per_cell_engine_cache == {} - # Without caching we don't impose pool_size on the user. - create_engine_kwargs = create_engine_mock.call_args.kwargs - assert "pool_size" not in create_engine_kwargs - assert "max_overflow" not in create_engine_kwargs - assert create_engine_kwargs["pool_pre_ping"] is True - - -def test_acquire_engine_different_integrations_get_separate_engines(reset_engine_cache): + # In-session engines must use a single-connection pool so engine.begin() + # returns the same physical DBAPI connection across calls — that is what + # makes session state (USE WAREHOUSE, ...) persist between execute_sql + # calls in the same block. + kwargs = create_engine_mock.call_args.kwargs + assert kwargs["pool_size"] == 1 + assert kwargs["max_overflow"] == 0 + assert kwargs["pool_pre_ping"] is True + # On exit the session disposes the engine it owned. 
+ fake_engine.dispose.assert_called_once() + + +def test_sql_session_separates_engines_per_integration(): dict_a = _make_sql_alchemy_dict(integration_id="int_a") dict_b = _make_sql_alchemy_dict(integration_id="int_b") engines = [mock.Mock(name="engine_a"), mock.Mock(name="engine_b")] - with ( - mock.patch( - "deepnote_toolkit.sql.sql_execution.create_engine", side_effect=engines - ), - mock.patch( - "deepnote_toolkit.sql.sql_execution.register_post_run_cell_hook", - return_value=True, - ), + with mock.patch( + "deepnote_toolkit.sql.sql_execution.create_engine", side_effect=engines ): - engine_a, _ = se._acquire_engine(dict_a, dict_a["url"]) - engine_b, _ = se._acquire_engine(dict_b, dict_b["url"]) - engine_a_again, _ = se._acquire_engine(dict_a, dict_a["url"]) + with se.sql_session(): + engine_a, _, _ = se._acquire_engine(dict_a) + engine_b, _, _ = se._acquire_engine(dict_b) + engine_a_again, _, _ = se._acquire_engine(dict_a) assert engine_a is engines[0] assert engine_b is engines[1] assert engine_a_again is engines[0] -def test_acquire_engine_skips_cache_for_ssh_tunnel(reset_engine_cache): +def test_sql_session_opens_ssh_tunnel_once_and_closes_on_exit(): sql_alchemy_dict = _make_sql_alchemy_dict() sql_alchemy_dict["ssh_options"] = { "enabled": True, @@ -487,66 +462,141 @@ def test_acquire_engine_skips_cache_for_ssh_tunnel(reset_engine_cache): "port": 22, "user": "u", } + fake_engine = mock.Mock() + fake_server = mock.Mock() + fake_server.is_active = True + rewritten_url = "postgresql://u:p@127.0.0.1:65000/db" with ( mock.patch( - "deepnote_toolkit.sql.sql_execution.create_engine", return_value=mock.Mock() - ), + "deepnote_toolkit.sql.sql_execution._open_ssh_tunnel", + return_value=(fake_server, rewritten_url), + ) as open_tunnel, mock.patch( - "deepnote_toolkit.sql.sql_execution.register_post_run_cell_hook", - return_value=True, - ), + "deepnote_toolkit.sql.sql_execution.create_engine", + return_value=fake_engine, + ) as create_engine_mock, ): - _, owns = 
se._acquire_engine(sql_alchemy_dict, sql_alchemy_dict["url"]) + with se.sql_session(): + _, ssh_a, owns_a = se._acquire_engine(sql_alchemy_dict) + _, ssh_b, owns_b = se._acquire_engine(sql_alchemy_dict) - assert owns is True - assert se._per_cell_engine_cache == {} + # Tunnel opened once and reused; not torn down between calls. + assert open_tunnel.call_count == 1 + assert create_engine_mock.call_count == 1 + assert ssh_a is fake_server + assert ssh_b is fake_server + assert owns_a is False + assert owns_b is False + # Engine creation uses the rewritten (tunneled) URL. + assert create_engine_mock.call_args.args[0] == rewritten_url + # On session exit both the engine and the tunnel are torn down. + fake_engine.dispose.assert_called_once() + fake_server.close.assert_called_once() -def test_acquire_engine_skips_cache_when_user_provides_pool_config(reset_engine_cache): +def test_sql_session_skips_sharing_when_user_supplies_pool_config(): sql_alchemy_dict = _make_sql_alchemy_dict(params={"pool_size": 5}) + with mock.patch( + "deepnote_toolkit.sql.sql_execution.create_engine", return_value=mock.Mock() + ) as create_engine_mock: + with se.sql_session(): + _, _, owns_a = se._acquire_engine(sql_alchemy_dict) + _, _, owns_b = se._acquire_engine(sql_alchemy_dict) + + # Without sharing, caller owns disposal each call and we don't override the + # user's pool config (their pool_size=5 passes through unchanged). 
+ assert owns_a is True + assert owns_b is True + assert create_engine_mock.call_count == 2 + kwargs = create_engine_mock.call_args.kwargs + assert kwargs["pool_size"] == 5 + + +def test_sql_session_swallows_per_resource_teardown_errors(): + sql_alchemy_dict = _make_sql_alchemy_dict() + failing_engine = mock.Mock() + failing_engine.dispose.side_effect = RuntimeError("boom") + failing_server = mock.Mock() + failing_server.is_active = True + failing_server.close.side_effect = RuntimeError("boom") + with ( mock.patch( - "deepnote_toolkit.sql.sql_execution.create_engine", return_value=mock.Mock() + "deepnote_toolkit.sql.sql_execution.create_engine", + return_value=failing_engine, ), mock.patch( - "deepnote_toolkit.sql.sql_execution.register_post_run_cell_hook", - return_value=True, + "deepnote_toolkit.sql.sql_execution._open_ssh_tunnel", + return_value=(failing_server, "postgresql://u@127.0.0.1:1/db"), ), + mock.patch.object(se.logger, "warning") as mock_warning, ): - _, owns = se._acquire_engine(sql_alchemy_dict, sql_alchemy_dict["url"]) - - assert owns is True - assert se._per_cell_engine_cache == {} + sql_alchemy_dict["ssh_options"] = { + "enabled": True, + "host": "h", + "port": 22, + "user": "u", + } + with se.sql_session(): + se._acquire_engine(sql_alchemy_dict) + failing_engine.dispose.assert_called_once() + failing_server.close.assert_called_once() + # One warning per failing resource. 
+ assert mock_warning.call_count == 2 -def test_dispose_cell_engines_disposes_and_clears(reset_engine_cache): - fake_engine_a = mock.Mock() - fake_engine_b = mock.Mock() - se._per_cell_engine_cache["a"] = fake_engine_a - se._per_cell_engine_cache["b"] = fake_engine_b - se._dispose_cell_engines() +def test_nested_sql_session_does_not_steal_outer_resources(): + sql_alchemy_dict = _make_sql_alchemy_dict() + fake_engine = mock.Mock() - fake_engine_a.dispose.assert_called_once() - fake_engine_b.dispose.assert_called_once() - assert se._per_cell_engine_cache == {} + with mock.patch( + "deepnote_toolkit.sql.sql_execution.create_engine", return_value=fake_engine + ) as create_engine_mock: + with se.sql_session(): + engine_outer, _, _ = se._acquire_engine(sql_alchemy_dict) + with se.sql_session(): + engine_inner, _, _ = se._acquire_engine(sql_alchemy_dict) + # Inner block must not have disposed the outer-owned engine. + fake_engine.dispose.assert_not_called() + engine_after_inner, _, _ = se._acquire_engine(sql_alchemy_dict) + + assert engine_outer is fake_engine + assert engine_inner is fake_engine + assert engine_after_inner is fake_engine + assert create_engine_mock.call_count == 1 + # Outer session disposes once on exit. 
+ fake_engine.dispose.assert_called_once() -def test_dispose_cell_engines_swallows_individual_failures(reset_engine_cache): - failing_engine = mock.Mock() - failing_engine.dispose.side_effect = RuntimeError("boom") - healthy_engine = mock.Mock() - se._per_cell_engine_cache["fail"] = failing_engine - se._per_cell_engine_cache["ok"] = healthy_engine +def test_acquire_engine_closes_tunnel_when_engine_creation_fails_outside_session(): + sql_alchemy_dict = _make_sql_alchemy_dict() + sql_alchemy_dict["ssh_options"] = { + "enabled": True, + "host": "h", + "port": 22, + "user": "u", + } + fake_server = mock.Mock() + fake_server.is_active = True - with mock.patch.object(se.logger, "warning") as mock_warning: - se._dispose_cell_engines() + with ( + mock.patch( + "deepnote_toolkit.sql.sql_execution._open_ssh_tunnel", + return_value=(fake_server, "postgresql://u@127.0.0.1:1/db"), + ), + mock.patch( + "deepnote_toolkit.sql.sql_execution.create_engine", + side_effect=RuntimeError("bad url"), + ), + ): + with pytest.raises(RuntimeError): + se._acquire_engine(sql_alchemy_dict) - healthy_engine.dispose.assert_called_once() - assert se._per_cell_engine_cache == {} - mock_warning.assert_called_once() + # Tunnel must not be left open if engine construction blows up. + fake_server.close.assert_called_once() def test_create_sql_ssh_uri_missing_key(monkeypatch): From 029e24544562d253f468f3ab1fb7c3408525f1b0 Mon Sep 17 00:00:00 2001 From: Robert Date: Thu, 23 Apr 2026 14:32:48 +0200 Subject: [PATCH 3/3] refactor: Replace sql_session with setup_statements parameter MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Drops the sql_session() context manager and the engine lifecycle changes it required (per-cell registry, pool_size=1, SSH tunnel lifetime extension). Restores the original per-call create_engine + dispose flow. In its place, execute_sql / execute_sql_with_connection_json grow an optional setup_statements: list[str] parameter. 
When provided, those statements are run via connection.exec_driver_sql on the same DBAPI connection inside the same engine.begin() block as the main query, so session state set by `USE WAREHOUSE`, `USE ROLE`, `SET ...`, etc. is in effect for the main query without changing the engine's lifetime. Single call, single engine, single tunnel, single connection — no shared mutable state across calls and no special-casing for SSH or custom pool config. DuckDB's process-wide singleton means setup statements naturally persist there too; the duckdb branch executes them on the singleton before the main query for API parity. Co-Authored-By: Claude Opus 4.7 (1M context) --- deepnote_toolkit/__init__.py | 1 - deepnote_toolkit/sql/sql_execution.py | 296 ++++++++-------------- tests/unit/test_sql_execution.py | 41 +++ tests/unit/test_sql_execution_internal.py | 277 +++++--------------- 4 files changed, 210 insertions(+), 405 deletions(-) diff --git a/deepnote_toolkit/__init__.py b/deepnote_toolkit/__init__.py index 961ef983..39688789 100644 --- a/deepnote_toolkit/__init__.py +++ b/deepnote_toolkit/__init__.py @@ -34,7 +34,6 @@ (".set_notebook_path", "set_notebook_path"), (".sql.sql_execution", "execute_sql"), (".sql.sql_execution", "execute_sql_with_connection_json"), - (".sql.sql_execution", "sql_session"), (".variable_explorer", "deepnote_export_df"), (".variable_explorer", "deepnote_get_data_preview_json"), (".variable_explorer", "get_var_list"), diff --git a/deepnote_toolkit/sql/sql_execution.py b/deepnote_toolkit/sql/sql_execution.py index afa79971..c9e3a537 100644 --- a/deepnote_toolkit/sql/sql_execution.py +++ b/deepnote_toolkit/sql/sql_execution.py @@ -1,6 +1,5 @@ import base64 import contextlib -import contextvars import json import re import uuid @@ -103,6 +102,7 @@ def execute_sql_with_connection_json( audit_sql_comment="", sql_cache_mode="cache_disabled", return_variable_type="dataframe", + setup_statements=None, ): """ Executes a SQL query using the given connection 
JSON (string). @@ -112,6 +112,11 @@ def execute_sql_with_connection_json( :param sql_alchemy_json: String containing JSON with the connection details. Mandatory fields: url, params, param_style :param sql_cache_mode: SQL caching setting for the query. Possible values: "cache_disabled", "always_write", "read_or_write" + :param setup_statements: Optional list of raw SQL statements to run on the + same connection right before *template*. Use for session setup such as + ``USE WAREHOUSE abc`` whose effect must be visible to the main query. + Statements are not Jinja-rendered, parameter-bound, or audit-commented; + they are executed in order via ``connection.exec_driver_sql``. :return: Pandas dataframe with the result """ @@ -222,6 +227,7 @@ class ExecuteSqlError(Exception): sql_cache_mode, return_variable_type, query_preview_source, + setup_statements=setup_statements, ) @@ -231,6 +237,7 @@ def execute_sql( audit_sql_comment="", sql_cache_mode="cache_disabled", return_variable_type="dataframe", + setup_statements=None, ): """ Wrapper around execute_sql_with_connection_json which reads the connection JSON from @@ -238,6 +245,7 @@ def execute_sql( :param template: Templated SQL :param sql_alchemy_json_env_var: Name of the environment variable containing the connection JSON :param sql_cache_mode: SQL caching setting for the query. Possible values: "cache_disabled", "always_write", "read_or_write" + :param setup_statements: See ``execute_sql_with_connection_json``. :return: Pandas dataframe with the result """ @@ -261,6 +269,7 @@ class ExecuteSqlError(Exception): audit_sql_comment=audit_sql_comment, sql_cache_mode=sql_cache_mode, return_variable_type=return_variable_type, + setup_statements=setup_statements, ) @@ -362,53 +371,40 @@ def _handle_federated_auth_params(sql_alchemy_dict: dict[str, Any]) -> None: ) -def _open_ssh_tunnel(sql_alchemy_dict) -> tuple[Any, Any]: - """Open an SSH tunnel for *sql_alchemy_dict* and return ``(server, url)``. 
- - Caller is responsible for eventually closing the returned server with - ``_close_ssh_tunnel``. - """ - base64_encoded_key = dnenv.get_env("PRIVATE_SSH_KEY_BLOB") - if not base64_encoded_key: - raise Exception( - "The private key needed to establish the SSH connection is missing. Please try again or contact support." - ) - original_url = make_url(sql_alchemy_dict["url"]) - server = create_ssh_tunnel( - ssh_host=sql_alchemy_dict["ssh_options"]["host"], - ssh_port=int(sql_alchemy_dict["ssh_options"]["port"]), - ssh_user=sql_alchemy_dict["ssh_options"]["user"], - remote_host=original_url.host, - remote_port=int(original_url.port), - private_key=base64.b64decode(base64_encoded_key).decode("utf-8"), - ) - url = URL.create( - drivername=original_url.drivername, - username=original_url.username, - password=original_url.password, - host=server.local_bind_host, - port=server.local_bind_port, - database=original_url.database, - query=original_url.query, - ) - return server, url - - -def _close_ssh_tunnel(server) -> None: - if server is not None and server.is_active: - server.close() - - @contextlib.contextmanager def _create_sql_ssh_uri(ssh_enabled, sql_alchemy_dict): - if not ssh_enabled: + server = None + if ssh_enabled: + base64_encoded_key = dnenv.get_env("PRIVATE_SSH_KEY_BLOB") + if not base64_encoded_key: + raise Exception( + "The private key needed to establish the SSH connection is missing. Please try again or contact support." 
+ ) + original_url = make_url(sql_alchemy_dict["url"]) + try: + server = create_ssh_tunnel( + ssh_host=sql_alchemy_dict["ssh_options"]["host"], + ssh_port=int(sql_alchemy_dict["ssh_options"]["port"]), + ssh_user=sql_alchemy_dict["ssh_options"]["user"], + remote_host=original_url.host, + remote_port=int(original_url.port), + private_key=base64.b64decode(base64_encoded_key).decode("utf-8"), + ) + url = URL.create( + drivername=original_url.drivername, + username=original_url.username, + password=original_url.password, + host=server.local_bind_host, + port=server.local_bind_port, + database=original_url.database, + query=original_url.query, + ) + yield url + finally: + if server is not None and server.is_active: + server.close() + else: yield None - return - server, url = _open_ssh_tunnel(sql_alchemy_dict) - try: - yield url - finally: - _close_ssh_tunnel(server) def _execute_sql_with_caching( @@ -420,9 +416,14 @@ def _execute_sql_with_caching( sql_cache_mode, return_variable_type, query_preview_source, + setup_statements=None, ): # duckdb SQL is not cached, so we can skip the logic below for duckdb if requires_duckdb: + # DuckDB uses a process-wide singleton connection, so session state set + # by setup_statements naturally persists for the main query. + for stmt in setup_statements or []: + execute_duckdb_sql(stmt, {}) dataframe = execute_duckdb_sql(query, bind_params) # for Chained SQL we return the dataframe with the SQL source attached as DeepnoteQueryPreview object if return_variable_type == "query_preview": @@ -461,126 +462,10 @@ def _execute_sql_with_caching( cache_upload_url, return_variable_type, query_preview_source, # The original query before any transformations such as appending a LIMIT clause + setup_statements=setup_statements, ) -# Active sql_session() registry, or None when no session is active. -# Maps a per-connection key to a (engine, ssh_server) bundle owned by the -# session: subsequent execute_sql* calls inside the with-block reuse them. 
-_session_registry: contextvars.ContextVar[Optional[dict[str, tuple[Any, Any]]]] = ( - contextvars.ContextVar("deepnote_sql_session_registry", default=None) -) - - -def _session_resource_key(sql_alchemy_dict) -> Optional[str]: - """Key under which to share session resources for this connection, or None - when sharing is unsafe (e.g. user already specifies a custom pool config - that would clash with our ``pool_size=1``). - """ - params = sql_alchemy_dict.get("params") or {} - if "pool_size" in params or "max_overflow" in params or "poolclass" in params: - return None - - integration_id = sql_alchemy_dict.get("integration_id") - if integration_id: - return f"integration:{integration_id}" - - return f"url:{sql_alchemy_dict['url']}" - - -@contextlib.contextmanager -def sql_session(): - """Share SQLAlchemy engines, SSH tunnels, and connections across all - ``execute_sql*`` calls inside the with-block. - - Inside the session each unique connection (keyed by ``integration_id``, - or URL when not present) opens its SSH tunnel and SQLAlchemy engine - once; subsequent calls reuse them. Engines are created with - ``pool_size=1, max_overflow=0`` so ``engine.begin()`` always checks out - the same physical DBAPI connection — that is what makes Snowflake - session state set by ``USE WAREHOUSE``, ``USE ROLE``, ``SET ...`` - persist across multiple ``execute_sql`` calls in a single SQL block. - - Outside a session ``execute_sql*`` falls back to its previous per-call - behavior of opening and disposing everything for each invocation. - - Nested ``sql_session()`` blocks are no-ops; only the outermost block - owns and tears down the resources. 
- """ - if _session_registry.get() is not None: - yield - return - - registry: dict[str, tuple[Any, Any]] = {} - token = _session_registry.set(registry) - try: - yield - finally: - _session_registry.reset(token) - for engine, ssh_server in registry.values(): - try: - engine.dispose() - except Exception: - logger.warning( - "Error disposing SQL engine on session exit", exc_info=True - ) - try: - _close_ssh_tunnel(ssh_server) - except Exception: - logger.warning( - "Error closing SSH tunnel on session exit", exc_info=True - ) - registry.clear() - - -def _acquire_engine(sql_alchemy_dict) -> tuple[Any, Optional[Any], bool]: - """Return ``(engine, ssh_server, owns_resources)``. - - Inside an active ``sql_session()`` for a shareable connection, returns - cached resources (creating them on first use); ``owns_resources`` is - False and the caller must not dispose. Otherwise opens fresh resources - and ``owns_resources`` is True — caller must dispose the engine and - close the tunnel. - """ - registry = _session_registry.get() - key = _session_resource_key(sql_alchemy_dict) if registry is not None else None - in_session = registry is not None and key is not None - - if in_session: - cached = registry.get(key) - if cached is not None: - engine, ssh_server = cached - return engine, ssh_server, False - - ssh_server: Optional[Any] = None - if sql_alchemy_dict.get("ssh_options", {}).get("enabled"): - ssh_server, url = _open_ssh_tunnel(sql_alchemy_dict) - else: - url = sql_alchemy_dict["url"] - - extra_engine_params: dict[str, Any] = {"pool_pre_ping": True} - if in_session: - # See sql_session() docstring for why pool_size=1 is required to make - # session state (USE WAREHOUSE, ...) persist across calls. 
- extra_engine_params["pool_size"] = 1 - extra_engine_params["max_overflow"] = 0 - - try: - with suppress_third_party_deprecation_warnings(): - engine = create_engine( - url, **sql_alchemy_dict["params"], **extra_engine_params - ) - except Exception: - _close_ssh_tunnel(ssh_server) - raise - - if in_session: - registry[key] = (engine, ssh_server) - return engine, ssh_server, False - - return engine, ssh_server, True - - @contextlib.contextmanager def suppress_third_party_deprecation_warnings(): """Suppress known deprecation warnings from third-party SQL packages. @@ -609,46 +494,56 @@ def _query_data_source( cache_upload_url, return_variable_type, query_preview_source, + setup_statements=None, ): - engine, ssh_server, owns_resources = _acquire_engine(sql_alchemy_dict) + sshEnabled = sql_alchemy_dict.get("ssh_options", {}).get("enabled", False) - try: - dataframe = _execute_sql_on_engine(engine, query, bind_params) + with _create_sql_ssh_uri(sshEnabled, sql_alchemy_dict) as url: + if url is None: + url = sql_alchemy_dict["url"] - if dataframe is None: - return None + with suppress_third_party_deprecation_warnings(): + engine = create_engine( + url, **sql_alchemy_dict["params"], pool_pre_ping=True + ) - # sanitize dataframe so that we can safely call .to_parquet on it - _sanitize_dataframe_for_parquet(dataframe) - - dataframe_size_in_bytes = int(dataframe.memory_usage(deep=True).sum()) - output_sql_metadata( - { - "status": "success_no_cache", - "size_in_bytes": dataframe_size_in_bytes, - "compiled_query": query, - "variable_type": return_variable_type, - "integration_id": sql_alchemy_dict.get("integration_id"), - } - ) + try: + dataframe = _execute_sql_on_engine( + engine, query, bind_params, setup_statements=setup_statements + ) - # for Chained SQL we return the dataframe with the SQL source attached as DeepnoteQueryPreview object - if return_variable_type == "query_preview": - return _convert_dataframe_to_query_preview(dataframe, query_preview_source) + if 
dataframe is None: + return None + + # sanitize dataframe so that we can safely call .to_parquet on it + _sanitize_dataframe_for_parquet(dataframe) + + dataframe_size_in_bytes = int(dataframe.memory_usage(deep=True).sum()) + output_sql_metadata( + { + "status": "success_no_cache", + "size_in_bytes": dataframe_size_in_bytes, + "compiled_query": query, + "variable_type": return_variable_type, + "integration_id": sql_alchemy_dict.get("integration_id"), + } + ) + + # for Chained SQL we return the dataframe with the SQL source attached as DeepnoteQueryPreview object + if return_variable_type == "query_preview": + return _convert_dataframe_to_query_preview( + dataframe, query_preview_source + ) - # if df is larger than 5GB, don't upload it. See NB-988 - dataframe_is_cacheable = dataframe_size_in_bytes < 5 * 1024 * 1024 * 1024 + # if df is larger than 5GB, don't upload it. See NB-988 + dataframe_is_cacheable = dataframe_size_in_bytes < 5 * 1024 * 1024 * 1024 - if cache_upload_url is not None and dataframe_is_cacheable: - upload_sql_cache(dataframe, cache_upload_url) + if cache_upload_url is not None and dataframe_is_cacheable: + upload_sql_cache(dataframe, cache_upload_url) - return dataframe - finally: - if owns_resources: - try: - engine.dispose() - finally: - _close_ssh_tunnel(ssh_server) + return dataframe + finally: + engine.dispose() class CursorTrackingDBAPIConnection(wrapt.ObjectProxy): @@ -733,13 +628,19 @@ def _cancel_cursor(cursor: "DBAPICursor") -> None: pass # Best effort, ignore all errors -def _execute_sql_on_engine(engine, query, bind_params): +def _execute_sql_on_engine(engine, query, bind_params, setup_statements=None): """Run *query* on *engine* and return a DataFrame. Uses pandas.read_sql_query to execute the query with a SQLAlchemy connection. For pandas 2.2+ and SQLAlchemy < 2.0, which requires a raw DB-API connection with a `.cursor()` attribute, we use the underlying connection. 
+ When *setup_statements* is provided, each statement is executed on the + same DBAPI connection right before the main query so any session state + it sets (e.g. Snowflake ``USE WAREHOUSE``) is in effect when the main + query runs. Setup statements are issued via ``connection.exec_driver_sql`` + and any failure aborts the main query. + On exceptions (including KeyboardInterrupt from cell cancellation), all cursors created during execution are cancelled to stop running queries on the server. """ @@ -763,6 +664,11 @@ def _execute_sql_on_engine(engine, query, bind_params): ) with engine.begin() as connection: + # Run setup statements first on the same physical connection so any + # session state they set is visible to the main query below. + for stmt in setup_statements or []: + connection.exec_driver_sql(stmt) + # For pandas 2.2+ with SQLAlchemy < 2.0, use raw DBAPI connection if needs_raw_connection: tracking_connection = CursorTrackingDBAPIConnection(connection.connection) diff --git a/tests/unit/test_sql_execution.py b/tests/unit/test_sql_execution.py index a684077e..555a6c48 100644 --- a/tests/unit/test_sql_execution.py +++ b/tests/unit/test_sql_execution.py @@ -131,6 +131,7 @@ def test_sql_executed_with_audit_comment_but_hash_calculated_without_it( mock.ANY, mock.ANY, mock.ANY, + setup_statements=None, ) @mock.patch("deepnote_toolkit.sql.sql_execution._query_data_source") @@ -155,6 +156,7 @@ def test_return_variable_type_parameter(self, mocked_query_data_source): mock.ANY, "dataframe", mock.ANY, + setup_statements=None, ) # Test with explicit return_variable_type='query_preview' @@ -171,6 +173,7 @@ def test_return_variable_type_parameter(self, mocked_query_data_source): mock.ANY, "query_preview", mock.ANY, + setup_statements=None, ) @mock.patch("deepnote_toolkit.sql.sql_execution._query_data_source") @@ -199,6 +202,7 @@ def test_query_preview_preserves_trailing_inline_comment( mock.ANY, "query_preview", mock.ANY, + setup_statements=None, ) 
@mock.patch("deepnote_toolkit.sql.sql_caching._generate_cache_key") @@ -234,6 +238,7 @@ def test_sql_executed_with_audit_comment_with_semicolon( mock.ANY, "dataframe", mock.ANY, + setup_statements=None, ) @mock.patch("deepnote_toolkit.sql.sql_execution._query_data_source") @@ -331,6 +336,42 @@ def test_execute_sql_with_connection_json_with_snowflake_encrypted_private_key( ) +class TestSetupStatementsPlumbing(TestCase): + """Tests that setup_statements is plumbed from the public API down to _query_data_source.""" + + @mock.patch("deepnote_toolkit.sql.sql_execution._query_data_source") + def test_setup_statements_passed_to_query_data_source( + self, mocked_query_data_source + ): + os.environ["SQL_ENV_VAR"] = ( + '{"url":"postgresql://postgres:postgres@localhost:5432/postgres",' + '"params":{},"param_style":"qmark","integration_id":"int_1"}' + ) + + execute_sql( + "SELECT 1", + "SQL_ENV_VAR", + setup_statements=["USE WAREHOUSE abc", "USE ROLE r"], + ) + + _, kwargs = mocked_query_data_source.call_args + self.assertEqual( + kwargs["setup_statements"], ["USE WAREHOUSE abc", "USE ROLE r"] + ) + + @mock.patch("deepnote_toolkit.sql.sql_execution._query_data_source") + def test_no_setup_statements_passes_none_through(self, mocked_query_data_source): + os.environ["SQL_ENV_VAR"] = ( + '{"url":"postgresql://postgres:postgres@localhost:5432/postgres",' + '"params":{},"param_style":"qmark","integration_id":"int_1"}' + ) + + execute_sql("SELECT 1", "SQL_ENV_VAR") + + _, kwargs = mocked_query_data_source.call_args + self.assertIsNone(kwargs["setup_statements"]) + + class TestTrinoParamStyleAutoDetection(TestCase): """Tests for auto-detection of param_style for Trino connections""" diff --git a/tests/unit/test_sql_execution_internal.py b/tests/unit/test_sql_execution_internal.py index 41d36d0d..618e3773 100644 --- a/tests/unit/test_sql_execution_internal.py +++ b/tests/unit/test_sql_execution_internal.py @@ -377,226 +377,85 @@ def test_create_sql_ssh_uri_no_ssh(): assert url is 
None -def _make_sql_alchemy_dict(integration_id="integration_a", url=None, params=None): - return { - "url": url or "postgresql://u:p@localhost:5432/db", - "params": params if params is not None else {}, - "param_style": "qmark", - "integration_id": integration_id, - } +def test_execute_sql_on_engine_runs_setup_statements_in_order_before_main_query(): + """Setup statements must execute on the same connection as the main query, + in order, before pandas runs the main query.""" + import pandas as pd + mock_cursor = mock.MagicMock() + mock_engine = _setup_mock_engine_with_cursor(mock_cursor) -def test_acquire_engine_outside_session_owns_resources(): - sql_alchemy_dict = _make_sql_alchemy_dict() + call_log: list[str] = [] - with mock.patch( - "deepnote_toolkit.sql.sql_execution.create_engine", return_value=mock.Mock() - ) as create_engine_mock: - engine_a, ssh_a, owns_a = se._acquire_engine(sql_alchemy_dict) - engine_b, ssh_b, owns_b = se._acquire_engine(sql_alchemy_dict) - - assert owns_a is True - assert owns_b is True - assert ssh_a is None - assert ssh_b is None - # Two calls outside a session create two engines; caller disposes each. - assert create_engine_mock.call_count == 2 - # Outside a session we don't impose pool_size on the user. - kwargs = create_engine_mock.call_args.kwargs - assert "pool_size" not in kwargs - assert "max_overflow" not in kwargs - assert kwargs["pool_pre_ping"] is True - - -def test_sql_session_reuses_engine_within_block(): - sql_alchemy_dict = _make_sql_alchemy_dict() - fake_engine = mock.Mock() + # Track exec_driver_sql (used for setup) and pd.read_sql_query (the main + # query) on the same SA connection, so we can assert ordering. 
+ sa_connection = mock_engine.begin.return_value.__enter__.return_value + original_exec = sa_connection.exec_driver_sql - with mock.patch( - "deepnote_toolkit.sql.sql_execution.create_engine", return_value=fake_engine - ) as create_engine_mock: - with se.sql_session(): - engine_a, _, owns_a = se._acquire_engine(sql_alchemy_dict) - engine_b, _, owns_b = se._acquire_engine(sql_alchemy_dict) - - assert engine_a is fake_engine - assert engine_b is fake_engine - assert owns_a is False - assert owns_b is False - assert create_engine_mock.call_count == 1 - # In-session engines must use a single-connection pool so engine.begin() - # returns the same physical DBAPI connection across calls — that is what - # makes session state (USE WAREHOUSE, ...) persist between execute_sql - # calls in the same block. - kwargs = create_engine_mock.call_args.kwargs - assert kwargs["pool_size"] == 1 - assert kwargs["max_overflow"] == 0 - assert kwargs["pool_pre_ping"] is True - # On exit the session disposes the engine it owned. 
- fake_engine.dispose.assert_called_once() - - -def test_sql_session_separates_engines_per_integration(): - dict_a = _make_sql_alchemy_dict(integration_id="int_a") - dict_b = _make_sql_alchemy_dict(integration_id="int_b") - engines = [mock.Mock(name="engine_a"), mock.Mock(name="engine_b")] + def logging_exec(sql, *args, **kwargs): + call_log.append(f"setup:{sql}") + return original_exec(sql, *args, **kwargs) - with mock.patch( - "deepnote_toolkit.sql.sql_execution.create_engine", side_effect=engines - ): - with se.sql_session(): - engine_a, _, _ = se._acquire_engine(dict_a) - engine_b, _, _ = se._acquire_engine(dict_b) - engine_a_again, _, _ = se._acquire_engine(dict_a) - - assert engine_a is engines[0] - assert engine_b is engines[1] - assert engine_a_again is engines[0] - - -def test_sql_session_opens_ssh_tunnel_once_and_closes_on_exit(): - sql_alchemy_dict = _make_sql_alchemy_dict() - sql_alchemy_dict["ssh_options"] = { - "enabled": True, - "host": "h", - "port": 22, - "user": "u", - } - fake_engine = mock.Mock() - fake_server = mock.Mock() - fake_server.is_active = True - rewritten_url = "postgresql://u:p@127.0.0.1:65000/db" - - with ( - mock.patch( - "deepnote_toolkit.sql.sql_execution._open_ssh_tunnel", - return_value=(fake_server, rewritten_url), - ) as open_tunnel, - mock.patch( - "deepnote_toolkit.sql.sql_execution.create_engine", - return_value=fake_engine, - ) as create_engine_mock, - ): - with se.sql_session(): - _, ssh_a, owns_a = se._acquire_engine(sql_alchemy_dict) - _, ssh_b, owns_b = se._acquire_engine(sql_alchemy_dict) - - # Tunnel opened once and reused; not torn down between calls. - assert open_tunnel.call_count == 1 - assert create_engine_mock.call_count == 1 - assert ssh_a is fake_server - assert ssh_b is fake_server - assert owns_a is False - assert owns_b is False - # Engine creation uses the rewritten (tunneled) URL. 
- assert create_engine_mock.call_args.args[0] == rewritten_url - # On session exit both the engine and the tunnel are torn down. - fake_engine.dispose.assert_called_once() - fake_server.close.assert_called_once() - - -def test_sql_session_skips_sharing_when_user_supplies_pool_config(): - sql_alchemy_dict = _make_sql_alchemy_dict(params={"pool_size": 5}) + sa_connection.exec_driver_sql = logging_exec + + def fake_read_sql_query(sql, **_kwargs): + call_log.append(f"main:{sql}") + return pd.DataFrame({"x": [1]}) with mock.patch( - "deepnote_toolkit.sql.sql_execution.create_engine", return_value=mock.Mock() - ) as create_engine_mock: - with se.sql_session(): - _, _, owns_a = se._acquire_engine(sql_alchemy_dict) - _, _, owns_b = se._acquire_engine(sql_alchemy_dict) - - # Without sharing, caller owns disposal each call and we don't override the - # user's pool config (their pool_size=5 passes through unchanged). - assert owns_a is True - assert owns_b is True - assert create_engine_mock.call_count == 2 - kwargs = create_engine_mock.call_args.kwargs - assert kwargs["pool_size"] == 5 - - -def test_sql_session_swallows_per_resource_teardown_errors(): - sql_alchemy_dict = _make_sql_alchemy_dict() - failing_engine = mock.Mock() - failing_engine.dispose.side_effect = RuntimeError("boom") - failing_server = mock.Mock() - failing_server.is_active = True - failing_server.close.side_effect = RuntimeError("boom") - - with ( - mock.patch( - "deepnote_toolkit.sql.sql_execution.create_engine", - return_value=failing_engine, - ), - mock.patch( - "deepnote_toolkit.sql.sql_execution._open_ssh_tunnel", - return_value=(failing_server, "postgresql://u@127.0.0.1:1/db"), - ), - mock.patch.object(se.logger, "warning") as mock_warning, - ): - sql_alchemy_dict["ssh_options"] = { - "enabled": True, - "host": "h", - "port": 22, - "user": "u", - } - with se.sql_session(): - se._acquire_engine(sql_alchemy_dict) + "pandas.read_sql_query", side_effect=fake_read_sql_query + ) as mock_read: + result = 
se._execute_sql_on_engine( + mock_engine, + "SELECT 1", + {}, + setup_statements=["USE WAREHOUSE abc", "USE ROLE r"], + ) - failing_engine.dispose.assert_called_once() - failing_server.close.assert_called_once() - # One warning per failing resource. - assert mock_warning.call_count == 2 + assert call_log == [ + "setup:USE WAREHOUSE abc", + "setup:USE ROLE r", + "main:SELECT 1", + ] + assert mock_read.call_count == 1 + # And the main query was actually run. + assert list(result.columns) == ["x"] -def test_nested_sql_session_does_not_steal_outer_resources(): - sql_alchemy_dict = _make_sql_alchemy_dict() - fake_engine = mock.Mock() +def test_execute_sql_on_engine_no_setup_statements_runs_only_main_query(): + """No setup_statements (None or empty) is a no-op — main query runs as before.""" + import pandas as pd - with mock.patch( - "deepnote_toolkit.sql.sql_execution.create_engine", return_value=fake_engine - ) as create_engine_mock: - with se.sql_session(): - engine_outer, _, _ = se._acquire_engine(sql_alchemy_dict) - with se.sql_session(): - engine_inner, _, _ = se._acquire_engine(sql_alchemy_dict) - # Inner block must not have disposed the outer-owned engine. - fake_engine.dispose.assert_not_called() - engine_after_inner, _, _ = se._acquire_engine(sql_alchemy_dict) - - assert engine_outer is fake_engine - assert engine_inner is fake_engine - assert engine_after_inner is fake_engine - assert create_engine_mock.call_count == 1 - # Outer session disposes once on exit. 
- fake_engine.dispose.assert_called_once() - - -def test_acquire_engine_closes_tunnel_when_engine_creation_fails_outside_session(): - sql_alchemy_dict = _make_sql_alchemy_dict() - sql_alchemy_dict["ssh_options"] = { - "enabled": True, - "host": "h", - "port": 22, - "user": "u", - } - fake_server = mock.Mock() - fake_server.is_active = True - - with ( - mock.patch( - "deepnote_toolkit.sql.sql_execution._open_ssh_tunnel", - return_value=(fake_server, "postgresql://u@127.0.0.1:1/db"), - ), - mock.patch( - "deepnote_toolkit.sql.sql_execution.create_engine", - side_effect=RuntimeError("bad url"), - ), - ): - with pytest.raises(RuntimeError): - se._acquire_engine(sql_alchemy_dict) - - # Tunnel must not be left open if engine construction blows up. - fake_server.close.assert_called_once() + mock_cursor = mock.MagicMock() + mock_engine = _setup_mock_engine_with_cursor(mock_cursor) + sa_connection = mock_engine.begin.return_value.__enter__.return_value + sa_connection.exec_driver_sql = mock.Mock() + + with mock.patch("pandas.read_sql_query", return_value=pd.DataFrame({"x": [1]})): + se._execute_sql_on_engine(mock_engine, "SELECT 1", {}) + se._execute_sql_on_engine(mock_engine, "SELECT 1", {}, setup_statements=[]) + se._execute_sql_on_engine(mock_engine, "SELECT 1", {}, setup_statements=None) + + sa_connection.exec_driver_sql.assert_not_called() + + +def test_execute_sql_on_engine_aborts_main_query_when_setup_fails(): + """A failing setup statement must propagate and prevent the main query from running.""" + mock_cursor = mock.MagicMock() + mock_engine = _setup_mock_engine_with_cursor(mock_cursor) + sa_connection = mock_engine.begin.return_value.__enter__.return_value + sa_connection.exec_driver_sql = mock.Mock(side_effect=RuntimeError("bad warehouse")) + + with mock.patch("pandas.read_sql_query") as mock_read: + with pytest.raises(RuntimeError, match="bad warehouse"): + se._execute_sql_on_engine( + mock_engine, + "SELECT 1", + {}, + setup_statements=["USE WAREHOUSE 
missing"], + ) + + mock_read.assert_not_called() def test_create_sql_ssh_uri_missing_key(monkeypatch):