diff --git a/deepnote_toolkit/sql/sql_execution.py b/deepnote_toolkit/sql/sql_execution.py index 07b61fe..c9e3a53 100644 --- a/deepnote_toolkit/sql/sql_execution.py +++ b/deepnote_toolkit/sql/sql_execution.py @@ -102,6 +102,7 @@ def execute_sql_with_connection_json( audit_sql_comment="", sql_cache_mode="cache_disabled", return_variable_type="dataframe", + setup_statements=None, ): """ Executes a SQL query using the given connection JSON (string). @@ -111,6 +112,11 @@ def execute_sql_with_connection_json( :param sql_alchemy_json: String containing JSON with the connection details. Mandatory fields: url, params, param_style :param sql_cache_mode: SQL caching setting for the query. Possible values: "cache_disabled", "always_write", "read_or_write" + :param setup_statements: Optional list of raw SQL statements to run on the + same connection right before *template*. Use for session setup such as + ``USE WAREHOUSE abc`` whose effect must be visible to the main query. + Statements run in order and verbatim (no Jinja, bind params, audit comment) + via ``connection.exec_driver_sql``, or ``execute_duckdb_sql`` for DuckDB. :return: Pandas dataframe with the result """ @@ -221,6 +227,7 @@ class ExecuteSqlError(Exception): sql_cache_mode, return_variable_type, query_preview_source, + setup_statements=setup_statements, ) @@ -230,6 +237,7 @@ def execute_sql( audit_sql_comment="", sql_cache_mode="cache_disabled", return_variable_type="dataframe", + setup_statements=None, ): """ Wrapper around execute_sql_with_connection_json which reads the connection JSON from @@ -237,6 +245,7 @@ def execute_sql( :param template: Templated SQL :param sql_alchemy_json_env_var: Name of the environment variable containing the connection JSON :param sql_cache_mode: SQL caching setting for the query. Possible values: "cache_disabled", "always_write", "read_or_write" + :param setup_statements: See ``execute_sql_with_connection_json``. 
:return: Pandas dataframe with the result """ @@ -260,6 +269,7 @@ class ExecuteSqlError(Exception): audit_sql_comment=audit_sql_comment, sql_cache_mode=sql_cache_mode, return_variable_type=return_variable_type, + setup_statements=setup_statements, ) @@ -406,9 +416,14 @@ def _execute_sql_with_caching( sql_cache_mode, return_variable_type, query_preview_source, + setup_statements=None, ): # duckdb SQL is not cached, so we can skip the logic below for duckdb if requires_duckdb: + # DuckDB uses a process-wide singleton connection, so session state set + # by setup_statements naturally persists for the main query. + for stmt in setup_statements or []: + execute_duckdb_sql(stmt, {}) dataframe = execute_duckdb_sql(query, bind_params) # for Chained SQL we return the dataframe with the SQL source attached as DeepnoteQueryPreview object if return_variable_type == "query_preview": @@ -447,6 +462,7 @@ def _execute_sql_with_caching( cache_upload_url, return_variable_type, query_preview_source, # The original query before any transformations such as appending a LIMIT clause + setup_statements=setup_statements, ) @@ -478,6 +494,7 @@ def _query_data_source( cache_upload_url, return_variable_type, query_preview_source, + setup_statements=None, ): sshEnabled = sql_alchemy_dict.get("ssh_options", {}).get("enabled", False) @@ -491,7 +508,9 @@ def _query_data_source( ) try: - dataframe = _execute_sql_on_engine(engine, query, bind_params) + dataframe = _execute_sql_on_engine( + engine, query, bind_params, setup_statements=setup_statements + ) if dataframe is None: return None @@ -609,13 +628,19 @@ def _cancel_cursor(cursor: "DBAPICursor") -> None: pass # Best effort, ignore all errors -def _execute_sql_on_engine(engine, query, bind_params): +def _execute_sql_on_engine(engine, query, bind_params, setup_statements=None): """Run *query* on *engine* and return a DataFrame. Uses pandas.read_sql_query to execute the query with a SQLAlchemy connection. 
For pandas 2.2+ and SQLAlchemy < 2.0, which requires a raw DB-API connection with a `.cursor()` attribute, we use the underlying connection. + When *setup_statements* is provided, each statement is executed on the + same DBAPI connection right before the main query so any session state + it sets (e.g. Snowflake ``USE WAREHOUSE``) is in effect when the main + query runs. Setup statements are issued via ``connection.exec_driver_sql`` + and any failure aborts the main query. + On exceptions (including KeyboardInterrupt from cell cancellation), all cursors created during execution are cancelled to stop running queries on the server. """ @@ -639,6 +664,11 @@ def _execute_sql_on_engine(engine, query, bind_params): ) with engine.begin() as connection: + # Run setup statements first on the same physical connection so any + # session state they set is visible to the main query below. + for stmt in setup_statements or []: + connection.exec_driver_sql(stmt) + # For pandas 2.2+ with SQLAlchemy < 2.0, use raw DBAPI connection if needs_raw_connection: tracking_connection = CursorTrackingDBAPIConnection(connection.connection) diff --git a/tests/unit/test_sql_execution.py b/tests/unit/test_sql_execution.py index a684077..555a6c4 100644 --- a/tests/unit/test_sql_execution.py +++ b/tests/unit/test_sql_execution.py @@ -131,6 +131,7 @@ def test_sql_executed_with_audit_comment_but_hash_calculated_without_it( mock.ANY, mock.ANY, mock.ANY, + setup_statements=None, ) @mock.patch("deepnote_toolkit.sql.sql_execution._query_data_source") @@ -155,6 +156,7 @@ def test_return_variable_type_parameter(self, mocked_query_data_source): mock.ANY, "dataframe", mock.ANY, + setup_statements=None, ) # Test with explicit return_variable_type='query_preview' @@ -171,6 +173,7 @@ def test_return_variable_type_parameter(self, mocked_query_data_source): mock.ANY, "query_preview", mock.ANY, + setup_statements=None, ) @mock.patch("deepnote_toolkit.sql.sql_execution._query_data_source") @@ -199,6 +202,7 @@ def 
test_query_preview_preserves_trailing_inline_comment( mock.ANY, "query_preview", mock.ANY, + setup_statements=None, ) @mock.patch("deepnote_toolkit.sql.sql_caching._generate_cache_key") @@ -234,6 +238,7 @@ def test_sql_executed_with_audit_comment_with_semicolon( mock.ANY, "dataframe", mock.ANY, + setup_statements=None, ) @mock.patch("deepnote_toolkit.sql.sql_execution._query_data_source") @@ -331,6 +336,42 @@ def test_execute_sql_with_connection_json_with_snowflake_encrypted_private_key( ) +class TestSetupStatementsPlumbing(TestCase): + """Tests that setup_statements is plumbed from the public API down to _query_data_source.""" + + @mock.patch("deepnote_toolkit.sql.sql_execution._query_data_source") + def test_setup_statements_passed_to_query_data_source( + self, mocked_query_data_source + ): + os.environ["SQL_ENV_VAR"] = ( + '{"url":"postgresql://postgres:postgres@localhost:5432/postgres",' + '"params":{},"param_style":"qmark","integration_id":"int_1"}' + ) + + execute_sql( + "SELECT 1", + "SQL_ENV_VAR", + setup_statements=["USE WAREHOUSE abc", "USE ROLE r"], + ) + + _, kwargs = mocked_query_data_source.call_args + self.assertEqual( + kwargs["setup_statements"], ["USE WAREHOUSE abc", "USE ROLE r"] + ) + + @mock.patch("deepnote_toolkit.sql.sql_execution._query_data_source") + def test_no_setup_statements_passes_none_through(self, mocked_query_data_source): + os.environ["SQL_ENV_VAR"] = ( + '{"url":"postgresql://postgres:postgres@localhost:5432/postgres",' + '"params":{},"param_style":"qmark","integration_id":"int_1"}' + ) + + execute_sql("SELECT 1", "SQL_ENV_VAR") + + _, kwargs = mocked_query_data_source.call_args + self.assertIsNone(kwargs["setup_statements"]) + + class TestTrinoParamStyleAutoDetection(TestCase): """Tests for auto-detection of param_style for Trino connections""" diff --git a/tests/unit/test_sql_execution_internal.py b/tests/unit/test_sql_execution_internal.py index ad4ec00..618e377 100644 --- a/tests/unit/test_sql_execution_internal.py +++ 
b/tests/unit/test_sql_execution_internal.py @@ -377,6 +377,87 @@ def test_create_sql_ssh_uri_no_ssh(): assert url is None +def test_execute_sql_on_engine_runs_setup_statements_in_order_before_main_query(): + """Setup statements must execute on the same connection as the main query, + in order, before pandas runs the main query.""" + import pandas as pd + + mock_cursor = mock.MagicMock() + mock_engine = _setup_mock_engine_with_cursor(mock_cursor) + + call_log: list[str] = [] + + # Track exec_driver_sql (used for setup) and pd.read_sql_query (the main + # query) on the same SA connection, so we can assert ordering. + sa_connection = mock_engine.begin.return_value.__enter__.return_value + original_exec = sa_connection.exec_driver_sql + + def logging_exec(sql, *args, **kwargs): + call_log.append(f"setup:{sql}") + return original_exec(sql, *args, **kwargs) + + sa_connection.exec_driver_sql = logging_exec + + def fake_read_sql_query(sql, **_kwargs): + call_log.append(f"main:{sql}") + return pd.DataFrame({"x": [1]}) + + with mock.patch( + "pandas.read_sql_query", side_effect=fake_read_sql_query + ) as mock_read: + result = se._execute_sql_on_engine( + mock_engine, + "SELECT 1", + {}, + setup_statements=["USE WAREHOUSE abc", "USE ROLE r"], + ) + + assert call_log == [ + "setup:USE WAREHOUSE abc", + "setup:USE ROLE r", + "main:SELECT 1", + ] + assert mock_read.call_count == 1 + # And the main query was actually run. 
+ assert list(result.columns) == ["x"] + + +def test_execute_sql_on_engine_no_setup_statements_runs_only_main_query(): + """No setup_statements (None or empty) is a no-op — main query runs as before.""" + import pandas as pd + + mock_cursor = mock.MagicMock() + mock_engine = _setup_mock_engine_with_cursor(mock_cursor) + sa_connection = mock_engine.begin.return_value.__enter__.return_value + sa_connection.exec_driver_sql = mock.Mock() + + with mock.patch("pandas.read_sql_query", return_value=pd.DataFrame({"x": [1]})): + se._execute_sql_on_engine(mock_engine, "SELECT 1", {}) + se._execute_sql_on_engine(mock_engine, "SELECT 1", {}, setup_statements=[]) + se._execute_sql_on_engine(mock_engine, "SELECT 1", {}, setup_statements=None) + + sa_connection.exec_driver_sql.assert_not_called() + + +def test_execute_sql_on_engine_aborts_main_query_when_setup_fails(): + """A failing setup statement must propagate and prevent the main query from running.""" + mock_cursor = mock.MagicMock() + mock_engine = _setup_mock_engine_with_cursor(mock_cursor) + sa_connection = mock_engine.begin.return_value.__enter__.return_value + sa_connection.exec_driver_sql = mock.Mock(side_effect=RuntimeError("bad warehouse")) + + with mock.patch("pandas.read_sql_query") as mock_read: + with pytest.raises(RuntimeError, match="bad warehouse"): + se._execute_sql_on_engine( + mock_engine, + "SELECT 1", + {}, + setup_statements=["USE WAREHOUSE missing"], + ) + + mock_read.assert_not_called() + + def test_create_sql_ssh_uri_missing_key(monkeypatch): def fake_get_env(name, default=None): if name == "PRIVATE_SSH_KEY_BLOB":