diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index f064f7def..08c5f64e2 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -24,8 +24,10 @@ import anyio from mcp import ClientSession, ListToolsResult from mcp.client.session import ElicitationFnT +from mcp.shared.exceptions import McpError from mcp.types import ( BlobResourceContents, + ElicitationRequiredErrorData, GetPromptResult, ListPromptsResult, ListResourcesResult, @@ -667,7 +669,43 @@ async def call_tool_async( return self._handle_tool_execution_error(tool_use_id, e) def _handle_tool_execution_error(self, tool_use_id: str, exception: Exception) -> MCPToolResult: - """Create error ToolResult with consistent logging.""" + """Create error ToolResult with consistent logging and elicitation callback support. + + Args: + tool_use_id: Unique identifier for this tool use. + exception: The exception that occurred during tool execution. + + Returns: + MCPToolResult: Error result containing either the elicitation URL(s) or the + original exception message. + """ + if isinstance(exception, McpError) and exception.error.code == -32042: + try: + error_data = ElicitationRequiredErrorData.model_validate(exception.error.data) + elicitation_urls = [e.url for e in error_data.elicitations if e.url] + elicitation_messages = [e.message for e in error_data.elicitations if e.message] + + if elicitation_urls: + url_list = "\n".join(elicitation_urls) + message = elicitation_messages[0] if elicitation_messages else "Authorization required." + + return MCPToolResult( + status="error", + toolUseId=tool_use_id, + content=[ + { + "text": ( + f"URL_ELICITATION_REQUIRED: {message}\n\n" + f"The user must open the following URL(s) in their browser " + f"to complete authorization:\n\n{url_list}\n\n" + f"After the user completes the flow, retry this tool call." + ) + } + ], + ) + except Exception: + logger.debug("Failed to parse ElicitationRequiredErrorData from -32042 error", exc_info=True) + return MCPToolResult( status="error", toolUseId=tool_use_id, diff --git a/tests/strands/tools/mcp/test_mcp_client.py b/tests/strands/tools/mcp/test_mcp_client.py index e477c64d5..4bceb182e 100644 --- a/tests/strands/tools/mcp/test_mcp_client.py +++ b/tests/strands/tools/mcp/test_mcp_client.py @@ -914,3 +914,128 @@ async def test_handle_error_message_with_percent_in_message(): # This should not raise TypeError and should not raise the exception (since it's non-fatal) await client._handle_error_message(error_with_percent) + + +def test_call_tool_sync_elicitation_error(mock_transport, mock_session): + """Test that call_tool_sync correctly handles elicitation required errors.""" + from mcp.shared.exceptions import McpError + from mcp.types import ElicitationRequiredErrorData, ElicitRequestURLParams + + # Create an elicitation error (code -32042) + elicitation_data = ElicitationRequiredErrorData( + elicitations=[ + ElicitRequestURLParams( + url="https://example.com/auth", message="Please authorize the application", elicitationId="elicit-123" + ) + ] + ) + + error = McpError(error=MagicMock(code=-32042, data=elicitation_data.model_dump())) + + mock_session.call_tool.side_effect = error + + with MCPClient(mock_transport["transport_callable"]) as client: + result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"}) + + assert result["status"] == "error" + assert result["toolUseId"] == "test-123" + assert len(result["content"]) == 1 + assert "URL_ELICITATION_REQUIRED" in result["content"][0]["text"] + assert "https://example.com/auth" in result["content"][0]["text"] + assert "Please authorize the application" in result["content"][0]["text"] + assert "retry this tool call" in result["content"][0]["text"] + + +def test_call_tool_sync_elicitation_error_multiple_urls(mock_transport, mock_session): + """Test that call_tool_sync correctly handles elicitation errors with multiple URLs.""" + from mcp.shared.exceptions import McpError + from mcp.types import ElicitationRequiredErrorData, ElicitRequestURLParams + + # Create an elicitation error with multiple URLs + elicitation_data = ElicitationRequiredErrorData( + elicitations=[ + ElicitRequestURLParams( + url="https://example.com/auth1", message="First authorization", elicitationId="elicit-1" + ), + ElicitRequestURLParams( + url="https://example.com/auth2", message="Second authorization", elicitationId="elicit-2" + ), + ] + ) + + error = McpError(error=MagicMock(code=-32042, data=elicitation_data.model_dump())) + + mock_session.call_tool.side_effect = error + + with MCPClient(mock_transport["transport_callable"]) as client: + result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"}) + + assert result["status"] == "error" + assert result["toolUseId"] == "test-123" + assert len(result["content"]) == 1 + assert "URL_ELICITATION_REQUIRED" in result["content"][0]["text"] + assert "https://example.com/auth1" in result["content"][0]["text"] + assert "https://example.com/auth2" in result["content"][0]["text"] + assert "First authorization" in result["content"][0]["text"] + + +def test_call_tool_sync_elicitation_error_invalid_data(mock_transport, mock_session): + """Test that call_tool_sync handles malformed elicitation error data gracefully.""" + from mcp.shared.exceptions import McpError + + # Create an elicitation error with invalid data that can't be parsed + error = McpError(error=MagicMock(code=-32042, data={"invalid": "data structure"})) + + mock_session.call_tool.side_effect = error + + with MCPClient(mock_transport["transport_callable"]) as client: + result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"}) + + # Should fall back to generic error message when parsing fails + assert result["status"] == "error" + assert result["toolUseId"] == "test-123" + assert len(result["content"]) == 1 + assert "Tool execution failed" in result["content"][0]["text"] + + +@pytest.mark.asyncio +async def test_call_tool_async_elicitation_error(mock_transport, mock_session): + """Test that call_tool_async correctly handles elicitation required errors.""" + from mcp.shared.exceptions import McpError + from mcp.types import ElicitationRequiredErrorData, ElicitRequestURLParams + + # Create an elicitation error (code -32042) + elicitation_data = ElicitationRequiredErrorData( + elicitations=[ + ElicitRequestURLParams( + url="https://example.com/auth", message="Please authorize the application", elicitationId="elicit-123" + ) + ] + ) + + error = McpError(error=MagicMock(code=-32042, data=elicitation_data.model_dump())) + + with MCPClient(mock_transport["transport_callable"]) as client: + with ( + patch("asyncio.run_coroutine_threadsafe") as mock_run_coroutine_threadsafe, + patch("asyncio.wrap_future") as mock_wrap_future, + ): + mock_future = MagicMock() + mock_run_coroutine_threadsafe.return_value = mock_future + + # Create an async mock that raises the elicitation error + async def mock_awaitable(): + raise error + + mock_wrap_future.return_value = mock_awaitable() + + result = await client.call_tool_async( + tool_use_id="test-123", name="test_tool", arguments={"param": "value"} + ) + + assert result["status"] == "error" + assert result["toolUseId"] == "test-123" + assert len(result["content"]) == 1 + assert "URL_ELICITATION_REQUIRED" in result["content"][0]["text"] + assert "https://example.com/auth" in result["content"][0]["text"] + assert "Please authorize the application" in result["content"][0]["text"]