Skip to content

Commit b1f7eec

Browse files
authored
refactor: split RequestContext between server and client (#1987)
1 parent 050aeb6 commit b1f7eec

File tree

35 files changed

+210
-212
lines changed

35 files changed

+210
-212
lines changed

.github/actions/conformance/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ async def run_sse_retry(server_url: str) -> None:
187187

188188

189189
async def default_elicitation_callback(
190-
context: RequestContext[ClientSession, Any], # noqa: ARG001
190+
context: RequestContext[ClientSession],
191191
params: types.ElicitRequestParams,
192192
) -> types.ElicitResult | types.ErrorData:
193193
"""Accept elicitation and apply defaults from the schema (SEP-1034)."""

README.v2.md

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,6 @@ from contextlib import asynccontextmanager
229229
from dataclasses import dataclass
230230

231231
from mcp.server.mcpserver import Context, MCPServer
232-
from mcp.server.session import ServerSession
233232

234233

235234
# Mock database class for example
@@ -275,7 +274,7 @@ mcp = MCPServer("My App", lifespan=app_lifespan)
275274

276275
# Access type-safe lifespan context in tools
277276
@mcp.tool()
278-
def query_db(ctx: Context[ServerSession, AppContext]) -> str:
277+
def query_db(ctx: Context[AppContext]) -> str:
279278
"""Tool that uses initialized resources."""
280279
db = ctx.request_context.lifespan_context.db
281280
return db.query()
@@ -2135,7 +2134,7 @@ server_params = StdioServerParameters(
21352134

21362135
# Optional: create a sampling callback
21372136
async def handle_sampling_message(
2138-
context: RequestContext[ClientSession, None], params: types.CreateMessageRequestParams
2137+
context: RequestContext[ClientSession], params: types.CreateMessageRequestParams
21392138
) -> types.CreateMessageResult:
21402139
print(f"Sampling request: {params.messages}")
21412140
return types.CreateMessageResult(

docs/migration.md

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,50 @@ async def handle_tool(name: str, arguments: dict) -> list[TextContent]:
371371
await ctx.session.send_progress_notification(ctx.meta["progress_token"], 0.5, 100)
372372
```
373373

374+
### `RequestContext` and `ProgressContext` type parameters simplified
375+
376+
The `RequestContext` class has been split to separate shared fields from server-specific fields. The shared `RequestContext` now only takes 1 type parameter (the session type) instead of 3.
377+
378+
**`RequestContext` changes:**
379+
380+
- Type parameters reduced from `RequestContext[SessionT, LifespanContextT, RequestT]` to `RequestContext[SessionT]`
381+
- Server-specific fields (`lifespan_context`, `experimental`, `request`, `close_sse_stream`, `close_standalone_sse_stream`) moved to new `ServerRequestContext` class in `mcp.server.context`
382+
383+
**`ProgressContext` changes:**
384+
385+
- Type parameters reduced from `ProgressContext[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT]` to `ProgressContext[SessionT]`
386+
387+
**Before (v1):**
388+
389+
```python
390+
from mcp.shared.context import RequestContext, LifespanContextT, RequestT
391+
from mcp.shared.progress import ProgressContext
392+
393+
# RequestContext with 3 type parameters
394+
ctx: RequestContext[ClientSession, LifespanContextT, RequestT]
395+
396+
# ProgressContext with 5 type parameters
397+
progress_ctx: ProgressContext[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT]
398+
```
399+
400+
**After (v2):**
401+
402+
```python
403+
from mcp.shared.context import RequestContext
404+
from mcp.shared.progress import ProgressContext
405+
406+
# RequestContext with 1 type parameter
407+
ctx: RequestContext[ClientSession]
408+
409+
# ProgressContext with 1 type parameter
410+
progress_ctx: ProgressContext[ClientSession]
411+
412+
# For server-specific context with lifespan and request types
413+
from mcp.server.context import ServerRequestContext, LifespanContextT, RequestT
414+
415+
server_ctx: ServerRequestContext[LifespanContextT, RequestT]
416+
```
417+
374418
### Resource URI type changed from `AnyUrl` to `str`
375419

376420
The `uri` field on resource-related types now uses `str` instead of Pydantic's `AnyUrl`. This aligns with the [MCP specification schema](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/main/schema/draft/schema.ts) which defines URIs as plain strings (`uri: string`) without strict URL validation. This change allows relative paths like `users/me` that were previously rejected.

examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/main.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
"""
88

99
import asyncio
10-
from typing import Any
1110

1211
import click
1312
from mcp import ClientSession
@@ -24,7 +23,7 @@
2423

2524

2625
async def elicitation_callback(
27-
context: RequestContext[ClientSession, Any],
26+
context: RequestContext[ClientSession],
2827
params: ElicitRequestParams,
2928
) -> ElicitResult:
3029
"""Handle elicitation requests from the server."""
@@ -39,7 +38,7 @@ async def elicitation_callback(
3938

4039

4140
async def sampling_callback(
42-
context: RequestContext[ClientSession, Any],
41+
context: RequestContext[ClientSession],
4342
params: CreateMessageRequestParams,
4443
) -> CreateMessageResult:
4544
"""Handle sampling requests from the server."""

examples/snippets/clients/stdio_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
# Optional: create a sampling callback
2121
async def handle_sampling_message(
22-
context: RequestContext[ClientSession, None], params: types.CreateMessageRequestParams
22+
context: RequestContext[ClientSession], params: types.CreateMessageRequestParams
2323
) -> types.CreateMessageResult:
2424
print(f"Sampling request: {params.messages}")
2525
return types.CreateMessageResult(

examples/snippets/clients/url_elicitation_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838

3939

4040
async def handle_elicitation(
41-
context: RequestContext[ClientSession, Any],
41+
context: RequestContext[ClientSession],
4242
params: types.ElicitRequestParams,
4343
) -> types.ElicitResult | types.ErrorData:
4444
"""Handle elicitation requests from the server.

examples/snippets/servers/lifespan_example.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from dataclasses import dataclass
66

77
from mcp.server.mcpserver import Context, MCPServer
8-
from mcp.server.session import ServerSession
98

109

1110
# Mock database class for example
@@ -51,7 +50,7 @@ async def app_lifespan(server: MCPServer) -> AsyncIterator[AppContext]:
5150

5251
# Access type-safe lifespan context in tools
5352
@mcp.tool()
54-
def query_db(ctx: Context[ServerSession, AppContext]) -> str:
53+
def query_db(ctx: Context[AppContext]) -> str:
5554
"""Tool that uses initialized resources."""
5655
db = ctx.request_context.lifespan_context.db
5756
return db.query()

src/mcp/client/experimental/task_handlers.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@
1111
- Server polls client's task status via tasks/get, tasks/result, etc.
1212
"""
1313

14+
from __future__ import annotations
15+
1416
from dataclasses import dataclass, field
15-
from typing import TYPE_CHECKING, Any, Protocol
17+
from typing import TYPE_CHECKING, Protocol
1618

1719
from pydantic import TypeAdapter
1820

@@ -32,7 +34,7 @@ class GetTaskHandlerFnT(Protocol):
3234

3335
async def __call__(
3436
self,
35-
context: RequestContext["ClientSession", Any],
37+
context: RequestContext[ClientSession],
3638
params: types.GetTaskRequestParams,
3739
) -> types.GetTaskResult | types.ErrorData: ... # pragma: no branch
3840

@@ -45,7 +47,7 @@ class GetTaskResultHandlerFnT(Protocol):
4547

4648
async def __call__(
4749
self,
48-
context: RequestContext["ClientSession", Any],
50+
context: RequestContext[ClientSession],
4951
params: types.GetTaskPayloadRequestParams,
5052
) -> types.GetTaskPayloadResult | types.ErrorData: ... # pragma: no branch
5153

@@ -58,7 +60,7 @@ class ListTasksHandlerFnT(Protocol):
5860

5961
async def __call__(
6062
self,
61-
context: RequestContext["ClientSession", Any],
63+
context: RequestContext[ClientSession],
6264
params: types.PaginatedRequestParams | None,
6365
) -> types.ListTasksResult | types.ErrorData: ... # pragma: no branch
6466

@@ -71,7 +73,7 @@ class CancelTaskHandlerFnT(Protocol):
7173

7274
async def __call__(
7375
self,
74-
context: RequestContext["ClientSession", Any],
76+
context: RequestContext[ClientSession],
7577
params: types.CancelTaskRequestParams,
7678
) -> types.CancelTaskResult | types.ErrorData: ... # pragma: no branch
7779

@@ -88,7 +90,7 @@ class TaskAugmentedSamplingFnT(Protocol):
8890

8991
async def __call__(
9092
self,
91-
context: RequestContext["ClientSession", Any],
93+
context: RequestContext[ClientSession],
9294
params: types.CreateMessageRequestParams,
9395
task_metadata: types.TaskMetadata,
9496
) -> types.CreateTaskResult | types.ErrorData: ... # pragma: no branch
@@ -106,14 +108,14 @@ class TaskAugmentedElicitationFnT(Protocol):
106108

107109
async def __call__(
108110
self,
109-
context: RequestContext["ClientSession", Any],
111+
context: RequestContext[ClientSession],
110112
params: types.ElicitRequestParams,
111113
task_metadata: types.TaskMetadata,
112114
) -> types.CreateTaskResult | types.ErrorData: ... # pragma: no branch
113115

114116

115117
async def default_get_task_handler(
116-
context: RequestContext["ClientSession", Any],
118+
context: RequestContext[ClientSession],
117119
params: types.GetTaskRequestParams,
118120
) -> types.GetTaskResult | types.ErrorData:
119121
return types.ErrorData(
@@ -123,7 +125,7 @@ async def default_get_task_handler(
123125

124126

125127
async def default_get_task_result_handler(
126-
context: RequestContext["ClientSession", Any],
128+
context: RequestContext[ClientSession],
127129
params: types.GetTaskPayloadRequestParams,
128130
) -> types.GetTaskPayloadResult | types.ErrorData:
129131
return types.ErrorData(
@@ -133,7 +135,7 @@ async def default_get_task_result_handler(
133135

134136

135137
async def default_list_tasks_handler(
136-
context: RequestContext["ClientSession", Any],
138+
context: RequestContext[ClientSession],
137139
params: types.PaginatedRequestParams | None,
138140
) -> types.ListTasksResult | types.ErrorData:
139141
return types.ErrorData(
@@ -143,7 +145,7 @@ async def default_list_tasks_handler(
143145

144146

145147
async def default_cancel_task_handler(
146-
context: RequestContext["ClientSession", Any],
148+
context: RequestContext[ClientSession],
147149
params: types.CancelTaskRequestParams,
148150
) -> types.CancelTaskResult | types.ErrorData:
149151
return types.ErrorData(
@@ -153,7 +155,7 @@ async def default_cancel_task_handler(
153155

154156

155157
async def default_task_augmented_sampling(
156-
context: RequestContext["ClientSession", Any],
158+
context: RequestContext[ClientSession],
157159
params: types.CreateMessageRequestParams,
158160
task_metadata: types.TaskMetadata,
159161
) -> types.CreateTaskResult | types.ErrorData:
@@ -164,7 +166,7 @@ async def default_task_augmented_sampling(
164166

165167

166168
async def default_task_augmented_elicitation(
167-
context: RequestContext["ClientSession", Any],
169+
context: RequestContext[ClientSession],
168170
params: types.ElicitRequestParams,
169171
task_metadata: types.TaskMetadata,
170172
) -> types.CreateTaskResult | types.ErrorData:
@@ -248,7 +250,7 @@ def handles_request(request: types.ServerRequest) -> bool:
248250

249251
async def handle_request(
250252
self,
251-
ctx: RequestContext["ClientSession", Any],
253+
ctx: RequestContext[ClientSession],
252254
responder: RequestResponder[types.ServerRequest, types.ClientResult],
253255
) -> None:
254256
"""Handle a task-related request from the server.

src/mcp/client/session.py

Lines changed: 11 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import logging
24
from typing import Any, Protocol
35

@@ -22,30 +24,27 @@
2224
class SamplingFnT(Protocol):
2325
async def __call__(
2426
self,
25-
context: RequestContext["ClientSession", Any],
27+
context: RequestContext[ClientSession],
2628
params: types.CreateMessageRequestParams,
2729
) -> types.CreateMessageResult | types.CreateMessageResultWithTools | types.ErrorData: ... # pragma: no branch
2830

2931

3032
class ElicitationFnT(Protocol):
3133
async def __call__(
3234
self,
33-
context: RequestContext["ClientSession", Any],
35+
context: RequestContext[ClientSession],
3436
params: types.ElicitRequestParams,
3537
) -> types.ElicitResult | types.ErrorData: ... # pragma: no branch
3638

3739

3840
class ListRootsFnT(Protocol):
3941
async def __call__(
40-
self, context: RequestContext["ClientSession", Any]
42+
self, context: RequestContext[ClientSession]
4143
) -> types.ListRootsResult | types.ErrorData: ... # pragma: no branch
4244

4345

4446
class LoggingFnT(Protocol):
45-
async def __call__(
46-
self,
47-
params: types.LoggingMessageNotificationParams,
48-
) -> None: ... # pragma: no branch
47+
async def __call__(self, params: types.LoggingMessageNotificationParams) -> None: ... # pragma: no branch
4948

5049

5150
class MessageHandlerFnT(Protocol):
@@ -62,7 +61,7 @@ async def _default_message_handler(
6261

6362

6463
async def _default_sampling_callback(
65-
context: RequestContext["ClientSession", Any],
64+
context: RequestContext[ClientSession],
6665
params: types.CreateMessageRequestParams,
6766
) -> types.CreateMessageResult | types.CreateMessageResultWithTools | types.ErrorData:
6867
return types.ErrorData(
@@ -72,7 +71,7 @@ async def _default_sampling_callback(
7271

7372

7473
async def _default_elicitation_callback(
75-
context: RequestContext["ClientSession", Any],
74+
context: RequestContext[ClientSession],
7675
params: types.ElicitRequestParams,
7776
) -> types.ElicitResult | types.ErrorData:
7877
return types.ErrorData( # pragma: no cover
@@ -82,7 +81,7 @@ async def _default_elicitation_callback(
8281

8382

8483
async def _default_list_roots_callback(
85-
context: RequestContext["ClientSession", Any],
84+
context: RequestContext[ClientSession],
8685
) -> types.ListRootsResult | types.ErrorData:
8786
return types.ErrorData(
8887
code=types.INVALID_REQUEST,
@@ -153,10 +152,7 @@ async def initialize(self) -> types.InitializeResult:
153152
else None
154153
)
155154
elicitation = (
156-
types.ElicitationCapability(
157-
form=types.FormElicitationCapability(),
158-
url=types.UrlElicitationCapability(),
159-
)
155+
types.ElicitationCapability(form=types.FormElicitationCapability(), url=types.UrlElicitationCapability())
160156
if self._elicitation_callback is not _default_elicitation_callback
161157
else None
162158
)
@@ -414,12 +410,7 @@ async def send_roots_list_changed(self) -> None: # pragma: no cover
414410
await self.send_notification(types.RootsListChangedNotification())
415411

416412
async def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None:
417-
ctx = RequestContext[ClientSession, Any](
418-
request_id=responder.request_id,
419-
meta=responder.request_meta,
420-
session=self,
421-
lifespan_context=None,
422-
)
413+
ctx = RequestContext[ClientSession](request_id=responder.request_id, meta=responder.request_meta, session=self)
423414

424415
# Delegate to experimental task handler if applicable
425416
if self._task_handlers.handles_request(responder.request):

src/mcp/server/context.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from __future__ import annotations
2+
3+
from dataclasses import dataclass
4+
from typing import Any, Generic
5+
6+
from typing_extensions import TypeVar
7+
8+
from mcp.server.experimental.request_context import Experimental
9+
from mcp.server.session import ServerSession
10+
from mcp.shared.context import RequestContext
11+
from mcp.shared.message import CloseSSEStreamCallback
12+
13+
LifespanContextT = TypeVar("LifespanContextT")
14+
RequestT = TypeVar("RequestT", default=Any)
15+
16+
17+
@dataclass(kw_only=True)
18+
class ServerRequestContext(RequestContext[ServerSession], Generic[LifespanContextT, RequestT]):
19+
lifespan_context: LifespanContextT
20+
experimental: Experimental
21+
request: RequestT | None = None
22+
close_sse_stream: CloseSSEStreamCallback | None = None
23+
close_standalone_sse_stream: CloseSSEStreamCallback | None = None

0 commit comments

Comments
 (0)