diff --git a/README.md b/README.md index dc23d0d1d..9a661e978 100644 --- a/README.md +++ b/README.md @@ -1429,6 +1429,8 @@ app = Starlette( ) ``` +Security note: StreamableHTTP enforces a default `max_body_bytes=1_000_000` limit for incoming `application/json` POST bodies (413 on oversized payloads). Override via `mcp.streamable_http_app(max_body_bytes=...)` or `mcp.run("streamable-http", ..., max_body_bytes=...)`. Set to `None` to disable (not recommended). + _Full example: [examples/snippets/servers/streamable_http_basic_mounting.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/servers/streamable_http_basic_mounting.py)_ @@ -1601,6 +1603,8 @@ app = Starlette( app.router.routes.append(Host('mcp.acme.corp', app=mcp.sse_app())) ``` +Security note: SSE message endpoints enforce a default `max_body_bytes=1_000_000` limit for incoming `application/json` POST bodies (413 on oversized payloads). Override via `mcp.sse_app(max_body_bytes=...)` or `mcp.run("sse", ..., max_body_bytes=...)`. Set to `None` to disable (not recommended). + You can also mount multiple MCP servers at different sub-paths. The SSE transport automatically detects the mount path via ASGI's `root_path` mechanism, so message endpoints are correctly routed: ```python diff --git a/src/mcp/server/auth/handlers/register.py b/src/mcp/server/auth/handlers/register.py index 79eb0fb0c..6b3f63677 100644 --- a/src/mcp/server/auth/handlers/register.py +++ b/src/mcp/server/auth/handlers/register.py @@ -12,6 +12,7 @@ from mcp.server.auth.json_response import PydanticJSONResponse from mcp.server.auth.provider import OAuthAuthorizationServerProvider, RegistrationError, RegistrationErrorCode from mcp.server.auth.settings import ClientRegistrationOptions +from mcp.server.http_body import BodyTooLargeError, read_request_body from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata # this alias is a no-op; it's just to separate out the types exposed to the @@ -32,10 +33,12 @@ class RegistrationHandler: async def handle(self, request: Request) -> Response: # Implements dynamic client registration as defined in https://datatracker.ietf.org/doc/html/rfc7591#section-3.1 try: - body = await request.body() + body = await read_request_body(request, max_body_bytes=self.options.max_body_bytes) client_metadata = OAuthClientMetadata.model_validate_json(body) # Scope validation is handled below + except BodyTooLargeError: + return Response("Payload too large", status_code=413, headers={"Connection": "close"}) except ValidationError as validation_error: return PydanticJSONResponse( content=RegistrationErrorResponse( diff --git a/src/mcp/server/auth/settings.py b/src/mcp/server/auth/settings.py index 1649826db..db289ef38 100644 --- a/src/mcp/server/auth/settings.py +++ b/src/mcp/server/auth/settings.py @@ -1,11 +1,15 @@ from pydantic import AnyHttpUrl, BaseModel, Field +from mcp.server.http_body import DEFAULT_MAX_BODY_BYTES + class ClientRegistrationOptions(BaseModel): enabled: bool = False client_secret_expiry_seconds: int | None = None valid_scopes: list[str] | None = None default_scopes: list[str] | None = None + # Limit the size of incoming /register request bodies to avoid DoS via unbounded reads. + max_body_bytes: int = DEFAULT_MAX_BODY_BYTES class RevocationOptions(BaseModel): diff --git a/src/mcp/server/http_body.py b/src/mcp/server/http_body.py new file mode 100644 index 000000000..a0b6b7557 --- /dev/null +++ b/src/mcp/server/http_body.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +from dataclasses import dataclass + +from starlette.requests import Request + +DEFAULT_MAX_BODY_BYTES = 1_000_000 + + +@dataclass(frozen=True) +class BodyTooLargeError(Exception): + max_body_bytes: int + + def __str__(self) -> str: + return f"Request body exceeds max_body_bytes={self.max_body_bytes}" + + +async def read_request_body(request: Request, *, max_body_bytes: int | None = DEFAULT_MAX_BODY_BYTES) -> bytes: + """Read an HTTP request body with a hard cap. + + Notes: + - This avoids unbounded buffering of the request body in Python. + - If the body exceeds max_body_bytes, this raises BodyTooLargeError as soon + as possible. + """ + if max_body_bytes is None: + return await request.body() + + if max_body_bytes <= 0: + raise ValueError("max_body_bytes must be positive or None") + + # Fast-path: reject based on Content-Length when provided. + content_length = request.headers.get("content-length") + if content_length is not None: + try: + if int(content_length) > max_body_bytes: + raise BodyTooLargeError(max_body_bytes) + except ValueError: + # Ignore invalid Content-Length; we'll enforce while streaming. + pass + + body = bytearray() + async for chunk in request.stream(): + if not chunk: + continue + + # Never buffer more than max_body_bytes bytes. + remaining = max_body_bytes - len(body) + if remaining <= 0: + raise BodyTooLargeError(max_body_bytes) + if len(chunk) > remaining: + body.extend(chunk[:remaining]) + raise BodyTooLargeError(max_body_bytes) + + body.extend(chunk) + + return bytes(body) diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 96dcaf1c7..2621dbe59 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -93,6 +93,7 @@ async def main(): from mcp.server.auth.settings import AuthSettings from mcp.server.context import ServerRequestContext from mcp.server.experimental.request_context import Experimental +from mcp.server.http_body import DEFAULT_MAX_BODY_BYTES from mcp.server.lowlevel.experimental import ExperimentalHandlers from mcp.server.lowlevel.func_inspection import create_call_wrapper from mcp.server.lowlevel.helper_types import ReadResourceContents @@ -276,7 +277,7 @@ def session_manager(self) -> StreamableHTTPSessionManager: "Session manager can only be accessed after calling streamable_http_app(). " "The session manager is created lazily to avoid unnecessary initialization." ) - return self._session_manager # pragma: no cover + return self._session_manager def list_prompts(self): def decorator( @@ -810,6 +811,7 @@ def streamable_http_app( event_store: EventStore | None = None, retry_interval: int | None = None, transport_security: TransportSecuritySettings | None = None, + max_body_bytes: int | None = DEFAULT_MAX_BODY_BYTES, host: str = "127.0.0.1", auth: AuthSettings | None = None, token_verifier: TokenVerifier | None = None, @@ -817,7 +819,12 @@ def streamable_http_app( custom_starlette_routes: list[Route] | None = None, debug: bool = False, ) -> Starlette: - """Return an instance of the StreamableHTTP server app.""" + """Return an instance of the StreamableHTTP server app. + + Args: + max_body_bytes: Maximum size (in bytes) for JSON POST request bodies. Defaults + to 1_000_000. Set to None to disable this guard. + """ # Auto-enable DNS rebinding protection for localhost (IPv4 and IPv6) if transport_security is None and host in ("127.0.0.1", "localhost", "::1"): transport_security = TransportSecuritySettings( @@ -833,6 +840,7 @@ def streamable_http_app( json_response=json_response, stateless=stateless_http, security_settings=transport_security, + max_body_bytes=max_body_bytes, ) self._session_manager = session_manager diff --git a/src/mcp/server/mcpserver/server.py b/src/mcp/server/mcpserver/server.py index 8c1fc342b..85726d7f2 100644 --- a/src/mcp/server/mcpserver/server.py +++ b/src/mcp/server/mcpserver/server.py @@ -28,6 +28,7 @@ from mcp.server.context import LifespanContextT, RequestT, ServerRequestContext from mcp.server.elicitation import ElicitationResult, ElicitSchemaModelT, UrlElicitationResult, elicit_with_validation from mcp.server.elicitation import elicit_url as _elicit_url +from mcp.server.http_body import DEFAULT_MAX_BODY_BYTES from mcp.server.lowlevel.helper_types import ReadResourceContents from mcp.server.lowlevel.server import LifespanResultT, Server from mcp.server.lowlevel.server import lifespan as default_lifespan @@ -208,7 +209,7 @@ def session_manager(self) -> StreamableHTTPSessionManager: Raises: RuntimeError: If called before streamable_http_app() has been called. """ - return self._lowlevel_server.session_manager # pragma: no cover + return self._lowlevel_server.session_manager @overload def run(self, transport: Literal["stdio"] = ...) -> None: ... @@ -223,6 +224,7 @@ def run( sse_path: str = ..., message_path: str = ..., transport_security: TransportSecuritySettings | None = ..., + max_body_bytes: int | None = ..., ) -> None: ... @overload @@ -238,6 +240,7 @@ def run( event_store: EventStore | None = ..., retry_interval: int | None = ..., transport_security: TransportSecuritySettings | None = ..., + max_body_bytes: int | None = ..., ) -> None: ... def run( @@ -725,6 +728,7 @@ async def run_sse_async( # pragma: no cover sse_path: str = "/sse", message_path: str = "/messages/", transport_security: TransportSecuritySettings | None = None, + max_body_bytes: int | None = DEFAULT_MAX_BODY_BYTES, ) -> None: """Run the server using SSE transport.""" import uvicorn @@ -734,6 +738,7 @@ async def run_sse_async( # pragma: no cover message_path=message_path, transport_security=transport_security, host=host, + max_body_bytes=max_body_bytes, ) config = uvicorn.Config( @@ -756,6 +761,7 @@ async def run_streamable_http_async( # pragma: no cover event_store: EventStore | None = None, retry_interval: int | None = None, transport_security: TransportSecuritySettings | None = None, + max_body_bytes: int | None = DEFAULT_MAX_BODY_BYTES, ) -> None: """Run the server using StreamableHTTP transport.""" import uvicorn @@ -768,6 +774,7 @@ async def run_streamable_http_async( # pragma: no cover retry_interval=retry_interval, transport_security=transport_security, host=host, + max_body_bytes=max_body_bytes, ) config = uvicorn.Config( @@ -785,6 +792,7 @@ def sse_app( sse_path: str = "/sse", message_path: str = "/messages/", transport_security: TransportSecuritySettings | None = None, + max_body_bytes: int | None = DEFAULT_MAX_BODY_BYTES, host: str = "127.0.0.1", ) -> Starlette: """Return an instance of the SSE server app.""" @@ -796,7 +804,11 @@ def sse_app( allowed_origins=["http://127.0.0.1:*", "http://localhost:*", "http://[::1]:*"], ) - sse = SseServerTransport(message_path, security_settings=transport_security) + sse = SseServerTransport( + message_path, + security_settings=transport_security, + max_body_bytes=max_body_bytes, + ) async def handle_sse(scope: Scope, receive: Receive, send: Send): # pragma: no cover # Add client ID from auth context into request context if available @@ -914,6 +926,7 @@ def streamable_http_app( event_store: EventStore | None = None, retry_interval: int | None = None, transport_security: TransportSecuritySettings | None = None, + max_body_bytes: int | None = DEFAULT_MAX_BODY_BYTES, host: str = "127.0.0.1", ) -> Starlette: """Return an instance of the StreamableHTTP server app.""" @@ -924,6 +937,7 @@ def streamable_http_app( event_store=event_store, retry_interval=retry_interval, transport_security=transport_security, + max_body_bytes=max_body_bytes, host=host, auth=self.settings.auth, token_verifier=self._token_verifier, diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index 5be6b78ca..07f108537 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -51,6 +51,7 @@ async def handle_sse(request): from starlette.types import Receive, Scope, Send from mcp import types +from mcp.server.http_body import DEFAULT_MAX_BODY_BYTES, BodyTooLargeError, read_request_body from mcp.server.transport_security import ( TransportSecurityMiddleware, TransportSecuritySettings, @@ -75,7 +76,12 @@ class SseServerTransport: _read_stream_writers: dict[UUID, MemoryObjectSendStream[SessionMessage | Exception]] _security: TransportSecurityMiddleware - def __init__(self, endpoint: str, security_settings: TransportSecuritySettings | None = None) -> None: + def __init__( + self, + endpoint: str, + security_settings: TransportSecuritySettings | None = None, + max_body_bytes: int | None = DEFAULT_MAX_BODY_BYTES, + ) -> None: """Creates a new SSE server transport, which will direct the client to POST messages to the relative path given. @@ -83,6 +89,8 @@ def __init__(self, endpoint: str, security_settings: TransportSecuritySettings | endpoint: A relative path where messages should be posted (e.g., "/messages/"). security_settings: Optional security settings for DNS rebinding protection. + max_body_bytes: Maximum size (in bytes) for JSON POST request bodies. Defaults + to 1_000_000. Set to None to disable this guard. Note: We use relative paths instead of full URLs for several reasons: @@ -98,6 +106,8 @@ def __init__(self, endpoint: str, security_settings: TransportSecuritySettings | """ super().__init__() + if max_body_bytes is not None and max_body_bytes <= 0: + raise ValueError("max_body_bytes must be positive or None") # Validate that endpoint is a relative path and not a full URL if "://" in endpoint or endpoint.startswith("//") or "?" in endpoint or "#" in endpoint: @@ -113,6 +123,7 @@ def __init__(self, endpoint: str, security_settings: TransportSecuritySettings | self._endpoint = endpoint self._read_stream_writers = {} self._security = TransportSecurityMiddleware(security_settings) + self._max_body_bytes = max_body_bytes logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}") @asynccontextmanager @@ -194,7 +205,7 @@ async def response_wrapper(scope: Scope, receive: Receive, send: Send): logger.debug("Yielding read and write streams") yield (read_stream, write_stream) - async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) -> None: # pragma: no cover + async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) -> None: logger.debug("Handling POST message") request = Request(scope, receive) @@ -223,7 +234,12 @@ async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) response = Response("Could not find session", status_code=404) return await response(scope, receive, send) - body = await request.body() + try: + body = await read_request_body(request, max_body_bytes=self._max_body_bytes) + except BodyTooLargeError as e: + response = Response("Payload too large", status_code=413, headers={"Connection": "close"}) + logger.warning(f"Received payload too large: {e}") + return await response(scope, receive, send) logger.debug(f"Received JSON: {body}") try: diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index e9156f7ba..b54a1ad88 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -24,6 +24,7 @@ from starlette.responses import Response from starlette.types import Receive, Scope, Send +from mcp.server.http_body import DEFAULT_MAX_BODY_BYTES, BodyTooLargeError, read_request_body from mcp.server.transport_security import TransportSecurityMiddleware, TransportSecuritySettings from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS @@ -132,6 +133,7 @@ def __init__( event_store: EventStore | None = None, security_settings: TransportSecuritySettings | None = None, retry_interval: int | None = None, + max_body_bytes: int | None = DEFAULT_MAX_BODY_BYTES, ) -> None: """Initialize a new StreamableHTTP server transport. @@ -148,18 +150,23 @@ def __init__( retry field. When set, the server will send a retry field in SSE priming events to control client reconnection timing for polling behavior. Only used when event_store is provided. + max_body_bytes: Maximum size (in bytes) for JSON POST request bodies. Defaults + to 1_000_000. Set to None to disable this guard. Raises: ValueError: If the session ID contains invalid characters. """ if mcp_session_id is not None and not SESSION_ID_PATTERN.fullmatch(mcp_session_id): raise ValueError("Session ID must only contain visible ASCII characters (0x21-0x7E)") + if max_body_bytes is not None and max_body_bytes <= 0: + raise ValueError("max_body_bytes must be positive or None") self.mcp_session_id = mcp_session_id self.is_json_response_enabled = is_json_response_enabled self._event_store = event_store self._security = TransportSecurityMiddleware(security_settings) self._retry_interval = retry_interval + self._max_body_bytes = max_body_bytes self._request_streams: dict[ RequestId, tuple[ @@ -289,7 +296,7 @@ def _create_error_response( ) -> Response: """Create an error response with a simple string message.""" response_headers = {"Content-Type": CONTENT_TYPE_JSON} - if headers: # pragma: no cover + if headers: response_headers.update(headers) if self.mcp_session_id: @@ -427,6 +434,43 @@ async def _validate_accept_header(self, request: Request, scope: Scope, send: Se return False return True + async def _parse_jsonrpc_message( + self, + request: Request, + scope: Scope, + receive: Receive, + send: Send, + ) -> JSONRPCMessage | None: + """Read + parse a JSON-RPC message from an HTTP request body.""" + try: + body = await read_request_body(request, max_body_bytes=self._max_body_bytes) + except BodyTooLargeError as e: + response = self._create_error_response( + f"Payload too large: {e}", + HTTPStatus.REQUEST_ENTITY_TOO_LARGE, + headers={"Connection": "close"}, + ) + await response(scope, receive, send) + return None + + try: + raw_message = pydantic_core.from_json(body) + except ValueError as e: + response = self._create_error_response(f"Parse error: {str(e)}", HTTPStatus.BAD_REQUEST, PARSE_ERROR) + await response(scope, receive, send) + return None + + try: + return jsonrpc_message_adapter.validate_python(raw_message, by_name=False) + except ValidationError as e: # pragma: no cover + response = self._create_error_response( + f"Validation error: {str(e)}", + HTTPStatus.BAD_REQUEST, + INVALID_PARAMS, + ) + await response(scope, receive, send) + return None + async def _handle_post_request(self, scope: Scope, request: Request, receive: Receive, send: Send) -> None: """Handle POST requests containing JSON-RPC messages.""" writer = self._read_stream_writer @@ -446,25 +490,8 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re await response(scope, receive, send) return - # Parse the body - only read it once - body = await request.body() - - try: - raw_message = pydantic_core.from_json(body) - except ValueError as e: - response = self._create_error_response(f"Parse error: {str(e)}", HTTPStatus.BAD_REQUEST, PARSE_ERROR) - await response(scope, receive, send) - return - - try: - message = jsonrpc_message_adapter.validate_python(raw_message, by_name=False) - except ValidationError as e: # pragma: no cover - response = self._create_error_response( - f"Validation error: {str(e)}", - HTTPStatus.BAD_REQUEST, - INVALID_PARAMS, - ) - await response(scope, receive, send) + message = await self._parse_jsonrpc_message(request, scope, receive, send) + if message is None: return # Check if this is an initialization request diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index a954b24a4..c4b3ae3db 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -15,6 +15,7 @@ from starlette.responses import Response from starlette.types import Receive, Scope, Send +from mcp.server.http_body import DEFAULT_MAX_BODY_BYTES from mcp.server.streamable_http import ( MCP_SESSION_ID_HEADER, EventStore, @@ -56,6 +57,8 @@ class StreamableHTTPSessionManager: security_settings: Optional transport security settings. retry_interval: Retry interval in milliseconds to suggest to clients in SSE retry field. Used for SSE polling behavior. + max_body_bytes: Maximum size (in bytes) for JSON POST request bodies. Defaults + to 1_000_000. Set to None to disable this guard. """ def __init__( @@ -66,6 +69,7 @@ def __init__( stateless: bool = False, security_settings: TransportSecuritySettings | None = None, retry_interval: int | None = None, + max_body_bytes: int | None = DEFAULT_MAX_BODY_BYTES, ): self.app = app self.event_store = event_store @@ -73,6 +77,7 @@ def __init__( self.stateless = stateless self.security_settings = security_settings self.retry_interval = retry_interval + self.max_body_bytes = max_body_bytes # Session tracking (only used if not stateless) self._session_creation_lock = anyio.Lock() @@ -147,6 +152,7 @@ async def _handle_stateless_request(self, scope: Scope, receive: Receive, send: is_json_response_enabled=self.json_response, event_store=None, # No event store in stateless mode security_settings=self.security_settings, + max_body_bytes=self.max_body_bytes, ) # Start server in a new task @@ -198,6 +204,7 @@ async def _handle_stateful_request(self, scope: Scope, receive: Receive, send: S event_store=self.event_store, # May be None (no resumability) security_settings=self.security_settings, retry_interval=self.retry_interval, + max_body_bytes=self.max_body_bytes, ) assert http_transport.mcp_session_id is not None diff --git a/tests/client/test_stdio.py b/tests/client/test_stdio.py index f70c24eee..551979c14 100644 --- a/tests/client/test_stdio.py +++ b/tests/client/test_stdio.py @@ -350,6 +350,7 @@ async def test_nested_process_tree(self): with tempfile.NamedTemporaryFile(mode="w", delete=False) as f3: grandchild_file = f3.name + proc = None try: # Simple nested process tree test # We create parent -> child -> grandchild, each writing to a file @@ -405,9 +406,16 @@ async def test_nested_process_tree(self): for file_path, name in [(parent_file, "parent"), (child_file, "child"), (grandchild_file, "grandchild")]: if os.path.exists(file_path): # pragma: no branch initial_size = os.path.getsize(file_path) - await anyio.sleep(0.3) - new_size = os.path.getsize(file_path) - assert new_size > initial_size, f"{name} process should be writing" + # Under high load (e.g. CI + xdist), short fixed sleeps can be flaky on Windows. + # Poll for growth within a small deadline rather than asserting after a single sleep. + deadline = time.monotonic() + 3.0 + while time.monotonic() < deadline: + await anyio.sleep(0.1) + new_size = os.path.getsize(file_path) + if new_size > initial_size: # pragma: no branch + break + else: # pragma: no cover + raise AssertionError(f"{name} process should be writing") # Terminate the whole tree await _terminate_process_tree(proc) @@ -424,6 +432,14 @@ async def test_nested_process_tree(self): print("SUCCESS: All processes in tree terminated") finally: + # If the "is writing" assertions fail, ensure we still terminate the process tree to + # avoid leaking subprocesses (which can cause cascading failures on Windows). + if proc is not None: # pragma: no branch + try: + await _terminate_process_tree(proc) + except (OSError, RuntimeError, ProcessLookupError): # pragma: no cover + pass + # Clean up all marker files for f in [parent_file, child_file, grandchild_file]: try: diff --git a/tests/server/auth/test_error_handling.py b/tests/server/auth/test_error_handling.py index 7c5c43582..d5290371a 100644 --- a/tests/server/auth/test_error_handling.py +++ b/tests/server/auth/test_error_handling.py @@ -116,6 +116,31 @@ async def test_registration_error_handling(client: httpx.AsyncClient, oauth_prov assert data["error_description"] == "The redirect URI is invalid" +@pytest.mark.anyio +async def test_registration_rejects_payload_too_large(oauth_provider: MockOAuthProvider): + client_registration_options = ClientRegistrationOptions(enabled=True, max_body_bytes=10) + revocation_options = RevocationOptions(enabled=False) + + auth_routes = create_auth_routes( + oauth_provider, + issuer_url=AnyHttpUrl("http://localhost"), + client_registration_options=client_registration_options, + revocation_options=revocation_options, + ) + app = Starlette(routes=auth_routes) + + transport = ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://localhost") as client: + body = b'{"a":"' + (b"x" * 20) + b'"}' + response = await client.post( + "/register", + content=body, + headers={"Content-Type": "application/json"}, + ) + assert response.status_code == 413, response.content + assert response.text == "Payload too large" + + @pytest.mark.anyio async def test_authorize_error_handling( client: httpx.AsyncClient, diff --git a/tests/server/mcpserver/test_server.py b/tests/server/mcpserver/test_server.py index 979dc580f..55bbc03fe 100644 --- a/tests/server/mcpserver/test_server.py +++ b/tests/server/mcpserver/test_server.py @@ -77,6 +77,23 @@ async def test_sse_app_returns_starlette_app(self): assert sse_routes[0].path == "/sse" assert mount_routes[0].path == "/messages" + async def test_sse_app_passes_max_body_bytes(self): + mcp = MCPServer("test") + app = mcp.sse_app(host="0.0.0.0", max_body_bytes=123) + + mount_routes = [r for r in app.routes if isinstance(r, Mount)] + assert len(mount_routes) == 1 + + message_app: Any = mount_routes[0].app + assert hasattr(message_app, "__self__"), "Expected a bound method for message handler" + sse_transport: Any = message_app.__self__ + assert getattr(sse_transport, "_max_body_bytes") == 123 + + async def test_streamable_http_app_passes_max_body_bytes(self): + mcp = MCPServer("test") + mcp.streamable_http_app(host="0.0.0.0", max_body_bytes=123) + assert mcp.session_manager.max_body_bytes == 123 + async def test_non_ascii_description(self): """Test that MCPServer handles non-ASCII characters in descriptions correctly""" mcp = MCPServer() diff --git a/tests/server/test_http_body.py b/tests/server/test_http_body.py new file mode 100644 index 000000000..64e2fb56c --- /dev/null +++ b/tests/server/test_http_body.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +import pytest +from starlette.requests import Request +from starlette.types import Message + +from mcp.server.http_body import BodyTooLargeError, read_request_body + + +def make_request(*, body_chunks: list[bytes], headers: dict[str, str] | None = None) -> Request: + scope = { + "type": "http", + "method": "POST", + "path": "/", + "query_string": b"", + "headers": [(k.lower().encode(), v.encode()) for k, v in (headers or {}).items()], + } + + messages: list[Message] = [ + { + "type": "http.request", + "body": chunk, + "more_body": i < len(body_chunks) - 1, + } + for i, chunk in enumerate(body_chunks) + ] + + async def receive() -> Message: + if messages: + return messages.pop(0) + return {"type": "http.request", "body": b"", "more_body": False} + + return Request(scope, receive) + + +pytestmark = pytest.mark.anyio + + +async def test_read_request_body_allows_disabling_limit_with_none(): + request = make_request(body_chunks=[b"x" * 20]) + body = await read_request_body(request, max_body_bytes=None) + assert body == b"x" * 20 + + +async def test_read_request_body_rejects_non_positive_limit(): + request = make_request(body_chunks=[b"{}"]) + with pytest.raises(ValueError, match="max_body_bytes must be positive or None"): + await read_request_body(request, max_body_bytes=0) + + +async def test_read_request_body_ignores_invalid_content_length_header(): + request = make_request(body_chunks=[b"{}"], headers={"content-length": "not-a-number"}) + body = await read_request_body(request, max_body_bytes=10) + assert body == b"{}" + + +async def test_read_request_body_errors_if_more_chunks_arrive_after_limit_is_reached(): + # First chunk reaches the limit exactly; the next non-empty chunk should error. + request = make_request(body_chunks=[b"12345", b"6"]) + with pytest.raises(BodyTooLargeError): + await read_request_body(request, max_body_bytes=5) + + +async def test_read_request_body_handles_empty_request_body(): + request = make_request(body_chunks=[]) + body = await read_request_body(request, max_body_bytes=10) + assert body == b"" diff --git a/tests/server/test_max_body_bytes_validation.py b/tests/server/test_max_body_bytes_validation.py new file mode 100644 index 000000000..b1a5eabf7 --- /dev/null +++ b/tests/server/test_max_body_bytes_validation.py @@ -0,0 +1,14 @@ +import pytest + +from mcp.server.sse import SseServerTransport +from mcp.server.streamable_http import StreamableHTTPServerTransport + + +def test_sse_transport_rejects_non_positive_max_body_bytes(): + with pytest.raises(ValueError, match="max_body_bytes must be positive or None"): + SseServerTransport("/messages/", max_body_bytes=0) + + +def test_streamable_http_transport_rejects_non_positive_max_body_bytes(): + with pytest.raises(ValueError, match="max_body_bytes must be positive or None"): + StreamableHTTPServerTransport(mcp_session_id=None, max_body_bytes=0) diff --git a/tests/server/test_sse_max_body_bytes.py b/tests/server/test_sse_max_body_bytes.py new file mode 100644 index 000000000..69893ed2a --- /dev/null +++ b/tests/server/test_sse_max_body_bytes.py @@ -0,0 +1,271 @@ +from __future__ import annotations + +from collections.abc import Awaitable, Callable +from unittest.mock import AsyncMock +from uuid import uuid4 + +import anyio +import pytest +from pydantic import ValidationError +from starlette.responses import Response +from starlette.types import Message + +from mcp.server.sse import SseServerTransport +from mcp.shared.message import SessionMessage + + +def make_receive(body: bytes) -> Callable[[], Awaitable[Message]]: + async def receive() -> Message: + return {"type": "http.request", "body": body, "more_body": False} + + return receive + + +@pytest.mark.anyio +async def test_sse_max_body_bytes_rejects_large_request(): + sse_transport = SseServerTransport("/messages/", max_body_bytes=10) + + session_id = uuid4() + writer, reader = anyio.create_memory_object_stream[SessionMessage | Exception](0) + try: + sse_transport._read_stream_writers[session_id] = writer + + sent_messages: list[Message] = [] + response_body = b"" + + async def send(message: Message): + nonlocal response_body + sent_messages.append(message) + if message["type"] == "http.response.body": + response_body += message.get("body", b"") + + scope = { + "type": "http", + "method": "POST", + "path": "/messages/", + "query_string": f"session_id={session_id.hex}".encode(), + "headers": [(b"content-type", b"application/json")], + } + + body = b'{"a":"' + (b"x" * 20) + b'"}' + + receive = make_receive(body) + + await sse_transport.handle_post_message(scope, receive, send) + + response_start = next( + (msg for msg in sent_messages if msg["type"] == "http.response.start"), + None, + ) + assert response_start is not None, "Should have sent a response" + assert response_start["status"] == 413 + assert response_body == b"Payload too large" + finally: + await writer.aclose() + await reader.aclose() + + +@pytest.mark.anyio +async def test_sse_handle_post_message_short_circuits_on_security_error(): + sse_transport = SseServerTransport("/messages/") + sse_transport._security.validate_request = AsyncMock(return_value=Response("blocked", status_code=403)) # type: ignore[method-assign] + + sent_messages: list[Message] = [] + response_body = b"" + + async def send(message: Message): + nonlocal response_body + sent_messages.append(message) + if message["type"] == "http.response.body": + response_body += message.get("body", b"") + + scope = { + "type": "http", + "method": "POST", + "path": "/messages/", + "query_string": b"", + "headers": [(b"content-type", b"application/json")], + } + receive = make_receive(b"{}") + + await sse_transport.handle_post_message(scope, receive, send) + + response_start = next((msg for msg in sent_messages if msg["type"] == "http.response.start"), None) + assert response_start is not None, "Should have sent a response" + assert response_start["status"] == 403 + assert response_body == b"blocked" + + +@pytest.mark.anyio +async def test_sse_handle_post_message_returns_400_when_session_id_missing(): + sse_transport = SseServerTransport("/messages/") + + sent_messages: list[Message] = [] + response_body = b"" + + async def send(message: Message): + nonlocal response_body + sent_messages.append(message) + if message["type"] == "http.response.body": + response_body += message.get("body", b"") + + scope = { + "type": "http", + "method": "POST", + "path": "/messages/", + "query_string": b"", + "headers": [(b"content-type", b"application/json")], + } + receive = make_receive(b"{}") + + await sse_transport.handle_post_message(scope, receive, send) + + response_start = next((msg for msg in sent_messages if msg["type"] == "http.response.start"), None) + assert response_start is not None, "Should have sent a response" + assert response_start["status"] == 400 + assert response_body == b"session_id is required" + + +@pytest.mark.anyio +async def test_sse_handle_post_message_returns_400_when_session_id_invalid(): + sse_transport = SseServerTransport("/messages/") + + sent_messages: list[Message] = [] + response_body = b"" + + async def send(message: Message): + nonlocal response_body + sent_messages.append(message) + if message["type"] == "http.response.body": + response_body += message.get("body", b"") + + scope = { + "type": "http", + "method": "POST", + "path": "/messages/", + "query_string": b"session_id=not-a-uuid", + "headers": [(b"content-type", b"application/json")], + } + receive = make_receive(b"{}") + + await sse_transport.handle_post_message(scope, receive, send) + + response_start = next((msg for msg in sent_messages if msg["type"] == "http.response.start"), None) + assert response_start is not None, "Should have sent a response" + assert response_start["status"] == 400 + assert response_body == b"Invalid session ID" + + +@pytest.mark.anyio +async def test_sse_handle_post_message_returns_404_when_session_not_found(): + sse_transport = SseServerTransport("/messages/") + + sent_messages: list[Message] = [] + response_body = b"" + + async def send(message: Message): + nonlocal response_body + sent_messages.append(message) + if message["type"] == "http.response.body": + response_body += message.get("body", b"") + + scope = { + "type": "http", + "method": "POST", + "path": "/messages/", + "query_string": f"session_id={uuid4().hex}".encode(), + "headers": [(b"content-type", b"application/json")], + } + receive = make_receive(b"{}") + + await sse_transport.handle_post_message(scope, receive, send) + + response_start = next((msg for msg in sent_messages if msg["type"] == "http.response.start"), None) + assert response_start is not None, "Should have sent a response" + assert response_start["status"] == 404 + assert response_body == b"Could not find session" + + +@pytest.mark.anyio +async def test_sse_handle_post_message_returns_400_and_sends_error_on_invalid_jsonrpc(): + sse_transport = SseServerTransport("/messages/", max_body_bytes=1_000) + + session_id = uuid4() + writer, reader = anyio.create_memory_object_stream[SessionMessage | Exception](1) + try: + sse_transport._read_stream_writers[session_id] = writer + + sent_messages: list[Message] = [] + response_body = b"" + + async def send(message: Message): + nonlocal response_body + sent_messages.append(message) + if message["type"] == "http.response.body": + response_body += message.get("body", b"") + + scope = { + "type": "http", + "method": "POST", + "path": "/messages/", + "query_string": f"session_id={session_id.hex}".encode(), + "headers": [(b"content-type", b"application/json")], + } + receive = make_receive(b"{}") + + await sse_transport.handle_post_message(scope, receive, send) + + response_start = next((msg for msg in sent_messages if msg["type"] == "http.response.start"), None) + assert response_start is not None, "Should have sent a response" + assert response_start["status"] == 400 + assert response_body == b"Could not parse message" + + err = await reader.receive() + assert isinstance(err, ValidationError) + finally: + await writer.aclose() + await reader.aclose() + + +@pytest.mark.anyio +async def test_sse_handle_post_message_accepts_valid_jsonrpc_and_sends_session_message(): + sse_transport = SseServerTransport("/messages/", max_body_bytes=1_000) + + session_id = uuid4() + writer, reader = anyio.create_memory_object_stream[SessionMessage | Exception](1) + try: + sse_transport._read_stream_writers[session_id] = writer + + sent_messages: list[Message] = [] + response_body = b"" + + async def send(message: Message): + nonlocal response_body + sent_messages.append(message) + if message["type"] == "http.response.body": + response_body += message.get("body", b"") + + scope = { + "type": "http", + "method": "POST", + "path": "/messages/", + "query_string": f"session_id={session_id.hex}".encode(), + "headers": [(b"content-type", b"application/json")], + } + + body = b'{"jsonrpc":"2.0","id":1,"method":"initialize","params":{}}' + receive = make_receive(body) + + await sse_transport.handle_post_message(scope, receive, send) + + response_start = next((msg for msg in sent_messages if msg["type"] == "http.response.start"), None) + assert response_start is not None, "Should have sent a response" + assert response_start["status"] == 202 + assert response_body == b"Accepted" + + session_message = await reader.receive() + assert isinstance(session_message, SessionMessage) + assert getattr(session_message.message, "method", None) == "initialize" + finally: + await writer.aclose() + await reader.aclose() diff --git a/tests/server/test_streamable_http_manager.py b/tests/server/test_streamable_http_manager.py index af1b23619..027e43518 100644 --- a/tests/server/test_streamable_http_manager.py +++ b/tests/server/test_streamable_http_manager.py @@ -313,3 +313,48 @@ async def mock_receive(): assert error_data["id"] == "server-error" assert error_data["error"]["code"] == INVALID_REQUEST assert error_data["error"]["message"] == "Session not found" + + +@pytest.mark.anyio +async def test_max_body_bytes_rejects_large_request(): + app = Server("test-max-body-bytes") + manager = StreamableHTTPSessionManager(app=app, max_body_bytes=10) + + async with manager.run(): + app.run = AsyncMock(return_value=None) + + sent_messages: list[Message] = [] + response_body = b"" + + async def mock_send(message: Message): + nonlocal response_body + sent_messages.append(message) + if message["type"] == "http.response.body": + response_body += message.get("body", b"") + + scope = { + "type": "http", + "method": "POST", + "path": "/mcp", + "headers": [ + (b"content-type", b"application/json"), + (b"accept", b"application/json, text/event-stream"), + ], + } + + body = b'{"a":"' + (b"x" * 20) + b'"}' + + async def mock_receive(): + return {"type": "http.request", "body": body, "more_body": False} + + await manager.handle_request(scope, mock_receive, mock_send) + + response_start = next( + (msg for msg in sent_messages if msg["type"] == "http.response.start"), + None, + ) + assert response_start is not None, "Should have sent a response" + assert response_start["status"] == 413 + + error_data = json.loads(response_body) + assert "Payload too large" in error_data["error"]["message"]