Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 63 additions & 8 deletions deepnote_toolkit/runtime_patches.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,81 @@
from typing import Any, Optional, Union
from typing import Any

from deepnote_toolkit.logging import LoggerManager

logger = LoggerManager().get_logger()


def _monkeypatch_trino_cancel_on_error() -> None:
"""Monkey patch Trino client to cancel queries on exceptions.

When a query is running and an exception occurs (including KeyboardInterrupt
from cell cancellation), the query will continue running on the Trino server
unless explicitly cancelled. This patch wraps TrinoQuery.execute() and
TrinoQuery.fetch() to automatically cancel the query when an exception occurs.
"""
try:
from trino import client as trino_client

_original_execute = trino_client.TrinoQuery.execute
_original_fetch = trino_client.TrinoQuery.fetch

def _cancel_on_error(query: "trino_client.TrinoQuery") -> None:
"""Best-effort cancel when an error occurs."""
if not query._cancelled and query._next_uri:
try:
query.cancel()
except (KeyboardInterrupt, Exception):
pass

def _patched_execute(
self: "trino_client.TrinoQuery",
additional_http_headers: dict[str, str] | None = None,
) -> "trino_client.TrinoResult":
try:
return _original_execute(self, additional_http_headers)
except (KeyboardInterrupt, Exception):
_cancel_on_error(self)
raise

def _patched_fetch(
self: "trino_client.TrinoQuery",
) -> list[list[Any] | Any]:
try:
return _original_fetch(self)
except (KeyboardInterrupt, Exception):
_cancel_on_error(self)
raise

trino_client.TrinoQuery.execute = _patched_execute
trino_client.TrinoQuery.fetch = _patched_fetch
logger.debug(
"Successfully monkeypatched trino.client.TrinoQuery.execute and fetch"
)
except ImportError:
logger.warning(
"Could not monkeypatch Trino cancel on error: trino not available"
)
except Exception as e:
logger.warning("Failed to monkeypatch Trino cancel on error: %s", repr(e))


# TODO(BLU-5171): Temporary hack to allow cancelling BigQuery jobs on KeyboardInterrupt (e.g. when user cancels cell execution)
# Can be removed once
# 1. https://github.com/googleapis/python-bigquery/pull/2331 is merged and released
# 2. Dependencies updated for the toolkit. We don't depend on google-cloud-bigquery directly, but it's transitive
# dependency through sqlalchemy-bigquery
def _monkeypatch_bigquery_wait_or_cancel():
def _monkeypatch_bigquery_wait_or_cancel() -> None:
try:
import google.cloud.bigquery._job_helpers as _job_helpers
from google.cloud.bigquery import job, table

def _wait_or_cancel(
job_obj: job.QueryJob,
api_timeout: Optional[float],
wait_timeout: Optional[Union[object, float]],
retry: Optional[Any],
page_size: Optional[int],
max_results: Optional[int],
api_timeout: float | None,
wait_timeout: object | float | None,
retry: Any | None,
page_size: int | None,
max_results: int | None,
) -> table.RowIterator:
try:
return job_obj.result(
Expand Down Expand Up @@ -49,5 +103,6 @@ def _wait_or_cancel(
logger.warning("Failed to monkeypatch BigQuery _wait_or_cancel: %s", repr(e))


def apply_runtime_patches():
def apply_runtime_patches() -> None:
_monkeypatch_trino_cancel_on_error()
_monkeypatch_bigquery_wait_or_cancel()
68 changes: 68 additions & 0 deletions tests/unit/test_sql_execution_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,74 @@ def test_bigquery_wait_or_cancel_handles_keyboard_interrupt():
mock_job.cancel.assert_called_once_with(retry=None, timeout=30.0)


def test_trino_execute_cancels_on_keyboard_interrupt():
from trino import client as trino_client

mock_request = mock.Mock()
mock_request.next_uri = "http://trino:8080/v1/statement/query123/2"
mock_query = trino_client.TrinoQuery(mock_request, query="SELECT 1")

# Simulate KeyboardInterrupt during execute
with mock.patch.object(
mock_request, "post", side_effect=KeyboardInterrupt("User interrupted")
):
mock_query._next_uri = "http://trino:8080/v1/statement/query123/1"
mock_query._cancelled = False

with mock.patch.object(mock_query, "cancel") as mock_cancel:
with pytest.raises(KeyboardInterrupt):
# TrinoQuery.execute should be patched by `_monkeypatch_trino_cancel_on_error`
mock_query.execute()

mock_cancel.assert_called_once()


def test_trino_fetch_cancels_on_keyboard_interrupt():
from trino import client as trino_client

mock_request = mock.Mock()
mock_request.next_uri = "http://trino:8080/v1/statement/query123/2"
mock_query = trino_client.TrinoQuery(mock_request, query="SELECT 1")

# Simulate KeyboardInterrupt during fetch
with mock.patch.object(
mock_request, "get", side_effect=KeyboardInterrupt("User interrupted")
):
mock_query._next_uri = "http://trino:8080/v1/statement/query123/1"
mock_query._cancelled = False

with mock.patch.object(mock_query, "cancel") as mock_cancel:
with pytest.raises(KeyboardInterrupt):
# TrinoQuery.fetch should be patched by `_monkeypatch_trino_cancel_on_error`
mock_query.fetch()

mock_cancel.assert_called_once()


def test_trino_handles_cancel_failure_gracefully():
from trino import client as trino_client

mock_request = mock.Mock()
mock_request.next_uri = "http://trino:8080/v1/statement/query123/2"
mock_query = trino_client.TrinoQuery(mock_request, query="SELECT 1")

# Simulate KeyboardInterrupt during execute, and cancel() itself fails
with mock.patch.object(
mock_request, "post", side_effect=KeyboardInterrupt("User interrupted")
):
mock_query._next_uri = "http://trino:8080/v1/statement/query123/1"
mock_query._cancelled = False

with mock.patch.object(
mock_query, "cancel", side_effect=RuntimeError("Cancel failed")
) as mock_cancel:
# Should still raise the original KeyboardInterrupt, not the cancel error
with pytest.raises(KeyboardInterrupt):
mock_query.execute()

mock_cancel.assert_called_once()


def test_build_params_for_bigquery_oauth_ok():
with mock.patch(
"deepnote_toolkit.sql.sql_execution.bigquery.Client"
Expand Down