From 74ae5bbbc03d61843fd9cacb21daee66228b9e42 Mon Sep 17 00:00:00 2001 From: OlegWock Date: Fri, 30 Jan 2026 13:20:03 +0100 Subject: [PATCH 01/12] fix: Replace runtime patches with general solution --- deepnote_toolkit/runtime_initialization.py | 7 -- deepnote_toolkit/runtime_patches.py | 53 ----------- deepnote_toolkit/sql/sql_execution.py | 100 +++++++++++++++++++-- tests/unit/conftest.py | 8 -- tests/unit/test_sql_execution_internal.py | 100 +++++++++++++++++---- 5 files changed, 173 insertions(+), 95 deletions(-) delete mode 100644 deepnote_toolkit/runtime_patches.py diff --git a/deepnote_toolkit/runtime_initialization.py b/deepnote_toolkit/runtime_initialization.py index bfafaf5..d8a7e3c 100644 --- a/deepnote_toolkit/runtime_initialization.py +++ b/deepnote_toolkit/runtime_initialization.py @@ -6,8 +6,6 @@ import psycopg2.extensions import psycopg2.extras -from deepnote_toolkit.runtime_patches import apply_runtime_patches - from .dataframe_utils import add_formatters from .execute_post_start_hooks import execute_post_start_hooks from .logging import LoggerManager @@ -26,11 +24,6 @@ def init_deepnote_runtime(): logger.debug("Initializing Deepnote runtime environment started.") - try: - apply_runtime_patches() - except Exception as e: - logger.error("Failed to apply runtime patches with a error: %s", e) - # Register sparksql magic try: IPython.get_ipython().register_magics(SparkSql) diff --git a/deepnote_toolkit/runtime_patches.py b/deepnote_toolkit/runtime_patches.py deleted file mode 100644 index 9dda657..0000000 --- a/deepnote_toolkit/runtime_patches.py +++ /dev/null @@ -1,53 +0,0 @@ -from typing import Any, Optional, Union - -from deepnote_toolkit.logging import LoggerManager - -logger = LoggerManager().get_logger() - - -# TODO(BLU-5171): Temporary hack to allow cancelling BigQuery jobs on KeyboardInterrupt (e.g. when user cancels cell execution) -# Can be removed once -# 1. https://github.com/googleapis/python-bigquery/pull/2331 is merged and released -# 2. Dependencies updated for the toolkit. We don't depend on google-cloud-bigquery directly, but it's transitive -# dependency through sqlalchemy-bigquery -def _monkeypatch_bigquery_wait_or_cancel(): - try: - import google.cloud.bigquery._job_helpers as _job_helpers - from google.cloud.bigquery import job, table - - def _wait_or_cancel( - job_obj: job.QueryJob, - api_timeout: Optional[float], - wait_timeout: Optional[Union[object, float]], - retry: Optional[Any], - page_size: Optional[int], - max_results: Optional[int], - ) -> table.RowIterator: - try: - return job_obj.result( - page_size=page_size, - max_results=max_results, - retry=retry, - timeout=wait_timeout, - ) - except (KeyboardInterrupt, Exception): - try: - job_obj.cancel(retry=retry, timeout=api_timeout) - except (KeyboardInterrupt, Exception): - pass - raise - - _job_helpers._wait_or_cancel = _wait_or_cancel - logger.debug( - "Successfully monkeypatched google.cloud.bigquery._job_helpers._wait_or_cancel" - ) - except ImportError: - logger.warning( - "Could not monkeypatch BigQuery _wait_or_cancel: google.cloud.bigquery not available" - ) - except Exception as e: - logger.warning("Failed to monkeypatch BigQuery _wait_or_cancel: %s", repr(e)) - - -def apply_runtime_patches(): - _monkeypatch_bigquery_wait_or_cancel() diff --git a/deepnote_toolkit/sql/sql_execution.py b/deepnote_toolkit/sql/sql_execution.py index 3b8e96e..f51a174 100644 --- a/deepnote_toolkit/sql/sql_execution.py +++ b/deepnote_toolkit/sql/sql_execution.py @@ -1,16 +1,17 @@ import base64 import contextlib import json -import logging import re import uuid import warnings +import weakref from typing import Any from urllib.parse import quote import google.oauth2.credentials import numpy as np import requests +import wrapt from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization from google.api_core.client_info import ClientInfo @@ -28,6 +29,7 @@ get_project_auth_headers, ) from deepnote_toolkit.ipython_utils import output_sql_metadata +from deepnote_toolkit.logging import LoggerManager from deepnote_toolkit.ocelots.pandas.utils import deduplicate_columns from deepnote_toolkit.sql.duckdb_sql import execute_duckdb_sql from deepnote_toolkit.sql.jinjasql_utils import render_jinja_sql_template @@ -37,7 +39,7 @@ from deepnote_toolkit.sql.sql_utils import is_single_select_query from deepnote_toolkit.sql.url_utils import replace_user_pass_in_pg_url -logger = logging.getLogger(__name__) +logger = LoggerManager().get_logger() class IntegrationFederatedAuthParams(BaseModel): @@ -517,12 +519,90 @@ def _query_data_source( engine.dispose() +class CursorTrackingDBAPIConnection(wrapt.ObjectProxy): + """Wraps DBAPI connection to track cursors as they're created.""" + + def __init__(self, wrapped, cursor_registry=None): + super().__init__(wrapped) + # Use provided registry or create our own + self._self_cursor_registry = ( + cursor_registry if cursor_registry is not None else weakref.WeakSet() + ) + + def cursor(self, *args, **kwargs): + cursor = self.__wrapped__.cursor(*args, **kwargs) + self._self_cursor_registry.add(cursor) + return cursor + + def cancel_all_cursors(self): + """Cancel all tracked cursors. Best-effort, ignores errors.""" + for cursor in list(self._self_cursor_registry): + _cancel_cursor(cursor) + + +class CursorTrackingSQLAlchemyConnection(wrapt.ObjectProxy): + """A SQLAlchemy connection wrapper that tracks cursors for cancellation. + + This wrapper replaces the internal _dbapi_connection with a tracking proxy, + so all cursors created (including by exec_driver_sql) are tracked. + """ + + def __init__(self, wrapped): + super().__init__(wrapped) + self._self_cursors = weakref.WeakSet() + self._install_dbapi_wrapper() + + def _install_dbapi_wrapper(self): + """Replace SQLAlchemy's internal DBAPI connection with our tracking wrapper.""" + try: + # Access the internal DBAPI connection + dbapi_conn = self.__wrapped__._dbapi_connection + if dbapi_conn is None: + logger.warning("DBAPI connection is None, cannot install tracking") + return + + self.__wrapped__._dbapi_connection = CursorTrackingDBAPIConnection( + dbapi_conn, self._self_cursors + ) + except Exception as e: + logger.warning(f"Could not install DBAPI wrapper: {e}") + + def cancel_all_cursors(self): + """Cancel all tracked cursors. Best-effort, ignores errors.""" + for cursor in list(self._self_cursors): + _cancel_cursor(cursor) + + +def _cancel_cursor(cursor): + """Best-effort cancel a cursor using available methods.""" + try: + # BigQuery: cancel via query_job if available + query_job = getattr(cursor, "query_job", None) + if query_job is not None and hasattr(query_job, "cancel"): + try: + query_job.cancel() + except (Exception, KeyboardInterrupt): + pass + + # Generic DBAPI: try cursor.cancel() if available (Trino, etc.) + if hasattr(cursor, "cancel"): + try: + cursor.cancel() + except (Exception, KeyboardInterrupt): + pass + except (Exception, KeyboardInterrupt): + pass # Best effort, ignore all errors + + def _execute_sql_on_engine(engine, query, bind_params): """Run *query* on *engine* and return a DataFrame. Uses pandas.read_sql_query to execute the query with a SQLAlchemy connection. For pandas 2.2+ and SQLAlchemy < 2.0, which requires a raw DB-API connection with a `.cursor()` attribute, we use the underlying connection. + + On exceptions (including KeyboardInterrupt from cell cancellation), all cursors + created during execution are cancelled to stop running queries on the server. """ import pandas as pd @@ -544,12 +624,13 @@ def _execute_sql_on_engine(engine, query, bind_params): ) with engine.begin() as connection: - try: - # For pandas 2.2+, use raw connection to avoid 'cursor' AttributeError - connection_for_pandas = ( - connection.connection if needs_raw_connection else connection - ) + # For pandas 2.2+ with SQLAlchemy < 2.0, use raw DBAPI connection + if needs_raw_connection: + tracking_connection = CursorTrackingDBAPIConnection(connection.connection) + else: + tracking_connection = CursorTrackingSQLAlchemyConnection(connection) + try: # pandas.read_sql_query expects params as tuple (not list) for qmark/format style params_for_pandas = ( tuple(bind_params) if isinstance(bind_params, list) else bind_params @@ -557,13 +638,16 @@ def _execute_sql_on_engine(engine, query, bind_params): return pd.read_sql_query( query, - con=connection_for_pandas, + con=tracking_connection, params=params_for_pandas, coerce_float=coerce_float, ) except ResourceClosedError: # this happens if the query is e.g. UPDATE and pandas tries to create a dataframe from its result return None + except BaseException: + tracking_connection.cancel_all_cursors() + raise def _build_params_for_bigquery_oauth(params): diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 5f9179d..a871dd8 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -7,14 +7,6 @@ import pytest -@pytest.fixture(autouse=True, scope="session") -def apply_patches() -> None: - """Apply runtime patches once before any tests run.""" - from deepnote_toolkit.runtime_patches import apply_runtime_patches - - apply_runtime_patches() - - @pytest.fixture(autouse=True) def clean_runtime_state() -> Generator[None, None, None]: """Automatically clean in-memory env state and config cache before and after each test.""" diff --git a/tests/unit/test_sql_execution_internal.py b/tests/unit/test_sql_execution_internal.py index 05eecf0..5472d1c 100644 --- a/tests/unit/test_sql_execution_internal.py +++ b/tests/unit/test_sql_execution_internal.py @@ -9,25 +9,87 @@ from deepnote_toolkit.sql import sql_execution as se -def test_bigquery_wait_or_cancel_handles_keyboard_interrupt(): - import google.cloud.bigquery._job_helpers as _job_helpers - - mock_job = mock.Mock() - mock_job.result.side_effect = KeyboardInterrupt("User interrupted") - mock_job.cancel = mock.Mock() - - with pytest.raises(KeyboardInterrupt): - # _wait_or_cancel should be monkeypatched by `_monkeypatch_bigquery_wait_or_cancel` - _job_helpers._wait_or_cancel( - job_obj=mock_job, - api_timeout=30.0, - wait_timeout=60.0, - retry=None, - page_size=None, - max_results=None, - ) - - mock_job.cancel.assert_called_once_with(retry=None, timeout=30.0) +def test_execute_sql_on_engine_cancels_cursor_on_keyboard_interrupt(): + """Test that _execute_sql_on_engine cancels cursors on KeyboardInterrupt. + + We replace SQLAlchemy's _dbapi_connection with our tracking wrapper. + When SQLAlchemy creates a cursor, it goes through our wrapper. + """ + mock_dbapi_cursor = mock.Mock() + mock_dbapi_connection = mock.Mock() + mock_dbapi_connection.cursor.return_value = mock_dbapi_cursor + + def mock_read_sql_query(query, con, **kwargs): + # Simulate SQLAlchemy creating a cursor via the DBAPI connection + # After our wrapper is installed, _dbapi_connection.cursor() is tracked + con._dbapi_connection.cursor() + raise KeyboardInterrupt("Cancelled") + + mock_engine = mock.Mock() + mock_connection = mock.Mock() + mock_connection._dbapi_connection = mock_dbapi_connection + + mock_engine.begin.return_value.__enter__ = mock.Mock(return_value=mock_connection) + mock_engine.begin.return_value.__exit__ = mock.Mock(return_value=False) + + with mock.patch("pandas.read_sql_query", side_effect=mock_read_sql_query): + with pytest.raises(KeyboardInterrupt): + se._execute_sql_on_engine(mock_engine, "SELECT 1", {}) + + mock_dbapi_cursor.cancel.assert_called_once() + + +def test_execute_sql_on_engine_cancels_bigquery_query_job(): + """Test that _execute_sql_on_engine cancels BigQuery query_job if present.""" + mock_dbapi_cursor = mock.Mock() + mock_query_job = mock.Mock() + mock_dbapi_cursor.query_job = mock_query_job + mock_dbapi_connection = mock.Mock() + mock_dbapi_connection.cursor.return_value = mock_dbapi_cursor + + def mock_read_sql_query(query, con, **kwargs): + con._dbapi_connection.cursor() + raise KeyboardInterrupt("Cancelled") + + mock_engine = mock.Mock() + mock_connection = mock.Mock() + mock_connection._dbapi_connection = mock_dbapi_connection + + mock_engine.begin.return_value.__enter__ = mock.Mock(return_value=mock_connection) + mock_engine.begin.return_value.__exit__ = mock.Mock(return_value=False) + + with mock.patch("pandas.read_sql_query", side_effect=mock_read_sql_query): + with pytest.raises(KeyboardInterrupt): + se._execute_sql_on_engine(mock_engine, "SELECT 1", {}) + + mock_query_job.cancel.assert_called_once() + mock_dbapi_cursor.cancel.assert_called_once() + + +def test_execute_sql_on_engine_handles_cancel_errors_gracefully(): + """Test that _execute_sql_on_engine handles cancel errors gracefully.""" + mock_dbapi_cursor = mock.Mock() + mock_dbapi_cursor.cancel.side_effect = RuntimeError("Cancel failed") + mock_dbapi_connection = mock.Mock() + mock_dbapi_connection.cursor.return_value = mock_dbapi_cursor + + def mock_read_sql_query(query, con, **kwargs): + con._dbapi_connection.cursor() + raise KeyboardInterrupt("Cancelled") + + mock_engine = mock.Mock() + mock_connection = mock.Mock() + mock_connection._dbapi_connection = mock_dbapi_connection + + mock_engine.begin.return_value.__enter__ = mock.Mock(return_value=mock_connection) + mock_engine.begin.return_value.__exit__ = mock.Mock(return_value=False) + + with mock.patch("pandas.read_sql_query", side_effect=mock_read_sql_query): + # Should raise original KeyboardInterrupt, not the cancel error + with pytest.raises(KeyboardInterrupt): + se._execute_sql_on_engine(mock_engine, "SELECT 1", {}) + + mock_dbapi_cursor.cancel.assert_called_once() def test_build_params_for_bigquery_oauth_ok(): From 880b1393d5a0079a1c320520b1a4b834937481e9 Mon Sep 17 00:00:00 2001 From: OlegWock Date: Fri, 30 Jan 2026 13:54:41 +0100 Subject: [PATCH 02/12] Update tests --- tests/unit/test_sql_execution_internal.py | 118 +++++++++++----------- 1 file changed, 61 insertions(+), 57 deletions(-) diff --git a/tests/unit/test_sql_execution_internal.py b/tests/unit/test_sql_execution_internal.py index 5472d1c..b81b160 100644 --- a/tests/unit/test_sql_execution_internal.py +++ b/tests/unit/test_sql_execution_internal.py @@ -9,87 +9,91 @@ from deepnote_toolkit.sql import sql_execution as se -def test_execute_sql_on_engine_cancels_cursor_on_keyboard_interrupt(): - """Test that _execute_sql_on_engine cancels cursors on KeyboardInterrupt. +def _setup_mock_engine_with_cursor(mock_cursor): + """Helper to set up mock engine and connection with a custom cursor. - We replace SQLAlchemy's _dbapi_connection with our tracking wrapper. - When SQLAlchemy creates a cursor, it goes through our wrapper. + Returns mock_engine that can be passed to _execute_sql_on_engine. """ - mock_dbapi_cursor = mock.Mock() + import sqlalchemy + mock_dbapi_connection = mock.Mock() - mock_dbapi_connection.cursor.return_value = mock_dbapi_cursor + mock_dbapi_connection.cursor.return_value = mock_cursor - def mock_read_sql_query(query, con, **kwargs): - # Simulate SQLAlchemy creating a cursor via the DBAPI connection - # After our wrapper is installed, _dbapi_connection.cursor() is tracked - con._dbapi_connection.cursor() - raise KeyboardInterrupt("Cancelled") + mock_sa_connection = mock.Mock(spec=sqlalchemy.engine.Connection) + mock_sa_connection._dbapi_connection = mock_dbapi_connection + mock_sa_connection.connection = mock_dbapi_connection + mock_sa_connection.in_transaction.return_value = False - mock_engine = mock.Mock() - mock_connection = mock.Mock() - mock_connection._dbapi_connection = mock_dbapi_connection + # Mock exec_driver_sql to simulate cursor creation and execute + def mock_exec_driver_sql(sql, *args): + cursor = mock_sa_connection._dbapi_connection.cursor() + cursor.execute(sql, *args) + return cursor - mock_engine.begin.return_value.__enter__ = mock.Mock(return_value=mock_connection) + mock_sa_connection.exec_driver_sql = mock_exec_driver_sql + + mock_engine = mock.Mock() + mock_engine.begin.return_value.__enter__ = mock.Mock(return_value=mock_sa_connection) mock_engine.begin.return_value.__exit__ = mock.Mock(return_value=False) - with mock.patch("pandas.read_sql_query", side_effect=mock_read_sql_query): - with pytest.raises(KeyboardInterrupt): - se._execute_sql_on_engine(mock_engine, "SELECT 1", {}) + return mock_engine - mock_dbapi_cursor.cancel.assert_called_once() +def test_execute_sql_on_engine_cancels_cursor_on_keyboard_interrupt(): + """Test that _execute_sql_on_engine cancels cursors on KeyboardInterrupt. -def test_execute_sql_on_engine_cancels_bigquery_query_job(): - """Test that _execute_sql_on_engine cancels BigQuery query_job if present.""" - mock_dbapi_cursor = mock.Mock() - mock_query_job = mock.Mock() - mock_dbapi_cursor.query_job = mock_query_job - mock_dbapi_connection = mock.Mock() - mock_dbapi_connection.cursor.return_value = mock_dbapi_cursor + Uses real pandas.read_sql_query. The cursor's execute method throws + KeyboardInterrupt, simulating cell cancellation. + """ + mock_cursor = mock.MagicMock() + mock_cursor.execute.side_effect = KeyboardInterrupt("Cancelled") - def mock_read_sql_query(query, con, **kwargs): - con._dbapi_connection.cursor() - raise KeyboardInterrupt("Cancelled") + mock_engine = _setup_mock_engine_with_cursor(mock_cursor) - mock_engine = mock.Mock() - mock_connection = mock.Mock() - mock_connection._dbapi_connection = mock_dbapi_connection + with pytest.raises(KeyboardInterrupt): + se._execute_sql_on_engine(mock_engine, "SELECT 1", {}) - mock_engine.begin.return_value.__enter__ = mock.Mock(return_value=mock_connection) - mock_engine.begin.return_value.__exit__ = mock.Mock(return_value=False) + mock_cursor.cancel.assert_called_once() + + +def test_execute_sql_on_engine_cancels_bigquery_query_job(): + """Test that _execute_sql_on_engine cancels BigQuery query_job if present. + + Uses real pandas.read_sql_query. The cursor's execute method throws + KeyboardInterrupt, simulating cell cancellation. + """ + mock_query_job = mock.Mock() + mock_cursor = mock.MagicMock() + mock_cursor.execute.side_effect = KeyboardInterrupt("Cancelled") + mock_cursor.query_job = mock_query_job - with mock.patch("pandas.read_sql_query", side_effect=mock_read_sql_query): - with pytest.raises(KeyboardInterrupt): - se._execute_sql_on_engine(mock_engine, "SELECT 1", {}) + mock_engine = _setup_mock_engine_with_cursor(mock_cursor) + + with pytest.raises(KeyboardInterrupt): + se._execute_sql_on_engine(mock_engine, "SELECT 1", {}) mock_query_job.cancel.assert_called_once() - mock_dbapi_cursor.cancel.assert_called_once() + mock_cursor.cancel.assert_called_once() def test_execute_sql_on_engine_handles_cancel_errors_gracefully(): - """Test that _execute_sql_on_engine handles cancel errors gracefully.""" - mock_dbapi_cursor = mock.Mock() - mock_dbapi_cursor.cancel.side_effect = RuntimeError("Cancel failed") - mock_dbapi_connection = mock.Mock() - mock_dbapi_connection.cursor.return_value = mock_dbapi_cursor - - def mock_read_sql_query(query, con, **kwargs): - con._dbapi_connection.cursor() - raise KeyboardInterrupt("Cancelled") + """Test that _execute_sql_on_engine handles cancel errors gracefully. - mock_engine = mock.Mock() - mock_connection = mock.Mock() - mock_connection._dbapi_connection = mock_dbapi_connection + Uses real pandas.read_sql_query. The cursor's execute method throws + KeyboardInterrupt. The cursor's cancel method throws an error, which + should be silently handled. + """ + mock_cursor = mock.MagicMock() + mock_cursor.execute.side_effect = KeyboardInterrupt("Cancelled") + mock_cursor.cancel.side_effect = RuntimeError("Cancel failed") - mock_engine.begin.return_value.__enter__ = mock.Mock(return_value=mock_connection) - mock_engine.begin.return_value.__exit__ = mock.Mock(return_value=False) + mock_engine = _setup_mock_engine_with_cursor(mock_cursor) - with mock.patch("pandas.read_sql_query", side_effect=mock_read_sql_query): - # Should raise original KeyboardInterrupt, not the cancel error - with pytest.raises(KeyboardInterrupt): - se._execute_sql_on_engine(mock_engine, "SELECT 1", {}) + # Should raise original KeyboardInterrupt, not the cancel error + with pytest.raises(KeyboardInterrupt): + se._execute_sql_on_engine(mock_engine, "SELECT 1", {}) - mock_dbapi_cursor.cancel.assert_called_once() + mock_cursor.cancel.assert_called_once() def test_build_params_for_bigquery_oauth_ok(): From 1ad219e0ce577bf7b2d6366523275167d4be17e0 Mon Sep 17 00:00:00 2001 From: OlegWock Date: Fri, 30 Jan 2026 13:57:58 +0100 Subject: [PATCH 03/12] Formatting --- tests/unit/test_sql_execution_internal.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_sql_execution_internal.py b/tests/unit/test_sql_execution_internal.py index b81b160..3136ca7 100644 --- a/tests/unit/test_sql_execution_internal.py +++ b/tests/unit/test_sql_execution_internal.py @@ -33,7 +33,9 @@ def mock_exec_driver_sql(sql, *args): mock_sa_connection.exec_driver_sql = mock_exec_driver_sql mock_engine = mock.Mock() - mock_engine.begin.return_value.__enter__ = mock.Mock(return_value=mock_sa_connection) + mock_engine.begin.return_value.__enter__ = mock.Mock( + return_value=mock_sa_connection + ) mock_engine.begin.return_value.__exit__ = mock.Mock(return_value=False) return mock_engine From fc8b17c0174b1201e9b7aef932f0804174d48f54 Mon Sep 17 00:00:00 2001 From: OlegWock Date: Tue, 3 Feb 2026 14:46:32 +0100 Subject: [PATCH 04/12] Restore runtime patch for BigQuery --- deepnote_toolkit/runtime_initialization.py | 7 +++ deepnote_toolkit/runtime_patches.py | 53 ++++++++++++++++++++++ deepnote_toolkit/sql/sql_execution.py | 14 +----- tests/unit/conftest.py | 8 ++++ tests/unit/test_sql_execution_internal.py | 21 +++++++++ 5 files changed, 90 insertions(+), 13 deletions(-) create mode 100644 deepnote_toolkit/runtime_patches.py diff --git a/deepnote_toolkit/runtime_initialization.py b/deepnote_toolkit/runtime_initialization.py index d8a7e3c..bfafaf5 100644 --- a/deepnote_toolkit/runtime_initialization.py +++ b/deepnote_toolkit/runtime_initialization.py @@ -6,6 +6,8 @@ import psycopg2.extensions import psycopg2.extras +from deepnote_toolkit.runtime_patches import apply_runtime_patches + from .dataframe_utils import add_formatters from .execute_post_start_hooks import execute_post_start_hooks from .logging import LoggerManager @@ -24,6 +26,11 @@ def init_deepnote_runtime(): logger.debug("Initializing Deepnote runtime environment started.") + try: + apply_runtime_patches() + except Exception as e: + logger.error("Failed to apply runtime patches with a error: %s", e) + # Register sparksql magic try: IPython.get_ipython().register_magics(SparkSql) diff --git a/deepnote_toolkit/runtime_patches.py b/deepnote_toolkit/runtime_patches.py new file mode 100644 index 0000000..9dda657 --- /dev/null +++ b/deepnote_toolkit/runtime_patches.py @@ -0,0 +1,53 @@ +from typing import Any, Optional, Union + +from deepnote_toolkit.logging import LoggerManager + +logger = LoggerManager().get_logger() + + +# TODO(BLU-5171): Temporary hack to allow cancelling BigQuery jobs on KeyboardInterrupt (e.g. when user cancels cell execution) +# Can be removed once +# 1. https://github.com/googleapis/python-bigquery/pull/2331 is merged and released +# 2. Dependencies updated for the toolkit. We don't depend on google-cloud-bigquery directly, but it's transitive +# dependency through sqlalchemy-bigquery +def _monkeypatch_bigquery_wait_or_cancel(): + try: + import google.cloud.bigquery._job_helpers as _job_helpers + from google.cloud.bigquery import job, table + + def _wait_or_cancel( + job_obj: job.QueryJob, + api_timeout: Optional[float], + wait_timeout: Optional[Union[object, float]], + retry: Optional[Any], + page_size: Optional[int], + max_results: Optional[int], + ) -> table.RowIterator: + try: + return job_obj.result( + page_size=page_size, + max_results=max_results, + retry=retry, + timeout=wait_timeout, + ) + except (KeyboardInterrupt, Exception): + try: + job_obj.cancel(retry=retry, timeout=api_timeout) + except (KeyboardInterrupt, Exception): + pass + raise + + _job_helpers._wait_or_cancel = _wait_or_cancel + logger.debug( + "Successfully monkeypatched google.cloud.bigquery._job_helpers._wait_or_cancel" + ) + except ImportError: + logger.warning( + "Could not monkeypatch BigQuery _wait_or_cancel: google.cloud.bigquery not available" + ) + except Exception as e: + logger.warning("Failed to monkeypatch BigQuery _wait_or_cancel: %s", repr(e)) + + +def apply_runtime_patches(): + _monkeypatch_bigquery_wait_or_cancel() diff --git a/deepnote_toolkit/sql/sql_execution.py b/deepnote_toolkit/sql/sql_execution.py index f51a174..1f803f2 100644 --- a/deepnote_toolkit/sql/sql_execution.py +++ b/deepnote_toolkit/sql/sql_execution.py @@ -576,20 +576,8 @@ def cancel_all_cursors(self): def _cancel_cursor(cursor): """Best-effort cancel a cursor using available methods.""" try: - # BigQuery: cancel via query_job if available - query_job = getattr(cursor, "query_job", None) - if query_job is not None and hasattr(query_job, "cancel"): - try: - query_job.cancel() - except (Exception, KeyboardInterrupt): - pass - - # Generic DBAPI: try cursor.cancel() if available (Trino, etc.) if hasattr(cursor, "cancel"): - try: - cursor.cancel() - except (Exception, KeyboardInterrupt): - pass + cursor.cancel() except (Exception, KeyboardInterrupt): pass # Best effort, ignore all errors diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index a871dd8..5f9179d 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -7,6 +7,14 @@ import pytest +@pytest.fixture(autouse=True, scope="session") +def apply_patches() -> None: + """Apply runtime patches once before any tests run.""" + from deepnote_toolkit.runtime_patches import apply_runtime_patches + + apply_runtime_patches() + + @pytest.fixture(autouse=True) def clean_runtime_state() -> Generator[None, None, None]: """Automatically clean in-memory env state and config cache before and after each test.""" diff --git a/tests/unit/test_sql_execution_internal.py b/tests/unit/test_sql_execution_internal.py index 3136ca7..8160212 100644 --- a/tests/unit/test_sql_execution_internal.py +++ b/tests/unit/test_sql_execution_internal.py @@ -41,6 +41,27 @@ def mock_exec_driver_sql(sql, *args): return mock_engine +def test_bigquery_wait_or_cancel_handles_keyboard_interrupt(): + import google.cloud.bigquery._job_helpers as _job_helpers + + mock_job = mock.Mock() + mock_job.result.side_effect = KeyboardInterrupt("User interrupted") + mock_job.cancel = mock.Mock() + + with pytest.raises(KeyboardInterrupt): + # _wait_or_cancel should be monkeypatched by `_monkeypatch_bigquery_wait_or_cancel` + _job_helpers._wait_or_cancel( + job_obj=mock_job, + api_timeout=30.0, + wait_timeout=60.0, + retry=None, + page_size=None, + max_results=None, + ) + + mock_job.cancel.assert_called_once_with(retry=None, timeout=30.0) + + def test_execute_sql_on_engine_cancels_cursor_on_keyboard_interrupt(): """Test that _execute_sql_on_engine cancels cursors on KeyboardInterrupt. From a561b844e9bb810620d8628e1e043b87c8070d16 Mon Sep 17 00:00:00 2001 From: OlegWock Date: Tue, 3 Feb 2026 14:55:47 +0100 Subject: [PATCH 05/12] Address rabbit comments --- deepnote_toolkit/sql/sql_execution.py | 7 +++++-- tests/unit/test_sql_execution_internal.py | 10 +++++----- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/deepnote_toolkit/sql/sql_execution.py b/deepnote_toolkit/sql/sql_execution.py index 1f803f2..5e249a5 100644 --- a/deepnote_toolkit/sql/sql_execution.py +++ b/deepnote_toolkit/sql/sql_execution.py @@ -531,7 +531,10 @@ def __init__(self, wrapped, cursor_registry=None): def cursor(self, *args, **kwargs): cursor = self.__wrapped__.cursor(*args, **kwargs) - self._self_cursor_registry.add(cursor) + try: + self._self_cursor_registry.add(cursor) + except TypeError: + logger.warning(f"DBAPI Cursor of type {type(cursor)} can't be added to weakset and thus can't be tracked.") return cursor def cancel_all_cursors(self): @@ -558,7 +561,7 @@ def _install_dbapi_wrapper(self): # Access the internal DBAPI connection dbapi_conn = self.__wrapped__._dbapi_connection if dbapi_conn is None: - logger.warning("DBAPI connection is None, cannot install tracking") + logger.warning(f"DBAPI connection is None (connection type {type(self.__wrapped__)}), cannot install tracking") return self.__wrapped__._dbapi_connection = CursorTrackingDBAPIConnection( diff --git a/tests/unit/test_sql_execution_internal.py b/tests/unit/test_sql_execution_internal.py index 8160212..7db56d6 100644 --- a/tests/unit/test_sql_execution_internal.py +++ b/tests/unit/test_sql_execution_internal.py @@ -1,4 +1,5 @@ import uuid +from typing import Any from unittest import mock import numpy as np @@ -9,14 +10,14 @@ from deepnote_toolkit.sql import sql_execution as se -def _setup_mock_engine_with_cursor(mock_cursor): +def _setup_mock_engine_with_cursor(mock_cursor: mock.Mock) -> mock.Mock: """Helper to set up mock engine and connection with a custom cursor. Returns mock_engine that can be passed to _execute_sql_on_engine. """ import sqlalchemy - mock_dbapi_connection = mock.Mock() + mock_dbapi_connection: mock.Mock = mock.Mock() mock_dbapi_connection.cursor.return_value = mock_cursor mock_sa_connection = mock.Mock(spec=sqlalchemy.engine.Connection) @@ -24,9 +25,8 @@ def _setup_mock_engine_with_cursor(mock_cursor): mock_sa_connection.connection = mock_dbapi_connection mock_sa_connection.in_transaction.return_value = False - # Mock exec_driver_sql to simulate cursor creation and execute - def mock_exec_driver_sql(sql, *args): - cursor = mock_sa_connection._dbapi_connection.cursor() + def mock_exec_driver_sql(sql: str, *args: Any) -> mock.Mock: + cursor: mock.Mock = mock_sa_connection._dbapi_connection.cursor() cursor.execute(sql, *args) return cursor From 687aa9f86c96aa843f3c48688fe23c56aacfbb8a Mon Sep 17 00:00:00 2001 From: OlegWock Date: Tue, 3 Feb 2026 14:58:39 +0100 Subject: [PATCH 06/12] Formatting & remove outdated test --- deepnote_toolkit/sql/sql_execution.py | 8 ++++-- tests/unit/test_sql_execution_internal.py | 31 ++--------------------- 2 files changed, 8 insertions(+), 31 deletions(-) diff --git a/deepnote_toolkit/sql/sql_execution.py b/deepnote_toolkit/sql/sql_execution.py index 5e249a5..4895b2e 100644 --- a/deepnote_toolkit/sql/sql_execution.py +++ b/deepnote_toolkit/sql/sql_execution.py @@ -534,7 +534,9 @@ def cursor(self, *args, **kwargs): try: self._self_cursor_registry.add(cursor) except TypeError: - logger.warning(f"DBAPI Cursor of type {type(cursor)} can't be added to weakset and thus can't be tracked.") + logger.warning( + f"DBAPI Cursor of type {type(cursor)} can't be added to weakset and thus can't be tracked." + ) return cursor def cancel_all_cursors(self): @@ -561,7 +563,9 @@ def _install_dbapi_wrapper(self): # Access the internal DBAPI connection dbapi_conn = self.__wrapped__._dbapi_connection if dbapi_conn is None: - logger.warning(f"DBAPI connection is None (connection type {type(self.__wrapped__)}), cannot install tracking") + logger.warning( + f"DBAPI connection is None (connection type {type(self.__wrapped__)}), cannot install tracking" + ) return self.__wrapped__._dbapi_connection = CursorTrackingDBAPIConnection( diff --git a/tests/unit/test_sql_execution_internal.py b/tests/unit/test_sql_execution_internal.py index 7db56d6..5654edd 100644 --- a/tests/unit/test_sql_execution_internal.py +++ b/tests/unit/test_sql_execution_internal.py @@ -63,11 +63,8 @@ def test_bigquery_wait_or_cancel_handles_keyboard_interrupt(): def test_execute_sql_on_engine_cancels_cursor_on_keyboard_interrupt(): - """Test that _execute_sql_on_engine cancels cursors on KeyboardInterrupt. + """Test that _execute_sql_on_engine cancels cursors on KeyboardInterrupt.""" - Uses real pandas.read_sql_query. The cursor's execute method throws - KeyboardInterrupt, simulating cell cancellation. - """ mock_cursor = mock.MagicMock() mock_cursor.execute.side_effect = KeyboardInterrupt("Cancelled") @@ -79,33 +76,9 @@ def test_execute_sql_on_engine_cancels_cursor_on_keyboard_interrupt(): mock_cursor.cancel.assert_called_once() -def test_execute_sql_on_engine_cancels_bigquery_query_job(): - """Test that _execute_sql_on_engine cancels BigQuery query_job if present. - - Uses real pandas.read_sql_query. The cursor's execute method throws - KeyboardInterrupt, simulating cell cancellation. - """ - mock_query_job = mock.Mock() - mock_cursor = mock.MagicMock() - mock_cursor.execute.side_effect = KeyboardInterrupt("Cancelled") - mock_cursor.query_job = mock_query_job - - mock_engine = _setup_mock_engine_with_cursor(mock_cursor) - - with pytest.raises(KeyboardInterrupt): - se._execute_sql_on_engine(mock_engine, "SELECT 1", {}) - - mock_query_job.cancel.assert_called_once() - mock_cursor.cancel.assert_called_once() - - def test_execute_sql_on_engine_handles_cancel_errors_gracefully(): - """Test that _execute_sql_on_engine handles cancel errors gracefully. + """Test that _execute_sql_on_engine handles cancel errors gracefully.""" - Uses real pandas.read_sql_query. The cursor's execute method throws - KeyboardInterrupt. The cursor's cancel method throws an error, which - should be silently handled. - """ mock_cursor = mock.MagicMock() mock_cursor.execute.side_effect = KeyboardInterrupt("Cancelled") mock_cursor.cancel.side_effect = RuntimeError("Cancel failed") From 1cc81d680bd1049fbd89fb62171725fe844f712c Mon Sep 17 00:00:00 2001 From: OlegWock Date: Tue, 3 Feb 2026 16:11:07 +0100 Subject: [PATCH 07/12] Add tests --- tests/unit/test_sql_execution_internal.py | 52 +++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/tests/unit/test_sql_execution_internal.py b/tests/unit/test_sql_execution_internal.py index 5654edd..b0e8a7c 100644 --- a/tests/unit/test_sql_execution_internal.py +++ b/tests/unit/test_sql_execution_internal.py @@ -92,6 +92,58 @@ def test_execute_sql_on_engine_handles_cancel_errors_gracefully(): mock_cursor.cancel.assert_called_once() +def test_cursor_tracking_dbapi_connection_cancel_all_cursors(): + """Test that CursorTrackingDBAPIConnection.cancel_all_cursors cancels all tracked cursors.""" + mock_wrapped_conn = mock.Mock() + cursor1 = mock.Mock() + cursor2 = mock.Mock() + mock_wrapped_conn.cursor.side_effect = [cursor1, cursor2] + + tracking_conn = se.CursorTrackingDBAPIConnection(mock_wrapped_conn) + + # Create two cursors + tracking_conn.cursor() + tracking_conn.cursor() + + # Cancel all cursors + tracking_conn.cancel_all_cursors() + + cursor1.cancel.assert_called_once() + cursor2.cancel.assert_called_once() + + +def test_cursor_tracking_dbapi_connection_handles_unhashable_cursor(): + """Test that CursorTrackingDBAPIConnection handles cursors that can't be added to weakset.""" + mock_wrapped_conn = mock.Mock() + + class UnhashableCursor: + __hash__ = None + + unhashable_cursor = UnhashableCursor() + mock_wrapped_conn.cursor.return_value = unhashable_cursor + + tracking_conn = se.CursorTrackingDBAPIConnection(mock_wrapped_conn) + + with mock.patch.object(se.logger, "warning") as mock_warning: + result = tracking_conn.cursor() + + assert result is unhashable_cursor + mock_warning.assert_called_once() + assert "can't be added to weakset" in mock_warning.call_args[0][0] + + +def test_cursor_tracking_sqlalchemy_connection_handles_none_dbapi_connection(): + """Test that CursorTrackingSQLAlchemyConnection handles None dbapi connection.""" + mock_sa_conn = mock.Mock() + mock_sa_conn._dbapi_connection = None + + with mock.patch.object(se.logger, "warning") as mock_warning: + se.CursorTrackingSQLAlchemyConnection(mock_sa_conn) + + mock_warning.assert_called_once() + assert "DBAPI connection is None" in mock_warning.call_args[0][0] + + def test_build_params_for_bigquery_oauth_ok(): with mock.patch( "deepnote_toolkit.sql.sql_execution.bigquery.Client" From c6b5106c462284bd375d6b809feeac8ebad776ae Mon Sep 17 00:00:00 2001 From: OlegWock Date: Wed, 4 Feb 2026 11:53:22 +0100 Subject: [PATCH 08/12] Address smaller PR comments --- deepnote_toolkit/sql/sql_execution.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/deepnote_toolkit/sql/sql_execution.py b/deepnote_toolkit/sql/sql_execution.py index 4895b2e..02e92d4 100644 --- a/deepnote_toolkit/sql/sql_execution.py +++ b/deepnote_toolkit/sql/sql_execution.py @@ -5,7 +5,7 @@ import uuid import warnings import weakref -from typing import Any +from typing import Any, Optional from urllib.parse import quote import google.oauth2.credentials @@ -18,7 +18,8 @@ from google.cloud import bigquery from packaging.version import parse as parse_version from pydantic import BaseModel -from sqlalchemy.engine import URL, create_engine, make_url +from sqlalchemy.engine import URL, Connection, create_engine, make_url +from sqlalchemy.engine.interfaces import DBAPIConnection, DBAPICursor from sqlalchemy.exc import ResourceClosedError from deepnote_core.pydantic_compat_helpers import model_validate_compat @@ -522,7 +523,11 @@ def _query_data_source( class CursorTrackingDBAPIConnection(wrapt.ObjectProxy): """Wraps DBAPI connection to track cursors as they're created.""" - def __init__(self, wrapped, cursor_registry=None): + def __init__( + self, + wrapped: DBAPIConnection, + cursor_registry: Optional[weakref.WeakSet[DBAPICursor]] = None, + ) -> None: super().__init__(wrapped) # Use provided registry or create our own self._self_cursor_registry = ( @@ -552,9 +557,9 @@ class CursorTrackingSQLAlchemyConnection(wrapt.ObjectProxy): so all cursors created (including by exec_driver_sql) are tracked. """ - def __init__(self, wrapped): + def __init__(self, wrapped: Connection) -> None: super().__init__(wrapped) - self._self_cursors = weakref.WeakSet() + self._self_cursors: weakref.WeakSet[DBAPICursor] = weakref.WeakSet() self._install_dbapi_wrapper() def _install_dbapi_wrapper(self): @@ -580,10 +585,10 @@ def cancel_all_cursors(self): _cancel_cursor(cursor) -def _cancel_cursor(cursor): +def _cancel_cursor(cursor: DBAPICursor) -> None: """Best-effort cancel a cursor using available methods.""" try: - if hasattr(cursor, "cancel"): + if hasattr(cursor, "cancel") and callable(cursor.cancel): cursor.cancel() except (Exception, KeyboardInterrupt): pass # Best effort, ignore all errors From 5595abcac82eed66b438f8ee35a535ca95eeba3e Mon Sep 17 00:00:00 2001 From: OlegWock Date: Wed, 4 Feb 2026 12:24:45 +0100 Subject: [PATCH 09/12] Fix missing interfaces for older SQLAlchemy --- deepnote_toolkit/sql/sql_execution.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/deepnote_toolkit/sql/sql_execution.py b/deepnote_toolkit/sql/sql_execution.py index 02e92d4..31535e5 100644 --- a/deepnote_toolkit/sql/sql_execution.py +++ b/deepnote_toolkit/sql/sql_execution.py @@ -5,7 +5,7 @@ import uuid import warnings import weakref -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Optional from urllib.parse import quote import google.oauth2.credentials @@ -19,7 +19,6 @@ from packaging.version import parse as parse_version from pydantic import BaseModel from sqlalchemy.engine import URL, Connection, create_engine, make_url -from sqlalchemy.engine.interfaces import DBAPIConnection, DBAPICursor from sqlalchemy.exc import ResourceClosedError from deepnote_core.pydantic_compat_helpers import model_validate_compat @@ -40,6 +39,14 @@ from deepnote_toolkit.sql.sql_utils import is_single_select_query from deepnote_toolkit.sql.url_utils import replace_user_pass_in_pg_url +if TYPE_CHECKING: + try: + from sqlalchemy.engine.interfaces import DBAPIConnection, DBAPICursor + except ImportError: + # Not available in SQLAlchemy < 2.0. We use them only for typing, so replace with Any + DBAPIConnection = Any + DBAPICursor = Any + logger = LoggerManager().get_logger() @@ -525,8 +532,8 @@ class CursorTrackingDBAPIConnection(wrapt.ObjectProxy): def __init__( self, - wrapped: DBAPIConnection, - cursor_registry: Optional[weakref.WeakSet[DBAPICursor]] = None, + wrapped: "DBAPIConnection", + cursor_registry: Optional[weakref.WeakSet["DBAPICursor"]] = None, ) -> None: super().__init__(wrapped) # Use provided registry or create our own @@ -585,7 +592,7 @@ def cancel_all_cursors(self): _cancel_cursor(cursor) -def _cancel_cursor(cursor: DBAPICursor) -> None: +def _cancel_cursor(cursor: "DBAPICursor") -> None: """Best-effort cancel a cursor using available methods.""" try: if hasattr(cursor, "cancel") and callable(cursor.cancel): From 62576fb040bc211729eba56f77e8b62afa136e43 Mon Sep 17 00:00:00 2001 From: OlegWock Date: Thu, 5 Feb 2026 10:43:10 +0100 Subject: [PATCH 10/12] Avoid using private properties --- deepnote_toolkit/sql/sql_execution.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/deepnote_toolkit/sql/sql_execution.py b/deepnote_toolkit/sql/sql_execution.py index 31535e5..60eb4dd 100644 --- a/deepnote_toolkit/sql/sql_execution.py +++ b/deepnote_toolkit/sql/sql_execution.py @@ -553,7 +553,7 @@ def cursor(self, *args, **kwargs): def cancel_all_cursors(self): """Cancel all tracked cursors. Best-effort, ignores errors.""" - for cursor in list(self._self_cursor_registry): + for cursor in self._self_cursor_registry: _cancel_cursor(cursor) @@ -573,14 +573,14 @@ def _install_dbapi_wrapper(self): """Replace SQLAlchemy's internal DBAPI connection with our tracking wrapper.""" try: # Access the internal DBAPI connection - dbapi_conn = self.__wrapped__._dbapi_connection + dbapi_conn = self.__wrapped__.connection.dbapi_connection if dbapi_conn is None: logger.warning( f"DBAPI connection is None (connection type {type(self.__wrapped__)}), cannot install tracking" ) return - self.__wrapped__._dbapi_connection = CursorTrackingDBAPIConnection( + self.__wrapped__.connection.dbapi_connection = CursorTrackingDBAPIConnection( dbapi_conn, self._self_cursors ) except Exception as e: @@ -588,7 +588,7 @@ def _install_dbapi_wrapper(self): def cancel_all_cursors(self): """Cancel all tracked cursors. Best-effort, ignores errors.""" - for cursor in list(self._self_cursors): + for cursor in self._self_cursors: _cancel_cursor(cursor) From 9cb3ef4a716b8b3078f8f1bcfa467b926bab192e Mon Sep 17 00:00:00 2001 From: OlegWock Date: Thu, 5 Feb 2026 10:55:37 +0100 Subject: [PATCH 11/12] Properly handle older sqlalchemy versions --- deepnote_toolkit/sql/sql_execution.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/deepnote_toolkit/sql/sql_execution.py b/deepnote_toolkit/sql/sql_execution.py index 60eb4dd..fb94b6d 100644 --- a/deepnote_toolkit/sql/sql_execution.py +++ b/deepnote_toolkit/sql/sql_execution.py @@ -573,15 +573,23 @@ def _install_dbapi_wrapper(self): """Replace SQLAlchemy's internal DBAPI connection with our tracking wrapper.""" try: # Access the internal DBAPI connection - dbapi_conn = self.__wrapped__.connection.dbapi_connection + if hasattr(self.__wrapped__.connection, "dbapi_connection"): + dbapi_conn = self.__wrapped__.connection.dbapi_connection + dbapi_connection_attr_name = "dbapi_connection" + else: + # SQLAlchemy pre v1.4 + dbapi_conn = self.__wrapped__.connection.connection + dbapi_connection_attr_name = "connection" if dbapi_conn is None: logger.warning( f"DBAPI connection is None (connection type {type(self.__wrapped__)}), cannot install tracking" ) return - self.__wrapped__.connection.dbapi_connection = CursorTrackingDBAPIConnection( - dbapi_conn, self._self_cursors + setattr( + self.__wrapped__.connection, + dbapi_connection_attr_name, + CursorTrackingDBAPIConnection(dbapi_conn, self._self_cursors), ) except Exception as e: logger.warning(f"Could not install DBAPI wrapper: {e}") From 93302e6f786c2e37817180a7680ab4c8f9385800 Mon Sep 17 00:00:00 2001 From: OlegWock Date: Thu, 5 Feb 2026 11:25:20 +0100 Subject: [PATCH 12/12] Fix tests --- tests/unit/test_sql_execution_internal.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/tests/unit/test_sql_execution_internal.py b/tests/unit/test_sql_execution_internal.py index b0e8a7c..6826ab6 100644 --- a/tests/unit/test_sql_execution_internal.py +++ b/tests/unit/test_sql_execution_internal.py @@ -20,13 +20,18 @@ def _setup_mock_engine_with_cursor(mock_cursor: mock.Mock) -> mock.Mock: mock_dbapi_connection: mock.Mock = mock.Mock() mock_dbapi_connection.cursor.return_value = mock_cursor + mock_pool_connection = mock.Mock() + mock_pool_connection.dbapi_connection = mock_dbapi_connection + mock_pool_connection.cursor.side_effect = ( + lambda: mock_pool_connection.dbapi_connection.cursor() + ) + mock_sa_connection = mock.Mock(spec=sqlalchemy.engine.Connection) - mock_sa_connection._dbapi_connection = mock_dbapi_connection - mock_sa_connection.connection = mock_dbapi_connection + mock_sa_connection.connection = mock_pool_connection mock_sa_connection.in_transaction.return_value = False def mock_exec_driver_sql(sql: str, *args: Any) -> mock.Mock: - cursor: mock.Mock = mock_sa_connection._dbapi_connection.cursor() + cursor: mock.Mock = mock_sa_connection.connection.cursor() cursor.execute(sql, *args) return cursor @@ -134,8 +139,11 @@ class UnhashableCursor: def test_cursor_tracking_sqlalchemy_connection_handles_none_dbapi_connection(): """Test that CursorTrackingSQLAlchemyConnection handles None dbapi connection.""" + mock_conn_pool = mock.Mock() + mock_conn_pool.dbapi_connection = None + mock_sa_conn = mock.Mock() - mock_sa_conn._dbapi_connection = None + mock_sa_conn.connection = mock_conn_pool with mock.patch.object(se.logger, "warning") as mock_warning: se.CursorTrackingSQLAlchemyConnection(mock_sa_conn)