Skip to content
119 changes: 109 additions & 10 deletions deepnote_toolkit/sql/sql_execution.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,24 @@
import base64
import contextlib
import json
import logging
import re
import uuid
import warnings
from typing import Any
import weakref
from typing import TYPE_CHECKING, Any, Optional
from urllib.parse import quote

import google.oauth2.credentials
import numpy as np
import requests
import wrapt
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from google.api_core.client_info import ClientInfo
from google.cloud import bigquery
from packaging.version import parse as parse_version
from pydantic import BaseModel
from sqlalchemy.engine import URL, create_engine, make_url
from sqlalchemy.engine import URL, Connection, create_engine, make_url
from sqlalchemy.exc import ResourceClosedError

from deepnote_core.pydantic_compat_helpers import model_validate_compat
Expand All @@ -28,6 +29,7 @@
get_project_auth_headers,
)
from deepnote_toolkit.ipython_utils import output_sql_metadata
from deepnote_toolkit.logging import LoggerManager
from deepnote_toolkit.ocelots.pandas.utils import deduplicate_columns
from deepnote_toolkit.sql.duckdb_sql import execute_duckdb_sql
from deepnote_toolkit.sql.jinjasql_utils import render_jinja_sql_template
Expand All @@ -37,7 +39,15 @@
from deepnote_toolkit.sql.sql_utils import is_single_select_query
from deepnote_toolkit.sql.url_utils import replace_user_pass_in_pg_url

logger = logging.getLogger(__name__)
if TYPE_CHECKING:
try:
from sqlalchemy.engine.interfaces import DBAPIConnection, DBAPICursor
except ImportError:
# Not available in SQLAlchemy < 2.0. We use them only for typing, so replace with Any
DBAPIConnection = Any
DBAPICursor = Any

logger = LoggerManager().get_logger()


class IntegrationFederatedAuthParams(BaseModel):
Expand Down Expand Up @@ -517,12 +527,97 @@ def _query_data_source(
engine.dispose()


class CursorTrackingDBAPIConnection(wrapt.ObjectProxy):
    """Proxy for a DBAPI connection that records every cursor it hands out.

    Only ``cursor()`` is intercepted: each cursor obtained from the wrapped
    connection is added to a weak registry so that all of them can later be
    cancelled in one go through ``cancel_all_cursors()``.
    """

    def __init__(
        self,
        wrapped: "DBAPIConnection",
        cursor_registry: Optional[weakref.WeakSet["DBAPICursor"]] = None,
    ) -> None:
        """Wrap *wrapped* and track its cursors in *cursor_registry*.

        A private ``WeakSet`` is created when no registry is supplied.
        """
        super().__init__(wrapped)
        # Compare against None explicitly: an empty WeakSet is falsy, and a
        # caller-supplied (still empty) registry must be kept, not replaced.
        if cursor_registry is None:
            cursor_registry = weakref.WeakSet()
        # The ``_self_`` prefix is wrapt's convention for attributes that live
        # on the proxy itself rather than on the wrapped object.
        self._self_cursor_registry = cursor_registry

    def cursor(self, *args, **kwargs):
        """Create a cursor on the wrapped connection and try to register it.

        The cursor is returned to the caller in every case; if the registry
        rejects it with ``TypeError`` a warning is logged and the cursor is
        simply left untracked.
        """
        cursor = self.__wrapped__.cursor(*args, **kwargs)
        try:
            registry = self._self_cursor_registry
            registry.add(cursor)
        except TypeError:
            logger.warning(
                f"DBAPI Cursor of type {type(cursor)} can't be added to weakset and thus can't be tracked."
            )
        return cursor

    def cancel_all_cursors(self):
        """Ask every tracked cursor to cancel; failures are ignored (best effort)."""
        for tracked_cursor in self._self_cursor_registry:
            _cancel_cursor(tracked_cursor)


class CursorTrackingSQLAlchemyConnection(wrapt.ObjectProxy):
    """Proxy for a SQLAlchemy connection whose cursors can be cancelled.

    On construction the DBAPI connection held by the wrapped connection's
    ``.connection`` object is swapped for a ``CursorTrackingDBAPIConnection``
    that shares this proxy's cursor registry, so cursors opened through it
    (including by exec_driver_sql) end up tracked here.
    """

    def __init__(self, wrapped: Connection) -> None:
        """Wrap *wrapped* and immediately try to install cursor tracking."""
        super().__init__(wrapped)
        # Weak registry shared with the DBAPI-level tracking proxy.
        self._self_cursors: weakref.WeakSet[DBAPICursor] = weakref.WeakSet()
        self._install_dbapi_wrapper()

    def _install_dbapi_wrapper(self):
        """Swap the underlying DBAPI connection for a cursor-tracking proxy.

        Never raises: if the DBAPI connection is ``None`` or anything goes
        wrong, a warning is logged and the connection is left untouched.
        """
        try:
            pool_connection = self.__wrapped__.connection
            # Newer SQLAlchemy exposes the raw connection as
            # ``dbapi_connection``; older releases (pre v1.4) use ``connection``.
            if hasattr(pool_connection, "dbapi_connection"):
                dbapi_connection_attr_name = "dbapi_connection"
            else:
                dbapi_connection_attr_name = "connection"
            raw_connection = getattr(pool_connection, dbapi_connection_attr_name)

            if raw_connection is None:
                logger.warning(
                    f"DBAPI connection is None (connection type {type(self.__wrapped__)}), cannot install tracking"
                )
                return

            tracking_proxy = CursorTrackingDBAPIConnection(
                raw_connection, self._self_cursors
            )
            setattr(pool_connection, dbapi_connection_attr_name, tracking_proxy)
        except Exception as e:
            logger.warning(f"Could not install DBAPI wrapper: {e}")

    def cancel_all_cursors(self):
        """Ask every tracked cursor to cancel; failures are ignored (best effort)."""
        for tracked_cursor in self._self_cursors:
            _cancel_cursor(tracked_cursor)


def _cancel_cursor(cursor: "DBAPICursor") -> None:
"""Best-effort cancel a cursor using available methods."""
try:
if hasattr(cursor, "cancel") and callable(cursor.cancel):
cursor.cancel()
except (Exception, KeyboardInterrupt):
pass # Best effort, ignore all errors


def _execute_sql_on_engine(engine, query, bind_params):
"""Run *query* on *engine* and return a DataFrame.

Uses pandas.read_sql_query to execute the query with a SQLAlchemy connection.
For pandas 2.2+ and SQLAlchemy < 2.0, which requires a raw DB-API connection with a `.cursor()` attribute,
we use the underlying connection.

On exceptions (including KeyboardInterrupt from cell cancellation), all cursors
created during execution are cancelled to stop running queries on the server.
"""

import pandas as pd
Expand All @@ -544,26 +639,30 @@ def _execute_sql_on_engine(engine, query, bind_params):
)

with engine.begin() as connection:
try:
# For pandas 2.2+, use raw connection to avoid 'cursor' AttributeError
connection_for_pandas = (
connection.connection if needs_raw_connection else connection
)
# For pandas 2.2+ with SQLAlchemy < 2.0, use raw DBAPI connection
if needs_raw_connection:
tracking_connection = CursorTrackingDBAPIConnection(connection.connection)
else:
tracking_connection = CursorTrackingSQLAlchemyConnection(connection)

try:
# pandas.read_sql_query expects params as tuple (not list) for qmark/format style
params_for_pandas = (
tuple(bind_params) if isinstance(bind_params, list) else bind_params
)

return pd.read_sql_query(
query,
con=connection_for_pandas,
con=tracking_connection,
params=params_for_pandas,
coerce_float=coerce_float,
)
except ResourceClosedError:
# this happens if the query is e.g. UPDATE and pandas tries to create a dataframe from its result
return None
except BaseException:
tracking_connection.cancel_all_cursors()
raise


def _build_params_for_bigquery_oauth(params):
Expand Down
122 changes: 122 additions & 0 deletions tests/unit/test_sql_execution_internal.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import uuid
from typing import Any
from unittest import mock

import numpy as np
Expand All @@ -9,6 +10,42 @@
from deepnote_toolkit.sql import sql_execution as se


def _setup_mock_engine_with_cursor(mock_cursor: mock.Mock) -> mock.Mock:
    """Build a mock engine whose connections always hand out *mock_cursor*.

    The layering is: engine -> SQLAlchemy connection -> pool connection ->
    DBAPI connection -> cursor. The returned engine can be passed straight to
    ``_execute_sql_on_engine``.
    """
    import sqlalchemy

    # Innermost layer: the DBAPI connection that produces the cursor.
    dbapi_connection: mock.Mock = mock.Mock()
    dbapi_connection.cursor.return_value = mock_cursor

    pool_connection = mock.Mock()
    pool_connection.dbapi_connection = dbapi_connection

    def _cursor_via_current_dbapi_connection():
        # Resolve ``dbapi_connection`` at call time so that a replacement
        # assigned to the attribute later on is the one that gets used.
        return pool_connection.dbapi_connection.cursor()

    pool_connection.cursor.side_effect = _cursor_via_current_dbapi_connection

    sa_connection = mock.Mock(spec=sqlalchemy.engine.Connection)
    sa_connection.connection = pool_connection
    sa_connection.in_transaction.return_value = False

    def mock_exec_driver_sql(sql: str, *args: Any) -> mock.Mock:
        # Open a cursor through the pool connection and run the statement on it.
        cursor: mock.Mock = sa_connection.connection.cursor()
        cursor.execute(sql, *args)
        return cursor

    sa_connection.exec_driver_sql = mock_exec_driver_sql

    # ``engine.begin()`` is used as a context manager yielding the connection.
    mock_engine = mock.Mock()
    begin_context = mock_engine.begin.return_value
    begin_context.__enter__ = mock.Mock(return_value=sa_connection)
    begin_context.__exit__ = mock.Mock(return_value=False)

    return mock_engine


def test_bigquery_wait_or_cancel_handles_keyboard_interrupt():
import google.cloud.bigquery._job_helpers as _job_helpers

Expand All @@ -30,6 +67,91 @@ def test_bigquery_wait_or_cancel_handles_keyboard_interrupt():
mock_job.cancel.assert_called_once_with(retry=None, timeout=30.0)


def test_execute_sql_on_engine_cancels_cursor_on_keyboard_interrupt():
    """A KeyboardInterrupt raised while executing must cancel the cursor."""
    interrupted_cursor = mock.MagicMock()
    interrupted_cursor.execute.side_effect = KeyboardInterrupt("Cancelled")
    engine = _setup_mock_engine_with_cursor(interrupted_cursor)

    # The interrupt itself must still reach the caller.
    with pytest.raises(KeyboardInterrupt):
        se._execute_sql_on_engine(engine, "SELECT 1", {})

    interrupted_cursor.cancel.assert_called_once()


def test_execute_sql_on_engine_handles_cancel_errors_gracefully():
    """A failing ``cancel()`` must not mask the original KeyboardInterrupt."""
    failing_cursor = mock.MagicMock()
    failing_cursor.execute.side_effect = KeyboardInterrupt("Cancelled")
    failing_cursor.cancel.side_effect = RuntimeError("Cancel failed")
    engine = _setup_mock_engine_with_cursor(failing_cursor)

    # The error raised by cancel() is swallowed; the interrupt is what surfaces.
    with pytest.raises(KeyboardInterrupt):
        se._execute_sql_on_engine(engine, "SELECT 1", {})

    failing_cursor.cancel.assert_called_once()


def test_cursor_tracking_dbapi_connection_cancel_all_cursors():
    """``cancel_all_cursors`` must cancel every cursor the proxy handed out."""
    # Keep strong references: the registry only holds the cursors weakly.
    cursors = [mock.Mock(), mock.Mock()]
    wrapped_conn = mock.Mock()
    wrapped_conn.cursor.side_effect = list(cursors)

    tracking_conn = se.CursorTrackingDBAPIConnection(wrapped_conn)

    # Request one cursor per prepared mock so each of them gets registered.
    for _ in cursors:
        tracking_conn.cursor()

    tracking_conn.cancel_all_cursors()

    for tracked in cursors:
        tracked.cancel.assert_called_once()


def test_cursor_tracking_dbapi_connection_handles_unhashable_cursor():
    """A cursor the weak registry rejects is still returned, with a warning."""

    class UnhashableCursor:
        # Disabling hashing makes adding an instance to a WeakSet fail.
        __hash__ = None

    odd_cursor = UnhashableCursor()
    wrapped_conn = mock.Mock()
    wrapped_conn.cursor.return_value = odd_cursor

    tracking_conn = se.CursorTrackingDBAPIConnection(wrapped_conn)

    with mock.patch.object(se.logger, "warning") as mock_warning:
        returned = tracking_conn.cursor()

    assert returned is odd_cursor
    mock_warning.assert_called_once()
    positional_args, _ = mock_warning.call_args
    assert "can't be added to weakset" in positional_args[0]


def test_cursor_tracking_sqlalchemy_connection_handles_none_dbapi_connection():
    """Wrapping a connection whose DBAPI connection is None only logs a warning."""
    pool_connection = mock.Mock()
    pool_connection.dbapi_connection = None

    sa_connection = mock.Mock()
    sa_connection.connection = pool_connection

    with mock.patch.object(se.logger, "warning") as mock_warning:
        # Construction alone triggers the wrapper installation attempt.
        se.CursorTrackingSQLAlchemyConnection(sa_connection)

    mock_warning.assert_called_once()
    positional_args, _ = mock_warning.call_args
    assert "DBAPI connection is None" in positional_args[0]


def test_build_params_for_bigquery_oauth_ok():
with mock.patch(
"deepnote_toolkit.sql.sql_execution.bigquery.Client"
Expand Down