Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 32 additions & 2 deletions deepnote_toolkit/sql/sql_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def execute_sql_with_connection_json(
audit_sql_comment="",
sql_cache_mode="cache_disabled",
return_variable_type="dataframe",
setup_statements=None,
):
Comment on lines +105 to 106
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

wc -l deepnote_toolkit/sql/sql_execution.py

Repository: deepnote/deepnote-toolkit

Length of output: 110


🏁 Script executed:

cat -n deepnote_toolkit/sql/sql_execution.py | sed -n '100,115p'

Repository: deepnote/deepnote-toolkit

Length of output: 976


🏁 Script executed:

cat -n deepnote_toolkit/sql/sql_execution.py | sed -n '235,250p'

Repository: deepnote/deepnote-toolkit

Length of output: 868


🏁 Script executed:

cat -n deepnote_toolkit/sql/sql_execution.py | sed -n '414,430p'

Repository: deepnote/deepnote-toolkit

Length of output: 973


🏁 Script executed:

cat -n deepnote_toolkit/sql/sql_execution.py | sed -n '492,505p'

Repository: deepnote/deepnote-toolkit

Length of output: 581


🏁 Script executed:

cat -n deepnote_toolkit/sql/sql_execution.py | sed -n '625,640p'

Repository: deepnote/deepnote-toolkit

Length of output: 957


🏁 Script executed:

cat -n deepnote_toolkit/sql/sql_execution.py | sed -n '95,110p'

Repository: deepnote/deepnote-toolkit

Length of output: 682


🏁 Script executed:

cat -n deepnote_toolkit/sql/sql_execution.py | sed -n '230,250p'

Repository: deepnote/deepnote-toolkit

Length of output: 979


🏁 Script executed:

cat -n deepnote_toolkit/sql/sql_execution.py | sed -n '408,435p'

Repository: deepnote/deepnote-toolkit

Length of output: 1272


🏁 Script executed:

cat -n deepnote_toolkit/sql/sql_execution.py | sed -n '485,510p'

Repository: deepnote/deepnote-toolkit

Length of output: 928


🏁 Script executed:

cat -n deepnote_toolkit/sql/sql_execution.py | sed -n '625,660p'

Repository: deepnote/deepnote-toolkit

Length of output: 1882


🏁 Script executed:

cat -n deepnote_toolkit/sql/sql_execution.py | sed -n '99,107p'

Repository: deepnote/deepnote-toolkit

Length of output: 344


🏁 Script executed:

cat -n deepnote_toolkit/sql/sql_execution.py | sed -n '234,242p'

Repository: deepnote/deepnote-toolkit

Length of output: 331


🏁 Script executed:

cat -n deepnote_toolkit/sql/sql_execution.py | sed -n '410,421p'

Repository: deepnote/deepnote-toolkit

Length of output: 453


🏁 Script executed:

cat -n deepnote_toolkit/sql/sql_execution.py | sed -n '490,499p'

Repository: deepnote/deepnote-toolkit

Length of output: 395


🏁 Script executed:

cat -n deepnote_toolkit/sql/sql_execution.py | sed -n '631,647p'

Repository: deepnote/deepnote-toolkit

Length of output: 1076


Add type hints to all function parameters and return annotations.

The five functions at lines 99, 234, 410, 490, and 631 lack any parameter or return type annotations. Replace setup_statements=None with setup_statements: Optional[...] and add explicit return type annotations to each function. Per coding guidelines, use Optional[T] instead of T = None, and include return types for all functions.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@deepnote_toolkit/sql/sql_execution.py` around lines 105 - 106, Several
functions in this module are missing parameter and return type annotations;
replace any parameter defaults like setup_statements=None with an explicit
Optional type (e.g., setup_statements: Optional[Sequence[str]] = None) and add
explicit return type annotations for each function that currently lacks them.
Import typing primitives you need (from typing import Optional, Sequence,
Iterable, Any) and update the function signatures (including the function that
declares setup_statements) to use Optional[T] rather than T = None, and add
concrete return types (e.g., -> None, -> int, -> List[Any], or -> Iterable[str]
as appropriate) to the five un-annotated functions in this module so all
parameters and returns are fully typed. Ensure the parameter name
setup_statements is changed to setup_statements: Optional[<appropriate type>] =
None and that each function signature and its return type reflect the actual
values they produce.

"""
Executes a SQL query using the given connection JSON (string).
Expand All @@ -111,6 +112,11 @@ def execute_sql_with_connection_json(
:param sql_alchemy_json: String containing JSON with the connection details.
Mandatory fields: url, params, param_style
:param sql_cache_mode: SQL caching setting for the query. Possible values: "cache_disabled", "always_write", "read_or_write"
:param setup_statements: Optional list of raw SQL statements to run on the
same connection right before *template*. Use for session setup such as
``USE WAREHOUSE abc`` whose effect must be visible to the main query.
Statements are not Jinja-rendered, parameter-bound, or audit-commented;
they are executed in order via ``connection.exec_driver_sql``.
:return: Pandas dataframe with the result
"""

Expand Down Expand Up @@ -221,6 +227,7 @@ class ExecuteSqlError(Exception):
sql_cache_mode,
return_variable_type,
query_preview_source,
setup_statements=setup_statements,
)


Expand All @@ -230,13 +237,15 @@ def execute_sql(
audit_sql_comment="",
sql_cache_mode="cache_disabled",
return_variable_type="dataframe",
setup_statements=None,
):
"""
Wrapper around execute_sql_with_connection_json which reads the connection JSON from
environment variable.
:param template: Templated SQL
:param sql_alchemy_json_env_var: Name of the environment variable containing the connection JSON
:param sql_cache_mode: SQL caching setting for the query. Possible values: "cache_disabled", "always_write", "read_or_write"
:param setup_statements: See ``execute_sql_with_connection_json``.
:return: Pandas dataframe with the result
"""

Expand All @@ -260,6 +269,7 @@ class ExecuteSqlError(Exception):
audit_sql_comment=audit_sql_comment,
sql_cache_mode=sql_cache_mode,
return_variable_type=return_variable_type,
setup_statements=setup_statements,
)


Expand Down Expand Up @@ -406,9 +416,14 @@ def _execute_sql_with_caching(
sql_cache_mode,
return_variable_type,
query_preview_source,
setup_statements=None,
):
# duckdb SQL is not cached, so we can skip the logic below for duckdb
if requires_duckdb:
# DuckDB uses a process-wide singleton connection, so session state set
# by setup_statements naturally persists for the main query.
for stmt in setup_statements or []:
execute_duckdb_sql(stmt, {})
dataframe = execute_duckdb_sql(query, bind_params)
# for Chained SQL we return the dataframe with the SQL source attached as DeepnoteQueryPreview object
if return_variable_type == "query_preview":
Expand Down Expand Up @@ -447,6 +462,7 @@ def _execute_sql_with_caching(
cache_upload_url,
return_variable_type,
query_preview_source, # The original query before any transformations such as appending a LIMIT clause
setup_statements=setup_statements,
)


Expand Down Expand Up @@ -478,6 +494,7 @@ def _query_data_source(
cache_upload_url,
return_variable_type,
query_preview_source,
setup_statements=None,
):
sshEnabled = sql_alchemy_dict.get("ssh_options", {}).get("enabled", False)

Expand All @@ -491,7 +508,9 @@ def _query_data_source(
)

try:
dataframe = _execute_sql_on_engine(engine, query, bind_params)
dataframe = _execute_sql_on_engine(
engine, query, bind_params, setup_statements=setup_statements
)

if dataframe is None:
return None
Expand Down Expand Up @@ -609,13 +628,19 @@ def _cancel_cursor(cursor: "DBAPICursor") -> None:
pass # Best effort, ignore all errors


def _execute_sql_on_engine(engine, query, bind_params):
def _execute_sql_on_engine(engine, query, bind_params, setup_statements=None):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Validate setup_statements type before iterating.

Line 669 accepts any iterable. If a caller passes a string, it runs one-character “statements” and fails with misleading SQL errors.

Suggested fix
 def _execute_sql_on_engine(engine, query, bind_params, setup_statements=None):
@@
+    if setup_statements is None:
+        normalized_setup_statements: list[str] = []
+    elif isinstance(setup_statements, list) and all(
+        isinstance(stmt, str) for stmt in setup_statements
+    ):
+        normalized_setup_statements = setup_statements
+    else:
+        raise TypeError("setup_statements must be a list[str] or None")
+
     with engine.begin() as connection:
@@
-        for stmt in setup_statements or []:
+        for stmt in normalized_setup_statements:
             connection.exec_driver_sql(stmt)

Also applies to: 667-670

🧰 Tools
🪛 Ruff (0.15.10)

[warning] 631-631: Missing return type annotation for private function _execute_sql_on_engine

(ANN202)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@deepnote_toolkit/sql/sql_execution.py` at line 631, In
_execute_sql_on_engine, validate setup_statements before iterating: ensure it's
either None or an iterable of statements (not a plain string). Add a guard that
detects isinstance(setup_statements, str) and either wrap it as a single-item
list or raise a TypeError, and also validate items are strings (or convert them)
before the loop that executes each statement to avoid iterating over characters
and producing misleading SQL errors.

"""Run *query* on *engine* and return a DataFrame.

Uses pandas.read_sql_query to execute the query with a SQLAlchemy connection.
For pandas 2.2+ and SQLAlchemy < 2.0, which requires a raw DB-API connection with a `.cursor()` attribute,
we use the underlying connection.

When *setup_statements* is provided, each statement is executed on the
same DBAPI connection right before the main query so any session state
it sets (e.g. Snowflake ``USE WAREHOUSE``) is in effect when the main
query runs. Setup statements are issued via ``connection.exec_driver_sql``
and any failure aborts the main query.

On exceptions (including KeyboardInterrupt from cell cancellation), all cursors
created during execution are cancelled to stop running queries on the server.
"""
Expand All @@ -639,6 +664,11 @@ def _execute_sql_on_engine(engine, query, bind_params):
)

with engine.begin() as connection:
# Run setup statements first on the same physical connection so any
# session state they set is visible to the main query below.
for stmt in setup_statements or []:
connection.exec_driver_sql(stmt)

# For pandas 2.2+ with SQLAlchemy < 2.0, use raw DBAPI connection
if needs_raw_connection:
tracking_connection = CursorTrackingDBAPIConnection(connection.connection)
Expand Down
41 changes: 41 additions & 0 deletions tests/unit/test_sql_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def test_sql_executed_with_audit_comment_but_hash_calculated_without_it(
mock.ANY,
mock.ANY,
mock.ANY,
setup_statements=None,
)

@mock.patch("deepnote_toolkit.sql.sql_execution._query_data_source")
Expand All @@ -155,6 +156,7 @@ def test_return_variable_type_parameter(self, mocked_query_data_source):
mock.ANY,
"dataframe",
mock.ANY,
setup_statements=None,
)

# Test with explicit return_variable_type='query_preview'
Expand All @@ -171,6 +173,7 @@ def test_return_variable_type_parameter(self, mocked_query_data_source):
mock.ANY,
"query_preview",
mock.ANY,
setup_statements=None,
)

@mock.patch("deepnote_toolkit.sql.sql_execution._query_data_source")
Expand Down Expand Up @@ -199,6 +202,7 @@ def test_query_preview_preserves_trailing_inline_comment(
mock.ANY,
"query_preview",
mock.ANY,
setup_statements=None,
)

@mock.patch("deepnote_toolkit.sql.sql_caching._generate_cache_key")
Expand Down Expand Up @@ -234,6 +238,7 @@ def test_sql_executed_with_audit_comment_with_semicolon(
mock.ANY,
"dataframe",
mock.ANY,
setup_statements=None,
)

@mock.patch("deepnote_toolkit.sql.sql_execution._query_data_source")
Expand Down Expand Up @@ -331,6 +336,42 @@ def test_execute_sql_with_connection_json_with_snowflake_encrypted_private_key(
)


class TestSetupStatementsPlumbing(TestCase):
    """Tests that setup_statements is plumbed from the public API down to _query_data_source."""

    # Connection JSON that execute_sql reads from the SQL_ENV_VAR env var.
    _CONNECTION_JSON = (
        '{"url":"postgresql://postgres:postgres@localhost:5432/postgres",'
        '"params":{},"param_style":"qmark","integration_id":"int_1"}'
    )

    def setUp(self):
        # Each test needs the connection JSON in the environment. Register a
        # cleanup so the variable does not leak into unrelated test classes
        # (the previous version mutated os.environ without ever restoring it).
        os.environ["SQL_ENV_VAR"] = self._CONNECTION_JSON
        self.addCleanup(os.environ.pop, "SQL_ENV_VAR", None)

    @mock.patch("deepnote_toolkit.sql.sql_execution._query_data_source")
    def test_setup_statements_passed_to_query_data_source(
        self, mocked_query_data_source
    ):
        # A caller-supplied list must reach _query_data_source unchanged.
        execute_sql(
            "SELECT 1",
            "SQL_ENV_VAR",
            setup_statements=["USE WAREHOUSE abc", "USE ROLE r"],
        )

        _, kwargs = mocked_query_data_source.call_args
        self.assertEqual(
            kwargs["setup_statements"], ["USE WAREHOUSE abc", "USE ROLE r"]
        )

    @mock.patch("deepnote_toolkit.sql.sql_execution._query_data_source")
    def test_no_setup_statements_passes_none_through(self, mocked_query_data_source):
        # Omitting the argument must propagate the default of None, so the
        # downstream code path stays a strict no-op.
        execute_sql("SELECT 1", "SQL_ENV_VAR")

        _, kwargs = mocked_query_data_source.call_args
        self.assertIsNone(kwargs["setup_statements"])


class TestTrinoParamStyleAutoDetection(TestCase):
"""Tests for auto-detection of param_style for Trino connections"""

Expand Down
81 changes: 81 additions & 0 deletions tests/unit/test_sql_execution_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,87 @@ def test_create_sql_ssh_uri_no_ssh():
assert url is None


def test_execute_sql_on_engine_runs_setup_statements_in_order_before_main_query():
    """Setup statements must execute on the same connection as the main query,
    in order, before pandas runs the main query."""
    import pandas as pd

    engine = _setup_mock_engine_with_cursor(mock.MagicMock())

    events = []

    # Record both exec_driver_sql (setup path) and pandas.read_sql_query
    # (main-query path) against the same SA connection, so the relative
    # ordering of the two can be asserted afterwards.
    connection = engine.begin.return_value.__enter__.return_value
    real_exec = connection.exec_driver_sql

    def record_setup(sql, *args, **kwargs):
        events.append(f"setup:{sql}")
        return real_exec(sql, *args, **kwargs)

    connection.exec_driver_sql = record_setup

    def record_main(sql, **_kwargs):
        events.append(f"main:{sql}")
        return pd.DataFrame({"x": [1]})

    with mock.patch(
        "pandas.read_sql_query", side_effect=record_main
    ) as read_mock:
        frame = se._execute_sql_on_engine(
            engine,
            "SELECT 1",
            {},
            setup_statements=["USE WAREHOUSE abc", "USE ROLE r"],
        )

    expected_order = [
        "setup:USE WAREHOUSE abc",
        "setup:USE ROLE r",
        "main:SELECT 1",
    ]
    assert events == expected_order
    assert read_mock.call_count == 1
    # And the main query was actually run.
    assert list(frame.columns) == ["x"]


def test_execute_sql_on_engine_no_setup_statements_runs_only_main_query():
    """No setup_statements (None or empty) is a no-op — main query runs as before."""
    import pandas as pd

    engine = _setup_mock_engine_with_cursor(mock.MagicMock())
    connection = engine.begin.return_value.__enter__.return_value
    connection.exec_driver_sql = mock.Mock()

    # Exercise all three "absent" spellings: omitted, empty list, and None.
    with mock.patch("pandas.read_sql_query", return_value=pd.DataFrame({"x": [1]})):
        se._execute_sql_on_engine(engine, "SELECT 1", {})
        se._execute_sql_on_engine(engine, "SELECT 1", {}, setup_statements=[])
        se._execute_sql_on_engine(engine, "SELECT 1", {}, setup_statements=None)

    # No setup SQL may ever have been issued on the connection.
    connection.exec_driver_sql.assert_not_called()


def test_execute_sql_on_engine_aborts_main_query_when_setup_fails():
    """A failing setup statement must propagate and prevent the main query from running."""
    engine = _setup_mock_engine_with_cursor(mock.MagicMock())
    connection = engine.begin.return_value.__enter__.return_value
    connection.exec_driver_sql = mock.Mock(side_effect=RuntimeError("bad warehouse"))

    # The setup failure must surface to the caller as-is...
    with mock.patch("pandas.read_sql_query") as read_mock, pytest.raises(
        RuntimeError, match="bad warehouse"
    ):
        se._execute_sql_on_engine(
            engine,
            "SELECT 1",
            {},
            setup_statements=["USE WAREHOUSE missing"],
        )

    # ...and the main query must never have been attempted.
    read_mock.assert_not_called()


def test_create_sql_ssh_uri_missing_key(monkeypatch):
def fake_get_env(name, default=None):
if name == "PRIVATE_SSH_KEY_BLOB":
Expand Down
Loading