-
Notifications
You must be signed in to change notification settings - Fork 3.1k
Fix leaked anyio streams in streamable_http #1991
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
aabmass
wants to merge
8
commits into
modelcontextprotocol:main
Choose a base branch
from
aabmass:fix-leaked-stream
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+160
−142
Open
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
1606c96
Test case showing failure at main
aabmass bbcfbf6
Fix leaked anyio streams
aabmass de5d624
Update "pragma: no cover" locations for improved coverage
aabmass b11e04f
Move tests
aabmass 1f6e8ec
Use Client instead of ClientSession
aabmass 1a943ad
chore(deps): bump the uv group across 1 directory with 2 updates (#1961)
dependabot[bot] cf9e43e
Merge main and fix uv.lock
aabmass a0afd1d
don't worry about it
Kludex File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -328,19 +328,19 @@ def _create_json_response( | |
| headers=response_headers, | ||
| ) | ||
|
|
||
| def _get_session_id(self, request: Request) -> str | None: # pragma: no cover | ||
| def _get_session_id(self, request: Request) -> str | None: | ||
| """Extract the session ID from request headers.""" | ||
| return request.headers.get(MCP_SESSION_ID_HEADER) | ||
|
|
||
| def _create_event_data(self, event_message: EventMessage) -> dict[str, str]: # pragma: no cover | ||
| def _create_event_data(self, event_message: EventMessage) -> dict[str, str]: | ||
| """Create event data dictionary from an EventMessage.""" | ||
| event_data = { | ||
| "event": "message", | ||
| "data": event_message.message.model_dump_json(by_alias=True, exclude_none=True), | ||
| } | ||
|
|
||
| # If an event ID was provided, include it | ||
| if event_message.event_id: | ||
| if event_message.event_id: # pragma: no cover | ||
| event_data["id"] = event_message.event_id | ||
|
|
||
| return event_data | ||
|
|
@@ -381,9 +381,9 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No | |
|
|
||
| if request.method == "POST": | ||
| await self._handle_post_request(scope, request, receive, send) | ||
| elif request.method == "GET": # pragma: no cover | ||
| elif request.method == "GET": | ||
| await self._handle_get_request(request, send) | ||
| elif request.method == "DELETE": # pragma: no cover | ||
| elif request.method == "DELETE": | ||
| await self._handle_delete_request(request, send) | ||
| else: # pragma: no cover | ||
| await self._handle_unsupported_request(request, send) | ||
|
|
@@ -470,14 +470,14 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re | |
| # Check if this is an initialization request | ||
| is_initialization_request = isinstance(message, JSONRPCRequest) and message.method == "initialize" | ||
|
|
||
| if is_initialization_request: # pragma: no cover | ||
| if is_initialization_request: | ||
| # Check if the server already has an established session | ||
| if self.mcp_session_id: | ||
| # Check if request has a session ID | ||
| request_session_id = self._get_session_id(request) | ||
|
|
||
| # If request has a session ID but doesn't match, return 404 | ||
| if request_session_id and request_session_id != self.mcp_session_id: | ||
| if request_session_id and request_session_id != self.mcp_session_id: # pragma: no cover | ||
| response = self._create_error_response( | ||
| "Not Found: Invalid or expired session ID", | ||
| HTTPStatus.NOT_FOUND, | ||
|
|
@@ -488,7 +488,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re | |
| return | ||
|
|
||
| # For notifications and responses only, return 202 Accepted | ||
| if not isinstance(message, JSONRPCRequest): # pragma: no cover | ||
| if not isinstance(message, JSONRPCRequest): | ||
| # Create response object and send it | ||
| response = self._create_json_response( | ||
| None, | ||
|
|
@@ -561,14 +561,14 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re | |
| await response(scope, receive, send) | ||
| finally: | ||
| await self._clean_up_memory_streams(request_id) | ||
| else: # pragma: no cover | ||
| else: | ||
| # Create SSE stream | ||
| sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0) | ||
|
|
||
| # Store writer reference so close_sse_stream() can close it | ||
| self._sse_stream_writers[request_id] = sse_stream_writer | ||
|
|
||
| async def sse_writer(): | ||
| async def sse_writer(): # pragma: lax no cover | ||
| # Get the request ID from the incoming request message | ||
| try: | ||
| async with sse_stream_writer, request_stream_reader: | ||
|
|
@@ -617,11 +617,12 @@ async def sse_writer(): | |
| # Then send the message to be processed by the server | ||
| session_message = self._create_session_message(message, request, request_id, protocol_version) | ||
| await writer.send(session_message) | ||
| except Exception: | ||
| except Exception: # pragma: no cover | ||
| logger.exception("SSE response error") | ||
| await sse_stream_writer.aclose() | ||
| await sse_stream_reader.aclose() | ||
| await self._clean_up_memory_streams(request_id) | ||
| finally: | ||
| await sse_stream_reader.aclose() | ||
|
|
||
| except Exception as err: # pragma: no cover | ||
| logger.exception("Error handling POST request") | ||
|
|
@@ -635,33 +636,33 @@ async def sse_writer(): | |
| await writer.send(Exception(err)) | ||
| return | ||
|
|
||
| async def _handle_get_request(self, request: Request, send: Send) -> None: # pragma: no cover | ||
| async def _handle_get_request(self, request: Request, send: Send) -> None: | ||
| """Handle GET request to establish SSE. | ||
|
|
||
| This allows the server to communicate to the client without the client | ||
| first sending data via HTTP POST. The server can send JSON-RPC requests | ||
| and notifications on this stream. | ||
| """ | ||
| writer = self._read_stream_writer | ||
| if writer is None: | ||
| if writer is None: # pragma: no cover | ||
| raise ValueError("No read stream writer available. Ensure connect() is called first.") | ||
|
|
||
| # Validate Accept header - must include text/event-stream | ||
| _, has_sse = self._check_accept_headers(request) | ||
|
|
||
| if not has_sse: | ||
| if not has_sse: # pragma: no cover | ||
| response = self._create_error_response( | ||
| "Not Acceptable: Client must accept text/event-stream", | ||
| HTTPStatus.NOT_ACCEPTABLE, | ||
| ) | ||
| await response(request.scope, request.receive, send) | ||
| return | ||
|
|
||
| if not await self._validate_request_headers(request, send): | ||
| if not await self._validate_request_headers(request, send): # pragma: no cover | ||
| return | ||
|
|
||
| # Handle resumability: check for Last-Event-ID header | ||
| if last_event_id := request.headers.get(LAST_EVENT_ID_HEADER): | ||
| if last_event_id := request.headers.get(LAST_EVENT_ID_HEADER): # pragma: no cover | ||
| await self._replay_events(last_event_id, request, send) | ||
| return | ||
|
|
||
|
|
@@ -675,7 +676,7 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: # pr | |
| headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id | ||
|
|
||
| # Check if we already have an active GET stream | ||
| if GET_STREAM_KEY in self._request_streams: | ||
| if GET_STREAM_KEY in self._request_streams: # pragma: no cover | ||
| response = self._create_error_response( | ||
| "Conflict: Only one SSE stream is allowed per session", | ||
| HTTPStatus.CONFLICT, | ||
|
|
@@ -695,7 +696,7 @@ async def standalone_sse_writer(): | |
|
|
||
| async with sse_stream_writer, standalone_stream_reader: | ||
| # Process messages from the standalone stream | ||
| async for event_message in standalone_stream_reader: | ||
| async for event_message in standalone_stream_reader: # pragma: lax no cover | ||
| # For the standalone stream, we handle: | ||
| # - JSONRPCNotification (server sends notifications to client) | ||
| # - JSONRPCRequest (server sends requests to client) | ||
|
|
@@ -704,7 +705,7 @@ async def standalone_sse_writer(): | |
| # Send the message via SSE | ||
| event_data = self._create_event_data(event_message) | ||
| await sse_stream_writer.send(event_data) | ||
| except Exception: | ||
| except Exception: # pragma: no cover | ||
| logger.exception("Error in standalone SSE writer") | ||
| finally: | ||
| logger.debug("Closing standalone SSE writer") | ||
|
|
@@ -720,16 +721,17 @@ async def standalone_sse_writer(): | |
| try: | ||
| # This will send headers immediately and establish the SSE connection | ||
| await response(request.scope, request.receive, send) | ||
| except Exception: | ||
| except Exception: # pragma: lax no cover | ||
| logger.exception("Error in standalone SSE response") | ||
| await self._clean_up_memory_streams(GET_STREAM_KEY) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why this moves up?
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This stays in the |
||
| finally: | ||
| await sse_stream_writer.aclose() | ||
| await sse_stream_reader.aclose() | ||
| await self._clean_up_memory_streams(GET_STREAM_KEY) | ||
|
|
||
| async def _handle_delete_request(self, request: Request, send: Send) -> None: # pragma: no cover | ||
| async def _handle_delete_request(self, request: Request, send: Send) -> None: | ||
| """Handle DELETE requests for explicit session termination.""" | ||
| # Validate session ID | ||
| if not self.mcp_session_id: | ||
| if not self.mcp_session_id: # pragma: no cover | ||
| # If no session ID set, return Method Not Allowed | ||
| response = self._create_error_response( | ||
| "Method Not Allowed: Session termination not supported", | ||
|
|
@@ -738,7 +740,7 @@ async def _handle_delete_request(self, request: Request, send: Send) -> None: # | |
| await response(request.scope, request.receive, send) | ||
| return | ||
|
|
||
| if not await self._validate_request_headers(request, send): | ||
| if not await self._validate_request_headers(request, send): # pragma: no cover | ||
| return | ||
|
|
||
| await self.terminate() | ||
|
|
@@ -796,24 +798,24 @@ async def _handle_unsupported_request(self, request: Request, send: Send) -> Non | |
| ) | ||
| await response(request.scope, request.receive, send) | ||
|
|
||
| async def _validate_request_headers(self, request: Request, send: Send) -> bool: # pragma: no cover | ||
| async def _validate_request_headers(self, request: Request, send: Send) -> bool: # pragma: lax no cover | ||
| if not await self._validate_session(request, send): | ||
| return False | ||
| if not await self._validate_protocol_version(request, send): | ||
| return False | ||
| return True | ||
|
|
||
| async def _validate_session(self, request: Request, send: Send) -> bool: # pragma: no cover | ||
| async def _validate_session(self, request: Request, send: Send) -> bool: | ||
| """Validate the session ID in the request.""" | ||
| if not self.mcp_session_id: | ||
| if not self.mcp_session_id: # pragma: no cover | ||
| # If we're not using session IDs, return True | ||
| return True | ||
|
|
||
| # Get the session ID from the request headers | ||
| request_session_id = self._get_session_id(request) | ||
|
|
||
| # If no session ID provided but required, return error | ||
| if not request_session_id: | ||
| if not request_session_id: # pragma: no cover | ||
| response = self._create_error_response( | ||
| "Bad Request: Missing session ID", | ||
| HTTPStatus.BAD_REQUEST, | ||
|
|
@@ -822,7 +824,7 @@ async def _validate_session(self, request: Request, send: Send) -> bool: # prag | |
| return False | ||
|
|
||
| # If session ID doesn't match, return error | ||
| if request_session_id != self.mcp_session_id: | ||
| if request_session_id != self.mcp_session_id: # pragma: no cover | ||
| response = self._create_error_response( | ||
| "Not Found: Invalid or expired session ID", | ||
| HTTPStatus.NOT_FOUND, | ||
|
|
@@ -832,17 +834,17 @@ async def _validate_session(self, request: Request, send: Send) -> bool: # prag | |
|
|
||
| return True | ||
|
|
||
| async def _validate_protocol_version(self, request: Request, send: Send) -> bool: # pragma: no cover | ||
| async def _validate_protocol_version(self, request: Request, send: Send) -> bool: | ||
| """Validate the protocol version header in the request.""" | ||
| # Get the protocol version from the request headers | ||
| protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER) | ||
|
|
||
| # If no protocol version provided, assume default version | ||
| if protocol_version is None: | ||
| if protocol_version is None: # pragma: no cover | ||
| protocol_version = DEFAULT_NEGOTIATED_VERSION | ||
|
|
||
| # Check if the protocol version is supported | ||
| if protocol_version not in SUPPORTED_PROTOCOL_VERSIONS: | ||
| if protocol_version not in SUPPORTED_PROTOCOL_VERSIONS: # pragma: no cover | ||
| supported_versions = ", ".join(SUPPORTED_PROTOCOL_VERSIONS) | ||
| response = self._create_error_response( | ||
| f"Bad Request: Unsupported protocol version: {protocol_version}. " | ||
|
|
@@ -1004,10 +1006,7 @@ async def message_router(): | |
| try: | ||
| # Send both the message and the event ID | ||
| await self._request_streams[request_stream_id][0].send(EventMessage(message, event_id)) | ||
| except ( # pragma: no cover | ||
| anyio.BrokenResourceError, | ||
| anyio.ClosedResourceError, | ||
| ): | ||
| except (anyio.BrokenResourceError, anyio.ClosedResourceError): # pragma: no cover | ||
| # Stream might be closed, remove from registry | ||
| self._request_streams.pop(request_stream_id, None) | ||
| else: # pragma: no cover | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why does the reader moves to the finally but not the writer?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
On L569 it gets stored
Should I be closing it here too, or does it need to outlive this function?