Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGES/12030.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Reset the WebSocket heartbeat timer on inbound data to avoid false ping/pong timeouts while receiving large frames
-- by :user:`hoffmang9`.
2 changes: 2 additions & 0 deletions CONTRIBUTORS.txt
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ Dmitry Trofimov
Dmytro Bohomiakov
Dmytro Kuznetsov
Dustin J. Mitchell
Earle Lowe
Eduard Iskandarov
Eli Ribble
Elizabeth Leddy
Expand All @@ -146,6 +147,7 @@ Gabriel Tremblay
Gang Ji
Gary Leung
Gary Wilson Jr.
Gene Hoffman
Gennady Andreyev
Georges Dubus
Greg Holt
Expand Down
9 changes: 5 additions & 4 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1203,9 +1203,6 @@ async def _ws_connect(
transport = conn.transport
assert transport is not None
reader = WebSocketDataQueue(conn_proto, 2**16, loop=self._loop)
conn_proto.set_parser(
WebSocketReader(reader, max_msg_size, decode_text=decode_text), reader
)
writer = WebSocketWriter(
conn_proto,
transport,
Expand All @@ -1217,7 +1214,7 @@ async def _ws_connect(
resp.close()
raise
else:
return self._ws_response_class(
ws_resp = self._ws_response_class(
reader,
writer,
protocol,
Expand All @@ -1230,6 +1227,10 @@ async def _ws_connect(
compress=compress,
client_notakeover=notakeover,
)
parser = WebSocketReader(reader, max_msg_size, decode_text=decode_text)
cb = None if heartbeat is None else ws_resp._on_data_received
conn_proto.set_parser(parser, reader, data_received_cb=cb)
return ws_resp

def _prepare_headers(self, headers: LooseHeaders | None) -> "CIMultiDict[str]":
"""Add default headers and transform it to CIMultiDict"""
Expand Down
13 changes: 11 additions & 2 deletions aiohttp/client_proto.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
from contextlib import suppress
from typing import Any
from typing import Any, Callable

from .base_protocol import BaseProtocol
from .client_exceptions import (
Expand Down Expand Up @@ -34,6 +34,7 @@ def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
self._payload: StreamReader | None = None
self._skip_payload = False
self._payload_parser: WebSocketReader | None = None
self._data_received_cb: Callable[[], None] | None = None

self._timer = None

Expand Down Expand Up @@ -203,14 +204,20 @@ def set_exception(
self._drop_timeout()
super().set_exception(exc, exc_cause)

def set_parser(self, parser: Any, payload: Any) -> None:
def set_parser(
self,
parser: Any,
payload: Any,
data_received_cb: Callable[[], None] | None = None,
) -> None:
# TODO: actual types are:
# parser: WebSocketReader
# payload: WebSocketDataQueue
# but they are not generi enough
# Need an ABC for both types
self._payload = payload
self._payload_parser = parser
self._data_received_cb = data_received_cb

self._drop_timeout()

Expand Down Expand Up @@ -298,6 +305,8 @@ def data_received(self, data: bytes) -> None:

# custom payload parser - currently always WebSocketReader
if self._payload_parser is not None:
if self._data_received_cb is not None:
self._data_received_cb()
eof, tail = self._payload_parser.feed_data(data)
if eof:
self._payload = None
Expand Down
30 changes: 29 additions & 1 deletion aiohttp/client_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,17 @@ def __init__(
self._compress = compress
self._client_notakeover = client_notakeover
self._ping_task: asyncio.Task[None] | None = None
self._need_heartbeat_reset = False
self._heartbeat_reset_handle: asyncio.Handle | None = None

self._reset_heartbeat()

def _cancel_heartbeat(self) -> None:
self._cancel_pong_response_cb()
if self._heartbeat_reset_handle is not None:
self._heartbeat_reset_handle.cancel()
self._heartbeat_reset_handle = None
self._need_heartbeat_reset = False
if self._heartbeat_cb is not None:
self._heartbeat_cb.cancel()
self._heartbeat_cb = None
Expand All @@ -114,6 +120,23 @@ def _cancel_pong_response_cb(self) -> None:
self._pong_response_cb.cancel()
self._pong_response_cb = None

def _on_data_received(self) -> None:
if self._heartbeat is None or self._need_heartbeat_reset:
return
loop = self._loop
assert loop is not None
# Coalesce multiple chunks received in the same loop tick into a single
# heartbeat reset. Resetting immediately per chunk increases timer churn.
self._need_heartbeat_reset = True
self._heartbeat_reset_handle = loop.call_soon(self._flush_heartbeat_reset)

def _flush_heartbeat_reset(self) -> None:
self._heartbeat_reset_handle = None
if not self._need_heartbeat_reset:
return
self._reset_heartbeat()
self._need_heartbeat_reset = False

def _reset_heartbeat(self) -> None:
if self._heartbeat is None:
return
Expand All @@ -137,6 +160,12 @@ def _reset_heartbeat(self) -> None:

def _send_heartbeat(self) -> None:
self._heartbeat_cb = None

# If heartbeat reset is pending (data is being received), skip sending
# the ping and let the reset callback handle rescheduling the heartbeat.
if self._need_heartbeat_reset:
return

loop = self._loop
now = loop.time()
if now < self._heartbeat_when:
Expand Down Expand Up @@ -364,7 +393,6 @@ async def receive(
msg = await self._reader.read()
else:
msg = await self._reader.read()
self._reset_heartbeat()
finally:
self._waiting = False
if self._close_wait:
Expand Down
9 changes: 8 additions & 1 deletion aiohttp/web_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ class RequestHandler(BaseProtocol, Generic[_Request]):
"_task_handler",
"_upgrade",
"_payload_parser",
"_data_received_cb",
"_request_parser",
"logger",
"access_log",
Expand Down Expand Up @@ -226,6 +227,7 @@ def __init__(

self._messages: deque[_MsgType] = deque()
self._message_tail = b""
self._data_received_cb: Callable[[], None] | None = None

self._waiter: asyncio.Future[None] | None = None
self._handler_waiter: asyncio.Future[None] | None = None
Expand Down Expand Up @@ -402,11 +404,14 @@ def connection_lost(self, exc: BaseException | None) -> None:
self._payload_parser.feed_eof()
self._payload_parser = None

def set_parser(self, parser: Any) -> None:
def set_parser(
self, parser: Any, data_received_cb: Callable[[], None] | None = None
) -> None:
# Actual type is WebReader
assert self._payload_parser is None

self._payload_parser = parser
self._data_received_cb = data_received_cb

if self._message_tail:
self._payload_parser.feed_data(self._message_tail)
Expand Down Expand Up @@ -450,6 +455,8 @@ def data_received(self, data: bytes) -> None:

# feed payload
elif data:
if self._data_received_cb is not None:
self._data_received_cb()
eof, tail = self._payload_parser.feed_data(data)
if eof:
self.close()
Expand Down
46 changes: 38 additions & 8 deletions aiohttp/web_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ class WebSocketResponse(StreamResponse, Generic[_DecodeText]):
_heartbeat_cb: asyncio.TimerHandle | None = None
_pong_response_cb: asyncio.TimerHandle | None = None
_ping_task: asyncio.Task[None] | None = None
_need_heartbeat_reset: bool = False
_heartbeat_reset_handle: asyncio.Handle | None = None

def __init__(
self,
Expand Down Expand Up @@ -121,9 +123,15 @@ def __init__(
self._max_msg_size = max_msg_size
self._writer_limit = writer_limit
self._decode_text = decode_text
self._need_heartbeat_reset = False
self._heartbeat_reset_handle = None

def _cancel_heartbeat(self) -> None:
self._cancel_pong_response_cb()
if self._heartbeat_reset_handle is not None:
self._heartbeat_reset_handle.cancel()
self._heartbeat_reset_handle = None
self._need_heartbeat_reset = False
if self._heartbeat_cb is not None:
self._heartbeat_cb.cancel()
self._heartbeat_cb = None
Expand All @@ -136,6 +144,23 @@ def _cancel_pong_response_cb(self) -> None:
self._pong_response_cb.cancel()
self._pong_response_cb = None

def _on_data_received(self) -> None:
if self._heartbeat is None or self._need_heartbeat_reset:
return
loop = self._loop
assert loop is not None
# Coalesce multiple chunks received in the same loop tick into a single
# heartbeat reset. Resetting immediately per chunk increases timer churn.
self._need_heartbeat_reset = True
self._heartbeat_reset_handle = loop.call_soon(self._flush_heartbeat_reset)

def _flush_heartbeat_reset(self) -> None:
self._heartbeat_reset_handle = None
if not self._need_heartbeat_reset:
return
self._reset_heartbeat()
self._need_heartbeat_reset = False

def _reset_heartbeat(self) -> None:
if self._heartbeat is None:
return
Expand All @@ -159,6 +184,12 @@ def _reset_heartbeat(self) -> None:

def _send_heartbeat(self) -> None:
self._heartbeat_cb = None

# If heartbeat reset is pending (data is being received), skip sending
# the ping and let the reset callback handle rescheduling the heartbeat.
if self._need_heartbeat_reset:
return

loop = self._loop
assert loop is not None and self._writer is not None
now = loop.time()
Expand Down Expand Up @@ -352,14 +383,14 @@ def _post_start(
loop = self._loop
assert loop is not None
self._reader = WebSocketDataQueue(request._protocol, 2**16, loop=loop)
request.protocol.set_parser(
WebSocketReader(
self._reader,
self._max_msg_size,
compress=bool(self._compress),
decode_text=self._decode_text,
)
parser = WebSocketReader(
self._reader,
self._max_msg_size,
compress=bool(self._compress),
decode_text=self._decode_text,
)
cb = None if self._heartbeat is None else self._on_data_received
request.protocol.set_parser(parser, data_received_cb=cb)
# disable HTTP keepalive for WebSocket
request.protocol.keep_alive(False)

Expand Down Expand Up @@ -576,7 +607,6 @@ async def receive(
msg = await self._reader.read()
else:
msg = await self._reader.read()
self._reset_heartbeat()
finally:
self._waiting = False
if self._close_wait:
Expand Down
5 changes: 3 additions & 2 deletions docs/client_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -761,8 +761,9 @@ The client session supports the context manager protocol for self closing.
:param float heartbeat: Send *ping* message every *heartbeat*
seconds and wait *pong* response, if
*pong* response is not received then
close connection. The timer is reset on any data
reception.(optional)
close connection. The timer is reset on any
inbound data reception (coalesced per event loop
iteration). (optional)

:param str origin: Origin header to send to server(optional)

Expand Down
3 changes: 2 additions & 1 deletion docs/web_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -968,7 +968,8 @@ and :ref:`aiohttp-web-signals` handlers::
:param float heartbeat: Send `ping` message every `heartbeat`
seconds and wait `pong` response, close
connection if `pong` response is not
received. The timer is reset on any data reception.
received. The timer is reset on any inbound data
reception (coalesced per event loop iteration).

:param float timeout: Timeout value for the ``close``
operation. After sending the close websocket message,
Expand Down
Loading
Loading