Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 30 additions & 81 deletions conformance/src/bin/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@ use rmcp::{
model::*,
service::RequestContext,
transport::{
AuthClient, AuthorizationManager, StreamableHttpClientTransport,
auth::{OAuthClientConfig, OAuthState},
AuthClient, AuthorizationManager, StreamableHttpClientTransport, auth::OAuthState,
streamable_http_client::StreamableHttpClientTransportConfig,
},
};
Expand All @@ -17,9 +16,6 @@ use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

#[derive(Debug, Default, serde::Deserialize)]
struct ConformanceContext {
#[serde(default)]
name: Option<String>,
// pre-registration / client-credentials-basic
#[serde(default)]
client_id: Option<String>,
#[serde(default)]
Expand All @@ -29,15 +25,6 @@ struct ConformanceContext {
private_key_pem: Option<String>,
#[serde(default)]
signing_algorithm: Option<String>,
// cross-app-access
#[serde(default)]
idp_client_id: Option<String>,
#[serde(default)]
idp_id_token: Option<String>,
#[serde(default)]
idp_issuer: Option<String>,
#[serde(default)]
idp_token_endpoint: Option<String>,
}

fn load_context() -> ConformanceContext {
Expand Down Expand Up @@ -175,17 +162,17 @@ impl ClientHandler for FullClientHandler {
.and_then(|c| c.as_text())
.map(|t| t.text.clone())
.unwrap_or_default();
Ok(CreateMessageResult {
message: SamplingMessage::new(
Ok(CreateMessageResult::new(
SamplingMessage::new(
Role::Assistant,
SamplingMessageContent::text(format!(
"This is a mock LLM response to: {}",
prompt_text
)),
),
model: "mock-model".into(),
stop_reason: Some("endTurn".into()),
})
"mock-model".into(),
)
.with_stop_reason("endTurn"))
}
}

Expand Down Expand Up @@ -216,7 +203,7 @@ const REDIRECT_URI: &str = "http://localhost:3000/callback";
/// 4. Return an `AuthClient` wrapping `reqwest::Client`
async fn perform_oauth_flow(
server_url: &str,
ctx: &ConformanceContext,
_ctx: &ConformanceContext,
) -> anyhow::Result<AuthClient<reqwest::Client>> {
let mut oauth = OAuthState::new(server_url, None).await?;

Expand Down Expand Up @@ -335,12 +322,7 @@ async fn run_auth_client(server_url: &str, ctx: &ConformanceContext) -> anyhow::
for tool in &tools.tools {
let args = build_tool_arguments(tool);
let _ = client
.call_tool(CallToolRequestParams {
meta: None,
name: tool.name.clone(),
arguments: args,
task: None,
})
.call_tool(call_tool_params(tool.name.clone(), args))
.await;
}

Expand All @@ -352,7 +334,7 @@ async fn run_auth_client(server_url: &str, ctx: &ConformanceContext) -> anyhow::
/// then call tool which triggers 403 → re-auth with expanded scopes → retry.
async fn run_auth_scope_step_up_client(
server_url: &str,
ctx: &ConformanceContext,
_ctx: &ConformanceContext,
) -> anyhow::Result<()> {
// First auth
let mut oauth = OAuthState::new(server_url, None).await?;
Expand Down Expand Up @@ -388,12 +370,7 @@ async fn run_auth_scope_step_up_client(
for tool in &tools.tools {
let args = build_tool_arguments(tool);
match client
.call_tool(CallToolRequestParams {
meta: None,
name: tool.name.clone(),
arguments: args.clone(),
task: None,
})
.call_tool(call_tool_params(tool.name.clone(), args.clone()))
.await
{
Ok(_) => {
Expand Down Expand Up @@ -428,12 +405,7 @@ async fn run_auth_scope_step_up_client(
);
let client2 = BasicClientHandler.serve(transport2).await?;
let _ = client2
.call_tool(CallToolRequestParams {
meta: None,
name: tool.name.clone(),
arguments: args,
task: None,
})
.call_tool(call_tool_params(tool.name.clone(), args))
.await;
client2.cancel().await.ok();
return Ok(());
Expand Down Expand Up @@ -481,12 +453,7 @@ async fn run_auth_scope_retry_limit_client(
for tool in &tools.tools {
let args = build_tool_arguments(tool);
match client
.call_tool(CallToolRequestParams {
meta: None,
name: tool.name.clone(),
arguments: args,
task: None,
})
.call_tool(call_tool_params(tool.name.clone(), args))
.await
{
Ok(_) => {}
Expand Down Expand Up @@ -539,12 +506,7 @@ async fn run_auth_preregistered_client(
for tool in &tools.tools {
let args = build_tool_arguments(tool);
let _ = client
.call_tool(CallToolRequestParams {
meta: None,
name: tool.name.clone(),
arguments: args,
task: None,
})
.call_tool(call_tool_params(tool.name.clone(), args))
.await;
}
client.cancel().await?;
Expand Down Expand Up @@ -597,12 +559,7 @@ async fn run_client_credentials_basic(
for tool in &tools.tools {
let args = build_tool_arguments(tool);
let _ = client
.call_tool(CallToolRequestParams {
meta: None,
name: tool.name.clone(),
arguments: args,
task: None,
})
.call_tool(call_tool_params(tool.name.clone(), args))
.await;
}
client.cancel().await?;
Expand Down Expand Up @@ -667,12 +624,7 @@ async fn run_client_credentials_jwt(
for tool in &tools.tools {
let args = build_tool_arguments(tool);
let _ = client
.call_tool(CallToolRequestParams {
meta: None,
name: tool.name.clone(),
arguments: args,
task: None,
})
.call_tool(call_tool_params(tool.name.clone(), args))
.await;
}
client.cancel().await?;
Expand Down Expand Up @@ -783,6 +735,18 @@ async fn headless_authorize(auth_url: &str) -> anyhow::Result<(String, String)>
Ok((code, state))
}

/// Build a `CallToolRequestParams` for a tool, optionally with arguments.
fn call_tool_params(
name: std::borrow::Cow<'static, str>,
arguments: Option<serde_json::Map<String, Value>>,
) -> CallToolRequestParams {
let mut p = CallToolRequestParams::new(name);
if let Some(a) = arguments {
p = p.with_arguments(a);
}
p
}

/// Build arguments for a tool based on its input schema.
fn build_tool_arguments(tool: &Tool) -> Option<serde_json::Map<String, Value>> {
let schema = &tool.input_schema;
Expand Down Expand Up @@ -840,12 +804,7 @@ async fn run_tools_call_client(server_url: &str) -> anyhow::Result<()> {
for tool in &tools.tools {
let args = build_tool_arguments(tool);
let _ = client
.call_tool(CallToolRequestParams {
meta: None,
name: tool.name.clone(),
arguments: args,
task: None,
})
.call_tool(call_tool_params(tool.name.clone(), args))
.await?;
}
client.cancel().await?;
Expand All @@ -862,12 +821,7 @@ async fn run_elicitation_defaults_client(server_url: &str) -> anyhow::Result<()>
});
if let Some(tool) = test_tool {
let _ = client
.call_tool(CallToolRequestParams {
meta: None,
name: tool.name.clone(),
arguments: None,
task: None,
})
.call_tool(call_tool_params(tool.name.clone(), None))
.await?;
}
client.cancel().await?;
Expand All @@ -884,12 +838,7 @@ async fn run_sse_retry_client(server_url: &str) -> anyhow::Result<()> {
.find(|t| t.name.as_ref() == "test_reconnection")
{
let _ = client
.call_tool(CallToolRequestParams {
meta: None,
name: tool.name.clone(),
arguments: None,
task: None,
})
.call_tool(call_tool_params(tool.name.clone(), None))
.await?;
}
client.cancel().await?;
Expand Down
Loading