Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion .github/actions/conformance/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,27 @@ async def run_client_credentials_basic(server_url: str) -> None:
async def run_auth_code_client(server_url: str) -> None:
"""Authorization code flow (default for auth/* scenarios)."""
callback_handler = ConformanceOAuthCallbackHandler()
storage = InMemoryTokenStorage()

# Check for pre-registered client credentials from context
context_json = os.environ.get("MCP_CONFORMANCE_CONTEXT")
if context_json:
try:
context = json.loads(context_json)
client_id = context.get("client_id")
client_secret = context.get("client_secret")
if client_id:
await storage.set_client_info(
OAuthClientInformationFull(
client_id=client_id,
client_secret=client_secret,
redirect_uris=[AnyUrl("http://localhost:3000/callback")],
token_endpoint_auth_method="client_secret_basic" if client_secret else "none",
)
)
logger.debug(f"Pre-loaded client credentials: client_id={client_id}")
except json.JSONDecodeError:
pass
Comment on lines +297 to +298
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same comment as before: why wouldn't want to see the exception?


oauth_auth = OAuthClientProvider(
server_url=server_url,
Expand All @@ -284,7 +305,7 @@ async def run_auth_code_client(server_url: str) -> None:
grant_types=["authorization_code", "refresh_token"],
response_types=["code"],
),
storage=InMemoryTokenStorage(),
storage=storage,
redirect_handler=callback_handler.handle_redirect,
callback_handler=callback_handler.handle_callback,
client_metadata_url="https://conformance-test.local/client-metadata.json",
Expand Down
9 changes: 4 additions & 5 deletions .github/workflows/conformance.yml
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
name: Conformance Tests

on:
# Disabled: conformance tests are currently broken in CI
# push:
# branches: [main]
# pull_request:
push:
branches: [main]
pull_request:
workflow_dispatch:

concurrency:
Expand Down Expand Up @@ -43,4 +42,4 @@ jobs:
with:
node-version: 24
- run: uv sync --frozen --all-extras --package mcp
- run: npx @modelcontextprotocol/conformance@0.1.10 client --command 'uv run --frozen python .github/actions/conformance/client.py' --suite all
- run: npx @modelcontextprotocol/conformance@0.1.13 client --command 'uv run --frozen python .github/actions/conformance/client.py' --suite all
28 changes: 28 additions & 0 deletions src/mcp/client/auth/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ def __init__(
callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None = None,
timeout: float = 300.0,
client_metadata_url: str | None = None,
validate_resource_url: Callable[[str, str | None], Awaitable[None]] | None = None,
):
"""Initialize OAuth2 authentication.

Expand All @@ -243,6 +244,10 @@ def __init__(
advertises client_id_metadata_document_supported=true, this URL will be
used as the client_id instead of performing dynamic client registration.
Must be a valid HTTPS URL with a non-root pathname.
validate_resource_url: Optional callback to override resource URL validation.
Called with (server_url, prm_resource) where prm_resource is the resource
from Protected Resource Metadata (or None if not present). If not provided,
default validation rejects mismatched resources per RFC 8707.

Raises:
ValueError: If client_metadata_url is provided but not a valid HTTPS URL
Expand All @@ -263,6 +268,7 @@ def __init__(
timeout=timeout,
client_metadata_url=client_metadata_url,
)
self._validate_resource_url_callback = validate_resource_url
self._initialized = False

async def _handle_protected_resource_response(self, response: httpx.Response) -> bool:
Expand Down Expand Up @@ -476,6 +482,26 @@ async def _handle_oauth_metadata_response(self, response: httpx.Response) -> Non
metadata = OAuthMetadata.model_validate_json(content)
self.context.oauth_metadata = metadata

async def _validate_resource_match(self, prm: ProtectedResourceMetadata) -> None:
"""Validate that PRM resource matches the server URL per RFC 8707."""
prm_resource = str(prm.resource) if prm.resource else None

if self._validate_resource_url_callback is not None:
await self._validate_resource_url_callback(self.context.server_url, prm_resource)
return

if not prm_resource:
return # pragma: no cover
default_resource = resource_url_from_server_url(self.context.server_url)
# Normalize: Pydantic AnyHttpUrl adds trailing slash to root URLs
# (e.g. "https://example.com/") while resource_url_from_server_url may not.
if not default_resource.endswith("/"):
default_resource += "/"
if not prm_resource.endswith("/"):
prm_resource += "/"
if not check_resource_allowed(requested_resource=default_resource, configured_resource=prm_resource):
raise OAuthFlowError(f"Protected resource {prm_resource} does not match expected {default_resource}")

async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
"""HTTPX auth flow integration."""
async with self.context.lock:
Expand Down Expand Up @@ -517,6 +543,8 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.

prm = await handle_protected_resource_response(discovery_response)
if prm:
# Validate PRM resource matches server URL (RFC 8707)
await self._validate_resource_match(prm)
self.context.protected_resource_metadata = prm

# todo: try all authorization_servers to find the OASM
Expand Down
135 changes: 133 additions & 2 deletions tests/client/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from pydantic import AnyHttpUrl, AnyUrl

from mcp.client.auth import OAuthClientProvider, PKCEParameters
from mcp.client.auth.exceptions import OAuthFlowError
from mcp.client.auth.utils import (
build_oauth_authorization_server_metadata_discovery_urls,
build_protected_resource_metadata_discovery_urls,
Expand Down Expand Up @@ -818,6 +819,136 @@ async def test_resource_param_included_with_protected_resource_metadata(self, oa
assert "resource=" in content


@pytest.mark.anyio
async def test_validate_resource_rejects_mismatched_resource(
client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage
) -> None:
"""Client must reject PRM resource that doesn't match server URL."""
provider = OAuthClientProvider(
server_url="https://api.example.com/v1/mcp",
client_metadata=client_metadata,
storage=mock_storage,
)
provider._initialized = True

prm = ProtectedResourceMetadata(
resource=AnyHttpUrl("https://evil.example.com/mcp"),
authorization_servers=[AnyHttpUrl("https://auth.example.com")],
)
with pytest.raises(OAuthFlowError, match="does not match expected"):
await provider._validate_resource_match(prm)


@pytest.mark.anyio
async def test_validate_resource_accepts_matching_resource(
client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage
) -> None:
"""Client must accept PRM resource that matches server URL."""
provider = OAuthClientProvider(
server_url="https://api.example.com/v1/mcp",
client_metadata=client_metadata,
storage=mock_storage,
)
provider._initialized = True

prm = ProtectedResourceMetadata(
resource=AnyHttpUrl("https://api.example.com/v1/mcp"),
authorization_servers=[AnyHttpUrl("https://auth.example.com")],
)
# Should not raise
await provider._validate_resource_match(prm)


@pytest.mark.anyio
async def test_validate_resource_custom_callback(
client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage
) -> None:
"""Custom callback overrides default validation."""
callback_called_with: list[tuple[str, str | None]] = []

async def custom_validate(server_url: str, prm_resource: str | None) -> None:
callback_called_with.append((server_url, prm_resource))

provider = OAuthClientProvider(
server_url="https://api.example.com/v1/mcp",
client_metadata=client_metadata,
storage=mock_storage,
validate_resource_url=custom_validate,
)
provider._initialized = True

# This would normally fail default validation (different origin),
# but custom callback accepts it
prm = ProtectedResourceMetadata(
resource=AnyHttpUrl("https://evil.example.com/mcp"),
authorization_servers=[AnyHttpUrl("https://auth.example.com")],
)
await provider._validate_resource_match(prm)
assert callback_called_with == snapshot([("https://api.example.com/v1/mcp", "https://evil.example.com/mcp")])


@pytest.mark.anyio
async def test_validate_resource_accepts_root_url_with_trailing_slash(
client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage
) -> None:
"""Root URLs with trailing slash normalization should match."""
provider = OAuthClientProvider(
server_url="https://api.example.com",
client_metadata=client_metadata,
storage=mock_storage,
)
provider._initialized = True

prm = ProtectedResourceMetadata(
resource=AnyHttpUrl("https://api.example.com/"),
authorization_servers=[AnyHttpUrl("https://auth.example.com")],
)
# Should not raise despite trailing slash difference
await provider._validate_resource_match(prm)


@pytest.mark.anyio
async def test_validate_resource_accepts_server_url_with_trailing_slash(
client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage
) -> None:
"""Server URL with trailing slash should match PRM resource."""
provider = OAuthClientProvider(
server_url="https://api.example.com/v1/mcp/",
client_metadata=client_metadata,
storage=mock_storage,
)
provider._initialized = True

prm = ProtectedResourceMetadata(
resource=AnyHttpUrl("https://api.example.com/v1/mcp"),
authorization_servers=[AnyHttpUrl("https://auth.example.com")],
)
# Should not raise - both normalize to the same URL with trailing slash
await provider._validate_resource_match(prm)


@pytest.mark.anyio
async def test_get_resource_url_uses_canonical_when_prm_mismatches(
client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage
) -> None:
"""get_resource_url falls back to canonical URL when PRM resource doesn't match."""
provider = OAuthClientProvider(
server_url="https://api.example.com/v1/mcp",
client_metadata=client_metadata,
storage=mock_storage,
)
provider._initialized = True

# Set PRM with a resource that is NOT a parent of the server URL
provider.context.protected_resource_metadata = ProtectedResourceMetadata(
resource=AnyHttpUrl("https://other.example.com/mcp"),
authorization_servers=[AnyHttpUrl("https://auth.example.com")],
)

# get_resource_url should return the canonical server URL, not the PRM resource
assert provider.context.get_resource_url() == snapshot("https://api.example.com/v1/mcp")


class TestRegistrationResponse:
"""Test client registration response handling."""

Expand Down Expand Up @@ -963,7 +1094,7 @@ async def test_auth_flow_with_no_tokens(self, oauth_provider: OAuthClientProvide
# Send a successful discovery response with minimal protected resource metadata
discovery_response = httpx.Response(
200,
content=b'{"resource": "https://api.example.com/mcp", "authorization_servers": ["https://auth.example.com"]}',
content=b'{"resource": "https://api.example.com/v1/mcp", "authorization_servers": ["https://auth.example.com"]}',
request=discovery_request,
)

Expand Down Expand Up @@ -1116,7 +1247,7 @@ async def test_token_exchange_accepts_201_status(
# Send a successful discovery response with minimal protected resource metadata
discovery_response = httpx.Response(
200,
content=b'{"resource": "https://api.example.com/mcp", "authorization_servers": ["https://auth.example.com"]}',
content=b'{"resource": "https://api.example.com/v1/mcp", "authorization_servers": ["https://auth.example.com"]}',
request=discovery_request,
)

Expand Down