diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 8ac58ca4..4005b2fe 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -1,41 +1,41 @@ // For format details, see https://aka.ms/devcontainer.json. For config options, see the // README at: https://github.com/devcontainers/templates/tree/main/src/rust { - "name": "Rust", - // Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile - "image": "mcr.microsoft.com/devcontainers/rust:1-1-bullseye", - "features": { - "ghcr.io/devcontainers/features/node:1": {}, - "ghcr.io/devcontainers/features/python:1": { - "version": "3.10", - "toolsToInstall": "uv" - } - }, - // Configure tool-specific properties. - "customizations": { - "vscode": { - "settings": { - "editor.formatOnSave": true, - "[rust]": { - "editor.defaultFormatter": "rust-lang.rust-analyzer" - } - } - } - }, - // Use 'postCreateCommand' to run commands after the container is created. - "postCreateCommand": "uv venv" - // Use 'mounts' to make the cargo cache persistent in a Docker Volume. - // "mounts": [ - // { - // "source": "devcontainer-cargo-cache-${devcontainerId}", - // "target": "/usr/local/cargo", - // "type": "volume" - // } - // ] - // Features to add to the dev container. More info: https://containers.dev/features. - // "features": {}, - // Use 'forwardPorts' to make a list of ports inside the container available locally. - // "forwardPorts": [], - // Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root. - // "remoteUser": "root" + "name": "Rust", + // Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile + "image": "mcr.microsoft.com/devcontainers/rust:1-1-bullseye", + "features": { + "ghcr.io/devcontainers/features/node:1": {}, + "ghcr.io/devcontainers/features/python:1": { + "version": "3.10", + "toolsToInstall": "uv" + } + }, + // Configure tool-specific properties. + "customizations": { + "vscode": { + "settings": { + "editor.formatOnSave": true, + "[rust]": { + "editor.defaultFormatter": "rust-lang.rust-analyzer" + } + } + } + }, + // Use 'postCreateCommand' to run commands after the container is created. + "postCreateCommand": "uv venv && npm install -g @commitlint/config-conventional" + // Use 'mounts' to make the cargo cache persistent in a Docker Volume. + // "mounts": [ + // { + // "source": "devcontainer-cargo-cache-${devcontainerId}", + // "target": "/usr/local/cargo", + // "type": "volume" + // } + // ] + // Features to add to the dev container. More info: https://containers.dev/features. + // "features": {}, + // Use 'forwardPorts' to make a list of ports inside the container available locally. + // "forwardPorts": [], + // Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root. + // "remoteUser": "root" } \ No newline at end of file diff --git a/crates/rmcp/src/transport.rs b/crates/rmcp/src/transport.rs index 5b9318d9..8228ce7c 100644 --- a/crates/rmcp/src/transport.rs +++ b/crates/rmcp/src/transport.rs @@ -107,9 +107,9 @@ pub mod auth; #[cfg(feature = "auth")] #[cfg_attr(docsrs, doc(cfg(feature = "auth")))] pub use auth::{ - AuthError, AuthorizationManager, AuthorizationSession, AuthorizedHttpClient, CredentialStore, - InMemoryCredentialStore, InMemoryStateStore, StateStore, StoredAuthorizationState, - StoredCredentials, + AuthClient, AuthError, AuthorizationManager, AuthorizationSession, AuthorizedHttpClient, + CredentialStore, InMemoryCredentialStore, InMemoryStateStore, ScopeUpgradeConfig, StateStore, + StoredAuthorizationState, StoredCredentials, WWWAuthenticateParams, }; // #[cfg(feature = "transport-ws")] diff --git a/crates/rmcp/src/transport/auth.rs b/crates/rmcp/src/transport/auth.rs index de2cf5e9..ad2d69ab 100644 --- a/crates/rmcp/src/transport/auth.rs +++ b/crates/rmcp/src/transport/auth.rs @@ -23,6 +23,8 @@ const DEFAULT_EXCHANGE_URL: &str = "http://localhost"; pub struct StoredCredentials { pub client_id: String, pub token_response: Option, + #[serde(default)] + pub granted_scopes: Vec, } /// Trait for storing and retrieving OAuth2 credentials @@ -238,6 +240,12 @@ pub enum AuthError { #[error("Registration failed: {0}")] RegistrationFailed(String), + + #[error("Insufficient scope: {required_scope}")] + InsufficientScope { + required_scope: String, + upgrade_url: Option, + }, } /// oauth2 metadata @@ -250,6 +258,7 @@ pub struct AuthorizationMetadata { pub jwks_uri: Option, pub scopes_supported: Option>, pub response_types_supported: Option>, + pub code_challenge_methods_supported: Option>, // allow additional fields #[serde(flatten)] pub additional_fields: HashMap, @@ -259,6 +268,28 @@ pub struct AuthorizationMetadata { struct ResourceServerMetadata { authorization_server: Option, authorization_servers: Option>, + scopes_supported: Option>, +} + +/// Parameters extracted from WWW-Authenticate header +#[derive(Debug, Clone, Default)] +pub struct WWWAuthenticateParams { + pub resource_metadata_url: Option, + pub scope: Option, + pub error: Option, + pub error_description: Option, +} + +impl WWWAuthenticateParams { + /// check if this is an insufficient_scope error + pub fn is_insufficient_scope(&self) -> bool { + self.error.as_deref() == Some("insufficient_scope") + } + + /// check if this is an invalid_token error (expired/revoked) + pub fn is_invalid_token(&self) -> bool { + self.error.as_deref() == Some("invalid_token") + } } /// oauth2 client config @@ -291,6 +322,24 @@ type OAuthClient = oauth2::Client< >; type Credentials = (String, Option); +/// Configuration for scope upgrade behavior +#[derive(Debug, Clone)] +pub struct ScopeUpgradeConfig { + /// Maximum number of scope upgrade attempts before giving up + pub max_upgrade_attempts: u32, + /// Whether to automatically attempt scope upgrades on 403 + pub auto_upgrade: bool, +} + +impl Default for ScopeUpgradeConfig { + fn default() -> Self { + Self { + max_upgrade_attempts: 3, + auto_upgrade: true, + } + } +} + /// oauth2 auth manager pub struct AuthorizationManager { http_client: HttpClient, @@ -299,6 +348,13 @@ pub struct AuthorizationManager { credential_store: Arc, state_store: Arc, base_url: Url, + current_scopes: RwLock>, + scope_upgrade_attempts: RwLock, + scope_upgrade_config: ScopeUpgradeConfig, + /// scopes from the initial 401 WWW-Authenticate header, used by select_scopes() + www_auth_scopes: RwLock>, + /// scopes_supported from protected resource metadata (RFC 9728) + resource_scopes: RwLock>, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -373,11 +429,21 @@ impl AuthorizationManager { credential_store: Arc::new(InMemoryCredentialStore::new()), state_store: Arc::new(InMemoryStateStore::new()), base_url, + current_scopes: RwLock::new(Vec::new()), + scope_upgrade_attempts: RwLock::new(0), + scope_upgrade_config: ScopeUpgradeConfig::default(), + www_auth_scopes: RwLock::new(Vec::new()), + resource_scopes: RwLock::new(Vec::new()), }; Ok(manager) } + /// Set the scope upgrade configuration + pub fn set_scope_upgrade_config(&mut self, config: ScopeUpgradeConfig) { + self.scope_upgrade_config = config; + } + /// Set a custom credential store /// /// This allows you to provide your own implementation of credential storage, @@ -426,13 +492,13 @@ impl AuthorizationManager { Ok(()) } - /// discover oauth2 metadata + /// discover oauth2 metadata (per SEP-985: Protected Resource Metadata first, then direct OAuth) pub async fn discover_metadata(&self) -> Result { - if let Some(metadata) = self.try_discover_oauth_server(&self.base_url).await? { + if let Some(metadata) = self.discover_oauth_server_via_resource_metadata().await? { return Ok(metadata); } - if let Some(metadata) = self.discover_oauth_server_via_resource_metadata().await? { + if let Some(metadata) = self.try_discover_oauth_server(&self.base_url).await? { return Ok(metadata); } @@ -485,15 +551,34 @@ impl AuthorizationManager { self.oauth_client = Some(client_builder); Ok(()) } - /// validate if the server support the response type - fn validate_response_supported(&self, response_type: &str) -> Result<(), AuthError> { - if let Some(metadata) = self.metadata.as_ref() { - if let Some(response_types_supported) = metadata.response_types_supported.as_ref() { - if !response_types_supported.contains(&response_type.to_string()) { - return Err(AuthError::InvalidScope(response_type.to_string())); - } + /// validate authorization server metadata before starting authorization. + fn validate_server_metadata(&self, response_type: &str) -> Result<(), AuthError> { + let Some(metadata) = self.metadata.as_ref() else { + return Ok(()); + }; + + // RFC 8414 RECOMMENDS response_types_supported in the metadata. This field is optional, + // but if present and does not include the flow we use ("code"), bail out early with a clear error. + if let Some(response_types_supported) = metadata.response_types_supported.as_ref() { + if !response_types_supported.contains(&response_type.to_string()) { + return Err(AuthError::InvalidScope(response_type.to_string())); + } + } + + // for PKCE, we always send s256 since oauth 2.1 requires servers to support it, + // but warn if the server metadata suggests otherwise + match &metadata.code_challenge_methods_supported { + Some(methods) if !methods.iter().any(|m| m == "S256") => { + warn!( + ?methods, + "server does not advertise S256 in code_challenge_methods_supported, \ + proceeding with S256 anyway as oauth 2.1 requires it. \ + The server is not compliant with the specification!" + ); } + _ => {} } + Ok(()) } /// dynamic register oauth2 client @@ -512,10 +597,7 @@ impl AuthorizationManager { "Dynamic client registration not supported".to_string(), )); }; - - // RFC 8414 RECOMMENDS response_types_supported in the metadata. This field is optional, - // but if present and does not include the flow we use ("code"), bail out early with a clear error. - self.validate_response_supported("code")?; + self.validate_server_metadata("code")?; let registration_request = ClientRegistrationRequest { client_name: name.to_string(), @@ -604,9 +686,7 @@ impl AuthorizationManager { .oauth_client .as_ref() .ok_or_else(|| AuthError::InternalError("OAuth client not configured".to_string()))?; - - // ensure the server supports the response type we intend to use when metadata is available - self.validate_response_supported("code")?; + self.validate_server_metadata("code")?; // generate pkce challenge let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); @@ -614,7 +694,8 @@ impl AuthorizationManager { // build authorization request let mut auth_request = oauth_client .authorize_url(CsrfToken::new_random) - .set_pkce_challenge(pkce_challenge); + .set_pkce_challenge(pkce_challenge) + .add_extra_param("resource", self.base_url.to_string()); // add request scopes for scope in scopes { @@ -632,6 +713,111 @@ impl AuthorizationManager { Ok(auth_url.to_string()) } + /// get the current granted scopes + pub async fn get_current_scopes(&self) -> Vec { + self.current_scopes.read().await.clone() + } + + /// compute the union of current scopes and required scopes + fn compute_scope_union(current: &[String], required: &str) -> Vec { + let mut scope_set: std::collections::HashSet = current.iter().cloned().collect(); + for scope in required.split_whitespace() { + scope_set.insert(scope.to_string()); + } + scope_set.into_iter().collect() + } + + /// check if a scope upgrade is possible and allowed + pub async fn can_attempt_scope_upgrade(&self) -> bool { + if !self.scope_upgrade_config.auto_upgrade { + return false; + } + let attempts = *self.scope_upgrade_attempts.read().await; + attempts < self.scope_upgrade_config.max_upgrade_attempts + } + + /// select scopes based on SEP-835 priority: + /// 1. scope from WWW-Authenticate header (argument or stored from initial 401 probe) + /// 2. scopes_supported from protected resource metadata (RFC 9728) + /// 3. scopes_supported from authorization server metadata + /// 4. provided default scopes + pub fn select_scopes( + &self, + www_authenticate_scope: Option<&str>, + default_scopes: &[&str], + ) -> Vec { + if let Some(scope) = www_authenticate_scope { + return scope.split_whitespace().map(|s| s.to_string()).collect(); + } + + // use scopes from initial 401 WWW-Authenticate header + if let Ok(guard) = self.www_auth_scopes.try_read() { + if !guard.is_empty() { + return guard.clone(); + } + } + + // use scopes_supported from protected resource metadata (RFC 9728) + if let Ok(guard) = self.resource_scopes.try_read() { + if !guard.is_empty() { + return guard.clone(); + } + } + + // use scopes_supported from authorization server metadata + if let Some(metadata) = &self.metadata { + if let Some(scopes_supported) = &metadata.scopes_supported { + if !scopes_supported.is_empty() { + return scopes_supported.clone(); + } + } + } + + default_scopes.iter().map(|s| s.to_string()).collect() + } + + /// attempt to upgrade scopes after receiving a 403 insufficient_scope error. + /// returns the authorization URL for re-authorization with upgraded scopes. + pub async fn request_scope_upgrade(&self, required_scope: &str) -> Result { + if !self.scope_upgrade_config.auto_upgrade { + return Err(AuthError::InvalidScope( + "Scope upgrade is disabled".to_string(), + )); + } + + let mut attempts = self.scope_upgrade_attempts.write().await; + if *attempts >= self.scope_upgrade_config.max_upgrade_attempts { + return Err(AuthError::InvalidScope(format!( + "Maximum scope upgrade attempts ({}) exceeded", + self.scope_upgrade_config.max_upgrade_attempts + ))); + } + + *attempts += 1; + drop(attempts); + + let current_scopes = self.current_scopes.read().await.clone(); + let upgraded_scopes = Self::compute_scope_union(¤t_scopes, required_scope); + + debug!( + "Requesting scope upgrade: current={:?}, required={}, union={:?}", + current_scopes, required_scope, upgraded_scopes + ); + + let scope_refs: Vec<&str> = upgraded_scopes.iter().map(|s| s.as_str()).collect(); + self.get_authorization_url(&scope_refs).await + } + + /// reset scope upgrade attempt counter + pub async fn reset_scope_upgrade_attempts(&self) { + *self.scope_upgrade_attempts.write().await = 0; + } + + /// get the number of scope upgrade attempts made + pub async fn get_scope_upgrade_attempts(&self) -> u32 { + *self.scope_upgrade_attempts.read().await + } + /// exchange authorization code for access token pub async fn exchange_code_for_token( &self, @@ -666,6 +852,7 @@ impl AuthorizationManager { let token_result = match oauth_client .exchange_code(AuthorizationCode::new(code.to_string())) .set_pkce_verifier(pkce_verifier) + .add_extra_param("resource", self.base_url.to_string()) .request_async(&http_client) .await { @@ -690,11 +877,19 @@ impl AuthorizationManager { debug!("exchange token result: {:?}", token_result); - // Store credentials in the credential store + let granted_scopes: Vec = token_result + .scopes() + .map(|scopes| scopes.iter().map(|s| s.to_string()).collect()) + .unwrap_or_default(); + + *self.current_scopes.write().await = granted_scopes.clone(); + *self.scope_upgrade_attempts.write().await = 0; + let client_id = oauth_client.client_id().to_string(); let stored = StoredCredentials { client_id, token_response: Some(token_result.clone()), + granted_scopes, }; self.credential_store.save(stored).await?; @@ -751,10 +946,18 @@ impl AuthorizationManager { .await .map_err(|e| AuthError::TokenRefreshFailed(e.to_string()))?; + let granted_scopes: Vec = token_result + .scopes() + .map(|scopes| scopes.iter().map(|s| s.to_string()).collect()) + .unwrap_or_else(|| self.current_scopes.blocking_read().clone()); + + *self.current_scopes.write().await = granted_scopes.clone(); + let client_id = oauth_client.client_id().to_string(); let stored = StoredCredentials { client_id, token_response: Some(token_result.clone()), + granted_scopes, }; self.credential_store.save(stored).await?; @@ -770,17 +973,31 @@ impl AuthorizationManager { Ok(request.header(AUTHORIZATION, format!("Bearer {}", token))) } - /// handle response, check if need to re-authorize + /// handle response, check if need to re-authorize or scope upgrade pub async fn handle_response( &self, response: reqwest::Response, ) -> Result { if response.status() == StatusCode::UNAUTHORIZED { - // 401 Unauthorized, need to re-authorize - Err(AuthError::AuthorizationRequired) - } else { - Ok(response) + return Err(AuthError::AuthorizationRequired); } + if response.status() == StatusCode::FORBIDDEN { + for value in response.headers().get_all(WWW_AUTHENTICATE).iter() { + let Ok(value_str) = value.to_str() else { + continue; + }; + let params = Self::extract_www_authenticate_params(value_str, &self.base_url); + if params.is_insufficient_scope() { + let required_scope = params.scope.unwrap_or_default(); + return Err(AuthError::InsufficientScope { + required_scope, + upgrade_url: None, + }); + } + } + return Err(AuthError::AuthorizationFailed("Forbidden".to_string())); + } + Ok(response) } /// Generate discovery endpoint URLs following the priority order in spec-2025-11-25 4.3 "Authorization Server Metadata Discovery". @@ -874,6 +1091,13 @@ impl AuthorizationManager { return Ok(None); }; + // store scopes_supported from protected resource metadata for select_scopes() + if let Some(scopes) = resource_metadata.scopes_supported { + if !scopes.is_empty() { + *self.resource_scopes.write().await = scopes; + } + } + let mut candidates = Vec::new(); if let Some(single) = resource_metadata.authorization_server { @@ -973,9 +1197,14 @@ impl AuthorizationManager { let Ok(value_str) = value.to_str() else { continue; }; - if let Some(url) = - Self::extract_resource_metadata_url_from_header(value_str, &self.base_url) - { + let params = Self::extract_www_authenticate_params(value_str, &self.base_url); + if let Some(url) = params.resource_metadata_url { + if let Some(scope) = ¶ms.scope { + debug!("WWW-Authenticate header contains scope: {}", scope); + let scopes: Vec = + scope.split_whitespace().map(|s| s.to_string()).collect(); + *self.www_auth_scopes.write().await = scopes; + } parsed_url = Some(url); break; } @@ -1023,21 +1252,25 @@ impl AuthorizationManager { Ok(Some(metadata)) } - /// Extracts a url following `resource_metadata=` in a header value - fn extract_resource_metadata_url_from_header(header: &str, base_url: &Url) -> Option { + /// extract parameters from WWW-Authenticate header (resource_metadata and scope) + fn extract_www_authenticate_params(header: &str, base_url: &Url) -> WWWAuthenticateParams { + let mut params = WWWAuthenticateParams::default(); let header_lowercase = header.to_ascii_lowercase(); - let fragment_key = "resource_metadata="; - let mut search_offset = 0; - while let Some(pos) = header_lowercase[search_offset..].find(fragment_key) { - let global_pos = search_offset + pos + fragment_key.len(); + // extract resource_metadata + let mut search_offset = 0; + let resource_key = "resource_metadata="; + while let Some(pos) = header_lowercase[search_offset..].find(resource_key) { + let global_pos = search_offset + pos + resource_key.len(); let value_slice = &header[global_pos..]; if let Some((value, consumed)) = Self::parse_next_header_value(value_slice) { if let Ok(url) = Url::parse(&value) { - return Some(url); + params.resource_metadata_url = Some(url); + break; } if let Ok(url) = base_url.join(&value) { - return Some(url); + params.resource_metadata_url = Some(url); + break; } debug!("failed to parse resource metadata value `{value}` as URL"); search_offset = global_pos + consumed; @@ -1047,7 +1280,37 @@ impl AuthorizationManager { } } - None + // extract scope + let scope_key = "scope="; + if let Some(pos) = header_lowercase.find(scope_key) { + let global_pos = pos + scope_key.len(); + let value_slice = &header[global_pos..]; + if let Some((value, _consumed)) = Self::parse_next_header_value(value_slice) { + params.scope = Some(value); + } + } + + // extract error + let error_key = "error="; + if let Some(pos) = header_lowercase.find(error_key) { + let global_pos = pos + error_key.len(); + let value_slice = &header[global_pos..]; + if let Some((value, _consumed)) = Self::parse_next_header_value(value_slice) { + params.error = Some(value); + } + } + + // extract error_description + let desc_key = "error_description="; + if let Some(pos) = header_lowercase.find(desc_key) { + let global_pos = pos + desc_key.len(); + let value_slice = &header[global_pos..]; + if let Some((value, _consumed)) = Self::parse_next_header_value(value_slice) { + params.error_description = Some(value); + } + } + + params } /// Parses an authentication parameter value from a `WWW-Authenticate` header fragment. @@ -1164,6 +1427,19 @@ impl AuthorizationSession { }) } + /// create session for scope upgrade flow (existing manager + pre-computed auth url) + pub fn for_scope_upgrade( + auth_manager: AuthorizationManager, + auth_url: String, + redirect_uri: &str, + ) -> Self { + Self { + auth_manager, + auth_url, + redirect_uri: redirect_uri.to_string(), + } + } + /// get client_id and credentials pub async fn get_credentials(&self) -> Result { self.auth_manager.get_credentials().await @@ -1278,9 +1554,17 @@ impl OAuthState { AuthorizationManager::new(DEFAULT_EXCHANGE_URL).await?, ); + let granted_scopes: Vec = credentials + .scopes() + .map(|scopes| scopes.iter().map(|s| s.to_string()).collect()) + .unwrap_or_default(); + + *manager.current_scopes.write().await = granted_scopes.clone(); + let stored = StoredCredentials { client_id: client_id.to_string(), token_response: Some(credentials), + granted_scopes, }; manager.credential_store.save(stored).await?; @@ -1324,10 +1608,16 @@ impl OAuthState { debug!("start discovery"); let metadata = manager.discover_metadata().await?; manager.metadata = Some(metadata); + let selected_scopes: Vec = if scopes.is_empty() { + manager.select_scopes(None, &[]) + } else { + scopes.iter().map(|s| s.to_string()).collect() + }; + let scope_refs: Vec<&str> = selected_scopes.iter().map(|s| s.as_str()).collect(); debug!("start session"); let session = AuthorizationSession::new( manager, - scopes, + &scope_refs, redirect_uri, client_name, client_metadata_url, @@ -1371,6 +1661,35 @@ impl OAuthState { )) } } + + /// request scope upgrade (Authorized -> Session); returns auth URL to open + pub async fn request_scope_upgrade( + &mut self, + required_scope: &str, + redirect_uri: &str, + ) -> Result { + let placeholder = + OAuthState::Authorized(AuthorizationManager::new(DEFAULT_EXCHANGE_URL).await?); + let old = std::mem::replace(self, placeholder); + let OAuthState::Authorized(manager) = old else { + *self = old; + return Err(AuthError::InternalError( + "Not in authorized state".to_string(), + )); + }; + let auth_url = match manager.request_scope_upgrade(required_scope).await { + Ok(url) => url, + Err(e) => { + *self = OAuthState::Authorized(manager); + return Err(e); + } + }; + let session = + AuthorizationSession::for_scope_upgrade(manager, auth_url.clone(), redirect_uri); + *self = OAuthState::Session(session); + Ok(auth_url) + } + /// get current authorization url pub async fn get_authorization_url(&self) -> Result { match self { @@ -1457,101 +1776,27 @@ mod tests { use url::Url; use super::{ - AuthError, AuthorizationManager, InMemoryStateStore, StateStore, StoredAuthorizationState, - is_https_url, + AuthError, AuthorizationManager, AuthorizationMetadata, InMemoryStateStore, + ScopeUpgradeConfig, StateStore, StoredAuthorizationState, is_https_url, }; - // SEP-991: URL-based Client IDs - // Tests adapted from the TypeScript SDK's isHttpsUrl test suite + // -- url helpers -- + #[test] fn test_is_https_url_scenarios() { - // Returns true for valid https url with path assert!(is_https_url("https://example.com/client-metadata.json")); - // Returns true for https url with query params assert!(is_https_url("https://example.com/metadata?version=1")); - // Returns false for https url without path assert!(!is_https_url("https://example.com")); assert!(!is_https_url("https://example.com/")); assert!(!is_https_url("https://")); - // Returns false for http url assert!(!is_https_url("http://example.com/metadata")); - // Returns false for non-url strings assert!(!is_https_url("not a url")); - // Returns false for empty string assert!(!is_https_url("")); - // Returns false for javascript scheme assert!(!is_https_url("javascript:alert(1)")); - // Returns false for data scheme assert!(!is_https_url("data:text/html,")); } - #[test] - fn parses_resource_metadata_parameter() { - let header = r#"Bearer error="invalid_request", error_description="missing token", resource_metadata="https://example.com/.well-known/oauth-protected-resource/api""#; - let base = Url::parse("https://example.com/api").unwrap(); - let parsed = AuthorizationManager::extract_resource_metadata_url_from_header(header, &base); - assert_eq!( - parsed.unwrap().as_str(), - "https://example.com/.well-known/oauth-protected-resource/api" - ); - } - - #[test] - fn parses_relative_resource_metadata_parameter() { - let header = r#"Bearer error="invalid_request", resource_metadata="/.well-known/oauth-protected-resource/api""#; - let base = Url::parse("https://example.com/api").unwrap(); - let parsed = AuthorizationManager::extract_resource_metadata_url_from_header(header, &base); - assert_eq!( - parsed.unwrap().as_str(), - "https://example.com/.well-known/oauth-protected-resource/api" - ); - } - - #[test] - fn parse_auth_param_value_handles_quoted_string() { - let fragment = r#""example", realm="foo""#; - let parsed = AuthorizationManager::parse_next_header_value(fragment).unwrap(); - assert_eq!(parsed.0, "example"); - assert_eq!(parsed.1, 9); - } - - #[test] - fn parse_auth_param_value_handles_escaped_quotes_and_whitespace() { - let fragment = r#" "a\"b\\c" ,next=value"#; - let parsed = AuthorizationManager::parse_next_header_value(fragment).unwrap(); - assert_eq!(parsed.0, r#"a"b\c"#); - assert_eq!(parsed.1, 12); - } - - #[test] - fn parse_auth_param_value_handles_token_values() { - let fragment = " token,next"; - let parsed = AuthorizationManager::parse_next_header_value(fragment).unwrap(); - assert_eq!(parsed.0, "token"); - assert_eq!(parsed.1, 7); - } - - #[test] - fn parse_auth_param_value_handles_semicolon_separated_tokens() { - let fragment = r#" https://example.com/meta; error="invalid_token""#; - let parsed = AuthorizationManager::parse_next_header_value(fragment).unwrap(); - assert_eq!(parsed.0, "https://example.com/meta"); - assert_eq!(&fragment[..parsed.1], " https://example.com/meta"); - } - - #[test] - fn parse_auth_param_value_handles_semicolon_after_quoted_value() { - let fragment = r#" "https://example.com/meta"; error="invalid_token""#; - let parsed = AuthorizationManager::parse_next_header_value(fragment).unwrap(); - assert_eq!(parsed.0, "https://example.com/meta"); - assert_eq!(&fragment[..parsed.1], r#" "https://example.com/meta""#); - } - - #[test] - fn parse_auth_param_value_returns_none_for_unterminated_quotes() { - let fragment = r#""unterminated,value"#; - assert!(AuthorizationManager::parse_next_header_value(fragment).is_none()); - } + // -- well-known path generation -- #[test] fn well_known_paths_root() { @@ -1589,9 +1834,18 @@ mod tests { ); } + #[test] + fn test_protected_resource_metadata_paths() { + let paths = + AuthorizationManager::well_known_paths("/mcp/example", "oauth-protected-resource"); + assert!(paths.contains(&"/.well-known/oauth-protected-resource/mcp/example".to_string())); + assert!(paths.contains(&"/.well-known/oauth-protected-resource".to_string())); + } + + // -- discovery url generation -- + #[test] fn generate_discovery_urls() { - // Test root URL (no path components): OAuth first, then OpenID Connect let base_url = Url::parse("https://auth.example.com").unwrap(); let urls = AuthorizationManager::generate_discovery_urls(&base_url); assert_eq!(urls.len(), 2); @@ -1604,7 +1858,6 @@ mod tests { "https://auth.example.com/.well-known/openid-configuration" ); - // Test URL with single path segment: follow spec priority order let base_url = Url::parse("https://auth.example.com/tenant1").unwrap(); let urls = AuthorizationManager::generate_discovery_urls(&base_url); assert_eq!(urls.len(), 4); @@ -1625,7 +1878,6 @@ mod tests { "https://auth.example.com/.well-known/oauth-authorization-server" ); - // Test URL with path and trailing slash let base_url = Url::parse("https://auth.example.com/v1/mcp/").unwrap(); let urls = AuthorizationManager::generate_discovery_urls(&base_url); assert_eq!(urls.len(), 4); @@ -1646,7 +1898,6 @@ mod tests { "https://auth.example.com/.well-known/oauth-authorization-server" ); - // Test URL with multiple path segments let base_url = Url::parse("https://auth.example.com/tenant1/subtenant").unwrap(); let urls = AuthorizationManager::generate_discovery_urls(&base_url); assert_eq!(urls.len(), 4); @@ -1668,57 +1919,181 @@ mod tests { ); } - // StateStore and StoredAuthorizationState tests + #[test] + fn test_discovery_urls_with_path_suffix() { + let base_url = Url::parse("https://mcp.example.com/mcp").unwrap(); + let urls = AuthorizationManager::generate_discovery_urls(&base_url); - #[tokio::test] - async fn test_in_memory_state_store_save_and_load() { - let store = InMemoryStateStore::new(); - let pkce = PkceCodeVerifier::new("test-verifier".to_string()); - let csrf = CsrfToken::new("test-csrf".to_string()); - let state = StoredAuthorizationState::new(&pkce, &csrf); + let canonical_oauth_fallback = + "https://mcp.example.com/.well-known/oauth-authorization-server"; - // Save state - store.save("test-csrf", state).await.unwrap(); + assert!( + urls.iter().any(|u| u.as_str() == canonical_oauth_fallback), + "Expected discovery URLs to include canonical OAuth fallback '{}', but got: {:?}", + canonical_oauth_fallback, + urls.iter().map(|u| u.as_str()).collect::>() + ); + } - // Load state - let loaded = store.load("test-csrf").await.unwrap(); - assert!(loaded.is_some()); - let loaded = loaded.unwrap(); - assert_eq!(loaded.csrf_token, "test-csrf"); - assert_eq!(loaded.pkce_verifier, "test-verifier"); + // -- header value parsing -- + + #[test] + fn parse_auth_param_value_handles_quoted_string() { + let fragment = r#""example", realm="foo""#; + let parsed = AuthorizationManager::parse_next_header_value(fragment).unwrap(); + assert_eq!(parsed.0, "example"); + assert_eq!(parsed.1, 9); } - #[tokio::test] - async fn test_in_memory_state_store_load_nonexistent() { - let store = InMemoryStateStore::new(); - let result = store.load("nonexistent").await.unwrap(); - assert!(result.is_none()); + #[test] + fn parse_auth_param_value_handles_escaped_quotes_and_whitespace() { + let fragment = r#" "a\"b\\c" ,next=value"#; + let parsed = AuthorizationManager::parse_next_header_value(fragment).unwrap(); + assert_eq!(parsed.0, r#"a"b\c"#); + assert_eq!(parsed.1, 12); } - #[tokio::test] - async fn test_in_memory_state_store_delete() { - let store = InMemoryStateStore::new(); - let pkce = PkceCodeVerifier::new("verifier".to_string()); - let csrf = CsrfToken::new("csrf".to_string()); - let state = StoredAuthorizationState::new(&pkce, &csrf); + #[test] + fn parse_auth_param_value_handles_token_values() { + let fragment = " token,next"; + let parsed = AuthorizationManager::parse_next_header_value(fragment).unwrap(); + assert_eq!(parsed.0, "token"); + assert_eq!(parsed.1, 7); + } - store.save("csrf", state).await.unwrap(); - store.delete("csrf").await.unwrap(); + #[test] + fn parse_auth_param_value_handles_semicolon_separated_tokens() { + let fragment = r#" https://example.com/meta; error="invalid_token""#; + let parsed = AuthorizationManager::parse_next_header_value(fragment).unwrap(); + assert_eq!(parsed.0, "https://example.com/meta"); + assert_eq!(&fragment[..parsed.1], " https://example.com/meta"); + } - let result = store.load("csrf").await.unwrap(); - assert!(result.is_none()); + #[test] + fn parse_auth_param_value_handles_semicolon_after_quoted_value() { + let fragment = r#" "https://example.com/meta"; error="invalid_token""#; + let parsed = AuthorizationManager::parse_next_header_value(fragment).unwrap(); + assert_eq!(parsed.0, "https://example.com/meta"); + assert_eq!(&fragment[..parsed.1], r#" "https://example.com/meta""#); + } + + #[test] + fn parse_auth_param_value_returns_none_for_unterminated_quotes() { + let fragment = r#""unterminated,value"#; + assert!(AuthorizationManager::parse_next_header_value(fragment).is_none()); + } + + // -- www-authenticate param extraction -- + + #[test] + fn parses_resource_metadata_parameter() { + let header = r#"Bearer error="invalid_request", error_description="missing token", resource_metadata="https://example.com/.well-known/oauth-protected-resource/api""#; + let base = Url::parse("https://example.com/api").unwrap(); + let params = AuthorizationManager::extract_www_authenticate_params(header, &base); + assert_eq!( + params.resource_metadata_url.unwrap().as_str(), + "https://example.com/.well-known/oauth-protected-resource/api" + ); + } + + #[test] + fn parses_relative_resource_metadata_parameter() { + let header = r#"Bearer error="invalid_request", resource_metadata="/.well-known/oauth-protected-resource/api""#; + let base = Url::parse("https://example.com/api").unwrap(); + let params = AuthorizationManager::extract_www_authenticate_params(header, &base); + assert_eq!( + params.resource_metadata_url.unwrap().as_str(), + "https://example.com/.well-known/oauth-protected-resource/api" + ); + } + + #[test] + fn extract_www_authenticate_params_with_all_fields() { + let header = r#"Bearer error="invalid_token", resource_metadata="https://example.com/.well-known/oauth-protected-resource", scope="read:data write:data", error_description="token expired""#; + let base = Url::parse("https://example.com/api").unwrap(); + let params = AuthorizationManager::extract_www_authenticate_params(header, &base); + + assert_eq!( + params.resource_metadata_url.unwrap().as_str(), + "https://example.com/.well-known/oauth-protected-resource" + ); + assert_eq!(params.scope.unwrap(), "read:data write:data"); + assert_eq!(params.error.unwrap(), "invalid_token"); + assert_eq!(params.error_description.unwrap(), "token expired"); + } + + #[test] + fn extract_www_authenticate_params_insufficient_scope() { + let header = r#"Bearer error="insufficient_scope", scope="admin:write", error_description="Additional file write permission required""#; + let base = Url::parse("https://example.com/api").unwrap(); + let params = AuthorizationManager::extract_www_authenticate_params(header, &base); + + assert!(params.resource_metadata_url.is_none()); + assert!(params.is_insufficient_scope()); + assert!(!params.is_invalid_token()); + assert_eq!(params.scope.unwrap(), "admin:write"); + assert_eq!( + params.error_description.unwrap(), + "Additional file write permission required" + ); + } + + #[test] + fn extract_www_authenticate_params_with_only_resource_metadata() { + let header = r#"Bearer resource_metadata="/.well-known/oauth-protected-resource""#; + let base = Url::parse("https://example.com/api").unwrap(); + let params = AuthorizationManager::extract_www_authenticate_params(header, &base); + + assert_eq!( + params.resource_metadata_url.unwrap().as_str(), + "https://example.com/.well-known/oauth-protected-resource" + ); + assert!(params.scope.is_none()); } + #[test] + fn extract_www_authenticate_params_bare_bearer() { + let header = "Bearer"; + let base = Url::parse("https://example.com/api").unwrap(); + let params = AuthorizationManager::extract_www_authenticate_params(header, &base); + + assert!(params.resource_metadata_url.is_none()); + assert!(params.scope.is_none()); + assert!(params.error.is_none()); + assert!(params.error_description.is_none()); + } + + #[test] + fn extract_www_authenticate_params_error_only() { + let header = r#"Bearer error="invalid_token""#; + let base = Url::parse("https://example.com/api").unwrap(); + let params = AuthorizationManager::extract_www_authenticate_params(header, &base); + + assert!(params.resource_metadata_url.is_none()); + assert!(params.scope.is_none()); + assert!(params.is_invalid_token()); + assert!(!params.is_insufficient_scope()); + assert!(params.error_description.is_none()); + } + + #[test] + fn extract_www_authenticate_params_with_unquoted_scope() { + let header = r#"Bearer scope=read:data, error="insufficient_scope""#; + let base = Url::parse("https://example.com/api").unwrap(); + let params = AuthorizationManager::extract_www_authenticate_params(header, &base); + + assert_eq!(params.scope.unwrap(), "read:data"); + } + + // -- stored authorization state -- + #[test] fn test_stored_authorization_state_serialization() { let pkce = PkceCodeVerifier::new("my-verifier".to_string()); let csrf = CsrfToken::new("my-csrf".to_string()); let state = StoredAuthorizationState::new(&pkce, &csrf); - // Serialize to JSON let json = serde_json::to_string(&state).unwrap(); - - // Deserialize back let deserialized: StoredAuthorizationState = serde_json::from_str(&json).unwrap(); assert_eq!(deserialized.pkce_verifier, "my-verifier"); @@ -1741,28 +2116,63 @@ mod tests { let csrf = CsrfToken::new("csrf".to_string()); let state = StoredAuthorizationState::new(&pkce, &csrf); - // created_at should be a reasonable timestamp (after year 2020) assert!(state.created_at > 1577836800); // Jan 1, 2020 } + // -- state store -- + + #[tokio::test] + async fn test_in_memory_state_store_save_and_load() { + let store = InMemoryStateStore::new(); + let pkce = PkceCodeVerifier::new("test-verifier".to_string()); + let csrf = CsrfToken::new("test-csrf".to_string()); + let state = StoredAuthorizationState::new(&pkce, &csrf); + + store.save("test-csrf", state).await.unwrap(); + + let loaded = store.load("test-csrf").await.unwrap(); + assert!(loaded.is_some()); + let loaded = loaded.unwrap(); + assert_eq!(loaded.csrf_token, "test-csrf"); + assert_eq!(loaded.pkce_verifier, "test-verifier"); + } + + #[tokio::test] + async fn test_in_memory_state_store_load_nonexistent() { + let store = InMemoryStateStore::new(); + let result = store.load("nonexistent").await.unwrap(); + assert!(result.is_none()); + } + + #[tokio::test] + async fn test_in_memory_state_store_delete() { + let store = InMemoryStateStore::new(); + let pkce = PkceCodeVerifier::new("verifier".to_string()); + let csrf = CsrfToken::new("csrf".to_string()); + let state = StoredAuthorizationState::new(&pkce, &csrf); + + store.save("csrf", state).await.unwrap(); + store.delete("csrf").await.unwrap(); + + let result = store.load("csrf").await.unwrap(); + assert!(result.is_none()); + } + #[tokio::test] async fn test_in_memory_state_store_overwrite() { let store = InMemoryStateStore::new(); let csrf_key = "same-csrf"; - // Save first state let pkce1 = PkceCodeVerifier::new("verifier-1".to_string()); let csrf1 = CsrfToken::new(csrf_key.to_string()); let state1 = StoredAuthorizationState::new(&pkce1, &csrf1); store.save(csrf_key, state1).await.unwrap(); - // Overwrite with second state let pkce2 = PkceCodeVerifier::new("verifier-2".to_string()); let csrf2 = CsrfToken::new(csrf_key.to_string()); let state2 = StoredAuthorizationState::new(&pkce2, &csrf2); store.save(csrf_key, state2).await.unwrap(); - // Should get the second state let loaded = store.load(csrf_key).await.unwrap().unwrap(); assert_eq!(loaded.pkce_verifier, "verifier-2"); } @@ -1772,7 +2182,6 @@ mod tests { let store = Arc::new(InMemoryStateStore::new()); let mut handles = vec![]; - // Spawn 10 concurrent tasks that each save and load their own state for i in 0..10 { let store = Arc::clone(&store); let handle = tokio::spawn(async move { @@ -1794,36 +2203,15 @@ mod tests { handles.push(handle); } - // Wait for all tasks to complete for handle in handles { handle.await.unwrap(); } } - #[test] - fn test_discovery_urls_with_path_suffix() { - // When the base URL has a path suffix (e.g., /mcp), the discovery should - // eventually fall back to checking /.well-known/oauth-authorization-server - // at the root, not just /.well-known/oauth-authorization-server/{path}. - let base_url = Url::parse("https://mcp.example.com/mcp").unwrap(); - let urls = AuthorizationManager::generate_discovery_urls(&base_url); - - let canonical_oauth_fallback = - "https://mcp.example.com/.well-known/oauth-authorization-server"; - - assert!( - urls.iter().any(|u| u.as_str() == canonical_oauth_fallback), - "Expected discovery URLs to include canonical OAuth fallback '{}', but got: {:?}", - canonical_oauth_fallback, - urls.iter().map(|u| u.as_str()).collect::>() - ); - } - #[tokio::test] async fn test_custom_state_store_with_authorization_manager() { use std::sync::atomic::{AtomicUsize, Ordering}; - // Custom state store that tracks calls #[derive(Debug, Default)] struct TrackingStateStore { inner: InMemoryStateStore, @@ -1857,7 +2245,6 @@ mod tests { } } - // Verify custom store works standalone let store = TrackingStateStore::default(); let pkce = PkceCodeVerifier::new("test-verifier".to_string()); let csrf = CsrfToken::new("test-csrf".to_string()); @@ -1872,8 +2259,211 @@ mod tests { store.delete("test-csrf").await.unwrap(); assert_eq!(store.delete_count.load(Ordering::SeqCst), 1); - // Verify custom store can be set on AuthorizationManager let mut manager = AuthorizationManager::new("http://localhost").await.unwrap(); manager.set_state_store(TrackingStateStore::default()); } + + // -- metadata deserialization -- + + #[test] + fn test_code_challenge_methods_supported_deserialization() { + let json = r#"{ + "authorization_endpoint": "https://auth.example.com/authorize", + "token_endpoint": "https://auth.example.com/token", + "code_challenge_methods_supported": ["S256", "plain"] + }"#; + let metadata: AuthorizationMetadata = serde_json::from_str(json).unwrap(); + let methods = metadata.code_challenge_methods_supported.unwrap(); + assert!(methods.contains(&"S256".to_string())); + assert!(methods.contains(&"plain".to_string())); + } + + #[test] + fn test_code_challenge_methods_supported_missing_from_json() { + let json = r#"{ + "authorization_endpoint": "https://auth.example.com/authorize", + "token_endpoint": "https://auth.example.com/token" + }"#; + let metadata: AuthorizationMetadata = serde_json::from_str(json).unwrap(); + assert!(metadata.code_challenge_methods_supported.is_none()); + } + + // -- server validation -- + + #[tokio::test] + async fn test_validate_as_metadata_rejects_unsupported_response_type() { + let mut manager = AuthorizationManager::new("https://example.com") + .await + .unwrap(); + let metadata = AuthorizationMetadata { + authorization_endpoint: "https://auth.example.com/authorize".to_string(), + token_endpoint: "https://auth.example.com/token".to_string(), + response_types_supported: Some(vec!["token".to_string()]), + ..Default::default() + }; + manager.set_metadata(metadata); + assert!(manager.validate_server_metadata("code").is_err()); + } + + #[tokio::test] + async fn test_validate_as_metadata_passes_without_pkce_s256() { + let mut manager = AuthorizationManager::new("https://example.com") + .await + .unwrap(); + let metadata = AuthorizationMetadata { + authorization_endpoint: "https://auth.example.com/authorize".to_string(), + token_endpoint: "https://auth.example.com/token".to_string(), + response_types_supported: Some(vec!["code".to_string()]), + code_challenge_methods_supported: Some(vec!["plain".to_string()]), + ..Default::default() + }; + manager.set_metadata(metadata); + assert!(manager.validate_server_metadata("code").is_ok()); + } + + #[tokio::test] + async fn test_validate_as_metadata_passes_without_metadata() { + let manager = AuthorizationManager::new("https://example.com") + .await + .unwrap(); + assert!(manager.validate_server_metadata("code").is_ok()); + } + + // -- authorization flow -- + + #[tokio::test] + async fn test_authorization_url_is_valid() { + let base_url = "https://mcp.example.com/api"; + let auth_endpoint = "https://auth.example.com/authorize"; + let mut manager = AuthorizationManager::new(base_url).await.unwrap(); + + let metadata = AuthorizationMetadata { + authorization_endpoint: auth_endpoint.to_string(), + token_endpoint: "https://auth.example.com/token".to_string(), + registration_endpoint: None, + issuer: None, + jwks_uri: None, + scopes_supported: None, + response_types_supported: Some(vec!["code".to_string()]), + code_challenge_methods_supported: Some(vec!["S256".to_string()]), + additional_fields: std::collections::HashMap::new(), + }; + manager.set_metadata(metadata); + manager.configure_client_id("test-client-id").unwrap(); + + let auth_url = manager + .get_authorization_url(&["read", "write"]) + .await + .unwrap(); + let parsed = Url::parse(&auth_url).unwrap(); + + assert!(auth_url.starts_with(auth_endpoint)); + + let params: std::collections::HashMap<_, _> = parsed.query_pairs().collect(); + + assert_eq!( + params.get("response_type").map(|v| v.as_ref()), + Some("code") + ); + assert_eq!( + params.get("client_id").map(|v| v.as_ref()), + Some("test-client-id") + ); + assert!(params.contains_key("state")); + assert_eq!( + params.get("redirect_uri").map(|v| v.as_ref()), + Some(base_url) + ); + assert!(params.contains_key("code_challenge")); + assert_eq!( + params.get("code_challenge_method").map(|v| v.as_ref()), + Some("S256") + ); + assert_eq!(params.get("resource").map(|v| v.as_ref()), Some(base_url)); + + let scope = params + .get("scope") + .map(|v| v.to_string()) + .unwrap_or_default(); + assert!(scope.contains("read")); + assert!(scope.contains("write")); + } + + // -- scope management -- + + #[test] + fn compute_scope_union_adds_new_scopes() { + let current = vec!["read".to_string(), "write".to_string()]; + let result = AuthorizationManager::compute_scope_union(¤t, "admin delete"); + + assert!(result.contains(&"read".to_string())); + assert!(result.contains(&"write".to_string())); + assert!(result.contains(&"admin".to_string())); + assert!(result.contains(&"delete".to_string())); + assert_eq!(result.len(), 4); + } + + #[test] + fn compute_scope_union_deduplicates() { + let current = vec!["read".to_string(), "write".to_string()]; + let result = AuthorizationManager::compute_scope_union(¤t, "read admin"); + + assert!(result.contains(&"read".to_string())); + assert!(result.contains(&"write".to_string())); + assert!(result.contains(&"admin".to_string())); + assert_eq!(result.len(), 3); + } + + #[test] + fn compute_scope_union_handles_empty_current() { + let current: Vec = vec![]; + let result = AuthorizationManager::compute_scope_union(¤t, "read write"); + + assert!(result.contains(&"read".to_string())); + assert!(result.contains(&"write".to_string())); + assert_eq!(result.len(), 2); + } + + #[test] + fn scope_upgrade_config_default_values() { + let config = ScopeUpgradeConfig::default(); + assert_eq!(config.max_upgrade_attempts, 3); + assert!(config.auto_upgrade); + } + + #[tokio::test] + async fn authorization_manager_tracks_scope_upgrade_attempts() { + let manager = AuthorizationManager::new("http://localhost").await.unwrap(); + + assert_eq!(manager.get_scope_upgrade_attempts().await, 0); + + *manager.scope_upgrade_attempts.write().await = 2; + assert_eq!(manager.get_scope_upgrade_attempts().await, 2); + + manager.reset_scope_upgrade_attempts().await; + assert_eq!(manager.get_scope_upgrade_attempts().await, 0); + } + + #[tokio::test] + async fn authorization_manager_can_attempt_scope_upgrade_respects_config() { + let mut manager = AuthorizationManager::new("http://localhost").await.unwrap(); + + assert!(manager.can_attempt_scope_upgrade().await); + + manager.set_scope_upgrade_config(ScopeUpgradeConfig { + max_upgrade_attempts: 3, + auto_upgrade: false, + }); + assert!(!manager.can_attempt_scope_upgrade().await); + + manager.set_scope_upgrade_config(ScopeUpgradeConfig { + max_upgrade_attempts: 2, + auto_upgrade: true, + }); + *manager.scope_upgrade_attempts.write().await = 2; + assert!(!manager.can_attempt_scope_upgrade().await); + + *manager.scope_upgrade_attempts.write().await = 1; + assert!(manager.can_attempt_scope_upgrade().await); + } } diff --git a/crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs b/crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs index 0ecbad20..cc26bdc9 100644 --- a/crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs +++ b/crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs @@ -120,6 +120,22 @@ impl StreamableHttpClient for reqwest::Client { })); } } + if response.status() == reqwest::StatusCode::FORBIDDEN { + if let Some(header) = response.headers().get(WWW_AUTHENTICATE) { + let header_str = header.to_str().map_err(|_| { + StreamableHttpError::UnexpectedServerResponse(Cow::from( + "invalid www-authenticate header value", + )) + })?; + let scope = extract_scope_from_header(header_str); + return Err(StreamableHttpError::InsufficientScope( + InsufficientScopeError { + www_authenticate_header: header_str.to_string(), + required_scope: scope, + }, + )); + } + } let status = response.status(); if matches!( status, @@ -197,3 +213,81 @@ impl StreamableHttpClientTransport { StreamableHttpClientTransport::with_client(reqwest::Client::default(), config) } } + +/// extract scope parameter from WWW-Authenticate header +fn extract_scope_from_header(header: &str) -> Option { + let header_lowercase = header.to_ascii_lowercase(); + let scope_key = "scope="; + + if let Some(pos) = header_lowercase.find(scope_key) { + let start = pos + scope_key.len(); + let value_slice = &header[start..]; + + if let Some(stripped) = value_slice.strip_prefix('"') { + if let Some(end_quote) = stripped.find('"') { + return Some(stripped[..end_quote].to_string()); + } + } else { + let end = value_slice + .find(|c: char| c == ',' || c == ';' || c.is_whitespace()) + .unwrap_or(value_slice.len()); + if end > 0 { + return Some(value_slice[..end].to_string()); + } + } + } + + None +} + +#[cfg(test)] +mod tests { + use super::extract_scope_from_header; + use crate::transport::streamable_http_client::InsufficientScopeError; + + #[test] + fn extract_scope_quoted() { + let header = r#"Bearer error="insufficient_scope", scope="files:read files:write""#; + assert_eq!( + extract_scope_from_header(header), + Some("files:read files:write".to_string()) + ); + } + + #[test] + fn extract_scope_unquoted() { + let header = r#"Bearer scope=read:data, error="insufficient_scope""#; + assert_eq!( + extract_scope_from_header(header), + Some("read:data".to_string()) + ); + } + + #[test] + fn extract_scope_missing() { + let header = r#"Bearer error="invalid_token""#; + assert_eq!(extract_scope_from_header(header), None); + } + + #[test] + fn extract_scope_empty_header() { + assert_eq!(extract_scope_from_header("Bearer"), None); + } + + #[test] + fn insufficient_scope_error_can_upgrade() { + let with_scope = InsufficientScopeError { + www_authenticate_header: "Bearer scope=\"admin\"".to_string(), + required_scope: Some("admin".to_string()), + }; + assert!(with_scope.can_upgrade()); + assert_eq!(with_scope.get_required_scope(), Some("admin")); + + let without_scope = InsufficientScopeError { + www_authenticate_header: "Bearer error=\"insufficient_scope\"".to_string(), + required_scope: None, + }; + assert!(!without_scope.can_upgrade()); + assert_eq!(without_scope.get_required_scope(), None); + } +} diff --git a/crates/rmcp/src/transport/streamable_http_client.rs b/crates/rmcp/src/transport/streamable_http_client.rs index 35140c1b..550d261b 100644 --- a/crates/rmcp/src/transport/streamable_http_client.rs +++ b/crates/rmcp/src/transport/streamable_http_client.rs @@ -24,6 +24,24 @@ pub struct AuthRequiredError { pub www_authenticate_header: String, } +#[derive(Debug)] +pub struct InsufficientScopeError { + pub www_authenticate_header: String, + pub required_scope: Option, +} + +impl InsufficientScopeError { + /// check if scope upgrade is possible (i.e., we know what scope is required) + pub fn can_upgrade(&self) -> bool { + self.required_scope.is_some() + } + + /// get the required scope for upgrade + pub fn get_required_scope(&self) -> Option<&str> { + self.required_scope.as_deref() + } +} + #[derive(Error, Debug)] pub enum StreamableHttpError { #[error("SSE error: {0}")] @@ -56,6 +74,8 @@ pub enum StreamableHttpError { Auth(#[from] crate::transport::auth::AuthError), #[error("Auth required")] AuthRequired(AuthRequiredError), + #[error("Insufficient scope")] + InsufficientScope(InsufficientScopeError), } #[derive(Debug, Clone, Error)] diff --git a/docs/OAUTH_SUPPORT.md b/docs/OAUTH_SUPPORT.md index 3142c62b..b0b59f9f 100644 --- a/docs/OAUTH_SUPPORT.md +++ b/docs/OAUTH_SUPPORT.md @@ -1,13 +1,17 @@ # Model Context Protocol OAuth Authorization -This document describes the OAuth 2.1 authorization implementation for Model Context Protocol (MCP), following the [MCP 2025-03-26 Authorization Specification](https://modelcontextprotocol.io/specification/2025-03-26/basic/authorization/). +This document describes the OAuth 2.1 authorization implementation for Model Context Protocol (MCP), following the [MCP 2025-11-25 Authorization Specification](https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization/). ## Features -- Full support for OAuth 2.1 authorization flow -- PKCE support for enhanced security -- Authorization server metadata discovery -- Dynamic client registration +- Full support for OAuth 2.1 authorization flow with PKCE (S256) +- RFC 8707 resource parameter binding +- Protected Resource Metadata discovery (RFC 9728) +- Authorization Server Metadata discovery (RFC 8414 + OpenID Connect) +- Dynamic client registration (RFC 7591) +- Client ID Metadata Documents (CIMD) (SEP-991 / Client ID Metadata Documents ) +- Scope selection from WWW-Authenticate, Protected Resource Metadata, and AS metadata +- Scope upgrade on 403 insufficient_scope (SEP-835) - Automatic token refresh - Authorized HTTP Client implementation @@ -24,32 +28,43 @@ rmcp = { version = "0.1", features = ["auth", "transport-streamable-http-client- ### 2. Use OAuthState +The `OAuthState` state machine manages the full authorization lifecycle. When no +scopes are provided, the SDK automatically selects scopes from the server's +WWW-Authenticate header, Protected Resource Metadata, or AS metadata. + ```rust ignore - // Initialize oauth state machine + // initialize oauth state machine let mut oauth_state = OAuthState::new(&server_url, None) .await .context("Failed to initialize oauth state machine")?; + + // start authorization - pass empty scopes to let the SDK auto-select oauth_state - .start_authorization(&["mcp", "profile", "email"], MCP_REDIRECT_URI) + .start_authorization(&[], MCP_REDIRECT_URI, Some("My MCP Client")) .await .context("Failed to start authorization")?; +``` +If you know the scopes you need, you can still pass them explicitly: + +```rust ignore + oauth_state + .start_authorization(&["mcp", "profile"], MCP_REDIRECT_URI, Some("My MCP Client")) + .await + .context("Failed to start authorization")?; ``` -### 3. Get authorization url and do callback +### 3. Get authorization url and handle callback ```rust ignore - // Get authorization URL and guide user to open it + // get authorization URL and guide user to open it let auth_url = oauth_state.get_authorization_url().await?; println!("Please open the following URL in your browser for authorization:\n{}", auth_url); - // Handle callback - In real applications, this is typically done in a callback server + // handle callback - in real applications, this is typically done in a callback server let auth_code = "Authorization code (`code` param) obtained from browser after user authorization"; let csrf_token = "CSRF token (`state` param) obtained from browser after user authorization"; - let credentials = oauth_state.handle_callback(auth_code, csrf_token).await?; - - println!("Authorization successful, access token: {}", credentials.access_token); - + oauth_state.handle_callback(auth_code, csrf_token).await?; ``` ### 4. Use Authorized Streamable HTTP Transport and create client @@ -64,15 +79,27 @@ rmcp = { version = "0.1", features = ["auth", "transport-streamable-http-client- StreamableHttpClientTransportConfig::with_uri(MCP_SERVER_URL), ); - // Create client and connect to MCP server + // create client and connect to MCP server let client_service = ClientInfo::default(); let client = client_service.serve(transport).await?; ``` -### 5. Use Authorized HTTP Client after authorized +### 5. Handle scope upgrades + +If a server returns 403 with `insufficient_scope`, you can request a scope +upgrade. The SDK computes the union of current and required scopes and +transitions back to the session state for re-authorization. ```rust ignore - let client = oauth_state.to_authorized_http_client().await?; + match oauth_state.request_scope_upgrade("admin:write", MCP_REDIRECT_URI).await { + Ok(auth_url) => { + // open auth_url in browser, handle callback as before + println!("Re-authorize at: {}", auth_url); + } + Err(e) => { + eprintln!("Scope upgrade failed: {}", e); + } + } ``` ## Complete Examples @@ -92,19 +119,24 @@ cargo run -p mcp-client-examples --example clients_oauth_client ## Authorization Flow Description -1. **Metadata Discovery**: Client attempts to get authorization server metadata from `/.well-known/oauth-authorization-server` -2. **Client Registration**: If supported, client dynamically registers itself -3. **Authorization Request**: Build authorization URL with PKCE and guide user to access -4. **Authorization Code Exchange**: After user authorization, exchange authorization code for access token -5. **Token Usage**: Use access token for API calls -6. **Token Refresh**: Automatically use refresh token to get new access token when current one expires +1. **Resource Metadata Discovery**: Client probes the server and extracts `WWW-Authenticate` parameters including `resource_metadata` URL and `scope` +2. **Protected Resource Metadata**: Client fetches resource server metadata (RFC 9728) to find authorization server(s) and supported scopes +3. **AS Metadata Discovery**: Client discovers authorization server metadata via RFC 8414 and OpenID Connect well-known endpoints +4. **Client Registration**: If supported, client dynamically registers itself (or uses URL-based Client ID via SEP-991) +5. **Scope Selection**: SDK picks scopes from WWW-Authenticate > PRM > AS metadata > caller defaults +6. **Authorization Request**: Build authorization URL with PKCE (S256) and RFC 8707 resource parameter +7. **Authorization Code Exchange**: After user authorization, exchange code for access token (with resource parameter) +8. **Token Usage**: Use access token for API calls via `AuthClient` or `AuthorizedHttpClient` +9. **Token Refresh**: Automatically use refresh token to get new access token when current one expires +10. **Scope Upgrade**: On 403 insufficient_scope, compute scope union and re-authorize with upgraded scopes ## Security Considerations -- All tokens are securely stored in memory -- PKCE implementation prevents authorization code interception attacks -- Automatic token refresh support reduces user intervention -- Only accepts HTTPS connections or secure local callback URIs +- **PKCE S256 always enforced**: never falls back to `plain` or no challenge. OAuth 2.1 mandates S256 as Mandatory To Implement for servers. +- **RFC 8707 resource binding**: authorization and token requests include the `resource` parameter to bind tokens to the protected resource +- All tokens are securely stored in memory (custom credential stores supported) +- Automatic token refresh reduces user intervention +- Server metadata validation warns on non-compliant configurations but proceeds where relatively safe ## Troubleshooting @@ -114,10 +146,15 @@ If you encounter authorization issues, check the following: 2. Verify callback URI matches server's allowed redirect URIs 3. Check network connection and firewall settings 4. Verify server supports metadata discovery or dynamic client registration +5. If PKCE fails, the server may not support S256 (non-compliant with OAuth 2.1) +6. Check `tracing` logs at debug level for detailed discovery and validation info ## References -- [MCP Authorization Specification](https://modelcontextprotocol.io/specification/2025-03-26/basic/authorization/) +- [MCP Authorization Specification (2025-11-25)](https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization/) - [OAuth 2.1 Specification Draft](https://oauth.net/2.1/) - [RFC 8414: OAuth 2.0 Authorization Server Metadata](https://datatracker.ietf.org/doc/html/rfc8414) - [RFC 7591: OAuth 2.0 Dynamic Client Registration Protocol](https://datatracker.ietf.org/doc/html/rfc7591) +- [RFC 8707: Resource Indicators for OAuth 2.0](https://datatracker.ietf.org/doc/html/rfc8707) +- [RFC 9728: OAuth 2.0 Protected Resource Metadata](https://datatracker.ietf.org/doc/html/rfc9728) +- [RFC 7636: Proof Key for Code Exchange (PKCE)](https://datatracker.ietf.org/doc/html/rfc7636) diff --git a/examples/clients/src/auth/oauth_client.rs b/examples/clients/src/auth/oauth_client.rs index 4f94a3ce..456f3269 100644 --- a/examples/clients/src/auth/oauth_client.rs +++ b/examples/clients/src/auth/oauth_client.rs @@ -114,14 +114,16 @@ async fn main() -> Result<()> { client_metadata_url ); - // Initialize oauth state machine + // initialize oauth state machine let mut oauth_state = OAuthState::new(&server_url, None) .await .context("Failed to initialize oauth state machine")?; - // Use CIMD (SEP-991) with client metadata URL + // use CIMD (SEP-991) with client metadata URL. + // passing empty scopes lets the SDK auto-select from the server's + // WWW-Authenticate header, Protected Resource Metadata, or AS metadata. oauth_state .start_authorization_with_metadata_url( - &["mcp", "profile", "email"], + &[], MCP_REDIRECT_URI, Some("Test MCP Client"), Some(&client_metadata_url), diff --git a/examples/servers/src/complex_auth_streamhttp.rs b/examples/servers/src/complex_auth_streamhttp.rs index 33fa445c..84b68d9c 100644 --- a/examples/servers/src/complex_auth_streamhttp.rs +++ b/examples/servers/src/complex_auth_streamhttp.rs @@ -520,16 +520,13 @@ async fn oauth_authorization_server() -> impl IntoResponse { "response_types_supported".into(), Value::Array(vec![Value::String("code".into())]), ); - additional_fields.insert( - "code_challenge_methods_supported".into(), - Value::Array(vec![Value::String("S256".into())]), - ); let metadata = AuthorizationMetadata { authorization_endpoint: format!("http://{}/oauth/authorize", BIND_ADDRESS), token_endpoint: format!("http://{}/oauth/token", BIND_ADDRESS), scopes_supported: Some(vec!["profile".to_string(), "email".to_string()]), registration_endpoint: Some(format!("http://{}/oauth/register", BIND_ADDRESS)), response_types_supported: Some(vec!["code".to_string()]), + code_challenge_methods_supported: Some(vec!["S256".to_string()]), issuer: Some(BIND_ADDRESS.to_string()), jwks_uri: Some(format!("http://{}/oauth/jwks", BIND_ADDRESS)), additional_fields,