diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 67453624c..a9f5e110e 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -489,11 +489,12 @@ async def _get_cached_tool_definition(self, tool_name: str) -> types.Tool | None return tool - def call_tool(self, *, validate_input: bool = True): + def call_tool(self, *, validate_input: bool = True, validate_output: bool = True): """Register a tool call handler. Args: validate_input: If True, validates input against inputSchema. Default is True. + validate_output: If True, validates output against outputSchema. Default is True. The handler validates input against inputSchema (if validate_input=True), calls the tool function, and builds a CallToolResult with the results: @@ -501,7 +502,7 @@ def call_tool(self, *, validate_input: bool = True): - Structured content (dict): returned in structuredContent, serialized JSON text returned in content - Both: returned in content and structuredContent - If outputSchema is defined, validates structuredContent or errors if missing. + If validate_output is True and outputSchema is defined, validates structuredContent or errors if missing. """ def decorator( @@ -522,7 +523,11 @@ async def handler(req: types.CallToolRequest): try: tool_name = req.params.name arguments = req.params.arguments or {} - tool = await self._get_cached_tool_definition(tool_name) + + if validate_input or validate_output: + tool = await self._get_cached_tool_definition(tool_name) + else: + tool = None # input validation if validate_input and tool: @@ -557,7 +562,7 @@ async def handler(req: types.CallToolRequest): return self._make_error_result(f"Unexpected return type from tool: {type(results).__name__}") # output validation - if tool and tool.outputSchema is not None: + if validate_output and tool and tool.outputSchema is not None: if maybe_structured_content is None: return self._make_error_result( "Output validation error: outputSchema defined but no structured output returned" diff --git a/tests/server/test_lowlevel_input_validation.py b/tests/server/test_lowlevel_input_validation.py index 47cb57232..d1bad5933 100644 --- a/tests/server/test_lowlevel_input_validation.py +++ b/tests/server/test_lowlevel_input_validation.py @@ -19,8 +19,10 @@ async def run_tool_test( tools: list[Tool], - call_tool_handler: Callable[[str, dict[str, Any]], Awaitable[list[TextContent]]], + call_tool_handler: Callable[[str, dict[str, Any]], Awaitable[Any]], test_callback: Callable[[ClientSession], Awaitable[CallToolResult]], + validate_input: bool = True, + validate_output: bool = True, ) -> CallToolResult | None: """Helper to run a tool test with minimal boilerplate. @@ -28,6 +30,8 @@ async def run_tool_test( tools: List of tools to register call_tool_handler: Handler function for tool calls test_callback: Async function that performs the test using the client session + validate_input: Whether to enable input validation (default: True) + validate_output: Whether to enable output validation (default: True) Returns: The result of the tool call @@ -39,7 +43,7 @@ async def run_tool_test( async def list_tools(): return tools - @server.call_tool() + @server.call_tool(validate_input=validate_input, validate_output=validate_output) async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]: return await call_tool_handler(name, arguments) @@ -309,3 +313,138 @@ async def test_callback(client_session: ClientSession) -> CallToolResult: assert any( "Tool 'unknown_tool' not listed, no validation will be performed" in record.message for record in caplog.records ) + + +@pytest.mark.anyio +async def test_validate_input_false_with_invalid_input(): + """Test that when validate_input=False, invalid input is not validated.""" + tools = [ + Tool( + name="add", + description="Add two numbers", + inputSchema={ + "type": "object", + "properties": { + "a": {"type": "number"}, + "b": {"type": "number"}, + }, + "required": ["a", "b"], + }, + ) + ] + + async def call_tool_handler(name: str, arguments: dict[str, Any]) -> list[TextContent]: + if name == "add": + # Even with invalid input (string instead of number), this should execute + # because validation is disabled + a = arguments.get("a", 0) + b = arguments.get("b", 0) + return [TextContent(type="text", text=f"Result: {a} + {b}")] + else: # pragma: no cover + raise ValueError(f"Unknown tool: {name}") + + async def test_callback(client_session: ClientSession) -> CallToolResult: + # Call with invalid input (string instead of number) + # With validate_input=False, this should succeed + return await client_session.call_tool("add", {"a": "five", "b": "three"}) + + result = await run_tool_test(tools, call_tool_handler, test_callback, validate_input=False) + + # Verify results - should succeed because validation is disabled + assert result is not None + assert not result.isError + assert len(result.content) == 1 + assert result.content[0].type == "text" + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "Result: five + three" + + +@pytest.mark.anyio +async def test_validate_input_true_with_invalid_input(): + """Test that when validate_input=True (default), invalid input is validated.""" + tools = [ + Tool( + name="add", + description="Add two numbers", + inputSchema={ + "type": "object", + "properties": { + "a": {"type": "number"}, + "b": {"type": "number"}, + }, + "required": ["a", "b"], + }, + ) + ] + + async def call_tool_handler(name: str, arguments: dict[str, Any]) -> list[TextContent]: + # This should not be reached because validation will fail + return [TextContent(type="text", text="This should not be reached")] # pragma: no cover + + async def test_callback(client_session: ClientSession) -> CallToolResult: + # Call with invalid input (string instead of number) + # With validate_input=True (default), this should fail validation + return await client_session.call_tool("add", {"a": "five", "b": "three"}) + + result = await run_tool_test(tools, call_tool_handler, test_callback) + + # Verify error - input validation is enabled by default + assert result is not None + assert result.isError + assert len(result.content) == 1 + assert result.content[0].type == "text" + assert "Input validation error" in result.content[0].text + + +@pytest.mark.anyio +async def test_validate_both_false(): + """Test that when both validate_input and validate_output are False, no validation occurs.""" + tools = [ + Tool( + name="process", + description="Process data", + inputSchema={ + "type": "object", + "properties": { + "value": {"type": "number"}, + }, + "required": ["value"], + }, + outputSchema={ + "type": "object", + "properties": { + "result": {"type": "number"}, + }, + "required": ["result"], + }, + ) + ] + + async def call_tool_handler(name: str, arguments: dict[str, Any]) -> dict[str, Any]: + if name == "process": + # Invalid input (string) and invalid output (string), but no validation + value = arguments.get("value", 0) + return {"result": f"processed_{value}"} + else: # pragma: no cover + raise ValueError(f"Unknown tool: {name}") + + async def test_callback(client_session: ClientSession) -> CallToolResult: + # Call with invalid input, and handler returns invalid output + # With both validations disabled, server should not return error + try: + return await client_session.call_tool("process", {"value": "invalid"}) + except RuntimeError as e: + # Client validation will fail, but server validation was disabled + assert "Invalid structured content" in str(e) + return CallToolResult( + content=[TextContent(type="text", text="Server returned result")], + structuredContent={"result": "processed_invalid"}, + isError=False, + ) + + result = await run_tool_test(tools, call_tool_handler, test_callback, validate_input=False, validate_output=False) + + # Verify server didn't return an error + assert result is not None + assert not result.isError + assert result.structuredContent == {"result": "processed_invalid"} diff --git a/tests/server/test_lowlevel_output_validation.py b/tests/server/test_lowlevel_output_validation.py index f73544521..e35e7f540 100644 --- a/tests/server/test_lowlevel_output_validation.py +++ b/tests/server/test_lowlevel_output_validation.py @@ -21,6 +21,7 @@ async def run_tool_test( tools: list[Tool], call_tool_handler: Callable[[str, dict[str, Any]], Awaitable[Any]], test_callback: Callable[[ClientSession], Awaitable[CallToolResult]], + validate_output: bool = True, ) -> CallToolResult | None: """Helper to run a tool test with minimal boilerplate. @@ -28,6 +29,7 @@ async def run_tool_test( tools: List of tools to register call_tool_handler: Handler function for tool calls test_callback: Async function that performs the test using the client session + validate_output: Whether to enable output validation (default: True) Returns: The result of the tool call @@ -40,7 +42,7 @@ async def run_tool_test( async def list_tools(): return tools - @server.call_tool() + @server.call_tool(validate_output=validate_output) async def call_tool(name: str, arguments: dict[str, Any]): return await call_tool_handler(name, arguments) @@ -474,3 +476,149 @@ async def test_callback(client_session: ClientSession) -> CallToolResult: assert result.content[0].type == "text" assert "Output validation error:" in result.content[0].text assert "'five' is not of type 'integer'" in result.content[0].text + + +@pytest.mark.anyio +async def test_validate_output_false_returns_invalid_schema(): + """Test that when validate_output=False, server returns invalid output without error.""" + tools = [ + Tool( + name="tool_with_schema", + description="Tool with output schema", + inputSchema={ + "type": "object", + "properties": {}, + }, + outputSchema={ + "type": "object", + "properties": { + "required_field": {"type": "string"}, + }, + "required": ["required_field"], + }, + ) + ] + + async def call_tool_handler(name: str, arguments: dict[str, Any]) -> dict[str, Any]: + if name == "tool_with_schema": + # Missing required field, but server validation is disabled + return {"other_field": "value"} + else: # pragma: no cover + raise ValueError(f"Unknown tool: {name}") + + async def test_callback(client_session: ClientSession) -> CallToolResult: + # Note: Even though server validation is disabled, client validation will still fail + # This test verifies that the server doesn't return an error response + try: + return await client_session.call_tool("tool_with_schema", {}) + except RuntimeError as e: + # Client validation failed, but that's expected + # The important thing is that the server didn't return an error response + # We can verify this by checking the error message + assert "Invalid structured content" in str(e) + # Return a mock result to indicate server didn't error + return CallToolResult( + content=[TextContent(type="text", text="Server returned result")], + structuredContent={"other_field": "value"}, + isError=False, + ) + + result = await run_tool_test(tools, call_tool_handler, test_callback, validate_output=False) + + # Verify server didn't return an error - it returned the invalid output + assert result is not None + assert not result.isError + assert result.structuredContent == {"other_field": "value"} + + +@pytest.mark.anyio +async def test_validate_output_false_returns_no_structured_output(): + """Test that when validate_output=False, server returns without structured output without error.""" + tools = [ + Tool( + name="tool_with_schema", + description="Tool with output schema", + inputSchema={ + "type": "object", + "properties": {}, + }, + outputSchema={ + "type": "object", + "properties": { + "result": {"type": "string"}, + }, + "required": ["result"], + }, + ) + ] + + async def call_tool_handler(name: str, arguments: dict[str, Any]) -> list[TextContent]: + if name == "tool_with_schema": + # Returns only content, no structured output, but server validation is disabled + return [TextContent(type="text", text="No structured output")] + else: # pragma: no cover + raise ValueError(f"Unknown tool: {name}") + + async def test_callback(client_session: ClientSession) -> CallToolResult: + # Note: Even though server validation is disabled, client validation will still fail + # This test verifies that the server doesn't return an error response + try: + return await client_session.call_tool("tool_with_schema", {}) + except RuntimeError as e: + # Client validation failed, but that's expected + # The important thing is that the server didn't return an error response + assert "has an output schema but did not return structured content" in str(e) + # Return a mock result to indicate server didn't error + return CallToolResult( + content=[TextContent(type="text", text="No structured output")], structuredContent=None, isError=False + ) + + result = await run_tool_test(tools, call_tool_handler, test_callback, validate_output=False) + + # Verify server didn't return an error - it returned content without structured output + assert result is not None + assert not result.isError + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "No structured output" + assert result.structuredContent is None + + +@pytest.mark.anyio +async def test_validate_output_true_with_invalid_schema(): + """Test that when validate_output=True (default), invalid output schema is validated.""" + tools = [ + Tool( + name="tool_with_schema", + description="Tool with output schema", + inputSchema={ + "type": "object", + "properties": {}, + }, + outputSchema={ + "type": "object", + "properties": { + "required_field": {"type": "string"}, + }, + "required": ["required_field"], + }, + ) + ] + + async def call_tool_handler(name: str, arguments: dict[str, Any]) -> dict[str, Any]: + if name == "tool_with_schema": + # Missing required field, validation is enabled (default) + return {"other_field": "value"} + else: # pragma: no cover + raise ValueError(f"Unknown tool: {name}") + + async def test_callback(client_session: ClientSession) -> CallToolResult: + return await client_session.call_tool("tool_with_schema", {}) + + result = await run_tool_test(tools, call_tool_handler, test_callback) + + # Verify error - output validation is enabled by default + assert result is not None + assert result.isError + assert isinstance(result.content[0], TextContent) + assert "Output validation error:" in result.content[0].text