From 960d76cda16734ceee6875db8f6995a145a150c3 Mon Sep 17 00:00:00 2001 From: Alexander Alderman Webb Date: Fri, 30 Jan 2026 11:54:04 +0100 Subject: [PATCH 1/3] test(mcp): Use AsyncClient for SSE --- .../mcp/streaming_asgi_transport.py | 85 +++++++++ tests/integrations/mcp/test_mcp.py | 164 ++++++++++++++++-- 2 files changed, 237 insertions(+), 12 deletions(-) create mode 100644 tests/integrations/mcp/streaming_asgi_transport.py diff --git a/tests/integrations/mcp/streaming_asgi_transport.py b/tests/integrations/mcp/streaming_asgi_transport.py new file mode 100644 index 0000000000..03f84b0e91 --- /dev/null +++ b/tests/integrations/mcp/streaming_asgi_transport.py @@ -0,0 +1,85 @@ +import asyncio +from httpx import ASGITransport, Request, Response, AsyncByteStream +import anyio + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Any, Callable, MutableMapping + + +class StreamingASGITransport(ASGITransport): + """ + Simple transport whose only purpose is to keep GET request alive in SSE connections, allowing + tests involving SSE interactions to run in-process. + """ + + def __init__( + self, + app: "Callable", + keep_sse_alive: "asyncio.Event", + ) -> None: + self.keep_sse_alive = keep_sse_alive + super().__init__(app) + + async def handle_async_request(self, request: "Request") -> "Response": + scope = { + "type": "http", + "method": request.method, + "headers": [(k.lower(), v) for (k, v) in request.headers.raw], + "path": request.url.path, + "query_string": request.url.query, + } + + is_streaming_sse = scope["method"] == "GET" and scope["path"] == "/sse" + if not is_streaming_sse: + return await super().handle_async_request(request) + + request_body = b"" + if request.content: + request_body = await request.aread() + + body_sender, body_receiver = anyio.create_memory_object_stream[bytes](0) + + async def receive() -> "dict[str, Any]": + if self.keep_sse_alive.is_set(): + return {"type": "http.disconnect"} + + await self.keep_sse_alive.wait() # Keep alive :) + return {"type": "http.request", "body": request_body, "more_body": False} + + async def send(message: "MutableMapping[str, Any]") -> None: + if message["type"] == "http.response.body": + body = message.get("body", b"") + more_body = message.get("more_body", False) + + if body == b"" and not more_body: + return + + if body: + await body_sender.send(body) + + if not more_body: + await body_sender.aclose() + + async def run_app(): + await self.app(scope, receive, send) + + class StreamingBodyStream(AsyncByteStream): + def __init__(self, receiver, task): + self.receiver = receiver + self.task = task + + async def __aiter__(self): + try: + async for chunk in self.receiver: + yield chunk + except anyio.EndOfStream: + pass + finally: + await self.task + + stream = StreamingBodyStream(body_receiver, asyncio.create_task(run_app())) + response = Response(status_code=200, headers=[], stream=stream) + + return response diff --git a/tests/integrations/mcp/test_mcp.py b/tests/integrations/mcp/test_mcp.py index 798953bda1..ab3c2cf73d 100644 --- a/tests/integrations/mcp/test_mcp.py +++ b/tests/integrations/mcp/test_mcp.py @@ -15,6 +15,14 @@ that the integration properly instruments MCP handlers with Sentry spans. """ +import sentry_sdk + +from urllib.parse import urlparse, parse_qs +import anyio +import asyncio +import httpx +from .streaming_asgi_transport import StreamingASGITransport + import pytest import json from unittest import mock @@ -32,9 +40,10 @@ async def __call__(self, *args, **kwargs): from mcp.server.lowlevel.helper_types import ReadResourceContents from mcp.server.lowlevel import Server from mcp.server.lowlevel.server import request_ctx +from mcp.server.sse import SseServerTransport from mcp.server.streamable_http_manager import StreamableHTTPSessionManager -from starlette.routing import Mount +from starlette.routing import Mount, Route, Response from starlette.applications import Starlette try: @@ -129,6 +138,98 @@ def __init__(self, messages): self.messages = messages +async def json_rpc_sse( + app, method: str, params, request_id: str, keep_sse_alive: "asyncio.Event" +): + context = {} + + stream_complete = asyncio.Event() + endpoint_parsed = asyncio.Event() + + # https://github.com/Kludex/starlette/issues/104#issuecomment-729087925 + async with httpx.AsyncClient( + transport=StreamingASGITransport(app=app, keep_sse_alive=keep_sse_alive), + base_url="http://test", + ) as client: + + async def parse_stream(): + async with client.stream("GET", "/sse") as stream: + # Read directly from stream.stream instead of aiter_bytes() + async for chunk in stream.stream: + if b"event: endpoint" in chunk: + sse_text = chunk.decode("utf-8") + url = sse_text.split("data: ")[1] + + parsed = urlparse(url) + query_params = parse_qs(parsed.query) + context["session_id"] = query_params["session_id"][0] + endpoint_parsed.set() + continue + + if b"event: message" in chunk and b"structuredContent" in chunk: + sse_text = chunk.decode("utf-8") + + json_str = sse_text.split("data: ")[1] + context["response"] = json.loads(json_str) + break + + stream_complete.set() + + task = asyncio.create_task(parse_stream()) + await endpoint_parsed.wait() + + await client.post( + f"/messages/?session_id={context['session_id']}", + headers={ + "Content-Type": "application/json", + }, + json={ + "jsonrpc": "2.0", + "method": "initialize", + "params": { + "clientInfo": {"name": "test-client", "version": "1.0"}, + "protocolVersion": "2025-11-25", + "capabilities": {}, + }, + "id": request_id, + }, + ) + + # Notification response is mandatory. + # https://modelcontextprotocol.io/specification/2025-11-25/basic/lifecycle + await client.post( + f"/messages/?session_id={context['session_id']}", + headers={ + "Content-Type": "application/json", + "mcp-session-id": context["session_id"], + }, + json={ + "jsonrpc": "2.0", + "method": "notifications/initialized", + "params": {}, + }, + ) + + await client.post( + f"/messages/?session_id={context['session_id']}", + headers={ + "Content-Type": "application/json", + "mcp-session-id": context["session_id"], + }, + json={ + "jsonrpc": "2.0", + "method": method, + "params": params, + "id": request_id, + }, + ) + + await stream_complete.wait() + keep_sse_alive.set() + + return task, context["session_id"], context["response"] + + def test_integration_patches_server(sentry_init): """Test that MCPIntegration patches the Server class""" # Get original methods before integration @@ -1084,7 +1185,8 @@ async def async_tool(tool_name, arguments): assert all(span["op"] == OP.MCP_SERVER for span in tx["spans"]) -def test_sse_transport_detection(sentry_init, capture_events): +@pytest.mark.asyncio +async def test_sse_transport_detection(sentry_init, capture_events): """Test that SSE transport is correctly detected via query parameter""" sentry_init( integrations=[MCPIntegration()], @@ -1093,29 +1195,67 @@ def test_sse_transport_detection(sentry_init, capture_events): events = capture_events() server = Server("test-server") + sse = SseServerTransport("/messages/") + + sse_connection_closed = asyncio.Event() + + async def handle_sse(request): + async with sse.connect_sse( + request.scope, request.receive, request._send + ) as streams: + async with anyio.create_task_group() as tg: - # Set up mock request context with SSE transport - mock_ctx = MockRequestContext( - request_id="req-sse", session_id="session-sse-123", transport="sse" + async def run_server(): + await server.run( + streams[0], streams[1], server.create_initialization_options() + ) + + tg.start_soon(run_server) + + sse_connection_closed.set() + return Response() + + app = Starlette( + routes=[ + Route("/sse", endpoint=handle_sse, methods=["GET"]), + Mount("/messages/", app=sse.handle_post_message), + ], ) - request_ctx.set(mock_ctx) @server.call_tool() - def test_tool(tool_name, arguments): + async def test_tool(tool_name, arguments): return {"result": "success"} - with start_transaction(name="mcp tx"): - result = test_tool("sse_tool", {}) + keep_sse_alive = asyncio.Event() + app_task, session_id, result = await json_rpc_sse( + app, + method="tools/call", + params={ + "name": "sse_tool", + "arguments": {}, + }, + request_id="req-sse", + keep_sse_alive=keep_sse_alive, + ) - assert result == {"result": "success"} + await sse_connection_closed.wait() + await app_task - (tx,) = events + assert result["result"]["structuredContent"] == {"result": "success"} + + transactions = [ + event + for event in events + if event["type"] == "transaction" and event["transaction"] == "/sse" + ] + assert len(transactions) == 1 + tx = transactions[0] span = tx["spans"][0] # Check that SSE transport is detected assert span["data"][SPANDATA.MCP_TRANSPORT] == "sse" assert span["data"][SPANDATA.NETWORK_TRANSPORT] == "tcp" - assert span["data"][SPANDATA.MCP_SESSION_ID] == "session-sse-123" + assert span["data"][SPANDATA.MCP_SESSION_ID] == session_id def test_streamable_http_transport_detection( From d6c2fa574b3ab218a6e95b1d744b06e042dc965e Mon Sep 17 00:00:00 2001 From: Alexander Alderman Webb Date: Fri, 30 Jan 2026 15:00:40 +0100 Subject: [PATCH 2/3] simplify streaming transport --- tests/integrations/mcp/streaming_asgi_transport.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/integrations/mcp/streaming_asgi_transport.py b/tests/integrations/mcp/streaming_asgi_transport.py index 318705b1cb..681a5bc96e 100644 --- a/tests/integrations/mcp/streaming_asgi_transport.py +++ b/tests/integrations/mcp/streaming_asgi_transport.py @@ -66,9 +66,8 @@ async def run_app(): await self.app(scope, receive, send) class StreamingBodyStream(AsyncByteStream): - def __init__(self, receiver, task): + def __init__(self, receiver): self.receiver = receiver - self.task = task async def __aiter__(self): try: @@ -77,7 +76,8 @@ async def __aiter__(self): except anyio.EndOfStream: pass - stream = StreamingBodyStream(body_receiver, asyncio.create_task(run_app())) + stream = StreamingBodyStream(body_receiver) response = Response(status_code=200, headers=[], stream=stream) + asyncio.create_task(run_app()) return response From 0f47f063190cd2b84c1d02bc25333310d42c8ed0 Mon Sep 17 00:00:00 2001 From: Alexander Alderman Webb Date: Fri, 30 Jan 2026 15:20:22 +0100 Subject: [PATCH 3/3] remove unused import --- tests/integrations/mcp/test_mcp.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/integrations/mcp/test_mcp.py b/tests/integrations/mcp/test_mcp.py index 9eaf4e7bf1..8569ad18e4 100644 --- a/tests/integrations/mcp/test_mcp.py +++ b/tests/integrations/mcp/test_mcp.py @@ -15,8 +15,6 @@ that the integration properly instruments MCP handlers with Sentry spans. """ -import sentry_sdk - from urllib.parse import urlparse, parse_qs import anyio import asyncio