From a640f4f2c42d9d92267340f6f5db510f5b1b5b66 Mon Sep 17 00:00:00 2001 From: Gene Hoffman <30377676+hoffmang9@users.noreply.github.com> Date: Sun, 15 Feb 2026 15:40:03 -0800 Subject: [PATCH] Fix/ws heartbeat reset on data (#12030) --- CHANGES/12030.bugfix.rst | 2 + CONTRIBUTORS.txt | 2 + aiohttp/client.py | 9 ++- aiohttp/client_proto.py | 13 ++- aiohttp/client_ws.py | 30 ++++++- aiohttp/web_protocol.py | 9 ++- aiohttp/web_ws.py | 46 +++++++++-- docs/client_reference.rst | 5 +- docs/web_reference.rst | 3 +- tests/test_client_ws.py | 126 +++++++++++++++++++++++++++++ tests/test_client_ws_functional.py | 60 ++++++++++++++ tests/test_web_protocol.py | 48 +++++++++++ tests/test_web_websocket.py | 66 +++++++++++++++ 13 files changed, 400 insertions(+), 19 deletions(-) create mode 100644 CHANGES/12030.bugfix.rst create mode 100644 tests/test_web_protocol.py diff --git a/CHANGES/12030.bugfix.rst b/CHANGES/12030.bugfix.rst new file mode 100644 index 00000000000..5f2f8ba5c3c --- /dev/null +++ b/CHANGES/12030.bugfix.rst @@ -0,0 +1,2 @@ +Reset the WebSocket heartbeat timer on inbound data to avoid false ping/pong timeouts while receiving large frames +-- by :user:`hoffmang9`. diff --git a/CONTRIBUTORS.txt b/CONTRIBUTORS.txt index 9d593e1e6a2..7c5613648ca 100644 --- a/CONTRIBUTORS.txt +++ b/CONTRIBUTORS.txt @@ -120,6 +120,7 @@ Dmitry Trofimov Dmytro Bohomiakov Dmytro Kuznetsov Dustin J. Mitchell +Earle Lowe Eduard Iskandarov Eli Ribble Elizabeth Leddy @@ -146,6 +147,7 @@ Gabriel Tremblay Gang Ji Gary Leung Gary Wilson Jr. +Gene Hoffman Gennady Andreyev Georges Dubus Greg Holt diff --git a/aiohttp/client.py b/aiohttp/client.py index 7a5ef206ecd..26e67d490f2 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -1203,9 +1203,6 @@ async def _ws_connect( transport = conn.transport assert transport is not None reader = WebSocketDataQueue(conn_proto, 2**16, loop=self._loop) - conn_proto.set_parser( - WebSocketReader(reader, max_msg_size, decode_text=decode_text), reader - ) writer = WebSocketWriter( conn_proto, transport, @@ -1217,7 +1214,7 @@ async def _ws_connect( resp.close() raise else: - return self._ws_response_class( + ws_resp = self._ws_response_class( reader, writer, protocol, @@ -1230,6 +1227,10 @@ async def _ws_connect( compress=compress, client_notakeover=notakeover, ) + parser = WebSocketReader(reader, max_msg_size, decode_text=decode_text) + cb = None if heartbeat is None else ws_resp._on_data_received + conn_proto.set_parser(parser, reader, data_received_cb=cb) + return ws_resp def _prepare_headers(self, headers: LooseHeaders | None) -> "CIMultiDict[str]": """Add default headers and transform it to CIMultiDict""" diff --git a/aiohttp/client_proto.py b/aiohttp/client_proto.py index 601b545c82a..07b23c48390 100644 --- a/aiohttp/client_proto.py +++ b/aiohttp/client_proto.py @@ -1,6 +1,6 @@ import asyncio from contextlib import suppress -from typing import Any +from typing import Any, Callable from .base_protocol import BaseProtocol from .client_exceptions import ( @@ -34,6 +34,7 @@ def __init__(self, loop: asyncio.AbstractEventLoop) -> None: self._payload: StreamReader | None = None self._skip_payload = False self._payload_parser: WebSocketReader | None = None + self._data_received_cb: Callable[[], None] | None = None self._timer = None @@ -203,7 +204,12 @@ def set_exception( self._drop_timeout() super().set_exception(exc, exc_cause) - def set_parser(self, parser: Any, payload: Any) -> None: + def set_parser( + self, + parser: Any, + payload: Any, + data_received_cb: Callable[[], None] | None = None, + ) -> None: # TODO: actual types are: # parser: WebSocketReader # payload: WebSocketDataQueue @@ -211,6 +217,7 @@ def set_parser(self, parser: Any, payload: Any) -> None: # Need an ABC for both types self._payload = payload self._payload_parser = parser + self._data_received_cb = data_received_cb self._drop_timeout() @@ -298,6 +305,8 @@ def data_received(self, data: bytes) -> None: # custom payload parser - currently always WebSocketReader if self._payload_parser is not None: + if self._data_received_cb is not None: + self._data_received_cb() eof, tail = self._payload_parser.feed_data(data) if eof: self._payload = None diff --git a/aiohttp/client_ws.py b/aiohttp/client_ws.py index f2e92149e55..b387f4bfc94 100644 --- a/aiohttp/client_ws.py +++ b/aiohttp/client_ws.py @@ -97,11 +97,17 @@ def __init__( self._compress = compress self._client_notakeover = client_notakeover self._ping_task: asyncio.Task[None] | None = None + self._need_heartbeat_reset = False + self._heartbeat_reset_handle: asyncio.Handle | None = None self._reset_heartbeat() def _cancel_heartbeat(self) -> None: self._cancel_pong_response_cb() + if self._heartbeat_reset_handle is not None: + self._heartbeat_reset_handle.cancel() + self._heartbeat_reset_handle = None + self._need_heartbeat_reset = False if self._heartbeat_cb is not None: self._heartbeat_cb.cancel() self._heartbeat_cb = None @@ -114,6 +120,23 @@ def _cancel_pong_response_cb(self) -> None: self._pong_response_cb.cancel() self._pong_response_cb = None + def _on_data_received(self) -> None: + if self._heartbeat is None or self._need_heartbeat_reset: + return + loop = self._loop + assert loop is not None + # Coalesce multiple chunks received in the same loop tick into a single + # heartbeat reset. Resetting immediately per chunk increases timer churn. + self._need_heartbeat_reset = True + self._heartbeat_reset_handle = loop.call_soon(self._flush_heartbeat_reset) + + def _flush_heartbeat_reset(self) -> None: + self._heartbeat_reset_handle = None + if not self._need_heartbeat_reset: + return + self._reset_heartbeat() + self._need_heartbeat_reset = False + def _reset_heartbeat(self) -> None: if self._heartbeat is None: return @@ -137,6 +160,12 @@ def _reset_heartbeat(self) -> None: def _send_heartbeat(self) -> None: self._heartbeat_cb = None + + # If heartbeat reset is pending (data is being received), skip sending + # the ping and let the reset callback handle rescheduling the heartbeat. + if self._need_heartbeat_reset: + return + loop = self._loop now = loop.time() if now < self._heartbeat_when: @@ -364,7 +393,6 @@ async def receive( msg = await self._reader.read() else: msg = await self._reader.read() - self._reset_heartbeat() finally: self._waiting = False if self._close_wait: diff --git a/aiohttp/web_protocol.py b/aiohttp/web_protocol.py index bd39c48050d..4f6b8baf2b7 100644 --- a/aiohttp/web_protocol.py +++ b/aiohttp/web_protocol.py @@ -170,6 +170,7 @@ class RequestHandler(BaseProtocol, Generic[_Request]): "_task_handler", "_upgrade", "_payload_parser", + "_data_received_cb", "_request_parser", "logger", "access_log", @@ -226,6 +227,7 @@ def __init__( self._messages: deque[_MsgType] = deque() self._message_tail = b"" + self._data_received_cb: Callable[[], None] | None = None self._waiter: asyncio.Future[None] | None = None self._handler_waiter: asyncio.Future[None] | None = None @@ -402,11 +404,14 @@ def connection_lost(self, exc: BaseException | None) -> None: self._payload_parser.feed_eof() self._payload_parser = None - def set_parser(self, parser: Any) -> None: + def set_parser( + self, parser: Any, data_received_cb: Callable[[], None] | None = None + ) -> None: # Actual type is WebReader assert self._payload_parser is None self._payload_parser = parser + self._data_received_cb = data_received_cb if self._message_tail: self._payload_parser.feed_data(self._message_tail) @@ -450,6 +455,8 @@ def data_received(self, data: bytes) -> None: # feed payload elif data: + if self._data_received_cb is not None: + self._data_received_cb() eof, tail = self._payload_parser.feed_data(data) if eof: self.close() diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index d55d3687d92..dee7225d428 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -93,6 +93,8 @@ class WebSocketResponse(StreamResponse, Generic[_DecodeText]): _heartbeat_cb: asyncio.TimerHandle | None = None _pong_response_cb: asyncio.TimerHandle | None = None _ping_task: asyncio.Task[None] | None = None + _need_heartbeat_reset: bool = False + _heartbeat_reset_handle: asyncio.Handle | None = None def __init__( self, @@ -121,9 +123,15 @@ def __init__( self._max_msg_size = max_msg_size self._writer_limit = writer_limit self._decode_text = decode_text + self._need_heartbeat_reset = False + self._heartbeat_reset_handle = None def _cancel_heartbeat(self) -> None: self._cancel_pong_response_cb() + if self._heartbeat_reset_handle is not None: + self._heartbeat_reset_handle.cancel() + self._heartbeat_reset_handle = None + self._need_heartbeat_reset = False if self._heartbeat_cb is not None: self._heartbeat_cb.cancel() self._heartbeat_cb = None @@ -136,6 +144,23 @@ def _cancel_pong_response_cb(self) -> None: self._pong_response_cb.cancel() self._pong_response_cb = None + def _on_data_received(self) -> None: + if self._heartbeat is None or self._need_heartbeat_reset: + return + loop = self._loop + assert loop is not None + # Coalesce multiple chunks received in the same loop tick into a single + # heartbeat reset. Resetting immediately per chunk increases timer churn. + self._need_heartbeat_reset = True + self._heartbeat_reset_handle = loop.call_soon(self._flush_heartbeat_reset) + + def _flush_heartbeat_reset(self) -> None: + self._heartbeat_reset_handle = None + if not self._need_heartbeat_reset: + return + self._reset_heartbeat() + self._need_heartbeat_reset = False + def _reset_heartbeat(self) -> None: if self._heartbeat is None: return @@ -159,6 +184,12 @@ def _reset_heartbeat(self) -> None: def _send_heartbeat(self) -> None: self._heartbeat_cb = None + + # If heartbeat reset is pending (data is being received), skip sending + # the ping and let the reset callback handle rescheduling the heartbeat. + if self._need_heartbeat_reset: + return + loop = self._loop assert loop is not None and self._writer is not None now = loop.time() @@ -352,14 +383,14 @@ def _post_start( loop = self._loop assert loop is not None self._reader = WebSocketDataQueue(request._protocol, 2**16, loop=loop) - request.protocol.set_parser( - WebSocketReader( - self._reader, - self._max_msg_size, - compress=bool(self._compress), - decode_text=self._decode_text, - ) + parser = WebSocketReader( + self._reader, + self._max_msg_size, + compress=bool(self._compress), + decode_text=self._decode_text, ) + cb = None if self._heartbeat is None else self._on_data_received + request.protocol.set_parser(parser, data_received_cb=cb) # disable HTTP keepalive for WebSocket request.protocol.keep_alive(False) @@ -576,7 +607,6 @@ async def receive( msg = await self._reader.read() else: msg = await self._reader.read() - self._reset_heartbeat() finally: self._waiting = False if self._close_wait: diff --git a/docs/client_reference.rst b/docs/client_reference.rst index d085bded527..a4018141519 100644 --- a/docs/client_reference.rst +++ b/docs/client_reference.rst @@ -761,8 +761,9 @@ The client session supports the context manager protocol for self closing. :param float heartbeat: Send *ping* message every *heartbeat* seconds and wait *pong* response, if *pong* response is not received then - close connection. The timer is reset on any data - reception.(optional) + close connection. The timer is reset on any + inbound data reception (coalesced per event loop + iteration). (optional) :param str origin: Origin header to send to server(optional) diff --git a/docs/web_reference.rst b/docs/web_reference.rst index 89615285dd9..b7ef0bc90a9 100644 --- a/docs/web_reference.rst +++ b/docs/web_reference.rst @@ -968,7 +968,8 @@ and :ref:`aiohttp-web-signals` handlers:: :param float heartbeat: Send `ping` message every `heartbeat` seconds and wait `pong` response, close connection if `pong` response is not - received. The timer is reset on any data reception. + received. The timer is reset on any inbound data + reception (coalesced per event loop iteration). :param float timeout: Timeout value for the ``close`` operation. After sending the close websocket message, diff --git a/tests/test_client_ws.py b/tests/test_client_ws.py index 655240d3eaa..e36852dc26c 100644 --- a/tests/test_client_ws.py +++ b/tests/test_client_ws.py @@ -663,6 +663,132 @@ async def test_receive_runtime_err(loop: asyncio.AbstractEventLoop) -> None: await resp.receive() +async def test_heartbeat_reset_coalesces_on_data( + loop: asyncio.AbstractEventLoop, +) -> None: + response = mock.Mock() + response.connection = None + resp = client.ClientWebSocketResponse( + mock.Mock(), + mock.Mock(), + None, + response, + ClientWSTimeout(ws_receive=10.0), + True, + True, + loop, + heartbeat=0.05, + ) + with mock.patch.object(resp, "_reset_heartbeat", autospec=True) as reset: + resp._on_data_received() + resp._on_data_received() + + await asyncio.sleep(0) + + assert reset.call_count == 1 + + +async def test_receive_does_not_reset_heartbeat( + loop: asyncio.AbstractEventLoop, +) -> None: + response = mock.Mock() + response.connection = None + msg = mock.Mock(type=aiohttp.WSMsgType.TEXT) + reader = mock.Mock() + reader.read = mock.AsyncMock(return_value=msg) + resp = client.ClientWebSocketResponse( + reader, + mock.Mock(), + None, + response, + ClientWSTimeout(ws_receive=10.0), + True, + True, + loop, + heartbeat=0.05, + ) + with mock.patch.object(resp, "_reset_heartbeat", autospec=True) as reset: + received = await resp.receive() + + assert received is msg + reset.assert_not_called() + + +async def test_cancel_heartbeat_cancels_pending_heartbeat_reset_handle( + loop: asyncio.AbstractEventLoop, +) -> None: + response = mock.Mock() + response.connection = None + resp = client.ClientWebSocketResponse( + mock.Mock(), + mock.Mock(), + None, + response, + ClientWSTimeout(ws_receive=10.0), + True, + True, + loop, + heartbeat=0.05, + ) + + resp._on_data_received() + handle = resp._heartbeat_reset_handle + assert handle is not None + + resp._cancel_heartbeat() + + assert resp._heartbeat_reset_handle is None + assert resp._need_heartbeat_reset is False + assert handle.cancelled() + + +async def test_flush_heartbeat_reset_returns_early_when_not_needed( + loop: asyncio.AbstractEventLoop, +) -> None: + response = mock.Mock() + response.connection = None + resp = client.ClientWebSocketResponse( + mock.Mock(), + mock.Mock(), + None, + response, + ClientWSTimeout(ws_receive=10.0), + True, + True, + loop, + heartbeat=0.05, + ) + resp._need_heartbeat_reset = False + + with mock.patch.object(resp, "_reset_heartbeat", autospec=True) as reset: + resp._flush_heartbeat_reset() + reset.assert_not_called() + + +async def test_send_heartbeat_returns_early_when_reset_is_pending( + loop: asyncio.AbstractEventLoop, +) -> None: + response = mock.Mock() + response.connection = None + writer = mock.Mock() + resp = client.ClientWebSocketResponse( + mock.Mock(), + writer, + None, + response, + ClientWSTimeout(ws_receive=10.0), + True, + True, + loop, + heartbeat=0.05, + ) + resp._need_heartbeat_reset = True + + resp._send_heartbeat() + + writer.send_frame.assert_not_called() + + async def test_ws_connect_close_resp_on_err( loop: asyncio.AbstractEventLoop, ws_key: str, key_data: bytes ) -> None: diff --git a/tests/test_client_ws_functional.py b/tests/test_client_ws_functional.py index 3cefbb26d3d..ddd1404579f 100644 --- a/tests/test_client_ws_functional.py +++ b/tests/test_client_ws_functional.py @@ -1,6 +1,8 @@ import asyncio import json +import struct import sys +from contextlib import suppress from typing import Literal, NoReturn from unittest import mock @@ -818,6 +820,64 @@ async def handler(request: web.Request) -> NoReturn: assert resp.close_code is WSCloseCode.ABNORMAL_CLOSURE +async def test_heartbeat_does_not_timeout_while_receiving_large_frame( + aiohttp_client: AiohttpClient, +) -> None: + """Slowly receiving a single large frame should not trip heartbeat. + + Regression test for the behavior described in + https://github.com/aio-libs/aiohttp/discussions/12023: on slow connections, + the websocket heartbeat used to be reset only after a full message was read, + which could cause a ping/pong timeout while bytes were still being received. + """ + payload = b"x" * 2048 + heartbeat = 0.05 + chunk_size = 64 + delay = 0.01 + + async def handler(request: web.Request) -> web.WebSocketResponse: + ws = web.WebSocketResponse() + await ws.prepare(request) + + assert ws._writer is not None + transport = ws._writer.transport + + # Server-to-client frames are not masked. + length = len(payload) # payload is fixed length of 2048 bytes + header = bytes((0x82, 126)) + struct.pack("!H", length) + + frame = header + payload + for i in range(0, len(frame), chunk_size): + transport.write(frame[i : i + chunk_size]) + await asyncio.sleep(delay) + + # Ensure the server side is cleaned up. + with suppress(asyncio.TimeoutError): + await ws.receive(timeout=1.0) + with suppress(Exception): + await ws.close() + return ws + + app = web.Application() + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + + async with client.ws_connect("/", heartbeat=heartbeat) as resp: + # If heartbeat was not reset on any incoming bytes, the client would start + # sending PINGs while we're still streaming the message body, and since the + # server handler never calls receive(), no PONG would be produced and the + # client would close with a ping/pong timeout. + with mock.patch.object( + resp._writer, "send_frame", wraps=resp._writer.send_frame + ) as sf: + msg = await resp.receive() + assert ( + sf.call_args_list.count(mock.call(b"", WSMsgType.PING)) == 0 + ), "Heartbeat PING sent while data was still being received" + assert msg.type is WSMsgType.BINARY + assert msg.data == payload + + async def test_heartbeat_no_pong_after_receive_many_messages( aiohttp_client: AiohttpClient, ) -> None: diff --git a/tests/test_web_protocol.py b/tests/test_web_protocol.py new file mode 100644 index 00000000000..5ae1e5dd756 --- /dev/null +++ b/tests/test_web_protocol.py @@ -0,0 +1,48 @@ +import asyncio +from typing import Any, cast +from unittest import mock + +from aiohttp.web_protocol import RequestHandler + + +class _DummyManager: + def __init__(self) -> None: + self.request_handler = mock.Mock() + self.request_factory = mock.Mock() + + +class _DummyParser: + def __init__(self) -> None: + self.received: list[bytes] = [] + + def feed_data(self, data: bytes) -> tuple[bool, bytes]: + self.received.append(data) + return False, b"" + + +def test_set_parser_does_not_call_data_received_cb_for_tail( + loop: asyncio.AbstractEventLoop, +) -> None: + handler: RequestHandler[Any] = RequestHandler(cast(Any, _DummyManager()), loop=loop) + handler._message_tail = b"tail" + cb = mock.Mock() + parser = _DummyParser() + + handler.set_parser(parser, data_received_cb=cb) + + cb.assert_not_called() + assert parser.received == [b"tail"] + + +def test_data_received_calls_data_received_cb( + loop: asyncio.AbstractEventLoop, +) -> None: + handler: RequestHandler[Any] = RequestHandler(cast(Any, _DummyManager()), loop=loop) + cb = mock.Mock() + parser = _DummyParser() + + handler.set_parser(parser, data_received_cb=cb) + handler.data_received(b"x") + + assert cb.call_count == 1 + assert parser.received == [b"x"] diff --git a/tests/test_web_websocket.py b/tests/test_web_websocket.py index 33380d94560..d3ec524b345 100644 --- a/tests/test_web_websocket.py +++ b/tests/test_web_websocket.py @@ -38,6 +38,7 @@ def app(loop: asyncio.AbstractEventLoop) -> web.Application: def protocol() -> web.RequestHandler[web.Request]: ret = mock.Mock() ret.set_parser.return_value = ret + ret._timeout_ceil_threshold = 5 return ret @@ -118,6 +119,41 @@ async def test_nonstarted_receive_str() -> None: await ws.receive_str() +async def test_cancel_heartbeat_cancels_pending_heartbeat_reset_handle( + loop: asyncio.AbstractEventLoop, +) -> None: + ws = web.WebSocketResponse(heartbeat=0.05) + ws._loop = loop + ws._on_data_received() + handle = ws._heartbeat_reset_handle + assert handle is not None + + ws._cancel_heartbeat() + + assert ws._heartbeat_reset_handle is None + assert ws._need_heartbeat_reset is False + assert handle.cancelled() + + +async def test_flush_heartbeat_reset_returns_early_when_not_needed() -> None: + ws = web.WebSocketResponse(heartbeat=0.05) + ws._need_heartbeat_reset = False + + with mock.patch.object(ws, "_reset_heartbeat") as reset: + ws._flush_heartbeat_reset() + reset.assert_not_called() + + +async def test_send_heartbeat_returns_early_when_reset_is_pending() -> None: + ws = web.WebSocketResponse(heartbeat=0.05) + ws._need_heartbeat_reset = True + + ws._send_heartbeat() + + assert ws._pong_response_cb is None + assert ws._ping_task is None + + async def test_nonstarted_receive_bytes() -> None: ws = web.WebSocketResponse() with pytest.raises(RuntimeError): @@ -175,6 +211,36 @@ async def test_heartbeat_timeout(make_request: _RequestMaker) -> None: assert ws.closed +async def test_heartbeat_reset_coalesces_on_data( + make_request: _RequestMaker, +) -> None: + req = make_request("GET", "/") + ws = web.WebSocketResponse(heartbeat=0.05) + await ws.prepare(req) + + with mock.patch.object(ws, "_reset_heartbeat") as reset: + ws._on_data_received() + ws._on_data_received() + + await asyncio.sleep(0) + + assert reset.call_count == 1 + + +async def test_receive_does_not_reset_heartbeat() -> None: + ws = web.WebSocketResponse(heartbeat=0.05) + msg = mock.Mock(type=WSMsgType.TEXT) + reader = mock.Mock() + reader.read = mock.AsyncMock(return_value=msg) + ws._reader = reader + + with mock.patch.object(ws, "_reset_heartbeat") as reset: + received = await ws.receive() + + assert received is msg + reset.assert_not_called() + + def test_websocket_ready() -> None: websocket_ready = WebSocketReady(True, "chat") assert websocket_ready.ok is True