Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/actions/conformance/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ async def run_sse_retry(server_url: str) -> None:


async def default_elicitation_callback(
context: RequestContext[ClientSession, Any], # noqa: ARG001
context: RequestContext[ClientSession],
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can have a ClientRequestContext.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea true, would be nice to have both

params: types.ElicitRequestParams,
) -> types.ElicitResult | types.ErrorData:
"""Accept elicitation and apply defaults from the schema (SEP-1034)."""
Expand Down
5 changes: 2 additions & 3 deletions README.v2.md
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,6 @@ from contextlib import asynccontextmanager
from dataclasses import dataclass

from mcp.server.mcpserver import Context, MCPServer
from mcp.server.session import ServerSession


# Mock database class for example
Expand Down Expand Up @@ -275,7 +274,7 @@ mcp = MCPServer("My App", lifespan=app_lifespan)

# Access type-safe lifespan context in tools
@mcp.tool()
def query_db(ctx: Context[ServerSession, AppContext]) -> str:
def query_db(ctx: Context[AppContext]) -> str:
"""Tool that uses initialized resources."""
db = ctx.request_context.lifespan_context.db
return db.query()
Expand Down Expand Up @@ -2135,7 +2134,7 @@ server_params = StdioServerParameters(

# Optional: create a sampling callback
async def handle_sampling_message(
context: RequestContext[ClientSession, None], params: types.CreateMessageRequestParams
context: RequestContext[ClientSession], params: types.CreateMessageRequestParams
) -> types.CreateMessageResult:
print(f"Sampling request: {params.messages}")
return types.CreateMessageResult(
Expand Down
44 changes: 44 additions & 0 deletions docs/migration.md
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,50 @@ async def handle_tool(name: str, arguments: dict) -> list[TextContent]:
await ctx.session.send_progress_notification(ctx.meta["progress_token"], 0.5, 100)
```

### `RequestContext` and `ProgressContext` type parameters simplified

The `RequestContext` class has been split to separate shared fields from server-specific fields. The shared `RequestContext` now only takes 1 type parameter (the session type) instead of 3.

**`RequestContext` changes:**

- Type parameters reduced from `RequestContext[SessionT, LifespanContextT, RequestT]` to `RequestContext[SessionT]`
- Server-specific fields (`lifespan_context`, `experimental`, `request`, `close_sse_stream`, `close_standalone_sse_stream`) moved to new `ServerRequestContext` class in `mcp.server.context`

**`ProgressContext` changes:**

- Type parameters reduced from `ProgressContext[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT]` to `ProgressContext[SessionT]`

**Before (v1):**

```python
from mcp.shared.context import RequestContext, LifespanContextT, RequestT
from mcp.shared.progress import ProgressContext

# RequestContext with 3 type parameters
ctx: RequestContext[ClientSession, LifespanContextT, RequestT]

# ProgressContext with 5 type parameters
progress_ctx: ProgressContext[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT]
```

**After (v2):**

```python
from mcp.shared.context import RequestContext
from mcp.shared.progress import ProgressContext

# RequestContext with 1 type parameter
ctx: RequestContext[ClientSession]

# ProgressContext with 1 type parameter
progress_ctx: ProgressContext[ClientSession]

# For server-specific context with lifespan and request types
from mcp.server.context import ServerRequestContext, LifespanContextT, RequestT

server_ctx: ServerRequestContext[LifespanContextT, RequestT]
```

### Resource URI type changed from `AnyUrl` to `str`

The `uri` field on resource-related types now uses `str` instead of Pydantic's `AnyUrl`. This aligns with the [MCP specification schema](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/main/schema/draft/schema.ts) which defines URIs as plain strings (`uri: string`) without strict URL validation. This change allows relative paths like `users/me` that were previously rejected.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
"""

import asyncio
from typing import Any

import click
from mcp import ClientSession
Expand All @@ -24,7 +23,7 @@


async def elicitation_callback(
context: RequestContext[ClientSession, Any],
context: RequestContext[ClientSession],
params: ElicitRequestParams,
) -> ElicitResult:
"""Handle elicitation requests from the server."""
Expand All @@ -39,7 +38,7 @@ async def elicitation_callback(


async def sampling_callback(
context: RequestContext[ClientSession, Any],
context: RequestContext[ClientSession],
params: CreateMessageRequestParams,
) -> CreateMessageResult:
"""Handle sampling requests from the server."""
Expand Down
2 changes: 1 addition & 1 deletion examples/snippets/clients/stdio_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

# Optional: create a sampling callback
async def handle_sampling_message(
context: RequestContext[ClientSession, None], params: types.CreateMessageRequestParams
context: RequestContext[ClientSession], params: types.CreateMessageRequestParams
) -> types.CreateMessageResult:
print(f"Sampling request: {params.messages}")
return types.CreateMessageResult(
Expand Down
2 changes: 1 addition & 1 deletion examples/snippets/clients/url_elicitation_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@


async def handle_elicitation(
context: RequestContext[ClientSession, Any],
context: RequestContext[ClientSession],
params: types.ElicitRequestParams,
) -> types.ElicitResult | types.ErrorData:
"""Handle elicitation requests from the server.
Expand Down
3 changes: 1 addition & 2 deletions examples/snippets/servers/lifespan_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from dataclasses import dataclass

from mcp.server.mcpserver import Context, MCPServer
from mcp.server.session import ServerSession


# Mock database class for example
Expand Down Expand Up @@ -51,7 +50,7 @@ async def app_lifespan(server: MCPServer) -> AsyncIterator[AppContext]:

# Access type-safe lifespan context in tools
@mcp.tool()
def query_db(ctx: Context[ServerSession, AppContext]) -> str:
def query_db(ctx: Context[AppContext]) -> str:
"""Tool that uses initialized resources."""
db = ctx.request_context.lifespan_context.db
return db.query()
30 changes: 16 additions & 14 deletions src/mcp/client/experimental/task_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
- Server polls client's task status via tasks/get, tasks/result, etc.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Protocol
from typing import TYPE_CHECKING, Protocol

from pydantic import TypeAdapter

Expand All @@ -32,7 +34,7 @@ class GetTaskHandlerFnT(Protocol):

async def __call__(
self,
context: RequestContext["ClientSession", Any],
context: RequestContext[ClientSession],
params: types.GetTaskRequestParams,
) -> types.GetTaskResult | types.ErrorData: ... # pragma: no branch

Expand All @@ -45,7 +47,7 @@ class GetTaskResultHandlerFnT(Protocol):

async def __call__(
self,
context: RequestContext["ClientSession", Any],
context: RequestContext[ClientSession],
params: types.GetTaskPayloadRequestParams,
) -> types.GetTaskPayloadResult | types.ErrorData: ... # pragma: no branch

Expand All @@ -58,7 +60,7 @@ class ListTasksHandlerFnT(Protocol):

async def __call__(
self,
context: RequestContext["ClientSession", Any],
context: RequestContext[ClientSession],
params: types.PaginatedRequestParams | None,
) -> types.ListTasksResult | types.ErrorData: ... # pragma: no branch

Expand All @@ -71,7 +73,7 @@ class CancelTaskHandlerFnT(Protocol):

async def __call__(
self,
context: RequestContext["ClientSession", Any],
context: RequestContext[ClientSession],
params: types.CancelTaskRequestParams,
) -> types.CancelTaskResult | types.ErrorData: ... # pragma: no branch

Expand All @@ -88,7 +90,7 @@ class TaskAugmentedSamplingFnT(Protocol):

async def __call__(
self,
context: RequestContext["ClientSession", Any],
context: RequestContext[ClientSession],
params: types.CreateMessageRequestParams,
task_metadata: types.TaskMetadata,
) -> types.CreateTaskResult | types.ErrorData: ... # pragma: no branch
Expand All @@ -106,14 +108,14 @@ class TaskAugmentedElicitationFnT(Protocol):

async def __call__(
self,
context: RequestContext["ClientSession", Any],
context: RequestContext[ClientSession],
params: types.ElicitRequestParams,
task_metadata: types.TaskMetadata,
) -> types.CreateTaskResult | types.ErrorData: ... # pragma: no branch


async def default_get_task_handler(
context: RequestContext["ClientSession", Any],
context: RequestContext[ClientSession],
params: types.GetTaskRequestParams,
) -> types.GetTaskResult | types.ErrorData:
return types.ErrorData(
Expand All @@ -123,7 +125,7 @@ async def default_get_task_handler(


async def default_get_task_result_handler(
context: RequestContext["ClientSession", Any],
context: RequestContext[ClientSession],
params: types.GetTaskPayloadRequestParams,
) -> types.GetTaskPayloadResult | types.ErrorData:
return types.ErrorData(
Expand All @@ -133,7 +135,7 @@ async def default_get_task_result_handler(


async def default_list_tasks_handler(
context: RequestContext["ClientSession", Any],
context: RequestContext[ClientSession],
params: types.PaginatedRequestParams | None,
) -> types.ListTasksResult | types.ErrorData:
return types.ErrorData(
Expand All @@ -143,7 +145,7 @@ async def default_list_tasks_handler(


async def default_cancel_task_handler(
context: RequestContext["ClientSession", Any],
context: RequestContext[ClientSession],
params: types.CancelTaskRequestParams,
) -> types.CancelTaskResult | types.ErrorData:
return types.ErrorData(
Expand All @@ -153,7 +155,7 @@ async def default_cancel_task_handler(


async def default_task_augmented_sampling(
context: RequestContext["ClientSession", Any],
context: RequestContext[ClientSession],
params: types.CreateMessageRequestParams,
task_metadata: types.TaskMetadata,
) -> types.CreateTaskResult | types.ErrorData:
Expand All @@ -164,7 +166,7 @@ async def default_task_augmented_sampling(


async def default_task_augmented_elicitation(
context: RequestContext["ClientSession", Any],
context: RequestContext[ClientSession],
params: types.ElicitRequestParams,
task_metadata: types.TaskMetadata,
) -> types.CreateTaskResult | types.ErrorData:
Expand Down Expand Up @@ -248,7 +250,7 @@ def handles_request(request: types.ServerRequest) -> bool:

async def handle_request(
self,
ctx: RequestContext["ClientSession", Any],
ctx: RequestContext[ClientSession],
responder: RequestResponder[types.ServerRequest, types.ClientResult],
) -> None:
"""Handle a task-related request from the server.
Expand Down
31 changes: 11 additions & 20 deletions src/mcp/client/session.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import logging
from typing import Any, Protocol

Expand All @@ -22,30 +24,27 @@
class SamplingFnT(Protocol):
async def __call__(
self,
context: RequestContext["ClientSession", Any],
context: RequestContext[ClientSession],
params: types.CreateMessageRequestParams,
) -> types.CreateMessageResult | types.CreateMessageResultWithTools | types.ErrorData: ... # pragma: no branch


class ElicitationFnT(Protocol):
async def __call__(
self,
context: RequestContext["ClientSession", Any],
context: RequestContext[ClientSession],
params: types.ElicitRequestParams,
) -> types.ElicitResult | types.ErrorData: ... # pragma: no branch


class ListRootsFnT(Protocol):
async def __call__(
self, context: RequestContext["ClientSession", Any]
self, context: RequestContext[ClientSession]
) -> types.ListRootsResult | types.ErrorData: ... # pragma: no branch


class LoggingFnT(Protocol):
async def __call__(
self,
params: types.LoggingMessageNotificationParams,
) -> None: ... # pragma: no branch
async def __call__(self, params: types.LoggingMessageNotificationParams) -> None: ... # pragma: no branch


class MessageHandlerFnT(Protocol):
Expand All @@ -62,7 +61,7 @@ async def _default_message_handler(


async def _default_sampling_callback(
context: RequestContext["ClientSession", Any],
context: RequestContext[ClientSession],
params: types.CreateMessageRequestParams,
) -> types.CreateMessageResult | types.CreateMessageResultWithTools | types.ErrorData:
return types.ErrorData(
Expand All @@ -72,7 +71,7 @@ async def _default_sampling_callback(


async def _default_elicitation_callback(
context: RequestContext["ClientSession", Any],
context: RequestContext[ClientSession],
params: types.ElicitRequestParams,
) -> types.ElicitResult | types.ErrorData:
return types.ErrorData( # pragma: no cover
Expand All @@ -82,7 +81,7 @@ async def _default_elicitation_callback(


async def _default_list_roots_callback(
context: RequestContext["ClientSession", Any],
context: RequestContext[ClientSession],
) -> types.ListRootsResult | types.ErrorData:
return types.ErrorData(
code=types.INVALID_REQUEST,
Expand Down Expand Up @@ -153,10 +152,7 @@ async def initialize(self) -> types.InitializeResult:
else None
)
elicitation = (
types.ElicitationCapability(
form=types.FormElicitationCapability(),
url=types.UrlElicitationCapability(),
)
types.ElicitationCapability(form=types.FormElicitationCapability(), url=types.UrlElicitationCapability())
if self._elicitation_callback is not _default_elicitation_callback
else None
)
Expand Down Expand Up @@ -414,12 +410,7 @@ async def send_roots_list_changed(self) -> None: # pragma: no cover
await self.send_notification(types.RootsListChangedNotification())

async def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None:
ctx = RequestContext[ClientSession, Any](
request_id=responder.request_id,
meta=responder.request_meta,
session=self,
lifespan_context=None,
)
ctx = RequestContext[ClientSession](request_id=responder.request_id, meta=responder.request_meta, session=self)

# Delegate to experimental task handler if applicable
if self._task_handlers.handles_request(responder.request):
Expand Down
23 changes: 23 additions & 0 deletions src/mcp/server/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Generic

from typing_extensions import TypeVar

from mcp.server.experimental.request_context import Experimental
from mcp.server.session import ServerSession
from mcp.shared.context import RequestContext
from mcp.shared.message import CloseSSEStreamCallback

LifespanContextT = TypeVar("LifespanContextT")
RequestT = TypeVar("RequestT", default=Any)


@dataclass(kw_only=True)
class ServerRequestContext(RequestContext[ServerSession], Generic[LifespanContextT, RequestT]):
lifespan_context: LifespanContextT
experimental: Experimental
request: RequestT | None = None
close_sse_stream: CloseSSEStreamCallback | None = None
close_standalone_sse_stream: CloseSSEStreamCallback | None = None
Loading