From 609863c77a29d2f39ad65b3c0265f91fa93a2ec1 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Tue, 3 Feb 2026 10:53:48 +0100 Subject: [PATCH 1/4] refactor: split RequestContext between server and client --- README.v2.md | 5 ++- .../main.py | 5 ++- examples/snippets/clients/stdio_client.py | 2 +- .../clients/url_elicitation_client.py | 2 +- examples/snippets/servers/lifespan_example.py | 3 +- src/mcp/client/experimental/task_handlers.py | 30 +++++++++--------- src/mcp/client/session.py | 31 +++++++------------ src/mcp/server/context.py | 21 +++++++++++++ .../server/experimental/request_context.py | 12 ++----- src/mcp/server/lowlevel/server.py | 27 ++++++---------- src/mcp/server/mcpserver/prompts/base.py | 5 ++- src/mcp/server/mcpserver/prompts/manager.py | 5 ++- .../mcpserver/resources/resource_manager.py | 5 ++- .../server/mcpserver/resources/templates.py | 5 ++- src/mcp/server/mcpserver/server.py | 16 +++++----- src/mcp/server/mcpserver/tools/base.py | 5 ++- .../server/mcpserver/tools/tool_manager.py | 5 ++- src/mcp/shared/context.py | 20 +++--------- src/mcp/shared/progress.py | 24 +++----------- tests/client/test_list_roots_callback.py | 4 +-- tests/client/test_sampling_callback.py | 4 +-- tests/client/test_session.py | 8 ++--- .../tasks/client/test_capabilities.py | 10 +++--- .../tasks/client/test_handlers.py | 20 ++++++------ .../tasks/test_elicitation_scenarios.py | 16 +++++----- tests/issues/test_176_progress_token.py | 6 ++-- tests/issues/test_355_type_error.py | 3 +- tests/server/mcpserver/test_elicitation.py | 20 ++++++------ tests/server/mcpserver/test_integration.py | 4 +-- tests/server/mcpserver/test_tool_manager.py | 6 ++-- .../server/mcpserver/test_url_elicitation.py | 22 ++++++------- tests/shared/test_progress_notifications.py | 10 ++---- tests/shared/test_streamable_http.py | 2 +- 33 files changed, 161 insertions(+), 202 deletions(-) create mode 100644 src/mcp/server/context.py diff --git a/README.v2.md b/README.v2.md index d34b7832b..4fc110448 100644 --- a/README.v2.md +++ b/README.v2.md @@ -229,7 +229,6 @@ from contextlib import asynccontextmanager from dataclasses import dataclass from mcp.server.mcpserver import Context, MCPServer -from mcp.server.session import ServerSession # Mock database class for example @@ -275,7 +274,7 @@ mcp = MCPServer("My App", lifespan=app_lifespan) # Access type-safe lifespan context in tools @mcp.tool() -def query_db(ctx: Context[ServerSession, AppContext]) -> str: +def query_db(ctx: Context[AppContext]) -> str: """Tool that uses initialized resources.""" db = ctx.request_context.lifespan_context.db return db.query() @@ -2135,7 +2134,7 @@ server_params = StdioServerParameters( # Optional: create a sampling callback async def handle_sampling_message( - context: RequestContext[ClientSession, None], params: types.CreateMessageRequestParams + context: RequestContext[ClientSession], params: types.CreateMessageRequestParams ) -> types.CreateMessageResult: print(f"Sampling request: {params.messages}") return types.CreateMessageResult( diff --git a/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/main.py b/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/main.py index d37958f52..a929418fa 100644 --- a/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/main.py +++ b/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/main.py @@ -7,7 +7,6 @@ """ import asyncio -from typing import Any import click from mcp import ClientSession @@ -24,7 +23,7 @@ async def elicitation_callback( - context: RequestContext[ClientSession, Any], + context: RequestContext[ClientSession], params: ElicitRequestParams, ) -> ElicitResult: """Handle elicitation requests from the server.""" @@ -39,7 +38,7 @@ async def elicitation_callback( async def sampling_callback( - context: RequestContext[ClientSession, Any], + context: RequestContext[ClientSession], params: CreateMessageRequestParams, ) -> CreateMessageResult: """Handle sampling requests from the server.""" diff --git a/examples/snippets/clients/stdio_client.py b/examples/snippets/clients/stdio_client.py index f096a2649..ab3959f09 100644 --- a/examples/snippets/clients/stdio_client.py +++ b/examples/snippets/clients/stdio_client.py @@ -19,7 +19,7 @@ # Optional: create a sampling callback async def handle_sampling_message( - context: RequestContext[ClientSession, None], params: types.CreateMessageRequestParams + context: RequestContext[ClientSession], params: types.CreateMessageRequestParams ) -> types.CreateMessageResult: print(f"Sampling request: {params.messages}") return types.CreateMessageResult( diff --git a/examples/snippets/clients/url_elicitation_client.py b/examples/snippets/clients/url_elicitation_client.py index 8cf1f88f0..b534135e0 100644 --- a/examples/snippets/clients/url_elicitation_client.py +++ b/examples/snippets/clients/url_elicitation_client.py @@ -38,7 +38,7 @@ async def handle_elicitation( - context: RequestContext[ClientSession, Any], + context: RequestContext[ClientSession], params: types.ElicitRequestParams, ) -> types.ElicitResult | types.ErrorData: """Handle elicitation requests from the server. diff --git a/examples/snippets/servers/lifespan_example.py b/examples/snippets/servers/lifespan_example.py index 2925f1060..f290d31dd 100644 --- a/examples/snippets/servers/lifespan_example.py +++ b/examples/snippets/servers/lifespan_example.py @@ -5,7 +5,6 @@ from dataclasses import dataclass from mcp.server.mcpserver import Context, MCPServer -from mcp.server.session import ServerSession # Mock database class for example @@ -51,7 +50,7 @@ async def app_lifespan(server: MCPServer) -> AsyncIterator[AppContext]: # Access type-safe lifespan context in tools @mcp.tool() -def query_db(ctx: Context[ServerSession, AppContext]) -> str: +def query_db(ctx: Context[AppContext]) -> str: """Tool that uses initialized resources.""" db = ctx.request_context.lifespan_context.db return db.query() diff --git a/src/mcp/client/experimental/task_handlers.py b/src/mcp/client/experimental/task_handlers.py index d6cde09fa..448322cfb 100644 --- a/src/mcp/client/experimental/task_handlers.py +++ b/src/mcp/client/experimental/task_handlers.py @@ -11,8 +11,10 @@ - Server polls client's task status via tasks/get, tasks/result, etc. """ +from __future__ import annotations + from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Protocol +from typing import TYPE_CHECKING, Protocol from pydantic import TypeAdapter @@ -32,7 +34,7 @@ class GetTaskHandlerFnT(Protocol): async def __call__( self, - context: RequestContext["ClientSession", Any], + context: RequestContext[ClientSession], params: types.GetTaskRequestParams, ) -> types.GetTaskResult | types.ErrorData: ... # pragma: no branch @@ -45,7 +47,7 @@ class GetTaskResultHandlerFnT(Protocol): async def __call__( self, - context: RequestContext["ClientSession", Any], + context: RequestContext[ClientSession], params: types.GetTaskPayloadRequestParams, ) -> types.GetTaskPayloadResult | types.ErrorData: ... # pragma: no branch @@ -58,7 +60,7 @@ class ListTasksHandlerFnT(Protocol): async def __call__( self, - context: RequestContext["ClientSession", Any], + context: RequestContext[ClientSession], params: types.PaginatedRequestParams | None, ) -> types.ListTasksResult | types.ErrorData: ... # pragma: no branch @@ -71,7 +73,7 @@ class CancelTaskHandlerFnT(Protocol): async def __call__( self, - context: RequestContext["ClientSession", Any], + context: RequestContext[ClientSession], params: types.CancelTaskRequestParams, ) -> types.CancelTaskResult | types.ErrorData: ... # pragma: no branch @@ -88,7 +90,7 @@ class TaskAugmentedSamplingFnT(Protocol): async def __call__( self, - context: RequestContext["ClientSession", Any], + context: RequestContext[ClientSession], params: types.CreateMessageRequestParams, task_metadata: types.TaskMetadata, ) -> types.CreateTaskResult | types.ErrorData: ... # pragma: no branch @@ -106,14 +108,14 @@ class TaskAugmentedElicitationFnT(Protocol): async def __call__( self, - context: RequestContext["ClientSession", Any], + context: RequestContext[ClientSession], params: types.ElicitRequestParams, task_metadata: types.TaskMetadata, ) -> types.CreateTaskResult | types.ErrorData: ... # pragma: no branch async def default_get_task_handler( - context: RequestContext["ClientSession", Any], + context: RequestContext[ClientSession], params: types.GetTaskRequestParams, ) -> types.GetTaskResult | types.ErrorData: return types.ErrorData( @@ -123,7 +125,7 @@ async def default_get_task_handler( async def default_get_task_result_handler( - context: RequestContext["ClientSession", Any], + context: RequestContext[ClientSession], params: types.GetTaskPayloadRequestParams, ) -> types.GetTaskPayloadResult | types.ErrorData: return types.ErrorData( @@ -133,7 +135,7 @@ async def default_get_task_result_handler( async def default_list_tasks_handler( - context: RequestContext["ClientSession", Any], + context: RequestContext[ClientSession], params: types.PaginatedRequestParams | None, ) -> types.ListTasksResult | types.ErrorData: return types.ErrorData( @@ -143,7 +145,7 @@ async def default_list_tasks_handler( async def default_cancel_task_handler( - context: RequestContext["ClientSession", Any], + context: RequestContext[ClientSession], params: types.CancelTaskRequestParams, ) -> types.CancelTaskResult | types.ErrorData: return types.ErrorData( @@ -153,7 +155,7 @@ async def default_cancel_task_handler( async def default_task_augmented_sampling( - context: RequestContext["ClientSession", Any], + context: RequestContext[ClientSession], params: types.CreateMessageRequestParams, task_metadata: types.TaskMetadata, ) -> types.CreateTaskResult | types.ErrorData: @@ -164,7 +166,7 @@ async def default_task_augmented_sampling( async def default_task_augmented_elicitation( - context: RequestContext["ClientSession", Any], + context: RequestContext[ClientSession], params: types.ElicitRequestParams, task_metadata: types.TaskMetadata, ) -> types.CreateTaskResult | types.ErrorData: @@ -248,7 +250,7 @@ def handles_request(request: types.ServerRequest) -> bool: async def handle_request( self, - ctx: RequestContext["ClientSession", Any], + ctx: RequestContext[ClientSession], responder: RequestResponder[types.ServerRequest, types.ClientResult], ) -> None: """Handle a task-related request from the server. diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 5080e5385..b10d02ce6 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging from typing import Any, Protocol @@ -22,7 +24,7 @@ class SamplingFnT(Protocol): async def __call__( self, - context: RequestContext["ClientSession", Any], + context: RequestContext[ClientSession], params: types.CreateMessageRequestParams, ) -> types.CreateMessageResult | types.CreateMessageResultWithTools | types.ErrorData: ... # pragma: no branch @@ -30,22 +32,19 @@ async def __call__( class ElicitationFnT(Protocol): async def __call__( self, - context: RequestContext["ClientSession", Any], + context: RequestContext[ClientSession], params: types.ElicitRequestParams, ) -> types.ElicitResult | types.ErrorData: ... # pragma: no branch class ListRootsFnT(Protocol): async def __call__( - self, context: RequestContext["ClientSession", Any] + self, context: RequestContext[ClientSession] ) -> types.ListRootsResult | types.ErrorData: ... # pragma: no branch class LoggingFnT(Protocol): - async def __call__( - self, - params: types.LoggingMessageNotificationParams, - ) -> None: ... # pragma: no branch + async def __call__(self, params: types.LoggingMessageNotificationParams) -> None: ... # pragma: no branch class MessageHandlerFnT(Protocol): @@ -62,7 +61,7 @@ async def _default_message_handler( async def _default_sampling_callback( - context: RequestContext["ClientSession", Any], + context: RequestContext[ClientSession], params: types.CreateMessageRequestParams, ) -> types.CreateMessageResult | types.CreateMessageResultWithTools | types.ErrorData: return types.ErrorData( @@ -72,7 +71,7 @@ async def _default_sampling_callback( async def _default_elicitation_callback( - context: RequestContext["ClientSession", Any], + context: RequestContext[ClientSession], params: types.ElicitRequestParams, ) -> types.ElicitResult | types.ErrorData: return types.ErrorData( # pragma: no cover @@ -82,7 +81,7 @@ async def _default_elicitation_callback( async def _default_list_roots_callback( - context: RequestContext["ClientSession", Any], + context: RequestContext[ClientSession], ) -> types.ListRootsResult | types.ErrorData: return types.ErrorData( code=types.INVALID_REQUEST, @@ -153,10 +152,7 @@ async def initialize(self) -> types.InitializeResult: else None ) elicitation = ( - types.ElicitationCapability( - form=types.FormElicitationCapability(), - url=types.UrlElicitationCapability(), - ) + types.ElicitationCapability(form=types.FormElicitationCapability(), url=types.UrlElicitationCapability()) if self._elicitation_callback is not _default_elicitation_callback else None ) @@ -414,12 +410,7 @@ async def send_roots_list_changed(self) -> None: # pragma: no cover await self.send_notification(types.RootsListChangedNotification()) async def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None: - ctx = RequestContext[ClientSession, Any]( - request_id=responder.request_id, - meta=responder.request_meta, - session=self, - lifespan_context=None, - ) + ctx = RequestContext[ClientSession](request_id=responder.request_id, meta=responder.request_meta, session=self) # Delegate to experimental task handler if applicable if self._task_handlers.handles_request(responder.request): diff --git a/src/mcp/server/context.py b/src/mcp/server/context.py new file mode 100644 index 000000000..aba6c0ce6 --- /dev/null +++ b/src/mcp/server/context.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Generic, TypeVar + +from mcp.server.experimental.request_context import Experimental +from mcp.server.session import ServerSession +from mcp.shared.context import RequestContext +from mcp.shared.message import CloseSSEStreamCallback + +LifespanContextT = TypeVar("LifespanContextT") +RequestT = TypeVar("RequestT", default=Any) + + +@dataclass(kw_only=True) +class ServerRequestContext(RequestContext[ServerSession], Generic[LifespanContextT, RequestT]): + lifespan_context: LifespanContextT + experimental: Experimental + request: RequestT | None = None + close_sse_stream: CloseSSEStreamCallback | None = None + close_standalone_sse_stream: CloseSSEStreamCallback | None = None diff --git a/src/mcp/server/experimental/request_context.py b/src/mcp/server/experimental/request_context.py index 3bf12179b..80ae5912b 100644 --- a/src/mcp/server/experimental/request_context.py +++ b/src/mcp/server/experimental/request_context.py @@ -57,10 +57,7 @@ def client_supports_tasks(self) -> bool: return self._client_capabilities.tasks is not None def validate_task_mode( - self, - tool_task_mode: TaskExecutionMode | None, - *, - raise_error: bool = True, + self, tool_task_mode: TaskExecutionMode | None, *, raise_error: bool = True ) -> ErrorData | None: """Validate that the request is compatible with the tool's task execution mode. @@ -95,12 +92,7 @@ def validate_task_mode( return error - def validate_for_tool( - self, - tool: Tool, - *, - raise_error: bool = True, - ) -> ErrorData | None: + def validate_for_tool(self, tool: Tool, *, raise_error: bool = True) -> ErrorData | None: """Validate that the request is compatible with the given tool. Convenience wrapper around validate_task_mode that extracts the mode from a Tool. diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 1dfa47129..9bab9d73a 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -91,6 +91,7 @@ async def main(): from mcp.server.auth.provider import OAuthAuthorizationServerProvider, TokenVerifier from mcp.server.auth.routes import build_resource_metadata_url, create_auth_routes, create_protected_resource_routes from mcp.server.auth.settings import AuthSettings +from mcp.server.context import ServerRequestContext from mcp.server.experimental.request_context import Experimental from mcp.server.lowlevel.experimental import ExperimentalHandlers from mcp.server.lowlevel.func_inspection import create_call_wrapper @@ -100,7 +101,6 @@ async def main(): from mcp.server.streamable_http import EventStore from mcp.server.streamable_http_manager import StreamableHTTPASGIApp, StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings -from mcp.shared.context import RequestContext from mcp.shared.exceptions import MCPError, UrlElicitationRequiredError from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.shared.session import RequestResponder @@ -117,16 +117,11 @@ async def main(): CombinationContent: TypeAlias = tuple[UnstructuredContent, StructuredContent] # This will be properly typed in each Server instance's context -request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any, Any]] = contextvars.ContextVar("request_ctx") +request_ctx: contextvars.ContextVar[ServerRequestContext[Any, Any]] = contextvars.ContextVar("request_ctx") class NotificationOptions: - def __init__( - self, - prompts_changed: bool = False, - resources_changed: bool = False, - tools_changed: bool = False, - ): + def __init__(self, prompts_changed: bool = False, resources_changed: bool = False, tools_changed: bool = False): self.prompts_changed = prompts_changed self.resources_changed = resources_changed self.tools_changed = tools_changed @@ -253,9 +248,7 @@ def get_capabilities( return capabilities @property - def request_context( - self, - ) -> RequestContext[ServerSession, LifespanResultT, RequestT]: + def request_context(self) -> ServerRequestContext[LifespanResultT, RequestT]: """If called outside of a request context, this will raise a LookupError.""" return request_ctx.get() @@ -762,12 +755,12 @@ async def _handle_request( if hasattr(req, "params") and req.params is not None: task_metadata = getattr(req.params, "task", None) token = request_ctx.set( - RequestContext( - message.request_id, - message.request_meta, - session, - lifespan_context, - Experimental( + ServerRequestContext( + request_id=message.request_id, + meta=message.request_meta, + session=session, + lifespan_context=lifespan_context, + experimental=Experimental( task_metadata=task_metadata, _client_capabilities=client_capabilities, _session=session, diff --git a/src/mcp/server/mcpserver/prompts/base.py b/src/mcp/server/mcpserver/prompts/base.py index 751733f9c..17744a670 100644 --- a/src/mcp/server/mcpserver/prompts/base.py +++ b/src/mcp/server/mcpserver/prompts/base.py @@ -14,9 +14,8 @@ from mcp.types import ContentBlock, Icon, TextContent if TYPE_CHECKING: + from mcp.server.context import LifespanContextT, RequestT from mcp.server.mcpserver.server import Context - from mcp.server.session import ServerSessionT - from mcp.shared.context import LifespanContextT, RequestT class Message(BaseModel): @@ -137,7 +136,7 @@ def from_function( async def render( self, arguments: dict[str, Any] | None = None, - context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None, + context: Context[LifespanContextT, RequestT] | None = None, ) -> list[Message]: """Render the prompt with arguments.""" # Validate required arguments diff --git a/src/mcp/server/mcpserver/prompts/manager.py b/src/mcp/server/mcpserver/prompts/manager.py index 34d4c7e94..21b974131 100644 --- a/src/mcp/server/mcpserver/prompts/manager.py +++ b/src/mcp/server/mcpserver/prompts/manager.py @@ -8,9 +8,8 @@ from mcp.server.mcpserver.utilities.logging import get_logger if TYPE_CHECKING: + from mcp.server.context import LifespanContextT, RequestT from mcp.server.mcpserver.server import Context - from mcp.server.session import ServerSessionT - from mcp.shared.context import LifespanContextT, RequestT logger = get_logger(__name__) @@ -50,7 +49,7 @@ async def render_prompt( self, name: str, arguments: dict[str, Any] | None = None, - context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None, + context: Context[LifespanContextT, RequestT] | None = None, ) -> list[Message]: """Render a prompt by name with arguments.""" prompt = self.get_prompt(name) diff --git a/src/mcp/server/mcpserver/resources/resource_manager.py b/src/mcp/server/mcpserver/resources/resource_manager.py index 589015688..ed5b74123 100644 --- a/src/mcp/server/mcpserver/resources/resource_manager.py +++ b/src/mcp/server/mcpserver/resources/resource_manager.py @@ -13,9 +13,8 @@ from mcp.types import Annotations, Icon if TYPE_CHECKING: + from mcp.server.context import LifespanContextT, RequestT from mcp.server.mcpserver.server import Context - from mcp.server.session import ServerSessionT - from mcp.shared.context import LifespanContextT, RequestT logger = get_logger(__name__) @@ -82,7 +81,7 @@ def add_template( return template async def get_resource( - self, uri: AnyUrl | str, context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None + self, uri: AnyUrl | str, context: Context[LifespanContextT, RequestT] | None = None ) -> Resource: """Get resource by URI, checking concrete resources first, then templates.""" uri_str = str(uri) diff --git a/src/mcp/server/mcpserver/resources/templates.py b/src/mcp/server/mcpserver/resources/templates.py index 698ac3682..e796823d9 100644 --- a/src/mcp/server/mcpserver/resources/templates.py +++ b/src/mcp/server/mcpserver/resources/templates.py @@ -16,9 +16,8 @@ from mcp.types import Annotations, Icon if TYPE_CHECKING: + from mcp.server.context import LifespanContextT, RequestT from mcp.server.mcpserver.server import Context - from mcp.server.session import ServerSessionT - from mcp.shared.context import LifespanContextT, RequestT class ResourceTemplate(BaseModel): @@ -100,7 +99,7 @@ async def create_resource( self, uri: str, params: dict[str, Any], - context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None, + context: Context[LifespanContextT, RequestT] | None = None, ) -> Resource: """Create a resource from the template with the given parameters.""" try: diff --git a/src/mcp/server/mcpserver/server.py b/src/mcp/server/mcpserver/server.py index fa63a4ef7..8c1fc342b 100644 --- a/src/mcp/server/mcpserver/server.py +++ b/src/mcp/server/mcpserver/server.py @@ -25,6 +25,7 @@ from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend, RequireAuthMiddleware from mcp.server.auth.provider import OAuthAuthorizationServerProvider, ProviderTokenVerifier, TokenVerifier from mcp.server.auth.settings import AuthSettings +from mcp.server.context import LifespanContextT, RequestT, ServerRequestContext from mcp.server.elicitation import ElicitationResult, ElicitSchemaModelT, UrlElicitationResult, elicit_with_validation from mcp.server.elicitation import elicit_url as _elicit_url from mcp.server.lowlevel.helper_types import ReadResourceContents @@ -36,13 +37,11 @@ from mcp.server.mcpserver.tools import Tool, ToolManager from mcp.server.mcpserver.utilities.context_injection import find_context_parameter from mcp.server.mcpserver.utilities.logging import configure_logging, get_logger -from mcp.server.session import ServerSession, ServerSessionT from mcp.server.sse import SseServerTransport from mcp.server.stdio import stdio_server from mcp.server.streamable_http import EventStore from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings -from mcp.shared.context import LifespanContextT, RequestContext, RequestT from mcp.types import Annotations, ContentBlock, GetPromptResult, Icon, ToolAnnotations from mcp.types import Prompt as MCPPrompt from mcp.types import PromptArgument as MCPPromptArgument @@ -294,7 +293,7 @@ async def list_tools(self) -> list[MCPTool]: for info in tools ] - def get_context(self) -> Context[ServerSession, LifespanResultT, Request]: + def get_context(self) -> Context[LifespanResultT, Request]: """Returns a Context object. Note that the context will only be valid during a request; outside a request, most methods will error. """ @@ -972,7 +971,7 @@ async def get_prompt(self, name: str, arguments: dict[str, Any] | None = None) - raise ValueError(str(e)) -class Context(BaseModel, Generic[ServerSessionT, LifespanContextT, RequestT]): +class Context(BaseModel, Generic[LifespanContextT, RequestT]): """Context object providing access to MCP capabilities. This provides a cleaner interface to MCP's RequestContext functionality. @@ -1006,14 +1005,15 @@ def my_tool(x: int, ctx: Context) -> str: The context is optional - tools that don't need it can omit the parameter. """ - _request_context: RequestContext[ServerSessionT, LifespanContextT, RequestT] | None + _request_context: ServerRequestContext[LifespanContextT, RequestT] | None _mcp_server: MCPServer | None def __init__( self, *, - request_context: (RequestContext[ServerSessionT, LifespanContextT, RequestT] | None) = None, + request_context: ServerRequestContext[LifespanContextT, RequestT] | None = None, mcp_server: MCPServer | None = None, + # TODO(Marcelo): We should drop this kwargs parameter. **kwargs: Any, ): super().__init__(**kwargs) @@ -1028,9 +1028,7 @@ def mcp_server(self) -> MCPServer: return self._mcp_server # pragma: no cover @property - def request_context( - self, - ) -> RequestContext[ServerSessionT, LifespanContextT, RequestT]: + def request_context(self) -> ServerRequestContext[LifespanContextT, RequestT]: """Access to the underlying request context.""" if self._request_context is None: # pragma: no cover raise ValueError("Context is not available outside of a request") diff --git a/src/mcp/server/mcpserver/tools/base.py b/src/mcp/server/mcpserver/tools/base.py index 9798fd96e..f6bfadbc4 100644 --- a/src/mcp/server/mcpserver/tools/base.py +++ b/src/mcp/server/mcpserver/tools/base.py @@ -16,9 +16,8 @@ from mcp.types import Icon, ToolAnnotations if TYPE_CHECKING: + from mcp.server.context import LifespanContextT, RequestT from mcp.server.mcpserver.server import Context - from mcp.server.session import ServerSessionT - from mcp.shared.context import LifespanContextT, RequestT class Tool(BaseModel): @@ -93,7 +92,7 @@ def from_function( async def run( self, arguments: dict[str, Any], - context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None, + context: Context[LifespanContextT, RequestT] | None = None, convert_result: bool = False, ) -> Any: """Run the tool with arguments.""" diff --git a/src/mcp/server/mcpserver/tools/tool_manager.py b/src/mcp/server/mcpserver/tools/tool_manager.py index 5decefb7e..c6f8384bd 100644 --- a/src/mcp/server/mcpserver/tools/tool_manager.py +++ b/src/mcp/server/mcpserver/tools/tool_manager.py @@ -6,12 +6,11 @@ from mcp.server.mcpserver.exceptions import ToolError from mcp.server.mcpserver.tools.base import Tool from mcp.server.mcpserver.utilities.logging import get_logger -from mcp.shared.context import LifespanContextT, RequestT from mcp.types import Icon, ToolAnnotations if TYPE_CHECKING: + from mcp.server.context import LifespanContextT, RequestT from mcp.server.mcpserver.server import Context - from mcp.server.session import ServerSessionT logger = get_logger(__name__) @@ -82,7 +81,7 @@ async def call_tool( self, name: str, arguments: dict[str, Any], - context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None, + context: Context[LifespanContextT, RequestT] | None = None, convert_result: bool = False, ) -> Any: """Call a tool by name with arguments.""" diff --git a/src/mcp/shared/context.py b/src/mcp/shared/context.py index 890536a5d..2facc2a49 100644 --- a/src/mcp/shared/context.py +++ b/src/mcp/shared/context.py @@ -1,30 +1,20 @@ """Request context for MCP handlers.""" -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import Any, Generic from typing_extensions import TypeVar -from mcp.shared.message import CloseSSEStreamCallback from mcp.shared.session import BaseSession from mcp.types import RequestId, RequestParamsMeta SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any]) -LifespanContextT = TypeVar("LifespanContextT") -RequestT = TypeVar("RequestT", default=Any) -@dataclass -class RequestContext(Generic[SessionT, LifespanContextT, RequestT]): +@dataclass(kw_only=True) +class RequestContext(Generic[SessionT]): + """Common context for handling incoming requests.""" + request_id: RequestId meta: RequestParamsMeta | None session: SessionT - lifespan_context: LifespanContextT - # NOTE: This is typed as Any to avoid circular imports. The actual type is - # mcp.server.experimental.request_context.Experimental, but importing it here - # triggers mcp.server.__init__ -> mcpserver -> tools -> back to this module. - # The Server sets this to an Experimental instance at runtime. - experimental: Any = field(default=None) - request: RequestT | None = None - close_sse_stream: CloseSSEStreamCallback | None = None - close_standalone_sse_stream: CloseSSEStreamCallback | None = None diff --git a/src/mcp/shared/progress.py b/src/mcp/shared/progress.py index bc54304cb..7225ac8d0 100644 --- a/src/mcp/shared/progress.py +++ b/src/mcp/shared/progress.py @@ -5,15 +5,7 @@ from pydantic import BaseModel -from mcp.shared.context import LifespanContextT, RequestContext -from mcp.shared.session import ( - BaseSession, - ReceiveNotificationT, - ReceiveRequestT, - SendNotificationT, - SendRequestT, - SendResultT, -) +from mcp.shared.context import RequestContext, SessionT from mcp.types import ProgressToken @@ -23,8 +15,8 @@ class Progress(BaseModel): @dataclass -class ProgressContext(Generic[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT]): - session: BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT] +class ProgressContext(Generic[SessionT]): + session: SessionT progress_token: ProgressToken total: float | None current: float = field(default=0.0, init=False) @@ -39,15 +31,9 @@ async def progress(self, amount: float, message: str | None = None) -> None: @contextmanager def progress( - ctx: RequestContext[ - BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT], - LifespanContextT, - ], + ctx: RequestContext[SessionT], total: float | None = None, -) -> Generator[ - ProgressContext[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT], - None, -]: +) -> Generator[ProgressContext[SessionT], None]: progress_token = ctx.meta.get("progress_token") if ctx.meta else None if progress_token is None: # pragma: no cover raise ValueError("No progress token provided") diff --git a/tests/client/test_list_roots_callback.py b/tests/client/test_list_roots_callback.py index 6919624c7..06c292bd2 100644 --- a/tests/client/test_list_roots_callback.py +++ b/tests/client/test_list_roots_callback.py @@ -28,12 +28,12 @@ async def test_list_roots_callback(): ) async def list_roots_callback( - context: RequestContext[ClientSession, None], + context: RequestContext[ClientSession], ) -> ListRootsResult: return callback_return @server.tool("test_list_roots") - async def test_list_roots(context: Context[ServerSession, None], message: str): + async def test_list_roots(context: Context[ServerSession], message: str): roots = await context.session.list_roots() assert roots == callback_return return True diff --git a/tests/client/test_sampling_callback.py b/tests/client/test_sampling_callback.py index 78d7ba688..28995e0fb 100644 --- a/tests/client/test_sampling_callback.py +++ b/tests/client/test_sampling_callback.py @@ -26,7 +26,7 @@ async def test_sampling_callback(): ) async def sampling_callback( - context: RequestContext[ClientSession, None], + context: RequestContext[ClientSession], params: CreateMessageRequestParams, ) -> CreateMessageResult: return callback_return @@ -71,7 +71,7 @@ async def test_create_message_backwards_compat_single_content(): ) async def sampling_callback( - context: RequestContext[ClientSession, None], + context: RequestContext[ClientSession], params: CreateMessageRequestParams, ) -> CreateMessageResult: return callback_return diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 220c571a5..40bd65b97 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -1,4 +1,4 @@ -from typing import Any +from __future__ import annotations import anyio import pytest @@ -390,7 +390,7 @@ async def test_client_capabilities_with_custom_callbacks(): received_capabilities = None async def custom_sampling_callback( # pragma: no cover - context: RequestContext["ClientSession", Any], + context: RequestContext[ClientSession], params: types.CreateMessageRequestParams, ) -> types.CreateMessageResult | types.ErrorData: return types.CreateMessageResult( @@ -400,7 +400,7 @@ async def custom_sampling_callback( # pragma: no cover ) async def custom_list_roots_callback( # pragma: no cover - context: RequestContext["ClientSession", Any], + context: RequestContext[ClientSession], ) -> types.ListRootsResult | types.ErrorData: return types.ListRootsResult(roots=[]) @@ -474,7 +474,7 @@ async def test_client_capabilities_with_sampling_tools(): received_capabilities = None async def custom_sampling_callback( # pragma: no cover - context: RequestContext["ClientSession", Any], + context: RequestContext[ClientSession], params: types.CreateMessageRequestParams, ) -> types.CreateMessageResult | types.ErrorData: return types.CreateMessageResult( diff --git a/tests/experimental/tasks/client/test_capabilities.py b/tests/experimental/tasks/client/test_capabilities.py index 7bb806696..04561a090 100644 --- a/tests/experimental/tasks/client/test_capabilities.py +++ b/tests/experimental/tasks/client/test_capabilities.py @@ -88,13 +88,13 @@ async def test_client_capabilities_with_tasks(): # Define custom handlers to trigger capability building (never actually called) async def my_list_tasks_handler( - context: RequestContext[ClientSession, None], + context: RequestContext[ClientSession], params: types.PaginatedRequestParams | None, ) -> types.ListTasksResult | types.ErrorData: raise NotImplementedError async def my_cancel_task_handler( - context: RequestContext[ClientSession, None], + context: RequestContext[ClientSession], params: types.CancelTaskRequestParams, ) -> types.CancelTaskResult | types.ErrorData: raise NotImplementedError @@ -168,13 +168,13 @@ async def test_client_capabilities_auto_built_from_handlers(): # Define custom handlers (not defaults) async def my_list_tasks_handler( - context: RequestContext[ClientSession, None], + context: RequestContext[ClientSession], params: types.PaginatedRequestParams | None, ) -> types.ListTasksResult | types.ErrorData: raise NotImplementedError async def my_cancel_task_handler( - context: RequestContext[ClientSession, None], + context: RequestContext[ClientSession], params: types.CancelTaskRequestParams, ) -> types.CancelTaskResult | types.ErrorData: raise NotImplementedError @@ -249,7 +249,7 @@ async def test_client_capabilities_with_task_augmented_handlers(): # Define task-augmented handler async def my_augmented_sampling_handler( - context: RequestContext[ClientSession, None], + context: RequestContext[ClientSession], params: types.CreateMessageRequestParams, task_metadata: types.TaskMetadata, ) -> types.CreateTaskResult | types.ErrorData: diff --git a/tests/experimental/tasks/client/test_handlers.py b/tests/experimental/tasks/client/test_handlers.py index 0cac3c736..9061aedc2 100644 --- a/tests/experimental/tasks/client/test_handlers.py +++ b/tests/experimental/tasks/client/test_handlers.py @@ -113,7 +113,7 @@ async def test_client_handles_get_task_request(client_streams: ClientTestStreams received_task_id: str | None = None async def get_task_handler( - context: RequestContext[ClientSession, None], + context: RequestContext[ClientSession], params: GetTaskRequestParams, ) -> GetTaskResult | ErrorData: nonlocal received_task_id @@ -176,7 +176,7 @@ async def test_client_handles_get_task_result_request(client_streams: ClientTest store = InMemoryTaskStore() async def get_task_result_handler( - context: RequestContext[ClientSession, None], + context: RequestContext[ClientSession], params: GetTaskPayloadRequestParams, ) -> GetTaskPayloadResult | ErrorData: result = await store.get_result(params.task_id) @@ -239,7 +239,7 @@ async def test_client_handles_list_tasks_request(client_streams: ClientTestStrea store = InMemoryTaskStore() async def list_tasks_handler( - context: RequestContext[ClientSession, None], + context: RequestContext[ClientSession], params: types.PaginatedRequestParams | None, ) -> ListTasksResult | ErrorData: cursor = params.cursor if params else None @@ -294,7 +294,7 @@ async def test_client_handles_cancel_task_request(client_streams: ClientTestStre store = InMemoryTaskStore() async def cancel_task_handler( - context: RequestContext[ClientSession, None], + context: RequestContext[ClientSession], params: CancelTaskRequestParams, ) -> CancelTaskResult | ErrorData: task = await store.get_task(params.task_id) @@ -361,7 +361,7 @@ async def test_client_task_augmented_sampling(client_streams: ClientTestStreams) background_tg: list[TaskGroup | None] = [None] async def task_augmented_sampling_callback( - context: RequestContext[ClientSession, None], + context: RequestContext[ClientSession], params: CreateMessageRequestParams, task_metadata: TaskMetadata, ) -> CreateTaskResult: @@ -384,7 +384,7 @@ async def do_sampling() -> None: return CreateTaskResult(task=task) async def get_task_handler( - context: RequestContext[ClientSession, None], + context: RequestContext[ClientSession], params: GetTaskRequestParams, ) -> GetTaskResult | ErrorData: task = await store.get_task(params.task_id) @@ -400,7 +400,7 @@ async def get_task_handler( ) async def get_task_result_handler( - context: RequestContext[ClientSession, None], + context: RequestContext[ClientSession], params: GetTaskPayloadRequestParams, ) -> GetTaskPayloadResult | ErrorData: result = await store.get_result(params.task_id) @@ -505,7 +505,7 @@ async def test_client_task_augmented_elicitation(client_streams: ClientTestStrea background_tg: list[TaskGroup | None] = [None] async def task_augmented_elicitation_callback( - context: RequestContext[ClientSession, None], + context: RequestContext[ClientSession], params: ElicitRequestParams, task_metadata: TaskMetadata, ) -> CreateTaskResult | ErrorData: @@ -524,7 +524,7 @@ async def do_elicitation() -> None: return CreateTaskResult(task=task) async def get_task_handler( - context: RequestContext[ClientSession, None], + context: RequestContext[ClientSession], params: GetTaskRequestParams, ) -> GetTaskResult | ErrorData: task = await store.get_task(params.task_id) @@ -540,7 +540,7 @@ async def get_task_handler( ) async def get_task_result_handler( - context: RequestContext[ClientSession, None], + context: RequestContext[ClientSession], params: GetTaskPayloadRequestParams, ) -> GetTaskPayloadResult | ErrorData: result = await store.get_result(params.task_id) diff --git a/tests/experimental/tasks/test_elicitation_scenarios.py b/tests/experimental/tasks/test_elicitation_scenarios.py index 1cefe847d..f755658c4 100644 --- a/tests/experimental/tasks/test_elicitation_scenarios.py +++ b/tests/experimental/tasks/test_elicitation_scenarios.py @@ -53,7 +53,7 @@ def create_client_task_handlers( task_complete_events: dict[str, Event] = {} async def handle_augmented_elicitation( - context: RequestContext[ClientSession, Any], + context: RequestContext[ClientSession], params: ElicitRequestParams, task_metadata: TaskMetadata, ) -> CreateTaskResult: @@ -72,7 +72,7 @@ async def complete_task() -> None: return CreateTaskResult(task=task) async def handle_get_task( - context: RequestContext[ClientSession, Any], + context: RequestContext[ClientSession], params: Any, ) -> GetTaskResult: """Handle tasks/get from server.""" @@ -89,7 +89,7 @@ async def handle_get_task( ) async def handle_get_task_result( - context: RequestContext[ClientSession, Any], + context: RequestContext[ClientSession], params: Any, ) -> GetTaskPayloadResult | ErrorData: """Handle tasks/result from server.""" @@ -121,7 +121,7 @@ def create_sampling_task_handlers( task_complete_events: dict[str, Event] = {} async def handle_augmented_sampling( - context: RequestContext[ClientSession, Any], + context: RequestContext[ClientSession], params: CreateMessageRequestParams, task_metadata: TaskMetadata, ) -> CreateTaskResult: @@ -140,7 +140,7 @@ async def complete_task() -> None: return CreateTaskResult(task=task) async def handle_get_task( - context: RequestContext[ClientSession, Any], + context: RequestContext[ClientSession], params: Any, ) -> GetTaskResult: """Handle tasks/get from server.""" @@ -157,7 +157,7 @@ async def handle_get_task( ) async def handle_get_task_result( - context: RequestContext[ClientSession, Any], + context: RequestContext[ClientSession], params: Any, ) -> GetTaskPayloadResult | ErrorData: """Handle tasks/result from server.""" @@ -211,7 +211,7 @@ async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResu # Elicitation callback for client async def elicitation_callback( - context: RequestContext[ClientSession, Any], + context: RequestContext[ClientSession], params: ElicitRequestParams, ) -> ElicitResult: elicit_received.set() @@ -379,7 +379,7 @@ async def work(task: ServerTaskContext) -> CallToolResult: # Elicitation callback for client async def elicitation_callback( - context: RequestContext[ClientSession, Any], + context: RequestContext[ClientSession], params: ElicitRequestParams, ) -> ElicitResult: elicit_received.set() diff --git a/tests/issues/test_176_progress_token.py b/tests/issues/test_176_progress_token.py index db0a66f9b..fb4bb0101 100644 --- a/tests/issues/test_176_progress_token.py +++ b/tests/issues/test_176_progress_token.py @@ -2,8 +2,9 @@ import pytest +from mcp.server.context import ServerRequestContext +from mcp.server.experimental.request_context import Experimental from mcp.server.mcpserver import Context -from mcp.shared.context import RequestContext pytestmark = pytest.mark.anyio @@ -16,11 +17,12 @@ async def test_progress_token_zero_first_call(): mock_session.send_progress_notification = AsyncMock() # Create request context with progress token 0 - request_context = RequestContext( + request_context = ServerRequestContext( request_id="test-request", session=mock_session, meta={"progress_token": 0}, lifespan_context=None, + experimental=Experimental(), ) # Create context with our mocks diff --git a/tests/issues/test_355_type_error.py b/tests/issues/test_355_type_error.py index 33d6b455b..905cf7eee 100644 --- a/tests/issues/test_355_type_error.py +++ b/tests/issues/test_355_type_error.py @@ -3,7 +3,6 @@ from dataclasses import dataclass from mcp.server.mcpserver import Context, MCPServer -from mcp.server.session import ServerSession class Database: # Replace with your actual DB type @@ -45,7 +44,7 @@ async def app_lifespan(server: MCPServer) -> AsyncIterator[AppContext]: # pragm # Access type-safe lifespan context in tools @mcp.tool() -def query_db(ctx: Context[ServerSession, AppContext]) -> str: # pragma: no cover +def query_db(ctx: Context[AppContext]) -> str: # pragma: no cover """Tool that uses initialized resources""" db = ctx.request_context.lifespan_context.db return db.query() diff --git a/tests/server/mcpserver/test_elicitation.py b/tests/server/mcpserver/test_elicitation.py index 5a55592ab..37e87a1f4 100644 --- a/tests/server/mcpserver/test_elicitation.py +++ b/tests/server/mcpserver/test_elicitation.py @@ -65,7 +65,7 @@ async def test_stdio_elicitation(): create_ask_user_tool(mcp) # Create a custom handler for elicitation requests - async def elicitation_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): if params.message == "Tool wants to ask: What is your name?": return ElicitResult(action="accept", content={"answer": "Test User"}) else: # pragma: no cover @@ -82,7 +82,7 @@ async def test_stdio_elicitation_decline(): mcp = MCPServer(name="StdioElicitationDeclineServer") create_ask_user_tool(mcp) - async def elicitation_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): return ElicitResult(action="decline") await call_tool_and_assert( @@ -121,7 +121,7 @@ class InvalidNestedSchema(BaseModel): # Dummy callback (won't be called due to validation failure) async def elicitation_callback( - context: RequestContext[ClientSession, None], params: ElicitRequestParams + context: RequestContext[ClientSession], params: ElicitRequestParams ): # pragma: no cover return ElicitResult(action="accept", content={}) @@ -177,7 +177,7 @@ async def optional_tool(ctx: Context[ServerSession, None]) -> str: for content, expected in test_cases: - async def callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + async def callback(context: RequestContext[ClientSession], params: ElicitRequestParams): return ElicitResult(action="accept", content=content) await call_tool_and_assert(mcp, callback, "optional_tool", {}, expected) @@ -196,7 +196,7 @@ async def invalid_optional_tool(ctx: Context[ServerSession, None]) -> str: return f"Validation failed: {str(e)}" async def elicitation_callback( - context: RequestContext[ClientSession, None], params: ElicitRequestParams + context: RequestContext[ClientSession], params: ElicitRequestParams ): # pragma: no cover return ElicitResult(action="accept", content={}) @@ -220,7 +220,7 @@ async def valid_multiselect_tool(ctx: Context[ServerSession, None]) -> str: return f"Name: {result.data.name}, Tags: {', '.join(result.data.tags)}" return f"User {result.action}" # pragma: no cover - async def multiselect_callback(context: RequestContext[ClientSession, Any], params: ElicitRequestParams): + async def multiselect_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): if "Please provide tags" in params.message: return ElicitResult(action="accept", content={"name": "Test", "tags": ["tag1", "tag2"]}) return ElicitResult(action="decline") # pragma: no cover @@ -240,7 +240,7 @@ async def optional_multiselect_tool(ctx: Context[ServerSession, None]) -> str: return f"Name: {result.data.name}, Tags: {tags_str}" return f"User {result.action}" # pragma: no cover - async def optional_multiselect_callback(context: RequestContext[ClientSession, Any], params: ElicitRequestParams): + async def optional_multiselect_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): if "Please provide optional tags" in params.message: return ElicitResult(action="accept", content={"name": "Test", "tags": ["tag1", "tag2"]}) return ElicitResult(action="decline") # pragma: no cover @@ -274,7 +274,7 @@ async def defaults_tool(ctx: Context[ServerSession, None]) -> str: return f"User {result.action}" # First verify that defaults are present in the JSON schema sent to clients - async def callback_schema_verify(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + async def callback_schema_verify(context: RequestContext[ClientSession], params: ElicitRequestParams): # Verify the schema includes defaults assert isinstance(params, types.ElicitRequestFormParams), "Expected form mode elicitation" schema = params.requested_schema @@ -296,7 +296,7 @@ async def callback_schema_verify(context: RequestContext[ClientSession, None], p ) # Test overriding defaults - async def callback_override(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + async def callback_override(context: RequestContext[ClientSession], params: ElicitRequestParams): return ElicitResult( action="accept", content={"email": "john@example.com", "name": "John", "age": 25, "subscribe": False} ) @@ -372,7 +372,7 @@ async def select_color_legacy(ctx: Context[ServerSession, None]) -> str: return f"User: {result.data.user_name}, Color: {result.data.color}" return f"User {result.action}" # pragma: no cover - async def enum_callback(context: RequestContext[ClientSession, Any], params: ElicitRequestParams): + async def enum_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): if "colors" in params.message and "legacy" not in params.message: return ElicitResult(action="accept", content={"user_name": "Bob", "favorite_colors": ["red", "green"]}) elif "color" in params.message: diff --git a/tests/server/mcpserver/test_integration.py b/tests/server/mcpserver/test_integration.py index 132427e5e..40453b89d 100644 --- a/tests/server/mcpserver/test_integration.py +++ b/tests/server/mcpserver/test_integration.py @@ -185,7 +185,7 @@ def create_client_for_transport(transport: str, server_url: str): # Callback functions for testing async def sampling_callback( - context: RequestContext[ClientSession, None], params: CreateMessageRequestParams + context: RequestContext[ClientSession], params: CreateMessageRequestParams ) -> CreateMessageResult: """Sampling callback for tests.""" return CreateMessageResult( @@ -198,7 +198,7 @@ async def sampling_callback( ) -async def elicitation_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): +async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): """Elicitation callback for tests.""" # For restaurant booking test if "No tables available" in params.message: diff --git a/tests/server/mcpserver/test_tool_manager.py b/tests/server/mcpserver/test_tool_manager.py index 42cac073c..550bba50a 100644 --- a/tests/server/mcpserver/test_tool_manager.py +++ b/tests/server/mcpserver/test_tool_manager.py @@ -6,12 +6,12 @@ import pytest from pydantic import BaseModel +from mcp.server.context import LifespanContextT, RequestT from mcp.server.mcpserver import Context, MCPServer from mcp.server.mcpserver.exceptions import ToolError from mcp.server.mcpserver.tools import Tool, ToolManager from mcp.server.mcpserver.utilities.func_metadata import ArgModelBase, FuncMetadata from mcp.server.session import ServerSessionT -from mcp.shared.context import LifespanContextT, RequestT from mcp.types import TextContent, ToolAnnotations @@ -347,9 +347,7 @@ def tool_without_context(x: int) -> str: # pragma: no cover tool = manager.add_tool(tool_without_context) assert tool.context_kwarg is None - def tool_with_parametrized_context( - x: int, ctx: Context[ServerSessionT, LifespanContextT, RequestT] - ) -> str: # pragma: no cover + def tool_with_parametrized_context(x: int, ctx: Context[LifespanContextT, RequestT]) -> str: # pragma: no cover return str(x) tool = manager.add_tool(tool_with_parametrized_context) diff --git a/tests/server/mcpserver/test_url_elicitation.py b/tests/server/mcpserver/test_url_elicitation.py index 45ec40a37..667a4279a 100644 --- a/tests/server/mcpserver/test_url_elicitation.py +++ b/tests/server/mcpserver/test_url_elicitation.py @@ -29,7 +29,7 @@ async def request_api_key(ctx: Context[ServerSession, None]) -> str: return f"User {result.action}" # Create elicitation callback that accepts URL mode - async def elicitation_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): assert params.mode == "url" assert params.url == "https://example.com/api_key_setup" assert params.elicitation_id == "test-elicitation-001" @@ -58,7 +58,7 @@ async def oauth_flow(ctx: Context[ServerSession, None]) -> str: # Test only checks decline path return f"User {result.action} authorization" - async def elicitation_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): assert params.mode == "url" return ElicitResult(action="decline") @@ -84,7 +84,7 @@ async def payment_flow(ctx: Context[ServerSession, None]) -> str: # Test only checks cancel path return f"User {result.action} payment" - async def elicitation_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): assert params.mode == "url" return ElicitResult(action="cancel") @@ -111,7 +111,7 @@ async def setup_credentials(ctx: Context[ServerSession, None]) -> str: # Test only checks accept path - return the type name return type(result).__name__ - async def elicitation_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): return ElicitResult(action="accept") async with Client(mcp, elicitation_callback=elicitation_callback) as client: @@ -138,7 +138,7 @@ async def check_url_response(ctx: Context[ServerSession, None]) -> str: assert result.content is None return f"Action: {result.action}, Content: {result.content}" - async def elicitation_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): # Verify that this is URL mode assert params.mode == "url" assert isinstance(params, types.ElicitRequestURLParams) @@ -171,7 +171,7 @@ async def ask_name(ctx: Context[ServerSession, None]) -> str: assert result.data is not None return f"Hello, {result.data.name}!" - async def elicitation_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): # Verify form mode parameters assert params.mode == "form" assert isinstance(params, types.ElicitRequestFormParams) @@ -207,7 +207,7 @@ async def trigger_elicitation(ctx: Context[ServerSession, None]) -> str: return "Elicitation completed" - async def elicitation_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): return ElicitResult(action="accept") # pragma: no cover async with Client(mcp, elicitation_callback=elicitation_callback) as client: @@ -264,7 +264,7 @@ async def test_cancel(ctx: Context[ServerSession, None]) -> str: return "Not cancelled" # pragma: no cover # Test declined result - async def decline_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + async def decline_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): return ElicitResult(action="decline") async with Client(mcp, elicitation_callback=decline_callback) as client: @@ -274,7 +274,7 @@ async def decline_callback(context: RequestContext[ClientSession, None], params: assert result.content[0].text == "Declined" # Test cancelled result - async def cancel_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + async def cancel_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): return ElicitResult(action="cancel") async with Client(mcp, elicitation_callback=cancel_callback) as client: @@ -304,7 +304,7 @@ async def use_deprecated_elicit(ctx: Context[ServerSession, None]) -> str: return f"Email: {result.content.get('email', 'none')}" return "No email provided" # pragma: no cover - async def elicitation_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): # Verify this is form mode assert params.mode == "form" assert params.requested_schema is not None @@ -332,7 +332,7 @@ async def direct_elicit_url(ctx: Context[ServerSession, None]) -> str: ) return f"Result: {result.action}" - async def elicitation_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): assert params.mode == "url" assert params.elicitation_id == "ctx-test-001" return ElicitResult(action="accept") diff --git a/tests/shared/test_progress_notifications.py b/tests/shared/test_progress_notifications.py index 81aa1ccbc..ca632148b 100644 --- a/tests/shared/test_progress_notifications.py +++ b/tests/shared/test_progress_notifications.py @@ -1,4 +1,4 @@ -from typing import Any, cast +from typing import Any from unittest.mock import patch import anyio @@ -14,7 +14,7 @@ from mcp.shared.context import RequestContext from mcp.shared.message import SessionMessage from mcp.shared.progress import progress -from mcp.shared.session import BaseSession, RequestResponder +from mcp.shared.session import RequestResponder @pytest.mark.anyio @@ -279,14 +279,10 @@ async def handle_client_message( request_id="test-request", session=client_session, meta={"progress_token": progress_token}, - lifespan_context=None, ) - # cast for type checker - typed_context = cast(RequestContext[BaseSession[Any, Any, Any, Any, Any], Any], request_context) - # Utilize progress context manager - with progress(typed_context, total=100) as p: + with progress(request_context, total=100) as p: await p.progress(10, message="Loading configuration...") await p.progress(30, message="Connecting to database...") await p.progress(40, message="Fetching data...") diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 70a9fca40..c1d0e3062 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -1344,7 +1344,7 @@ async def test_streamablehttp_server_sampling(basic_server: None, basic_server_u # Define sampling callback that returns a mock response async def sampling_callback( - context: RequestContext[ClientSession, Any], + context: RequestContext[ClientSession], params: types.CreateMessageRequestParams, ) -> types.CreateMessageResult: nonlocal sampling_callback_invoked, captured_message_params From 42e2104d2f6cfdd4682e90895982348071ef478a Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Tue, 3 Feb 2026 11:04:56 +0100 Subject: [PATCH 2/4] fix --- src/mcp/server/context.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/mcp/server/context.py b/src/mcp/server/context.py index aba6c0ce6..0951a0784 100644 --- a/src/mcp/server/context.py +++ b/src/mcp/server/context.py @@ -1,7 +1,9 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any, Generic, TypeVar +from typing import Any, Generic + +from typing_extensions import TypeVar from mcp.server.experimental.request_context import Experimental from mcp.server.session import ServerSession From 3585b22bc2b54d7ceff7fd1d99ab08b55c05dc3c Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Tue, 3 Feb 2026 11:08:13 +0100 Subject: [PATCH 3/4] ... --- .github/actions/conformance/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/actions/conformance/client.py b/.github/actions/conformance/client.py index 3a2ac2802..87f323132 100644 --- a/.github/actions/conformance/client.py +++ b/.github/actions/conformance/client.py @@ -187,7 +187,7 @@ async def run_sse_retry(server_url: str) -> None: async def default_elicitation_callback( - context: RequestContext[ClientSession, Any], # noqa: ARG001 + context: RequestContext[ClientSession], params: types.ElicitRequestParams, ) -> types.ElicitResult | types.ErrorData: """Accept elicitation and apply defaults from the schema (SEP-1034).""" From fdc1c3b3686ebfc07e7c814085e45ae8d9705be8 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Tue, 3 Feb 2026 11:21:40 +0100 Subject: [PATCH 4/4] add migration --- docs/migration.md | 44 ++++++++++++++++++++++++ tests/client/test_list_roots_callback.py | 13 ++----- 2 files changed, 47 insertions(+), 10 deletions(-) diff --git a/docs/migration.md b/docs/migration.md index b941fb5a1..84320ffef 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -371,6 +371,50 @@ async def handle_tool(name: str, arguments: dict) -> list[TextContent]: await ctx.session.send_progress_notification(ctx.meta["progress_token"], 0.5, 100) ``` +### `RequestContext` and `ProgressContext` type parameters simplified + +The `RequestContext` class has been split to separate shared fields from server-specific fields. The shared `RequestContext` now only takes 1 type parameter (the session type) instead of 3. + +**`RequestContext` changes:** + +- Type parameters reduced from `RequestContext[SessionT, LifespanContextT, RequestT]` to `RequestContext[SessionT]` +- Server-specific fields (`lifespan_context`, `experimental`, `request`, `close_sse_stream`, `close_standalone_sse_stream`) moved to new `ServerRequestContext` class in `mcp.server.context` + +**`ProgressContext` changes:** + +- Type parameters reduced from `ProgressContext[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT]` to `ProgressContext[SessionT]` + +**Before (v1):** + +```python +from mcp.shared.context import RequestContext, LifespanContextT, RequestT +from mcp.shared.progress import ProgressContext + +# RequestContext with 3 type parameters +ctx: RequestContext[ClientSession, LifespanContextT, RequestT] + +# ProgressContext with 5 type parameters +progress_ctx: ProgressContext[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT] +``` + +**After (v2):** + +```python +from mcp.shared.context import RequestContext +from mcp.shared.progress import ProgressContext + +# RequestContext with 1 type parameter +ctx: RequestContext[ClientSession] + +# ProgressContext with 1 type parameter +progress_ctx: ProgressContext[ClientSession] + +# For server-specific context with lifespan and request types +from mcp.server.context import ServerRequestContext, LifespanContextT, RequestT + +server_ctx: ServerRequestContext[LifespanContextT, RequestT] +``` + ### Resource URI type changed from `AnyUrl` to `str` The `uri` field on resource-related types now uses `str` instead of Pydantic's `AnyUrl`. This aligns with the [MCP specification schema](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/main/schema/draft/schema.ts) which defines URIs as plain strings (`uri: string`) without strict URL validation. This change allows relative paths like `users/me` that were previously rejected. diff --git a/tests/client/test_list_roots_callback.py b/tests/client/test_list_roots_callback.py index 06c292bd2..40265d57f 100644 --- a/tests/client/test_list_roots_callback.py +++ b/tests/client/test_list_roots_callback.py @@ -5,7 +5,6 @@ from mcp.client.session import ClientSession from mcp.server.mcpserver import MCPServer from mcp.server.mcpserver.server import Context -from mcp.server.session import ServerSession from mcp.shared.context import RequestContext from mcp.types import ListRootsResult, Root, TextContent @@ -16,14 +15,8 @@ async def test_list_roots_callback(): callback_return = ListRootsResult( roots=[ - Root( - uri=FileUrl("file://users/fake/test"), - name="Test Root 1", - ), - Root( - uri=FileUrl("file://users/fake/test/2"), - name="Test Root 2", - ), + Root(uri=FileUrl("file://users/fake/test"), name="Test Root 1"), + Root(uri=FileUrl("file://users/fake/test/2"), name="Test Root 2"), ] ) @@ -33,7 +26,7 @@ async def list_roots_callback( return callback_return @server.tool("test_list_roots") - async def test_list_roots(context: Context[ServerSession], message: str): + async def test_list_roots(context: Context[None], message: str): roots = await context.session.list_roots() assert roots == callback_return return True