From 0ce692f49ec603226d25b3e4f103390f087ebde7 Mon Sep 17 00:00:00 2001 From: OlegWock Date: Fri, 30 Jan 2026 09:55:37 +0100 Subject: [PATCH 1/2] fix(trino): Monkeypatch TrinoQuery to cancel on KeyboardInterrupt --- deepnote_toolkit/runtime_patches.py | 57 ++++++++++++++++++++++- tests/unit/test_sql_execution_internal.py | 22 +++++++++ 2 files changed, 78 insertions(+), 1 deletion(-) diff --git a/deepnote_toolkit/runtime_patches.py b/deepnote_toolkit/runtime_patches.py index 9dda657..786da37 100644 --- a/deepnote_toolkit/runtime_patches.py +++ b/deepnote_toolkit/runtime_patches.py @@ -1,10 +1,64 @@ -from typing import Any, Optional, Union +from typing import Any, Dict, List, Optional, Union from deepnote_toolkit.logging import LoggerManager logger = LoggerManager().get_logger() +def _monkeypatch_trino_cancel_on_error(): + """Monkey patch Trino client to cancel queries on exceptions. + + When a query is running and an exception occurs (including KeyboardInterrupt + from cell cancellation), the query will continue running on the Trino server + unless explicitly cancelled. This patch wraps TrinoQuery.execute() and + TrinoQuery.fetch() to automatically cancel the query when an exception occurs. + """ + try: + from trino import client as trino_client + + _original_execute = trino_client.TrinoQuery.execute + _original_fetch = trino_client.TrinoQuery.fetch + + def _cancel_on_error(query: "trino_client.TrinoQuery") -> None: + """Best-effort cancel when an error occurs.""" + if not query._cancelled and query._next_uri: + try: + query.cancel() + except (KeyboardInterrupt, Exception): + pass + + def _patched_execute( + self: "trino_client.TrinoQuery", + additional_http_headers: Optional[Dict[str, str]] = None, + ) -> "trino_client.TrinoResult": + try: + return _original_execute(self, additional_http_headers) + except (KeyboardInterrupt, Exception): + _cancel_on_error(self) + raise + + def _patched_fetch( + self: "trino_client.TrinoQuery", + ) -> List[Union[List[Any], Any]]: + try: + return _original_fetch(self) + except (KeyboardInterrupt, Exception): + _cancel_on_error(self) + raise + + trino_client.TrinoQuery.execute = _patched_execute + trino_client.TrinoQuery.fetch = _patched_fetch + logger.debug( + "Successfully monkeypatched trino.client.TrinoQuery.execute and fetch" + ) + except ImportError: + logger.warning( + "Could not monkeypatch Trino cancel on error: trino not available" + ) + except Exception as e: + logger.warning("Failed to monkeypatch Trino cancel on error: %s", repr(e)) + + # TODO(BLU-5171): Temporary hack to allow cancelling BigQuery jobs on KeyboardInterrupt (e.g. when user cancels cell execution) # Can be removed once # 1. https://github.com/googleapis/python-bigquery/pull/2331 is merged and released @@ -50,4 +104,5 @@ def _wait_or_cancel( def apply_runtime_patches(): + _monkeypatch_trino_cancel_on_error() _monkeypatch_bigquery_wait_or_cancel() diff --git a/tests/unit/test_sql_execution_internal.py b/tests/unit/test_sql_execution_internal.py index 05eecf0..65577ff 100644 --- a/tests/unit/test_sql_execution_internal.py +++ b/tests/unit/test_sql_execution_internal.py @@ -30,6 +30,28 @@ def test_bigquery_wait_or_cancel_handles_keyboard_interrupt(): mock_job.cancel.assert_called_once_with(retry=None, timeout=30.0) +def test_trino_cancels_on_keyboard_interrupt(): + from trino import client as trino_client + + mock_request = mock.Mock() + mock_request.next_uri = "http://trino:8080/v1/statement/query123/2" + mock_query = trino_client.TrinoQuery(mock_request, query="SELECT 1") + + # Simulate KeyboardInterrupt during execute + with mock.patch.object( + mock_request, "post", side_effect=KeyboardInterrupt("User interrupted") + ): + mock_query._next_uri = "http://trino:8080/v1/statement/query123/1" + mock_query._cancelled = False + + with mock.patch.object(mock_query, "cancel") as mock_cancel: + with pytest.raises(KeyboardInterrupt): + # TrinoQuery should be patched by `_monkeypatch_trino_cancel_on_error` + mock_query.execute() + + mock_cancel.assert_called_once() + + def test_build_params_for_bigquery_oauth_ok(): with mock.patch( "deepnote_toolkit.sql.sql_execution.bigquery.Client" From 413a16bf88ca64f100ccecf1271ea3d7e3f8955e Mon Sep 17 00:00:00 2001 From: OlegWock Date: Fri, 30 Jan 2026 10:07:53 +0100 Subject: [PATCH 2/2] Address PR feedback & add tests --- deepnote_toolkit/runtime_patches.py | 22 +++++----- tests/unit/test_sql_execution_internal.py | 50 ++++++++++++++++++++++- 2 files changed, 59 insertions(+), 13 deletions(-) diff --git a/deepnote_toolkit/runtime_patches.py b/deepnote_toolkit/runtime_patches.py index 786da37..985f1a7 100644 --- a/deepnote_toolkit/runtime_patches.py +++ b/deepnote_toolkit/runtime_patches.py @@ -1,11 +1,11 @@ -from typing import Any, Dict, List, Optional, Union +from typing import Any from deepnote_toolkit.logging import LoggerManager logger = LoggerManager().get_logger() -def _monkeypatch_trino_cancel_on_error(): +def _monkeypatch_trino_cancel_on_error() -> None: """Monkey patch Trino client to cancel queries on exceptions. When a query is running and an exception occurs (including KeyboardInterrupt @@ -29,7 +29,7 @@ def _cancel_on_error(query: "trino_client.TrinoQuery") -> None: def _patched_execute( self: "trino_client.TrinoQuery", - additional_http_headers: Optional[Dict[str, str]] = None, + additional_http_headers: dict[str, str] | None = None, ) -> "trino_client.TrinoResult": try: return _original_execute(self, additional_http_headers) @@ -39,7 +39,7 @@ def _patched_execute( def _patched_fetch( self: "trino_client.TrinoQuery", - ) -> List[Union[List[Any], Any]]: + ) -> list[list[Any] | Any]: try: return _original_fetch(self) except (KeyboardInterrupt, Exception): @@ -64,18 +64,18 @@ def _patched_fetch( # 1. https://github.com/googleapis/python-bigquery/pull/2331 is merged and released # 2. Dependencies updated for the toolkit. We don't depend on google-cloud-bigquery directly, but it's transitive # dependency through sqlalchemy-bigquery -def _monkeypatch_bigquery_wait_or_cancel(): +def _monkeypatch_bigquery_wait_or_cancel() -> None: try: import google.cloud.bigquery._job_helpers as _job_helpers from google.cloud.bigquery import job, table def _wait_or_cancel( job_obj: job.QueryJob, - api_timeout: Optional[float], - wait_timeout: Optional[Union[object, float]], - retry: Optional[Any], - page_size: Optional[int], - max_results: Optional[int], + api_timeout: float | None, + wait_timeout: object | float | None, + retry: Any | None, + page_size: int | None, + max_results: int | None, ) -> table.RowIterator: try: return job_obj.result( @@ -103,6 +103,6 @@ def _wait_or_cancel( logger.warning("Failed to monkeypatch BigQuery _wait_or_cancel: %s", repr(e)) -def apply_runtime_patches(): +def apply_runtime_patches() -> None: _monkeypatch_trino_cancel_on_error() _monkeypatch_bigquery_wait_or_cancel() diff --git a/tests/unit/test_sql_execution_internal.py b/tests/unit/test_sql_execution_internal.py index 65577ff..abc4109 100644 --- a/tests/unit/test_sql_execution_internal.py +++ b/tests/unit/test_sql_execution_internal.py @@ -30,7 +30,7 @@ def test_bigquery_wait_or_cancel_handles_keyboard_interrupt(): mock_job.cancel.assert_called_once_with(retry=None, timeout=30.0) -def test_trino_cancels_on_keyboard_interrupt(): +def test_trino_execute_cancels_on_keyboard_interrupt(): from trino import client as trino_client mock_request = mock.Mock() @@ -46,7 +46,53 @@ def test_trino_cancels_on_keyboard_interrupt(): with mock.patch.object(mock_query, "cancel") as mock_cancel: with pytest.raises(KeyboardInterrupt): - # TrinoQuery should be patched by `_monkeypatch_trino_cancel_on_error` + # TrinoQuery.execute should be patched by `_monkeypatch_trino_cancel_on_error` + mock_query.execute() + + mock_cancel.assert_called_once() + + +def test_trino_fetch_cancels_on_keyboard_interrupt(): + from trino import client as trino_client + + mock_request = mock.Mock() + mock_request.next_uri = "http://trino:8080/v1/statement/query123/2" + mock_query = trino_client.TrinoQuery(mock_request, query="SELECT 1") + + # Simulate KeyboardInterrupt during fetch + with mock.patch.object( + mock_request, "get", side_effect=KeyboardInterrupt("User interrupted") + ): + mock_query._next_uri = "http://trino:8080/v1/statement/query123/1" + mock_query._cancelled = False + + with mock.patch.object(mock_query, "cancel") as mock_cancel: + with pytest.raises(KeyboardInterrupt): + # TrinoQuery.fetch should be patched by `_monkeypatch_trino_cancel_on_error` + mock_query.fetch() + + mock_cancel.assert_called_once() + + +def test_trino_handles_cancel_failure_gracefully(): + from trino import client as trino_client + + mock_request = mock.Mock() + mock_request.next_uri = "http://trino:8080/v1/statement/query123/2" + mock_query = trino_client.TrinoQuery(mock_request, query="SELECT 1") + + # Simulate KeyboardInterrupt during execute, and cancel() itself fails + with mock.patch.object( + mock_request, "post", side_effect=KeyboardInterrupt("User interrupted") + ): + mock_query._next_uri = "http://trino:8080/v1/statement/query123/1" + mock_query._cancelled = False + + with mock.patch.object( + mock_query, "cancel", side_effect=RuntimeError("Cancel failed") + ) as mock_cancel: + # Should still raise the original KeyboardInterrupt, not the cancel error + with pytest.raises(KeyboardInterrupt): mock_query.execute() mock_cancel.assert_called_once()