diff --git a/deepnote_toolkit/sql/sql_execution.py b/deepnote_toolkit/sql/sql_execution.py index 3b8e96e..fb94b6d 100644 --- a/deepnote_toolkit/sql/sql_execution.py +++ b/deepnote_toolkit/sql/sql_execution.py @@ -1,23 +1,24 @@ import base64 import contextlib import json -import logging import re import uuid import warnings -from typing import Any +import weakref +from typing import TYPE_CHECKING, Any, Optional from urllib.parse import quote import google.oauth2.credentials import numpy as np import requests +import wrapt from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization from google.api_core.client_info import ClientInfo from google.cloud import bigquery from packaging.version import parse as parse_version from pydantic import BaseModel -from sqlalchemy.engine import URL, create_engine, make_url +from sqlalchemy.engine import URL, Connection, create_engine, make_url from sqlalchemy.exc import ResourceClosedError from deepnote_core.pydantic_compat_helpers import model_validate_compat @@ -28,6 +29,7 @@ get_project_auth_headers, ) from deepnote_toolkit.ipython_utils import output_sql_metadata +from deepnote_toolkit.logging import LoggerManager from deepnote_toolkit.ocelots.pandas.utils import deduplicate_columns from deepnote_toolkit.sql.duckdb_sql import execute_duckdb_sql from deepnote_toolkit.sql.jinjasql_utils import render_jinja_sql_template @@ -37,7 +39,15 @@ from deepnote_toolkit.sql.sql_utils import is_single_select_query from deepnote_toolkit.sql.url_utils import replace_user_pass_in_pg_url -logger = logging.getLogger(__name__) +if TYPE_CHECKING: + try: + from sqlalchemy.engine.interfaces import DBAPIConnection, DBAPICursor + except ImportError: + # Not available in SQLAlchemy < 2.0. 
class CursorTrackingDBAPIConnection(wrapt.ObjectProxy):
    """Proxy around a DBAPI connection that remembers the cursors it creates.

    Every cursor returned by :meth:`cursor` is recorded in a ``WeakSet`` so it
    can be cancelled later via :meth:`cancel_all_cursors`. Everything else is
    forwarded to the wrapped connection by ``wrapt.ObjectProxy``.
    """

    def __init__(
        self,
        wrapped: "DBAPIConnection",
        cursor_registry: "Optional[weakref.WeakSet[DBAPICursor]]" = None,
    ) -> None:
        """Wrap *wrapped* and record its cursors in *cursor_registry*.

        Args:
            wrapped: The DBAPI connection to proxy.
            cursor_registry: Optional externally owned set to record cursors
                in. When omitted, a private ``WeakSet`` is created.
        """
        super().__init__(wrapped)
        # Attributes prefixed with ``_self_`` are kept on the proxy itself
        # rather than being forwarded to the wrapped connection.
        self._self_cursor_registry = (
            weakref.WeakSet() if cursor_registry is None else cursor_registry
        )

    def cursor(self, *args, **kwargs):
        """Create a cursor on the wrapped connection and register it.

        The driver's cursor is returned unchanged. A cursor that cannot be
        stored in the registry (``TypeError``) is still returned; it is just
        not tracked, and a warning is logged.
        """
        cursor = self.__wrapped__.cursor(*args, **kwargs)
        try:
            self._self_cursor_registry.add(cursor)
        except TypeError:
            logger.warning(
                f"DBAPI Cursor of type {type(cursor)} can't be added to weakset and thus can't be tracked."
            )
        return cursor

    def cancel_all_cursors(self):
        """Cancel all tracked cursors. Best-effort, ignores errors."""
        _cancel_cursors(self._self_cursor_registry)


class CursorTrackingSQLAlchemyConnection(wrapt.ObjectProxy):
    """A SQLAlchemy connection wrapper that tracks cursors for cancellation.

    On construction, the DBAPI connection held by the wrapped connection's
    ``.connection`` object is replaced with a
    :class:`CursorTrackingDBAPIConnection`, so cursors created through it
    (including by ``exec_driver_sql``) are tracked.
    """

    def __init__(self, wrapped: Connection) -> None:
        """Wrap *wrapped* and install the cursor-tracking DBAPI proxy."""
        super().__init__(wrapped)
        self._self_cursors: "weakref.WeakSet[DBAPICursor]" = weakref.WeakSet()
        self._install_dbapi_wrapper()

    def _install_dbapi_wrapper(self):
        """Replace SQLAlchemy's internal DBAPI connection with our tracking wrapper.

        Any failure is logged as a warning and otherwise ignored, leaving the
        connection usable but untracked.
        """
        try:
            pool_connection = self.__wrapped__.connection
            if hasattr(pool_connection, "dbapi_connection"):
                attr_name = "dbapi_connection"
            else:
                # SQLAlchemy pre v1.4
                attr_name = "connection"

            dbapi_conn = getattr(pool_connection, attr_name)
            if dbapi_conn is None:
                logger.warning(
                    f"DBAPI connection is None (connection type {type(self.__wrapped__)}), cannot install tracking"
                )
                return

            # Share our registry so cursors opened through the proxy are
            # visible to cancel_all_cursors() on this object.
            setattr(
                pool_connection,
                attr_name,
                CursorTrackingDBAPIConnection(dbapi_conn, self._self_cursors),
            )
        except Exception as e:
            logger.warning(f"Could not install DBAPI wrapper: {e}")

    def cancel_all_cursors(self):
        """Cancel all tracked cursors. Best-effort, ignores errors."""
        _cancel_cursors(self._self_cursors)


def _cancel_cursors(cursors) -> None:
    """Best-effort cancel every cursor in *cursors* (any iterable of cursors)."""
    # Iterate over a snapshot so the registry changing while driver cancel
    # calls run cannot disturb the loop.
    for cursor in list(cursors):
        _cancel_cursor(cursor)


def _cancel_cursor(cursor: "DBAPICursor") -> None:
    """Best-effort cancel a cursor using available methods.

    Calls ``cursor.cancel()`` when the cursor exposes a callable ``cancel``.
    Any error, including ``KeyboardInterrupt``, is logged at debug level and
    suppressed so cancellation never masks the exception being handled.
    """
    try:
        cancel = getattr(cursor, "cancel", None)
        if callable(cancel):
            cancel()
    except (Exception, KeyboardInterrupt) as exc:
        # Deliberately swallowed: cancellation is best effort.
        logger.debug(f"Ignoring error while cancelling DBAPI cursor: {exc!r}")
""" import pandas as pd @@ -544,12 +639,13 @@ def _execute_sql_on_engine(engine, query, bind_params): ) with engine.begin() as connection: - try: - # For pandas 2.2+, use raw connection to avoid 'cursor' AttributeError - connection_for_pandas = ( - connection.connection if needs_raw_connection else connection - ) + # For pandas 2.2+ with SQLAlchemy < 2.0, use raw DBAPI connection + if needs_raw_connection: + tracking_connection = CursorTrackingDBAPIConnection(connection.connection) + else: + tracking_connection = CursorTrackingSQLAlchemyConnection(connection) + try: # pandas.read_sql_query expects params as tuple (not list) for qmark/format style params_for_pandas = ( tuple(bind_params) if isinstance(bind_params, list) else bind_params @@ -557,13 +653,16 @@ def _execute_sql_on_engine(engine, query, bind_params): return pd.read_sql_query( query, - con=connection_for_pandas, + con=tracking_connection, params=params_for_pandas, coerce_float=coerce_float, ) except ResourceClosedError: # this happens if the query is e.g. UPDATE and pandas tries to create a dataframe from its result return None + except BaseException: + tracking_connection.cancel_all_cursors() + raise def _build_params_for_bigquery_oauth(params): diff --git a/tests/unit/test_sql_execution_internal.py b/tests/unit/test_sql_execution_internal.py index 05eecf0..6826ab6 100644 --- a/tests/unit/test_sql_execution_internal.py +++ b/tests/unit/test_sql_execution_internal.py @@ -1,4 +1,5 @@ import uuid +from typing import Any from unittest import mock import numpy as np @@ -9,6 +10,42 @@ from deepnote_toolkit.sql import sql_execution as se +def _setup_mock_engine_with_cursor(mock_cursor: mock.Mock) -> mock.Mock: + """Helper to set up mock engine and connection with a custom cursor. + + Returns mock_engine that can be passed to _execute_sql_on_engine. 
def _setup_mock_engine_with_cursor(mock_cursor: mock.Mock) -> mock.Mock:
    """Build a mock engine whose DBAPI layer hands out *mock_cursor*.

    The returned engine's ``begin()`` context manager yields a mock
    ``sqlalchemy.engine.Connection``. Its ``exec_driver_sql`` opens a cursor
    through the pool connection and executes the statement on it. The result
    can be passed straight to ``_execute_sql_on_engine``.
    """
    import sqlalchemy

    dbapi_conn: mock.Mock = mock.Mock()
    dbapi_conn.cursor.return_value = mock_cursor

    pool_conn = mock.Mock()
    pool_conn.dbapi_connection = dbapi_conn

    def open_cursor():
        # Look up ``dbapi_connection`` at call time so that a replacement
        # installed after setup is the object that creates the cursor.
        return pool_conn.dbapi_connection.cursor()

    pool_conn.cursor.side_effect = open_cursor

    sa_conn = mock.Mock(spec=sqlalchemy.engine.Connection)
    sa_conn.connection = pool_conn
    sa_conn.in_transaction.return_value = False

    def mock_exec_driver_sql(sql: str, *args: Any) -> mock.Mock:
        opened: mock.Mock = sa_conn.connection.cursor()
        opened.execute(sql, *args)
        return opened

    sa_conn.exec_driver_sql = mock_exec_driver_sql

    engine = mock.Mock()
    begin_ctx = engine.begin.return_value
    begin_ctx.__enter__ = mock.Mock(return_value=sa_conn)
    begin_ctx.__exit__ = mock.Mock(return_value=False)
    return engine


def test_execute_sql_on_engine_cancels_cursor_on_keyboard_interrupt():
    """A KeyboardInterrupt during execute propagates and cancels the cursor."""
    interrupted_cursor = mock.MagicMock()
    interrupted_cursor.execute.side_effect = KeyboardInterrupt("Cancelled")
    engine = _setup_mock_engine_with_cursor(interrupted_cursor)

    with pytest.raises(KeyboardInterrupt):
        se._execute_sql_on_engine(engine, "SELECT 1", {})

    interrupted_cursor.cancel.assert_called_once()


def test_execute_sql_on_engine_handles_cancel_errors_gracefully():
    """A failing cancel() must not replace the original KeyboardInterrupt."""
    failing_cursor = mock.MagicMock()
    failing_cursor.execute.side_effect = KeyboardInterrupt("Cancelled")
    failing_cursor.cancel.side_effect = RuntimeError("Cancel failed")
    engine = _setup_mock_engine_with_cursor(failing_cursor)

    # The interrupt surfaces; the RuntimeError from cancel() is swallowed.
    with pytest.raises(KeyboardInterrupt):
        se._execute_sql_on_engine(engine, "SELECT 1", {})

    failing_cursor.cancel.assert_called_once()


def test_cursor_tracking_dbapi_connection_cancel_all_cursors():
    """cancel_all_cursors cancels every cursor created through the proxy."""
    cursors = [mock.Mock(), mock.Mock()]
    wrapped_conn = mock.Mock()
    wrapped_conn.cursor.side_effect = cursors

    tracking_conn = se.CursorTrackingDBAPIConnection(wrapped_conn)

    # Open one cursor per prepared mock, then cancel them all at once.
    for _ in cursors:
        tracking_conn.cursor()
    tracking_conn.cancel_all_cursors()

    for tracked in cursors:
        tracked.cancel.assert_called_once()


def test_cursor_tracking_dbapi_connection_handles_unhashable_cursor():
    """An unhashable cursor is returned untracked and a warning is logged."""

    class UnhashableCursor:
        __hash__ = None

    unhashable_cursor = UnhashableCursor()
    wrapped_conn = mock.Mock()
    wrapped_conn.cursor.return_value = unhashable_cursor

    tracking_conn = se.CursorTrackingDBAPIConnection(wrapped_conn)

    with mock.patch.object(se.logger, "warning") as warning_spy:
        returned = tracking_conn.cursor()

    assert returned is unhashable_cursor
    warning_spy.assert_called_once()
    assert "can't be added to weakset" in warning_spy.call_args[0][0]


def test_cursor_tracking_sqlalchemy_connection_handles_none_dbapi_connection():
    """A missing DBAPI connection yields a warning instead of an error."""
    pool_conn = mock.Mock()
    pool_conn.dbapi_connection = None

    sa_conn = mock.Mock()
    sa_conn.connection = pool_conn

    with mock.patch.object(se.logger, "warning") as warning_spy:
        se.CursorTrackingSQLAlchemyConnection(sa_conn)

    warning_spy.assert_called_once()
    assert "DBAPI connection is None" in warning_spy.call_args[0][0]