diff --git a/crates/rmcp-macros/src/prompt.rs b/crates/rmcp-macros/src/prompt.rs index d5b0b0f16..a7a13f450 100644 --- a/crates/rmcp-macros/src/prompt.rs +++ b/crates/rmcp-macros/src/prompt.rs @@ -42,9 +42,9 @@ impl ResolvedPromptAttribute { meta, } = self; let description = if let Some(description) = description { - quote! { Some(#description.into()) } + quote! { Some::(#description.into()) } } else { - quote! { None } + quote! { None:: } }; let title = if let Some(title) = title { quote! { Some(#title.into()) } @@ -63,14 +63,14 @@ impl ResolvedPromptAttribute { }; let tokens = quote! { pub fn #fn_ident() -> rmcp::model::Prompt { - rmcp::model::Prompt { - name: #name.into(), - description: #description, - arguments: #arguments, - title: #title, - icons: #icons, - meta: #meta, - } + rmcp::model::Prompt::from_raw( + #name, + #description, + #arguments, + ) + .with_title(#title) + .with_icons(#icons) + .with_meta(#meta) } }; syn::parse2::(tokens) diff --git a/crates/rmcp-macros/src/task_handler.rs b/crates/rmcp-macros/src/task_handler.rs index 4ad02d6b8..86664b18f 100644 --- a/crates/rmcp-macros/src/task_handler.rs +++ b/crates/rmcp-macros/src/task_handler.rs @@ -42,23 +42,16 @@ pub fn task_handler(attr: TokenStream, input: TokenStream) -> syn::Result>(); - Ok(rmcp::model::ListTasksResult { - tasks, - next_cursor: None, - total: Some(total), - }) + Ok(rmcp::model::ListTasksResult::new(tasks)) } }; item_impl.items.push(syn::parse2::(list_fn)?); @@ -106,17 +99,14 @@ pub fn task_handler(attr: TokenStream, input: TokenStream) -> syn::Result(enqueue_fn)?); @@ -151,15 +141,15 @@ pub fn task_handler(attr: TokenStream, input: TokenStream) -> syn::Result rmcp::model::TaskStatus::Failed, }; let timestamp = current_timestamp(); - let task = rmcp::model::Task { + let mut task = rmcp::model::Task::new( task_id, status, - status_message: None, - created_at: timestamp.clone(), - last_updated_at: timestamp, - ttl: completed_result.descriptor.ttl, - poll_interval: None, - }; + timestamp.clone(), + timestamp, + ); + if let Some(ttl) = completed_result.descriptor.ttl { + task = task.with_ttl(ttl); + } return Ok(rmcp::model::GetTaskResult { meta: None, task }); } @@ -167,15 +157,12 @@ pub fn task_handler(attr: TokenStream, input: TokenStream) -> syn::Result syn::Result { let value = ::serde_json::to_value(call_tool).unwrap_or(::serde_json::Value::Null); - return Ok(rmcp::model::GetTaskPayloadResult(value)); + return Ok(rmcp::model::GetTaskPayloadResult::new(value)); } Err(err) => return Err(McpError::internal_error( format!("task failed: {}", err), @@ -254,15 +241,12 @@ pub fn task_handler(attr: TokenStream, input: TokenStream) -> syn::Result rmcp::model::Tool { - rmcp::model::Tool { - name: #name.into(), - title: #title, - description: #description, - input_schema: #input_schema, - output_schema: #output_schema, - annotations: #annotations, - execution: #execution, - icons: #icons, - meta: #meta, - } + rmcp::model::Tool::new_with_raw( + #name, + #description, + #input_schema, + ) + .with_title(#title) + .with_raw_output_schema(#output_schema) + .with_annotations(#annotations) + .with_execution(#execution) + .with_icons(#icons) + .with_meta(#meta) } }; syn::parse2::(tokens) @@ -260,13 +260,13 @@ pub fn tool(attr: TokenStream, input: TokenStream) -> syn::Result { let idempotent_hint = wrap_option(idempotent_hint); let open_world_hint = wrap_option(open_world_hint); let token_stream = quote! { - Some(rmcp::model::ToolAnnotations { - title: #title, - read_only_hint: #read_only_hint, - destructive_hint: #destructive_hint, - idempotent_hint: #idempotent_hint, - open_world_hint: #open_world_hint, - }) + Some(rmcp::model::ToolAnnotations::from_raw( + #title, + #read_only_hint, + #destructive_hint, + #idempotent_hint, + #open_world_hint, + )) }; syn::parse2::(token_stream)? } else { @@ -296,9 +296,9 @@ pub fn tool(attr: TokenStream, input: TokenStream) -> syn::Result { }; let token_stream = quote! { - Some(rmcp::model::ToolExecution { - task_support: #task_support_expr, - }) + Some(rmcp::model::ToolExecution::from_raw( + #task_support_expr, + )) }; syn::parse2::(token_stream)? } else { diff --git a/crates/rmcp/README.md b/crates/rmcp/README.md index 217b22cd6..ebc1db336 100644 --- a/crates/rmcp/README.md +++ b/crates/rmcp/README.md @@ -19,7 +19,7 @@ Creating a server with tools is simple using the `#[tool]` macro: -```rust,no_run +```rust,ignore use rmcp::{ ServerHandler, ServiceExt, handler::server::tool::ToolRouter, @@ -68,11 +68,8 @@ impl Counter { #[tool_handler] impl ServerHandler for Counter { fn get_info(&self) -> ServerInfo { - ServerInfo { - instructions: Some("A simple counter that tallies the number of times the increment tool has been used".into()), - capabilities: ServerCapabilities::builder().enable_tools().build(), - ..Default::default() - } + ServerInfo::new(ServerCapabilities::builder().enable_tools().build()) + .with_instructions("A simple counter that tallies the number of times the increment tool has been used") } } @@ -147,7 +144,7 @@ To expose task support, enable the `tasks` capability when building `ServerCapab Creating a client to interact with a server: -```rust,no_run +```rust,ignore use rmcp::{ ServiceExt, model::CallToolRequestParams, @@ -176,12 +173,10 @@ async fn main() -> Result<(), Box> { // Call a tool let result = service - .call_tool(CallToolRequestParams { - meta: None, - name: "git_status".into(), - arguments: serde_json::json!({ "repo_path": "." }).as_object().cloned(), - task: None, - }) + .call_tool( + CallToolRequestParams::new("git_status") + .with_arguments(serde_json::json!({ "repo_path": "." }).as_object().cloned().unwrap_or_default()) + ) .await?; println!("Result: {result:#?}"); diff --git a/crates/rmcp/src/error.rs b/crates/rmcp/src/error.rs index c7901f4b5..74f7d4383 100644 --- a/crates/rmcp/src/error.rs +++ b/crates/rmcp/src/error.rs @@ -20,6 +20,7 @@ impl std::error::Error for ErrorData {} /// This is an unified error type for the errors could be returned by the service. #[derive(Debug, thiserror::Error)] #[allow(clippy::large_enum_variant)] +#[non_exhaustive] pub enum RmcpError { #[cfg(any(feature = "client", feature = "server"))] #[error("Service error: {0}")] diff --git a/crates/rmcp/src/model.rs b/crates/rmcp/src/model.rs index b358f5233..c0b3dc436 100644 --- a/crates/rmcp/src/model.rs +++ b/crates/rmcp/src/model.rs @@ -306,6 +306,7 @@ pub struct ProgressToken(pub NumberOrString); /// - `extensions`: Additional context data (similar to HTTP headers) #[derive(Debug, Clone, Default)] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[non_exhaustive] pub struct Request { pub method: M, pub params: P, @@ -379,6 +380,7 @@ impl GetExtensions for RequestNoParam { } #[derive(Debug, Clone, Default)] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[non_exhaustive] pub struct Notification { pub method: M, pub params: P, @@ -419,6 +421,17 @@ pub struct JsonRpcRequest { pub request: R, } +impl JsonRpcRequest { + /// Create a new JsonRpcRequest. + pub fn new(id: RequestId, request: R) -> Self { + Self { + jsonrpc: JsonRpcVersion2_0, + id, + request, + } + } +} + type DefaultResponse = JsonObject; #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] @@ -436,6 +449,17 @@ pub struct JsonRpcError { pub error: ErrorData, } +impl JsonRpcError { + /// Create a new JsonRpcError. + pub fn new(id: RequestId, error: ErrorData) -> Self { + Self { + jsonrpc: JsonRpcVersion2_0, + id, + error, + } + } +} + #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] pub struct JsonRpcNotification { @@ -467,7 +491,7 @@ impl ErrorCode { /// /// This structure follows the JSON-RPC 2.0 specification for error reporting, /// providing a standardized way to communicate errors between clients and servers. -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[derive(Default, Debug, Serialize, Deserialize, Clone, PartialEq)] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] pub struct ErrorData { /// The error type that occurred (using standard JSON-RPC error codes) @@ -745,6 +769,7 @@ pub type InitializedNotification = NotificationNoParam Self { + Self { + meta: None, + protocol_version: ProtocolVersion::default(), + capabilities, + client_info, + } + } + + pub fn with_protocol_version(mut self, protocol_version: ProtocolVersion) -> Self { + self.protocol_version = protocol_version; + self + } +} + impl RequestParamsMeta for InitializeRequestParams { fn meta(&self) -> Option<&Meta> { self.meta.as_ref() @@ -777,6 +819,7 @@ pub type InitializeRequestParam = InitializeRequestParams; #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] #[serde(rename_all = "camelCase")] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[non_exhaustive] pub struct InitializeResult { /// The MCP protocol version this server supports pub protocol_version: ProtocolVersion, @@ -789,6 +832,36 @@ pub struct InitializeResult { pub instructions: Option, } +impl InitializeResult { + /// Create a new `InitializeResult` with default protocol version and the given capabilities. + pub fn new(capabilities: ServerCapabilities) -> Self { + Self { + protocol_version: ProtocolVersion::default(), + capabilities, + server_info: Implementation::from_build_env(), + instructions: None, + } + } + + /// Set instructions on this result. + pub fn with_instructions(mut self, instructions: impl Into) -> Self { + self.instructions = Some(instructions.into()); + self + } + + /// Set the server info on this result. + pub fn with_server_info(mut self, server_info: Implementation) -> Self { + self.server_info = server_info; + self + } + + /// Set the protocol version on this result. + pub fn with_protocol_version(mut self, protocol_version: ProtocolVersion) -> Self { + self.protocol_version = protocol_version; + self + } +} + pub type ServerInfo = InitializeResult; pub type ClientInfo = InitializeRequestParams; @@ -828,6 +901,7 @@ impl Default for ClientInfo { #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] #[serde(rename_all = "camelCase")] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[non_exhaustive] pub struct Icon { /// A standard URI pointing to an icon resource pub src: String, @@ -839,9 +913,33 @@ pub struct Icon { pub sizes: Option>, } +impl Icon { + /// Create a new Icon with the given source URL. + pub fn new(src: impl Into) -> Self { + Self { + src: src.into(), + mime_type: None, + sizes: None, + } + } + + /// Set the MIME type. + pub fn with_mime_type(mut self, mime_type: impl Into) -> Self { + self.mime_type = Some(mime_type.into()); + self + } + + /// Set the sizes. + pub fn with_sizes(mut self, sizes: Vec) -> Self { + self.sizes = Some(sizes); + self + } +} + #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] #[serde(rename_all = "camelCase")] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[non_exhaustive] pub struct Implementation { pub name: String, #[serde(skip_serializing_if = "Option::is_none")] @@ -862,6 +960,18 @@ impl Default for Implementation { } impl Implementation { + /// Create a new Implementation. + pub fn new(name: impl Into, version: impl Into) -> Self { + Self { + name: name.into(), + title: None, + version: version.into(), + description: None, + icons: None, + website_url: None, + } + } + pub fn from_build_env() -> Self { Implementation { name: env!("CARGO_CRATE_NAME").to_owned(), @@ -872,11 +982,36 @@ impl Implementation { website_url: None, } } + + /// Set the human-readable title. + pub fn with_title(mut self, title: impl Into) -> Self { + self.title = Some(title.into()); + self + } + + /// Set the description. + pub fn with_description(mut self, description: impl Into) -> Self { + self.description = Some(description.into()); + self + } + + /// Set the icons. + pub fn with_icons(mut self, icons: Vec) -> Self { + self.icons = Some(icons); + self + } + + /// Set the website URL. + pub fn with_website_url(mut self, website_url: impl Into) -> Self { + self.website_url = Some(website_url.into()); + self + } } #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)] #[serde(rename_all = "camelCase")] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[non_exhaustive] pub struct PaginatedRequestParams { /// Protocol-level metadata for this request (SEP-1319) #[serde(rename = "_meta", default, skip_serializing_if = "Option::is_none")] @@ -885,6 +1020,13 @@ pub struct PaginatedRequestParams { pub cursor: Option, } +impl PaginatedRequestParams { + pub fn with_cursor(mut self, cursor: Option) -> Self { + self.cursor = cursor; + self + } +} + impl RequestParamsMeta for PaginatedRequestParams { fn meta(&self) -> Option<&Meta> { self.meta.as_ref() @@ -920,6 +1062,30 @@ pub struct ProgressNotificationParam { pub message: Option, } +impl ProgressNotificationParam { + /// Create a new ProgressNotificationParam with required fields. + pub fn new(progress_token: ProgressToken, progress: f64) -> Self { + Self { + progress_token, + progress, + total: None, + message: None, + } + } + + /// Set the total number of items to process. + pub fn with_total(mut self, total: f64) -> Self { + self.total = Some(total); + self + } + + /// Set a message describing the current progress. + pub fn with_message(mut self, message: impl Into) -> Self { + self.message = Some(message.into()); + self + } +} + pub type ProgressNotification = Notification; pub type Cursor = String; @@ -980,6 +1146,7 @@ const_string!(ReadResourceRequestMethod = "resources/read"); #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] #[serde(rename_all = "camelCase")] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[non_exhaustive] pub struct ReadResourceRequestParams { /// Protocol-level metadata for this request (SEP-1319) #[serde(rename = "_meta", default, skip_serializing_if = "Option::is_none")] @@ -988,6 +1155,22 @@ pub struct ReadResourceRequestParams { pub uri: String, } +impl ReadResourceRequestParams { + /// Create a new ReadResourceRequestParams with the given URI. + pub fn new(uri: impl Into) -> Self { + Self { + meta: None, + uri: uri.into(), + } + } + + /// Set the metadata for this request. + pub fn with_meta(mut self, meta: Meta) -> Self { + self.meta = Some(meta); + self + } +} + impl RequestParamsMeta for ReadResourceRequestParams { fn meta(&self) -> Option<&Meta> { self.meta.as_ref() @@ -1004,11 +1187,19 @@ pub type ReadResourceRequestParam = ReadResourceRequestParams; /// Result containing the contents of a read resource #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[non_exhaustive] pub struct ReadResourceResult { /// The actual content of the resource pub contents: Vec, } +impl ReadResourceResult { + /// Create a new ReadResourceResult with the given contents. + pub fn new(contents: Vec) -> Self { + Self { contents } + } +} + /// Request to read a specific resource pub type ReadResourceRequest = Request; @@ -1022,6 +1213,7 @@ const_string!(SubscribeRequestMethod = "resources/subscribe"); #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] #[serde(rename_all = "camelCase")] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[non_exhaustive] pub struct SubscribeRequestParams { /// Protocol-level metadata for this request (SEP-1319) #[serde(rename = "_meta", default, skip_serializing_if = "Option::is_none")] @@ -1030,6 +1222,16 @@ pub struct SubscribeRequestParams { pub uri: String, } +impl SubscribeRequestParams { + /// Create a new SubscribeRequestParams. + pub fn new(uri: impl Into) -> Self { + Self { + meta: None, + uri: uri.into(), + } + } +} + impl RequestParamsMeta for SubscribeRequestParams { fn meta(&self) -> Option<&Meta> { self.meta.as_ref() @@ -1051,6 +1253,7 @@ const_string!(UnsubscribeRequestMethod = "resources/unsubscribe"); #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] #[serde(rename_all = "camelCase")] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[non_exhaustive] pub struct UnsubscribeRequestParams { /// Protocol-level metadata for this request (SEP-1319) #[serde(rename = "_meta", default, skip_serializing_if = "Option::is_none")] @@ -1084,6 +1287,14 @@ pub struct ResourceUpdatedNotificationParam { /// The URI of the resource that was updated pub uri: String, } + +impl ResourceUpdatedNotificationParam { + /// Create a new ResourceUpdatedNotificationParam. + pub fn new(uri: impl Into) -> Self { + Self { uri: uri.into() } + } +} + /// Notification sent when a subscribed resource is updated pub type ResourceUpdatedNotification = Notification; @@ -1103,9 +1314,10 @@ paginated_result!(ListPromptsResult { const_string!(GetPromptRequestMethod = "prompts/get"); /// Parameters for retrieving a specific prompt -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[derive(Default, Debug, Serialize, Deserialize, Clone, PartialEq)] #[serde(rename_all = "camelCase")] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[non_exhaustive] pub struct GetPromptRequestParams { /// Protocol-level metadata for this request (SEP-1319) #[serde(rename = "_meta", default, skip_serializing_if = "Option::is_none")] @@ -1115,6 +1327,29 @@ pub struct GetPromptRequestParams { pub arguments: Option, } +impl GetPromptRequestParams { + /// Create a new `GetPromptRequestParams` with the given prompt name. + pub fn new(name: impl Into) -> Self { + Self { + meta: None, + name: name.into(), + arguments: None, + } + } + + /// Set the arguments for this prompt request. + pub fn with_arguments(mut self, arguments: JsonObject) -> Self { + self.arguments = Some(arguments); + self + } + + /// Set the metadata for this request. + pub fn with_meta(mut self, meta: Meta) -> Self { + self.meta = Some(meta); + self + } +} + impl RequestParamsMeta for GetPromptRequestParams { fn meta(&self) -> Option<&Meta> { self.meta.as_ref() @@ -1163,6 +1398,7 @@ const_string!(SetLevelRequestMethod = "logging/setLevel"); #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] #[serde(rename_all = "camelCase")] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[non_exhaustive] pub struct SetLevelRequestParams { /// Protocol-level metadata for this request (SEP-1319) #[serde(rename = "_meta", default, skip_serializing_if = "Option::is_none")] @@ -1171,6 +1407,13 @@ pub struct SetLevelRequestParams { pub level: LoggingLevel, } +impl SetLevelRequestParams { + /// Create a new SetLevelRequestParams with the given logging level. + pub fn new(level: LoggingLevel) -> Self { + Self { meta: None, level } + } +} + impl RequestParamsMeta for SetLevelRequestParams { fn meta(&self) -> Option<&Meta> { self.meta.as_ref() @@ -1201,6 +1444,27 @@ pub struct LoggingMessageNotificationParam { /// The actual log data pub data: Value, } + +impl LoggingMessageNotificationParam { + /// Create a new LoggingMessageNotificationParam. + pub fn new(level: LoggingLevel, data: Value) -> Self { + Self { + level, + logger: None, + data, + } + } + + /// Create with a logger name. + pub fn with_logger(level: LoggingLevel, logger: impl Into, data: Value) -> Self { + Self { + level, + logger: Some(logger.into()), + data, + } + } +} + /// Notification containing a log message pub type LoggingMessageNotification = Notification; @@ -1249,6 +1513,7 @@ impl Default for ToolChoiceMode { #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)] #[serde(rename_all = "camelCase")] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[non_exhaustive] pub struct ToolChoice { #[serde(skip_serializing_if = "Option::is_none")] pub mode: Option, @@ -1379,6 +1644,7 @@ impl From> for SamplingContent { /// for generating appropriate responses. #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[non_exhaustive] pub struct SamplingMessage { /// The role of the message sender (User or Assistant) pub role: Role, @@ -1539,9 +1805,10 @@ pub enum ContextInclusion { /// /// This implements `TaskAugmentedRequestParamsMeta` as sampling requests can be /// long-running and may benefit from task-based execution. -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[derive(Default, Debug, Serialize, Deserialize, Clone, PartialEq)] #[serde(rename_all = "camelCase")] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[non_exhaustive] pub struct CreateMessageRequestParams { /// Protocol-level metadata for this request (SEP-1319) #[serde(rename = "_meta", default, skip_serializing_if = "Option::is_none")] @@ -1598,6 +1865,72 @@ impl TaskAugmentedRequestParamsMeta for CreateMessageRequestParams { } impl CreateMessageRequestParams { + /// Create a new CreateMessageRequestParams with required fields. + pub fn new(messages: Vec, max_tokens: u32) -> Self { + Self { + meta: None, + task: None, + messages, + model_preferences: None, + system_prompt: None, + include_context: None, + temperature: None, + max_tokens, + stop_sequences: None, + metadata: None, + tools: None, + tool_choice: None, + } + } + + /// Set model preferences. + pub fn with_model_preferences(mut self, model_preferences: ModelPreferences) -> Self { + self.model_preferences = Some(model_preferences); + self + } + + /// Set system prompt. + pub fn with_system_prompt(mut self, system_prompt: impl Into) -> Self { + self.system_prompt = Some(system_prompt.into()); + self + } + + /// Set include context. + pub fn with_include_context(mut self, include_context: ContextInclusion) -> Self { + self.include_context = Some(include_context); + self + } + + /// Set temperature. + pub fn with_temperature(mut self, temperature: f32) -> Self { + self.temperature = Some(temperature); + self + } + + /// Set stop sequences. + pub fn with_stop_sequences(mut self, stop_sequences: Vec) -> Self { + self.stop_sequences = Some(stop_sequences); + self + } + + /// Set metadata. + pub fn with_metadata(mut self, metadata: Value) -> Self { + self.metadata = Some(metadata); + self + } + + /// Set tools. + pub fn with_tools(mut self, tools: Vec) -> Self { + self.tools = Some(tools); + self + } + + /// Set tool choice. + pub fn with_tool_choice(mut self, tool_choice: ToolChoice) -> Self { + self.tool_choice = Some(tool_choice); + self + } + /// Validate the sampling request parameters per SEP-1577 spec requirements. /// /// Checks: @@ -1688,6 +2021,7 @@ pub type CreateMessageRequestParam = CreateMessageRequestParams; #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] #[serde(rename_all = "camelCase")] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[non_exhaustive] pub struct ModelPreferences { /// Specific model names or families to prefer (e.g., "claude", "gpt") #[serde(skip_serializing_if = "Option::is_none")] @@ -1703,18 +2037,70 @@ pub struct ModelPreferences { pub intelligence_priority: Option, } +impl ModelPreferences { + /// Create a new default ModelPreferences. + pub fn new() -> Self { + Self { + hints: None, + cost_priority: None, + speed_priority: None, + intelligence_priority: None, + } + } + + /// Set hints for model selection. + pub fn with_hints(mut self, hints: Vec) -> Self { + self.hints = Some(hints); + self + } + + /// Set cost priority (0.0 to 1.0). + pub fn with_cost_priority(mut self, cost_priority: f32) -> Self { + self.cost_priority = Some(cost_priority); + self + } + + /// Set speed priority (0.0 to 1.0). + pub fn with_speed_priority(mut self, speed_priority: f32) -> Self { + self.speed_priority = Some(speed_priority); + self + } + + /// Set intelligence priority (0.0 to 1.0). + pub fn with_intelligence_priority(mut self, intelligence_priority: f32) -> Self { + self.intelligence_priority = Some(intelligence_priority); + self + } +} + +impl Default for ModelPreferences { + fn default() -> Self { + Self::new() + } +} + /// A hint suggesting a preferred model name or family. /// /// Model hints are advisory suggestions that help clients choose appropriate /// models. They can be specific model names or general families like "claude" or "gpt". -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[derive(Default, Debug, Serialize, Deserialize, Clone, PartialEq)] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[non_exhaustive] pub struct ModelHint { /// The suggested model name or family identifier #[serde(skip_serializing_if = "Option::is_none")] pub name: Option, } +impl ModelHint { + /// Create a new ModelHint with a name. + pub fn new(name: impl Into) -> Self { + Self { + name: Some(name.into()), + } + } +} + // ============================================================================= // COMPLETION AND AUTOCOMPLETE // ============================================================================= @@ -1768,6 +2154,7 @@ impl CompletionContext { #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] #[serde(rename_all = "camelCase")] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[non_exhaustive] pub struct CompleteRequestParams { /// Protocol-level metadata for this request (SEP-1319) #[serde(rename = "_meta", default, skip_serializing_if = "Option::is_none")] @@ -1779,6 +2166,24 @@ pub struct CompleteRequestParams { pub context: Option, } +impl CompleteRequestParams { + /// Create a new CompleteRequestParams with required fields. + pub fn new(r#ref: Reference, argument: ArgumentInfo) -> Self { + Self { + meta: None, + r#ref, + argument, + context: None, + } + } + + /// Set the completion context + pub fn with_context(mut self, context: CompletionContext) -> Self { + self.context = Some(context); + self + } +} + impl RequestParamsMeta for CompleteRequestParams { fn meta(&self) -> Option<&Meta> { self.meta.as_ref() @@ -1875,10 +2280,18 @@ impl CompletionInfo { #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)] #[serde(rename_all = "camelCase")] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[non_exhaustive] pub struct CompleteResult { pub completion: CompletionInfo, } +impl CompleteResult { + /// Create a new CompleteResult with the given completion info. + pub fn new(completion: CompletionInfo) -> Self { + Self { completion } + } +} + #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] #[serde(tag = "type")] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] @@ -1939,6 +2352,7 @@ pub struct ResourceReference { #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[non_exhaustive] pub struct PromptReference { pub name: String, #[serde(skip_serializing_if = "Option::is_none")] @@ -1960,6 +2374,7 @@ pub struct ArgumentInfo { #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[non_exhaustive] pub struct Root { pub uri: String, #[serde(skip_serializing_if = "Option::is_none")] @@ -1972,6 +2387,7 @@ pub type ListRootsRequest = RequestNoParam; #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)] #[serde(rename_all = "camelCase")] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[non_exhaustive] pub struct ListRootsResult { pub roots: Vec, } @@ -2180,18 +2596,45 @@ pub struct CreateElicitationResult { pub content: Option, } +impl CreateElicitationResult { + /// Create a new CreateElicitationResult. + pub fn new(action: ElicitationAction) -> Self { + Self { + action, + content: None, + } + } + + /// Create with content. + pub fn with_content(action: ElicitationAction, content: Value) -> Self { + Self { + action, + content: Some(content), + } + } +} + /// Request type for creating an elicitation to gather user input pub type CreateElicitationRequest = Request; /// Notification parameters for an url elicitation completion notification. -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[derive(Default, Debug, Serialize, Deserialize, Clone, PartialEq)] #[serde(rename_all = "camelCase")] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] pub struct ElicitationResponseNotificationParam { pub elicitation_id: String, } +impl ElicitationResponseNotificationParam { + /// Create a new ElicitationResponseNotificationParam. + pub fn new(elicitation_id: impl Into) -> Self { + Self { + elicitation_id: elicitation_id.into(), + } + } +} + /// Notification sent when an url elicitation process is completed. pub type ElicitationCompletionNotification = Notification; @@ -2204,9 +2647,10 @@ pub type ElicitationCompletionNotification = /// /// Contains the content returned by the tool execution and an optional /// flag indicating whether the operation resulted in an error. -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[derive(Default, Debug, Serialize, Deserialize, Clone, PartialEq)] #[serde(rename_all = "camelCase")] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[non_exhaustive] pub struct CallToolResult { /// The content returned by the tool (text, images, etc.) pub content: Vec, @@ -2289,6 +2733,12 @@ impl CallToolResult { } } + /// Set the metadata on this result + pub fn with_meta(mut self, meta: Option) -> Self { + self.meta = meta; + self + } + /// Convert the `structured_content` part of response into a certain type. /// /// # About json schema validation @@ -2336,9 +2786,10 @@ const_string!(CallToolRequestMethod = "tools/call"); /// /// This implements `TaskAugmentedRequestParamsMeta` as tool calls can be /// long-running and may benefit from task-based execution. -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[derive(Default, Debug, Serialize, Deserialize, Clone, PartialEq)] #[serde(rename_all = "camelCase")] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[non_exhaustive] pub struct CallToolRequestParams { /// Protocol-level metadata for this request (SEP-1319) #[serde(rename = "_meta", default, skip_serializing_if = "Option::is_none")] @@ -2353,6 +2804,30 @@ pub struct CallToolRequestParams { pub task: Option, } +impl CallToolRequestParams { + /// Creates a new `CallToolRequestParams` with the given tool name. + pub fn new(name: impl Into>) -> Self { + Self { + meta: None, + name: name.into(), + arguments: None, + task: None, + } + } + + /// Sets the arguments for this tool call. + pub fn with_arguments(mut self, arguments: JsonObject) -> Self { + self.arguments = Some(arguments); + self + } + + /// Sets the task metadata for this tool call. + pub fn with_task(mut self, task: Option) -> Self { + self.task = task; + self + } +} + impl RequestParamsMeta for CallToolRequestParams { fn meta(&self) -> Option<&Meta> { self.meta.as_ref() @@ -2386,6 +2861,7 @@ pub type CallToolRequest = Request #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] #[serde(rename_all = "camelCase")] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[non_exhaustive] pub struct CreateMessageResult { /// The identifier of the model that generated the response pub model: String, @@ -2398,11 +2874,32 @@ pub struct CreateMessageResult { } impl CreateMessageResult { + /// Create a new CreateMessageResult with required fields. + pub fn new(message: SamplingMessage, model: String) -> Self { + Self { + message, + model, + stop_reason: None, + } + } + pub const STOP_REASON_END_TURN: &str = "endTurn"; pub const STOP_REASON_END_SEQUENCE: &str = "stopSequence"; pub const STOP_REASON_END_MAX_TOKEN: &str = "maxTokens"; pub const STOP_REASON_TOOL_USE: &str = "toolUse"; + /// Set the stop reason. + pub fn with_stop_reason(mut self, stop_reason: Option) -> Self { + self.stop_reason = stop_reason; + self + } + + /// Set the model identifier. + pub fn with_model(mut self, model: impl Into) -> Self { + self.model = model.into(); + self + } + /// Validate the result per SEP-1577: role must be "assistant". pub fn validate(&self) -> Result<(), String> { if self.message.role != Role::Assistant { @@ -2412,15 +2909,32 @@ impl CreateMessageResult { } } -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[derive(Default, Debug, Serialize, Deserialize, Clone, PartialEq)] #[serde(rename_all = "camelCase")] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[non_exhaustive] pub struct GetPromptResult { #[serde(skip_serializing_if = "Option::is_none")] pub description: Option, pub messages: Vec, } +impl GetPromptResult { + /// Create a new GetPromptResult with required fields. + pub fn new(messages: Vec) -> Self { + Self { + description: None, + messages, + } + } + + /// Set the description + pub fn with_description>(mut self, description: D) -> Self { + self.description = Some(description.into()); + self + } +} + // ============================================================================= // TASK MANAGEMENT // ============================================================================= @@ -2512,6 +3026,7 @@ pub type GetTaskInfoResult = GetTaskResult; #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)] #[serde(rename_all = "camelCase")] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[non_exhaustive] pub struct ListTasksResult { pub tasks: Vec, #[serde(skip_serializing_if = "Option::is_none")] @@ -2520,6 +3035,17 @@ pub struct ListTasksResult { pub total: Option, } +impl ListTasksResult { + /// Create a new ListTasksResult. + pub fn new(tasks: Vec) -> Self { + Self { + tasks, + next_cursor: None, + total: None, + } + } +} + // ============================================================================= // MESSAGE TYPE UNIONS // ============================================================================= diff --git a/crates/rmcp/src/model/annotated.rs b/crates/rmcp/src/model/annotated.rs index f9921146a..9158e10be 100644 --- a/crates/rmcp/src/model/annotated.rs +++ b/crates/rmcp/src/model/annotated.rs @@ -11,6 +11,7 @@ use super::{ #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)] #[serde(rename_all = "camelCase")] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[non_exhaustive] pub struct Annotations { #[serde(skip_serializing_if = "Option::is_none")] pub audience: Option>, diff --git a/crates/rmcp/src/model/capabilities.rs b/crates/rmcp/src/model/capabilities.rs index e5716acca..b47a8a849 100644 --- a/crates/rmcp/src/model/capabilities.rs +++ b/crates/rmcp/src/model/capabilities.rs @@ -243,6 +243,7 @@ pub struct SamplingCapability { /// ``` #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[non_exhaustive] pub struct ClientCapabilities { #[serde(skip_serializing_if = "Option::is_none")] pub experimental: Option, @@ -280,6 +281,7 @@ pub struct ClientCapabilities { #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)] #[serde(rename_all = "camelCase")] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[non_exhaustive] pub struct ServerCapabilities { #[serde(skip_serializing_if = "Option::is_none")] pub experimental: Option, diff --git a/crates/rmcp/src/model/content.rs b/crates/rmcp/src/model/content.rs index beb4d9f5d..83658b023 100644 --- a/crates/rmcp/src/model/content.rs +++ b/crates/rmcp/src/model/content.rs @@ -38,6 +38,17 @@ pub struct RawEmbeddedResource { pub meta: Option, pub resource: ResourceContents, } + +impl RawEmbeddedResource { + /// Create a new RawEmbeddedResource. + pub fn new(resource: ResourceContents) -> Self { + Self { + meta: None, + resource, + } + } +} + pub type EmbeddedResource = Annotated; impl EmbeddedResource { @@ -63,6 +74,7 @@ pub type AudioContent = Annotated; #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[non_exhaustive] pub struct ToolUseContent { /// Unique identifier for this tool call pub id: String, @@ -79,6 +91,7 @@ pub struct ToolUseContent { #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[non_exhaustive] pub struct ToolResultContent { /// Optional metadata #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] diff --git a/crates/rmcp/src/model/elicitation_schema.rs b/crates/rmcp/src/model/elicitation_schema.rs index 5e7506e49..cdbb87d6d 100644 --- a/crates/rmcp/src/model/elicitation_schema.rs +++ b/crates/rmcp/src/model/elicitation_schema.rs @@ -89,6 +89,7 @@ pub enum StringFormat { #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] #[serde(rename_all = "camelCase")] +#[non_exhaustive] pub struct StringSchema { /// Type discriminator #[serde(rename = "type")] @@ -237,6 +238,7 @@ impl StringSchema { #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] #[serde(rename_all = "camelCase")] +#[non_exhaustive] pub struct NumberSchema { /// Type discriminator #[serde(rename = "type")] @@ -444,6 +446,7 @@ impl IntegerSchema { #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] #[serde(rename_all = "camelCase")] +#[non_exhaustive] pub struct BooleanSchema { /// Type discriminator #[serde(rename = "type")] @@ -513,6 +516,16 @@ pub struct ConstTitle { pub title: String, } +impl ConstTitle { + /// Create a new ConstTitle. + pub fn new(const_: impl Into, title: impl Into) -> Self { + Self { + const_: const_.into(), + title: title.into(), + } + } +} + /// Legacy enum schema, keep for backward compatibility #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] @@ -533,6 +546,7 @@ pub struct LegacyEnumSchema { #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[non_exhaustive] pub struct UntitledSingleSelectEnumSchema { #[serde(rename = "type")] pub type_: StringTypeConst, @@ -550,6 +564,7 @@ pub struct UntitledSingleSelectEnumSchema { #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[non_exhaustive] pub struct TitledSingleSelectEnumSchema { #[serde(rename = "type")] pub type_: StringTypeConst, @@ -563,6 +578,19 @@ pub struct TitledSingleSelectEnumSchema { pub default: Option, } +impl TitledSingleSelectEnumSchema { + /// Create a new TitledSingleSelectEnumSchema. + pub fn new(one_of: Vec) -> Self { + Self { + type_: StringTypeConst, + title: None, + description: None, + one_of, + default: None, + } + } +} + /// Combined single-select #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] @@ -592,10 +620,18 @@ pub struct TitledItems { pub any_of: Vec, } +impl TitledItems { + /// Create a new TitledItems. + pub fn new(any_of: Vec) -> Self { + Self { any_of } + } +} + /// Multi-select untitled options #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] #[serde(rename_all = "camelCase")] +#[non_exhaustive] pub struct UntitledMultiSelectEnumSchema { #[serde(rename = "type")] pub type_: ArrayTypeConst, @@ -616,6 +652,7 @@ pub struct UntitledMultiSelectEnumSchema { #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] #[serde(rename_all = "camelCase")] +#[non_exhaustive] pub struct TitledMultiSelectEnumSchema { #[serde(rename = "type")] pub type_: ArrayTypeConst, @@ -632,6 +669,51 @@ pub struct TitledMultiSelectEnumSchema { pub default: Option>, } +impl TitledMultiSelectEnumSchema { + /// Create a new TitledMultiSelectEnumSchema. + pub fn new(items: TitledItems) -> Self { + Self { + type_: ArrayTypeConst, + title: None, + description: None, + min_items: None, + max_items: None, + items, + default: None, + } + } + + /// Set the title. + pub fn with_title(mut self, title: impl Into>) -> Self { + self.title = Some(title.into()); + self + } + + /// Set the description. + pub fn with_description(mut self, description: impl Into>) -> Self { + self.description = Some(description.into()); + self + } + + /// Set the minimum number of items. + pub fn with_min_items(mut self, min_items: u64) -> Self { + self.min_items = Some(min_items); + self + } + + /// Set the maximum number of items. + pub fn with_max_items(mut self, max_items: u64) -> Self { + self.max_items = Some(max_items); + self + } + + /// Set the default values. + pub fn with_default(mut self, default: Vec) -> Self { + self.default = Some(default); + self + } +} + /// Multi-select enum options #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] diff --git a/crates/rmcp/src/model/prompt.rs b/crates/rmcp/src/model/prompt.rs index f90aff199..4d491d0e1 100644 --- a/crates/rmcp/src/model/prompt.rs +++ b/crates/rmcp/src/model/prompt.rs @@ -7,9 +7,10 @@ use super::{ }; /// A prompt that can be used to generate text from a model -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[non_exhaustive] pub struct Prompt { /// The name of the prompt pub name: String, @@ -49,11 +50,46 @@ impl Prompt { meta: None, } } + + /// Create a new prompt from raw fields (used by the macro) + pub fn from_raw( + name: impl Into, + description: Option>, + arguments: Option>, + ) -> Self { + Prompt { + name: name.into(), + title: None, + description: description.map(Into::into), + arguments, + icons: None, + meta: None, + } + } + + /// Set the human-readable title + pub fn with_title(mut self, title: Option) -> Self { + self.title = title; + self + } + + /// Set the icons + pub fn with_icons(mut self, icons: Option>) -> Self { + self.icons = icons; + self + } + + /// Set the metadata + pub fn with_meta(mut self, meta: Option) -> Self { + self.meta = meta; + self + } } /// Represents a prompt argument that can be passed to customize the prompt -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[non_exhaustive] pub struct PromptArgument { /// The name of the argument pub name: String, @@ -68,6 +104,36 @@ pub struct PromptArgument { pub required: Option, } +impl PromptArgument { + /// Create a new prompt argument + pub fn new>(name: N) -> Self { + PromptArgument { + name: name.into(), + title: None, + description: None, + required: None, + } + } + + /// Set the title + pub fn with_title>(mut self, title: T) -> Self { + self.title = Some(title.into()); + self + } + + /// Set the description + pub fn with_description>(mut self, description: D) -> Self { + self.description = Some(description.into()); + self + } + + /// Set the required flag + pub fn with_required(mut self, required: bool) -> Self { + self.required = Some(required); + self + } +} + /// Represents the role of a message sender in a prompt conversation #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] @@ -112,6 +178,7 @@ impl PromptMessageContent { /// A message in a prompt conversation #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[non_exhaustive] pub struct PromptMessage { /// The role of the message sender pub role: PromptMessageRole, @@ -120,6 +187,11 @@ pub struct PromptMessage { } impl PromptMessage { + /// Create a new prompt message with the given role and content + pub fn new(role: PromptMessageRole, content: PromptMessageContent) -> Self { + Self { role, content } + } + /// Create a new text message with the given role and text content pub fn new_text>(role: PromptMessageRole, text: S) -> Self { Self { diff --git a/crates/rmcp/src/model/resource.rs b/crates/rmcp/src/model/resource.rs index cf3a1071f..8a25e25ba 100644 --- a/crates/rmcp/src/model/resource.rs +++ b/crates/rmcp/src/model/resource.rs @@ -80,6 +80,7 @@ pub enum ResourceContents { } impl ResourceContents { + /// Create text resource contents. pub fn text(text: impl Into, uri: impl Into) -> Self { Self::TextResourceContents { uri: uri.into(), @@ -88,6 +89,34 @@ impl ResourceContents { meta: None, } } + + /// Create blob resource contents. + pub fn blob(blob: impl Into, uri: impl Into) -> Self { + Self::BlobResourceContents { + uri: uri.into(), + mime_type: None, + blob: blob.into(), + meta: None, + } + } + + /// Set the MIME type on this resource contents. + pub fn with_mime_type(mut self, mime_type: impl Into) -> Self { + match &mut self { + Self::TextResourceContents { mime_type: mt, .. } => *mt = Some(mime_type.into()), + Self::BlobResourceContents { mime_type: mt, .. } => *mt = Some(mime_type.into()), + } + self + } + + /// Set the metadata on this resource contents. + pub fn with_meta(mut self, meta: Meta) -> Self { + match &mut self { + Self::TextResourceContents { meta: m, .. } => *m = Some(meta), + Self::BlobResourceContents { meta: m, .. } => *m = Some(meta), + } + self + } } impl RawResource { @@ -104,6 +133,80 @@ impl RawResource { meta: None, } } + + /// Set the human-readable title. + pub fn with_title(mut self, title: impl Into) -> Self { + self.title = Some(title.into()); + self + } + + /// Set the description. + pub fn with_description(mut self, description: impl Into) -> Self { + self.description = Some(description.into()); + self + } + + /// Set the MIME type. + pub fn with_mime_type(mut self, mime_type: impl Into) -> Self { + self.mime_type = Some(mime_type.into()); + self + } + + /// Set the size in bytes. + pub fn with_size(mut self, size: u32) -> Self { + self.size = Some(size); + self + } + + /// Set the icons. + pub fn with_icons(mut self, icons: Vec) -> Self { + self.icons = Some(icons); + self + } + + /// Set the metadata. + pub fn with_meta(mut self, meta: Meta) -> Self { + self.meta = Some(meta); + self + } +} + +impl RawResourceTemplate { + /// Creates a new RawResourceTemplate with a URI template and name. + pub fn new(uri_template: impl Into, name: impl Into) -> Self { + Self { + uri_template: uri_template.into(), + name: name.into(), + title: None, + description: None, + mime_type: None, + icons: None, + } + } + + /// Set the human-readable title. + pub fn with_title(mut self, title: impl Into) -> Self { + self.title = Some(title.into()); + self + } + + /// Set the description. + pub fn with_description(mut self, description: impl Into) -> Self { + self.description = Some(description.into()); + self + } + + /// Set the MIME type. + pub fn with_mime_type(mut self, mime_type: impl Into) -> Self { + self.mime_type = Some(mime_type.into()); + self + } + + /// Set the icons. + pub fn with_icons(mut self, icons: Vec) -> Self { + self.icons = Some(icons); + self + } } #[cfg(test)] diff --git a/crates/rmcp/src/model/task.rs b/crates/rmcp/src/model/task.rs index a18ed0c59..8373aa243 100644 --- a/crates/rmcp/src/model/task.rs +++ b/crates/rmcp/src/model/task.rs @@ -28,6 +28,7 @@ pub enum TaskStatus { #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)] #[serde(rename_all = "camelCase")] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[non_exhaustive] pub struct Task { /// Unique task identifier generated by the receiver. pub task_id: String, @@ -48,14 +49,60 @@ pub struct Task { pub poll_interval: Option, } +impl Task { + /// Create a new Task with required fields. + pub fn new( + task_id: String, + status: TaskStatus, + created_at: String, + last_updated_at: String, + ) -> Self { + Self { + task_id, + status, + status_message: None, + created_at, + last_updated_at, + ttl: None, + poll_interval: None, + } + } + + /// Set the status message. + pub fn with_status_message(mut self, status_message: impl Into) -> Self { + self.status_message = Some(status_message.into()); + self + } + + /// Set the TTL in milliseconds. `None` means unlimited retention. + pub fn with_ttl(mut self, ttl: u64) -> Self { + self.ttl = Some(ttl); + self + } + + /// Set the poll interval in milliseconds. + pub fn with_poll_interval(mut self, poll_interval: u64) -> Self { + self.poll_interval = Some(poll_interval); + self + } +} + /// Wrapper returned by task-augmented requests (CreateTaskResult in SEP-1686). #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[non_exhaustive] pub struct CreateTaskResult { pub task: Task, } +impl CreateTaskResult { + /// Create a new CreateTaskResult. + pub fn new(task: Task) -> Self { + Self { task } + } +} + /// Response to a `tasks/get` request. /// /// Per spec, `GetTaskResult = allOf[Result, Task]` — the Task fields are @@ -78,8 +125,16 @@ pub struct GetTaskResult { /// serialized as a JSON value. #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[non_exhaustive] pub struct GetTaskPayloadResult(pub Value); +impl GetTaskPayloadResult { + /// Create a new GetTaskPayloadResult with the given value. + pub fn new(value: Value) -> Self { + Self(value) + } +} + /// Response to a `tasks/cancel` request. /// /// Per spec, `CancelTaskResult = allOf[Result, Task]` — same shape as `GetTaskResult`. @@ -104,3 +159,14 @@ pub struct TaskList { #[serde(skip_serializing_if = "Option::is_none")] pub total: Option, } + +impl TaskList { + /// Create a new TaskList. + pub fn new(tasks: Vec) -> Self { + Self { + tasks, + next_cursor: None, + total: None, + } + } +} diff --git a/crates/rmcp/src/model/tool.rs b/crates/rmcp/src/model/tool.rs index 9732faca1..82b762de3 100644 --- a/crates/rmcp/src/model/tool.rs +++ b/crates/rmcp/src/model/tool.rs @@ -10,9 +10,10 @@ use serde_json::Value; use super::{Icon, JsonObject, Meta}; /// A tool that can be used by a model. -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[non_exhaustive] pub struct Tool { /// The name of the tool pub name: Cow<'static, str>, @@ -67,6 +68,7 @@ pub enum TaskSupport { #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)] #[serde(rename_all = "camelCase")] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[non_exhaustive] pub struct ToolExecution { /// Indicates whether this tool supports task-based invocation. /// @@ -83,6 +85,11 @@ impl ToolExecution { Self::default() } + /// Create a ToolExecution from raw optional fields. + pub fn from_raw(task_support: Option) -> Self { + Self { task_support } + } + /// Set the task support mode. pub fn with_task_support(mut self, task_support: TaskSupport) -> Self { self.task_support = Some(task_support); @@ -101,6 +108,7 @@ impl ToolExecution { #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)] #[serde(rename_all = "camelCase")] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[non_exhaustive] pub struct ToolAnnotations { /// A human-readable title for the tool. #[serde(skip_serializing_if = "Option::is_none")] @@ -145,6 +153,24 @@ impl ToolAnnotations { pub fn new() -> Self { Self::default() } + + /// Create a new ToolAnnotations with all fields specified + pub fn from_raw( + title: Option, + read_only_hint: Option, + destructive_hint: Option, + idempotent_hint: Option, + open_world_hint: Option, + ) -> Self { + ToolAnnotations { + title, + read_only_hint, + destructive_hint, + idempotent_hint, + open_world_hint, + } + } + pub fn with_title(title: T) -> Self where T: Into, @@ -211,6 +237,59 @@ impl Tool { } } + /// Create a new tool with just a name and input schema (no description) + pub fn new_with_raw( + name: N, + description: Option>, + input_schema: S, + ) -> Self + where + N: Into>, + S: Into>, + { + Tool { + name: name.into(), + title: None, + description, + input_schema: input_schema.into(), + output_schema: None, + annotations: None, + execution: None, + icons: None, + meta: None, + } + } + + /// Set the human-readable title + pub fn with_title(mut self, title: Option) -> Self { + self.title = title; + self + } + + /// Set the output schema from a raw value + pub fn with_raw_output_schema(mut self, output_schema: Option>) -> Self { + self.output_schema = output_schema; + self + } + + /// Set the annotations + pub fn with_annotations(mut self, annotations: Option) -> Self { + self.annotations = annotations; + self + } + + /// Set the icons + pub fn with_icons(mut self, icons: Option>) -> Self { + self.icons = icons; + self + } + + /// Set the metadata + pub fn with_meta(mut self, meta: Option) -> Self { + self.meta = meta; + self + } + pub fn annotate(self, annotations: ToolAnnotations) -> Self { Tool { annotations: Some(annotations), @@ -219,11 +298,9 @@ impl Tool { } /// Set the execution configuration for this tool. - pub fn with_execution(self, execution: ToolExecution) -> Self { - Tool { - execution: Some(execution), - ..self - } + pub fn with_execution(mut self, execution: Option) -> Self { + self.execution = execution; + self } /// Returns the task support mode for this tool. diff --git a/crates/rmcp/src/service.rs b/crates/rmcp/src/service.rs index b12839c6f..d6613dd3c 100644 --- a/crates/rmcp/src/service.rs +++ b/crates/rmcp/src/service.rs @@ -566,6 +566,7 @@ impl RunningServiceCancellationToken { } #[derive(Debug)] +#[non_exhaustive] pub enum QuitReason { Cancelled, Closed, @@ -584,6 +585,19 @@ pub struct RequestContext { pub peer: Peer, } +impl RequestContext { + /// Create a new RequestContext. + pub fn new(id: RequestId, peer: Peer) -> Self { + Self { + ct: CancellationToken::new(), + id, + meta: Meta::default(), + extensions: Extensions::default(), + peer, + } + } +} + /// Request execution context #[derive(Debug, Clone)] pub struct NotificationContext { diff --git a/crates/rmcp/src/service/client.rs b/crates/rmcp/src/service/client.rs index 837fafeff..6528e4144 100644 --- a/crates/rmcp/src/service/client.rs +++ b/crates/rmcp/src/service/client.rs @@ -25,6 +25,7 @@ use crate::{ /// /// if you want to handle the error, you can use `serve_client_with_ct` or `serve_client` with `Result, ClientError>` #[derive(Error, Debug)] +#[non_exhaustive] pub enum ClientInitializeError { #[error("expect initialized response, but received: {0:?}")] ExpectedInitResponse(Option), diff --git a/crates/rmcp/src/service/server.rs b/crates/rmcp/src/service/server.rs index 5f54f3dcd..666a79980 100644 --- a/crates/rmcp/src/service/server.rs +++ b/crates/rmcp/src/service/server.rs @@ -47,6 +47,7 @@ impl ServiceRole for RoleServer { /// /// if you want to handle the error, you can use `serve_server_with_ct` or `serve_server` with `Result, ServerError>` #[derive(Error, Debug)] +#[non_exhaustive] pub enum ServerInitializeError { #[error("expect initialized request, but received: {0:?}")] ExpectedInitializeRequest(Option), @@ -457,6 +458,7 @@ impl Peer { /// Errors that can occur during typed elicitation operations #[cfg(feature = "elicitation")] #[derive(Error, Debug)] +#[non_exhaustive] pub enum ElicitationError { /// The elicitation request failed at the service level #[error("Service error: {0}")] @@ -808,6 +810,7 @@ impl Peer { /// ElicitationAction::Cancel => { /// println!("User cancelled/dismissed the request"); /// } + /// _ => {} /// } /// Ok(()) /// } @@ -858,6 +861,7 @@ impl Peer { /// ElicitationAction::Cancel => { /// println!("User cancelled/dismissed the request"); /// } + /// _ => {} /// } /// Ok(()) /// } diff --git a/crates/rmcp/src/transport/auth.rs b/crates/rmcp/src/transport/auth.rs index 2236590ab..6578b5c3b 100644 --- a/crates/rmcp/src/transport/auth.rs +++ b/crates/rmcp/src/transport/auth.rs @@ -266,6 +266,7 @@ impl AuthClient { /// Auth error #[derive(Debug, Error)] +#[non_exhaustive] pub enum AuthError { #[error("OAuth authorization required")] AuthorizationRequired, diff --git a/crates/rmcp/src/transport/streamable_http_client.rs b/crates/rmcp/src/transport/streamable_http_client.rs index 779dfe1c5..85915c976 100644 --- a/crates/rmcp/src/transport/streamable_http_client.rs +++ b/crates/rmcp/src/transport/streamable_http_client.rs @@ -44,6 +44,7 @@ impl InsufficientScopeError { } #[derive(Error, Debug)] +#[non_exhaustive] pub enum StreamableHttpError { #[error("SSE error: {0}")] Sse(#[from] SseError), @@ -81,12 +82,14 @@ pub enum StreamableHttpError { } #[derive(Debug, Clone, Error)] +#[non_exhaustive] pub enum StreamableHttpProtocolError { #[error("Missing session id in response")] MissingSessionIdInResponse, } #[allow(clippy::large_enum_variant)] +#[non_exhaustive] pub enum StreamableHttpPostResponse { Accepted, Json(ServerJsonRpcMessage, Option), diff --git a/crates/rmcp/src/transport/streamable_http_server/session/local.rs b/crates/rmcp/src/transport/streamable_http_server/session/local.rs index 6e197b5b8..cad533802 100644 --- a/crates/rmcp/src/transport/streamable_http_server/session/local.rs +++ b/crates/rmcp/src/transport/streamable_http_server/session/local.rs @@ -688,6 +688,7 @@ pub enum SessionEvent { } #[derive(Debug, Clone)] +#[non_exhaustive] pub enum SessionQuitReason { ServiceTerminated, ClientTerminated, @@ -880,6 +881,7 @@ pub type SessionTransport = WorkerTransport; #[allow(clippy::large_enum_variant)] #[derive(Debug, Error)] +#[non_exhaustive] pub enum LocalSessionWorkerError { #[error("transport terminated")] TransportTerminated, diff --git a/crates/rmcp/src/transport/worker.rs b/crates/rmcp/src/transport/worker.rs index 769d448a5..d7c53afd4 100644 --- a/crates/rmcp/src/transport/worker.rs +++ b/crates/rmcp/src/transport/worker.rs @@ -7,6 +7,7 @@ use super::{IntoTransport, Transport}; use crate::service::{RxJsonRpcMessage, ServiceRole, TxJsonRpcMessage}; #[derive(Debug, thiserror::Error)] +#[non_exhaustive] pub enum WorkerQuitReason { #[error("Join error {0}")] Join(#[from] tokio::task::JoinError), diff --git a/crates/rmcp/tests/common/calculator.rs b/crates/rmcp/tests/common/calculator.rs index 5b8cebf7a..22c6d38ef 100644 --- a/crates/rmcp/tests/common/calculator.rs +++ b/crates/rmcp/tests/common/calculator.rs @@ -53,10 +53,7 @@ impl Calculator { impl ServerHandler for Calculator { fn get_info(&self) -> ServerInfo { - ServerInfo { - instructions: Some("A simple calculator".into()), - capabilities: ServerCapabilities::builder().enable_tools().build(), - ..Default::default() - } + ServerInfo::new(ServerCapabilities::builder().enable_tools().build()) + .with_instructions("A simple calculator") } } diff --git a/crates/rmcp/tests/common/handlers.rs b/crates/rmcp/tests/common/handlers.rs index 654413fa8..2084981a2 100644 --- a/crates/rmcp/tests/common/handlers.rs +++ b/crates/rmcp/tests/common/handlers.rs @@ -3,15 +3,17 @@ use std::{ sync::{Arc, Mutex}, }; -use rmcp::{ - ClientHandler, ErrorData as McpError, RoleClient, RoleServer, ServerHandler, - model::*, - service::{NotificationContext, RequestContext}, -}; +#[cfg(feature = "client")] +use rmcp::service::NotificationContext; +#[cfg(feature = "client")] +use rmcp::{ClientHandler, RoleClient}; +use rmcp::{ErrorData as McpError, RoleServer, ServerHandler, model::*, service::RequestContext}; +#[cfg(feature = "client")] use serde_json::json; use tokio::sync::Notify; #[derive(Clone)] +#[allow(dead_code)] pub struct TestClientHandler { pub honor_this_server: bool, pub honor_all_servers: bool, @@ -46,6 +48,7 @@ impl TestClientHandler { } } +#[cfg(feature = "client")] impl ClientHandler for TestClientHandler { async fn create_message( &self, @@ -71,11 +74,11 @@ impl ClientHandler for TestClientHandler { _ => "Test response without context", }; - Ok(CreateMessageResult { - message: SamplingMessage::assistant_text(response.to_string()), - model: "test-model".to_string(), - stop_reason: Some(CreateMessageResult::STOP_REASON_END_TURN.to_string()), - }) + Ok(CreateMessageResult::new( + SamplingMessage::assistant_text(response.to_string()), + "test-model".to_string(), + ) + .with_stop_reason(Some(CreateMessageResult::STOP_REASON_END_TURN.to_string()))) } fn on_logging_message( @@ -106,10 +109,7 @@ impl TestServer { impl ServerHandler for TestServer { fn get_info(&self) -> ServerInfo { - ServerInfo { - capabilities: ServerCapabilities::builder().enable_logging().build(), - ..Default::default() - } + ServerInfo::new(ServerCapabilities::builder().enable_logging().build()) } fn set_level( diff --git a/crates/rmcp/tests/test_client_initialization.rs b/crates/rmcp/tests/test_client_initialization.rs index fed6eceed..c9b8f94a2 100644 --- a/crates/rmcp/tests/test_client_initialization.rs +++ b/crates/rmcp/tests/test_client_initialization.rs @@ -1,4 +1,6 @@ // cargo test --features "server client" --package rmcp test_client_initialization +#![cfg(feature = "client")] + mod common; use std::borrow::Cow; diff --git a/crates/rmcp/tests/test_completion.rs b/crates/rmcp/tests/test_completion.rs index bd563cadf..694ae4d9a 100644 --- a/crates/rmcp/tests/test_completion.rs +++ b/crates/rmcp/tests/test_completion.rs @@ -52,15 +52,14 @@ fn test_complete_request_param_serialization() { let mut args = HashMap::new(); args.insert("previous_input".to_string(), "test".to_string()); - let request = CompleteRequestParams { - meta: None, - r#ref: Reference::for_prompt("weather_prompt"), - argument: ArgumentInfo { + let request = CompleteRequestParams::new( + Reference::for_prompt("weather_prompt"), + ArgumentInfo { name: "location".to_string(), value: "San".to_string(), }, - context: Some(CompletionContext::with_arguments(args)), - }; + ) + .with_context(CompletionContext::with_arguments(args)); let json = serde_json::to_value(&request).unwrap(); assert!(json["ref"]["name"].as_str().unwrap() == "weather_prompt"); @@ -196,15 +195,13 @@ fn test_completion_context_empty() { #[test] fn test_mcp_schema_compliance() { // Test that our types serialize correctly according to MCP specification - let request = CompleteRequestParams { - meta: None, - r#ref: Reference::for_resource("file://{path}"), - argument: ArgumentInfo { + let request = CompleteRequestParams::new( + Reference::for_resource("file://{path}"), + ArgumentInfo { name: "path".to_string(), value: "src/".to_string(), }, - context: None, - }; + ); let json_str = serde_json::to_string(&request).unwrap(); let parsed: serde_json::Value = serde_json::from_str(&json_str).unwrap(); diff --git a/crates/rmcp/tests/test_custom_headers.rs b/crates/rmcp/tests/test_custom_headers.rs index 82537a80c..b83c85772 100644 --- a/crates/rmcp/tests/test_custom_headers.rs +++ b/crates/rmcp/tests/test_custom_headers.rs @@ -730,10 +730,7 @@ async fn test_server_rejects_unsupported_protocol_version() { impl ServerHandler for TestHandler { fn get_info(&self) -> ServerInfo { - ServerInfo { - capabilities: ServerCapabilities::builder().build(), - ..Default::default() - } + ServerInfo::new(ServerCapabilities::builder().build()) } } diff --git a/crates/rmcp/tests/test_elicitation.rs b/crates/rmcp/tests/test_elicitation.rs index ce8be280e..7d946a2bf 100644 --- a/crates/rmcp/tests/test_elicitation.rs +++ b/crates/rmcp/tests/test_elicitation.rs @@ -141,15 +141,13 @@ async fn test_elicitation_json_rpc_protocol() { let request = JsonRpcRequest { jsonrpc: JsonRpcVersion2_0, id: RequestId::Number(1), - request: CreateElicitationRequest { - method: ElicitationCreateRequestMethod, - params: CreateElicitationRequestParams::FormElicitationParams { + request: CreateElicitationRequest::new( + CreateElicitationRequestParams::FormElicitationParams { meta: None, message: "Do you want to continue?".to_string(), requested_schema: schema, }, - extensions: Default::default(), - }, + ), }; // Test serialization of complete request @@ -710,30 +708,24 @@ async fn test_elicitation_multi_select_enum() { assert_eq!( schema, &EnumSchema::Multi(MultiSelectEnumSchema::Titled( - TitledMultiSelectEnumSchema { - type_: ArrayTypeConst, - title: None, - description: None, - min_items: Some(1), - max_items: Some(2), - items: TitledItems { - any_of: vec![ - ConstTitle { - const_: "A".to_string(), - title: "A name".to_string() - }, - ConstTitle { - const_: "B".to_string(), - title: "B name".to_string() - }, - ConstTitle { - const_: "C".to_string(), - title: "C name".to_string() - } - ], - }, - default: None - } + TitledMultiSelectEnumSchema::new(TitledItems { + any_of: vec![ + ConstTitle { + const_: "A".to_string(), + title: "A name".to_string() + }, + ConstTitle { + const_: "B".to_string(), + title: "B name".to_string() + }, + ConstTitle { + const_: "C".to_string(), + title: "C name".to_string() + }, + ], + }) + .with_min_items(1) + .with_max_items(2) )) ) } @@ -789,26 +781,20 @@ async fn test_elicitation_single_select_enum() { assert_eq!( schema, &EnumSchema::Single(SingleSelectEnumSchema::Titled( - TitledSingleSelectEnumSchema { - type_: StringTypeConst, - title: None, - description: None, - one_of: vec![ - ConstTitle { - const_: "A".to_string(), - title: "A name".to_string() - }, - ConstTitle { - const_: "B".to_string(), - title: "B name".to_string() - }, - ConstTitle { - const_: "C".to_string(), - title: "C name".to_string() - } - ], - default: None - } + TitledSingleSelectEnumSchema::new(vec![ + ConstTitle { + const_: "A".to_string(), + title: "A name".to_string() + }, + ConstTitle { + const_: "B".to_string(), + title: "B name".to_string() + }, + ConstTitle { + const_: "C".to_string(), + title: "C name".to_string() + } + ]) )) ) } @@ -850,11 +836,8 @@ async fn test_elicitation_direction_server_to_client() { assert_eq!(serialized["requestedSchema"]["type"], "object"); // Test that elicitation requests are part of ServerRequest - let _server_request = ServerRequest::CreateElicitationRequest(CreateElicitationRequest { - method: ElicitationCreateRequestMethod, - params: elicitation_request, - extensions: Default::default(), - }); + let _server_request = + ServerRequest::CreateElicitationRequest(CreateElicitationRequest::new(elicitation_request)); // Test that client can respond with elicitation results let client_result = ClientResult::CreateElicitationResult(CreateElicitationResult { @@ -889,15 +872,13 @@ async fn test_elicitation_json_rpc_direction() { // 1. Server creates elicitation request let server_request = ServerJsonRpcMessage::request( - ServerRequest::CreateElicitationRequest(CreateElicitationRequest { - method: ElicitationCreateRequestMethod, - params: CreateElicitationRequestParams::FormElicitationParams { + ServerRequest::CreateElicitationRequest(CreateElicitationRequest::new( + CreateElicitationRequestParams::FormElicitationParams { meta: None, message: "Do you want to continue?".to_string(), requested_schema: schema, }, - extensions: Default::default(), - }), + )), RequestId::Number(1), ); @@ -1051,15 +1032,14 @@ async fn test_elicitation_capability_structure() { #[tokio::test] async fn test_client_capabilities_with_elicitation() { // Test ClientCapabilities with elicitation capability - let capabilities = ClientCapabilities { - elicitation: Some(ElicitationCapability { + let capabilities = ClientCapabilities::builder() + .enable_elicitation_with(ElicitationCapability { form: Some(FormElicitationCapability { schema_validation: Some(true), }), url: None, - }), - ..Default::default() - }; + }) + .build(); // Verify elicitation capability is present assert!(capabilities.elicitation.is_some()); @@ -1084,10 +1064,7 @@ async fn test_client_capabilities_with_elicitation() { ); // Test ClientCapabilities without elicitation - let capabilities_without = ClientCapabilities { - elicitation: None, - ..Default::default() - }; + let capabilities_without = ClientCapabilities::default(); assert!(capabilities_without.elicitation.is_none()); } @@ -1096,27 +1073,17 @@ async fn test_client_capabilities_with_elicitation() { #[tokio::test] async fn test_initialize_request_with_elicitation() { // Test InitializeRequestParam with elicitation capability - let init_param = InitializeRequestParams { - meta: None, - protocol_version: ProtocolVersion::LATEST, - capabilities: ClientCapabilities { - elicitation: Some(ElicitationCapability { + let init_param = InitializeRequestParams::new( + ClientCapabilities::builder() + .enable_elicitation_with(ElicitationCapability { form: Some(FormElicitationCapability { schema_validation: Some(true), }), url: None, - }), - ..Default::default() - }, - client_info: Implementation { - name: "test-client".to_string(), - version: "1.0.0".to_string(), - title: None, - description: None, - website_url: None, - icons: None, - }, - }; + }) + .build(), + Implementation::new("test-client", "1.0.0"), + ); // Verify the structure assert!(init_param.capabilities.elicitation.is_some()); @@ -1148,49 +1115,27 @@ async fn test_capability_checking_logic() { // Simulate the logic that would be used in supports_elicitation() // Case 1: Client with elicitation capability - let client_with_capability = InitializeRequestParams { - meta: None, - protocol_version: ProtocolVersion::LATEST, - capabilities: ClientCapabilities { - elicitation: Some(ElicitationCapability { + let client_with_capability = InitializeRequestParams::new( + ClientCapabilities::builder() + .enable_elicitation_with(ElicitationCapability { form: Some(FormElicitationCapability { schema_validation: Some(true), }), url: None, - }), - ..Default::default() - }, - client_info: Implementation { - name: "test-client".to_string(), - version: "1.0.0".to_string(), - title: None, - description: None, - website_url: None, - icons: None, - }, - }; + }) + .build(), + Implementation::new("test-client", "1.0.0"), + ); // Simulate supports_elicitation() logic let supports_elicitation = client_with_capability.capabilities.elicitation.is_some(); assert!(supports_elicitation); // Case 2: Client without elicitation capability - let client_without_capability = InitializeRequestParams { - meta: None, - protocol_version: ProtocolVersion::LATEST, - capabilities: ClientCapabilities { - elicitation: None, - ..Default::default() - }, - client_info: Implementation { - name: "test-client".to_string(), - version: "1.0.0".to_string(), - title: None, - description: None, - website_url: None, - icons: None, - }, - }; + let client_without_capability = InitializeRequestParams::new( + ClientCapabilities::default(), + Implementation::new("test-client", "1.0.0"), + ); let supports_elicitation = client_without_capability.capabilities.elicitation.is_some(); assert!(!supports_elicitation); } @@ -1910,16 +1855,14 @@ async fn test_url_elicitation_json_rpc_protocol() { let request = JsonRpcRequest { jsonrpc: JsonRpcVersion2_0, id: RequestId::Number(1), - request: CreateElicitationRequest { - method: ElicitationCreateRequestMethod, - params: CreateElicitationRequestParams::UrlElicitationParams { + request: CreateElicitationRequest::new( + CreateElicitationRequestParams::UrlElicitationParams { meta: None, message: "Please authorize this action at the following URL".to_string(), url: "https://auth.example.com/authorize/abc123".to_string(), elicitation_id: "auth-request-456".to_string(), }, - extensions: Default::default(), - }, + ), }; // Test serialization of complete request @@ -1977,11 +1920,7 @@ async fn test_elicitation_completion_notification() { assert_eq!(deserialized.elicitation_id, "elicit-789"); // Test complete notification structure - let notification = ElicitationCompletionNotification { - method: ElicitationCompletionNotificationMethod, - params: notification_params, - extensions: Default::default(), - }; + let notification = ElicitationCompletionNotification::new(notification_params); let json = serde_json::to_value(¬ification).unwrap(); assert_eq!(json["method"], "notifications/elicitation/complete"); @@ -2142,15 +2081,14 @@ async fn test_url_elicitation_required_error_code() { #[tokio::test] async fn test_client_capabilities_elicitation_modes() { // Test with form-only capability - let form_only_caps = ClientCapabilities { - elicitation: Some(ElicitationCapability { + let form_only_caps = ClientCapabilities::builder() + .enable_elicitation_with(ElicitationCapability { form: Some(FormElicitationCapability { schema_validation: Some(true), }), url: None, - }), - ..Default::default() - }; + }) + .build(); let json = serde_json::to_value(&form_only_caps).unwrap(); assert!(json["elicitation"]["form"].is_object()); @@ -2160,13 +2098,12 @@ async fn test_client_capabilities_elicitation_modes() { ); // Test with URL-only capability - let url_only_caps = ClientCapabilities { - elicitation: Some(ElicitationCapability { + let url_only_caps = ClientCapabilities::builder() + .enable_elicitation_with(ElicitationCapability { form: None, url: Some(UrlElicitationCapability::default()), - }), - ..Default::default() - }; + }) + .build(); let json = serde_json::to_value(&url_only_caps).unwrap(); assert!(json["elicitation"]["url"].is_object()); @@ -2179,15 +2116,14 @@ async fn test_client_capabilities_elicitation_modes() { ); // Test with both capabilities - let both_caps = ClientCapabilities { - elicitation: Some(ElicitationCapability { + let both_caps = ClientCapabilities::builder() + .enable_elicitation_with(ElicitationCapability { form: Some(FormElicitationCapability { schema_validation: Some(false), }), url: Some(UrlElicitationCapability::default()), - }), - ..Default::default() - }; + }) + .build(); let json = serde_json::to_value(&both_caps).unwrap(); assert!(json["elicitation"]["form"].is_object()); @@ -2201,11 +2137,8 @@ async fn test_elicitation_completion_in_server_notification() { elicitation_id: "notify-123".to_string(), }; - let completion_notification = ElicitationCompletionNotification { - method: ElicitationCompletionNotificationMethod, - params: notification_param.clone(), - extensions: Default::default(), - }; + let completion_notification = + ElicitationCompletionNotification::new(notification_param.clone()); // Test that it's part of ServerNotification let server_notification = diff --git a/crates/rmcp/tests/test_handler_wrappers.rs b/crates/rmcp/tests/test_handler_wrappers.rs index e1faddc91..18ec242ba 100644 --- a/crates/rmcp/tests/test_handler_wrappers.rs +++ b/crates/rmcp/tests/test_handler_wrappers.rs @@ -4,8 +4,8 @@ mod common; use std::sync::Arc; -use common::handlers::{TestClientHandler, TestServer}; -use rmcp::{ClientHandler, ServerHandler}; +use common::handlers::TestServer; +use rmcp::ServerHandler; #[test] fn test_wrapped_server_handlers() { @@ -16,8 +16,11 @@ fn test_wrapped_server_handlers() { accepts_server_handler(Arc::new(TestServer::new())); } +#[cfg(feature = "client")] #[test] fn test_wrapped_client_handlers() { + use common::handlers::TestClientHandler; + use rmcp::ClientHandler; // This test asserts that, when T: ClientHandler, both Box and Arc also implement ClientHandler. fn accepts_client_handler(_handler: H) {} diff --git a/crates/rmcp/tests/test_logging.rs b/crates/rmcp/tests/test_logging.rs index be63b24fb..11efd84c9 100644 --- a/crates/rmcp/tests/test_logging.rs +++ b/crates/rmcp/tests/test_logging.rs @@ -63,7 +63,7 @@ async fn test_logging_spec_compliance() -> anyhow::Result<()> { ] { client .peer() - .set_level(SetLevelRequestParams { meta: None, level }) + .set_level(SetLevelRequestParams::new(level)) .await?; // Wait for each message response @@ -121,10 +121,7 @@ async fn test_logging_user_scenarios() -> anyhow::Result<()> { // Test 1: Error reporting scenario client .peer() - .set_level(SetLevelRequestParams { - meta: None, - level: LoggingLevel::Error, - }) + .set_level(SetLevelRequestParams::new(LoggingLevel::Error)) .await?; receive_signal.notified().await; // Wait for response { @@ -148,10 +145,7 @@ async fn test_logging_user_scenarios() -> anyhow::Result<()> { // Test 2: Debug scenario client .peer() - .set_level(SetLevelRequestParams { - meta: None, - level: LoggingLevel::Debug, - }) + .set_level(SetLevelRequestParams::new(LoggingLevel::Debug)) .await?; receive_signal.notified().await; // Wait for response { @@ -172,10 +166,7 @@ async fn test_logging_user_scenarios() -> anyhow::Result<()> { // Test 3: Production monitoring scenario client .peer() - .set_level(SetLevelRequestParams { - meta: None, - level: LoggingLevel::Info, - }) + .set_level(SetLevelRequestParams::new(LoggingLevel::Info)) .await?; receive_signal.notified().await; // Wait for response { @@ -259,7 +250,7 @@ async fn test_logging_edge_cases() -> anyhow::Result<()> { ] { client .peer() - .set_level(SetLevelRequestParams { meta: None, level }) + .set_level(SetLevelRequestParams::new(level)) .await?; receive_signal.notified().await; @@ -319,7 +310,7 @@ async fn test_logging_optional_fields() -> anyhow::Result<()> { for level in [LoggingLevel::Info, LoggingLevel::Debug] { client .peer() - .set_level(SetLevelRequestParams { meta: None, level }) + .set_level(SetLevelRequestParams::new(level)) .await?; // Wait for each message response diff --git a/crates/rmcp/tests/test_message_protocol.rs b/crates/rmcp/tests/test_message_protocol.rs index 7ec3258c0..073486ff9 100644 --- a/crates/rmcp/tests/test_message_protocol.rs +++ b/crates/rmcp/tests/test_message_protocol.rs @@ -39,24 +39,10 @@ async fn test_context_inclusion_integration() -> anyhow::Result<()> { let client = handler.clone().serve(client_transport).await?; // Test ThisServer context inclusion - let request = ServerRequest::CreateMessageRequest(CreateMessageRequest { - method: Default::default(), - params: CreateMessageRequestParams { - meta: None, - task: None, - messages: vec![SamplingMessage::user_text("test message")], - include_context: Some(ContextInclusion::ThisServer), - model_preferences: None, - system_prompt: None, - temperature: None, - max_tokens: 100, - stop_sequences: None, - metadata: None, - tools: None, - tool_choice: None, - }, - extensions: Default::default(), - }); + let request = ServerRequest::CreateMessageRequest(CreateMessageRequest::new( + CreateMessageRequestParams::new(vec![SamplingMessage::user_text("test message")], 100) + .with_include_context(ContextInclusion::ThisServer), + )); let result = handler .handle_request( @@ -90,24 +76,10 @@ async fn test_context_inclusion_integration() -> anyhow::Result<()> { } // Test AllServers context inclusion - let request = ServerRequest::CreateMessageRequest(CreateMessageRequest { - method: Default::default(), - params: CreateMessageRequestParams { - meta: None, - task: None, - messages: vec![SamplingMessage::user_text("test message")], - include_context: Some(ContextInclusion::AllServers), - model_preferences: None, - system_prompt: None, - temperature: None, - max_tokens: 100, - stop_sequences: None, - metadata: None, - tools: None, - tool_choice: None, - }, - extensions: Default::default(), - }); + let request = ServerRequest::CreateMessageRequest(CreateMessageRequest::new( + CreateMessageRequestParams::new(vec![SamplingMessage::user_text("test message")], 100) + .with_include_context(ContextInclusion::AllServers), + )); let result = handler .handle_request( @@ -141,24 +113,10 @@ async fn test_context_inclusion_integration() -> anyhow::Result<()> { } // Test No context inclusion - let request = ServerRequest::CreateMessageRequest(CreateMessageRequest { - method: Default::default(), - params: CreateMessageRequestParams { - meta: None, - task: None, - messages: vec![SamplingMessage::user_text("test message")], - include_context: Some(ContextInclusion::None), - model_preferences: None, - system_prompt: None, - temperature: None, - max_tokens: 100, - stop_sequences: None, - metadata: None, - tools: None, - tool_choice: None, - }, - extensions: Default::default(), - }); + let request = ServerRequest::CreateMessageRequest(CreateMessageRequest::new( + CreateMessageRequestParams::new(vec![SamplingMessage::user_text("test message")], 100) + .with_include_context(ContextInclusion::None), + )); let result = handler .handle_request( @@ -212,24 +170,10 @@ async fn test_context_inclusion_ignored_integration() -> anyhow::Result<()> { let client = handler.clone().serve(client_transport).await?; // Test that context requests are ignored - let request = ServerRequest::CreateMessageRequest(CreateMessageRequest { - method: Default::default(), - params: CreateMessageRequestParams { - meta: None, - task: None, - messages: vec![SamplingMessage::user_text("test message")], - include_context: Some(ContextInclusion::ThisServer), - model_preferences: None, - system_prompt: None, - temperature: None, - max_tokens: 100, - stop_sequences: None, - metadata: None, - tools: None, - tool_choice: None, - }, - extensions: Default::default(), - }); + let request = ServerRequest::CreateMessageRequest(CreateMessageRequest::new( + CreateMessageRequestParams::new(vec![SamplingMessage::user_text("test message")], 100) + .with_include_context(ContextInclusion::ThisServer), + )); let result = handler .handle_request( @@ -282,27 +226,16 @@ async fn test_message_sequence_integration() -> anyhow::Result<()> { let handler = TestClientHandler::new(true, true); let client = handler.clone().serve(client_transport).await?; - let request = ServerRequest::CreateMessageRequest(CreateMessageRequest { - method: Default::default(), - params: CreateMessageRequestParams { - meta: None, - task: None, - messages: vec![ + let request = ServerRequest::CreateMessageRequest(CreateMessageRequest::new( + CreateMessageRequestParams::new( + vec![ SamplingMessage::user_text("first message"), SamplingMessage::assistant_text("second message"), ], - include_context: Some(ContextInclusion::ThisServer), - model_preferences: None, - system_prompt: None, - temperature: None, - max_tokens: 100, - stop_sequences: None, - metadata: None, - tools: None, - tool_choice: None, - }, - extensions: Default::default(), - }); + 100, + ) + .with_include_context(ContextInclusion::ThisServer), + )); let result = handler .handle_request( @@ -359,28 +292,16 @@ async fn test_message_sequence_validation_integration() -> anyhow::Result<()> { let client = handler.clone().serve(client_transport).await?; // Test valid sequence: User -> Assistant -> User - let request = ServerRequest::CreateMessageRequest(CreateMessageRequest { - method: Default::default(), - params: CreateMessageRequestParams { - meta: None, - task: None, - messages: vec![ + let request = ServerRequest::CreateMessageRequest(CreateMessageRequest::new( + CreateMessageRequestParams::new( + vec![ SamplingMessage::user_text("first user message"), SamplingMessage::assistant_text("first assistant response"), SamplingMessage::user_text("second user message"), ], - include_context: None, - model_preferences: None, - system_prompt: None, - temperature: None, - max_tokens: 100, - stop_sequences: None, - metadata: None, - tools: None, - tool_choice: None, - }, - extensions: Default::default(), - }); + 100, + ), + )); let result = handler .handle_request( @@ -398,24 +319,12 @@ async fn test_message_sequence_validation_integration() -> anyhow::Result<()> { assert!(matches!(result, ClientResult::CreateMessageResult(_))); // Test invalid: No user message - let request = ServerRequest::CreateMessageRequest(CreateMessageRequest { - method: Default::default(), - params: CreateMessageRequestParams { - meta: None, - task: None, - messages: vec![SamplingMessage::assistant_text("assistant message")], - include_context: None, - model_preferences: None, - system_prompt: None, - temperature: None, - max_tokens: 100, - stop_sequences: None, - metadata: None, - tools: None, - tool_choice: None, - }, - extensions: Default::default(), - }); + let request = ServerRequest::CreateMessageRequest(CreateMessageRequest::new( + CreateMessageRequestParams::new( + vec![SamplingMessage::assistant_text("assistant message")], + 100, + ), + )); let result = handler .handle_request( @@ -452,24 +361,10 @@ async fn test_selective_context_handling_integration() -> anyhow::Result<()> { let client = handler.clone().serve(client_transport).await?; // Test ThisServer is honored - let request = ServerRequest::CreateMessageRequest(CreateMessageRequest { - method: Default::default(), - params: CreateMessageRequestParams { - meta: None, - task: None, - messages: vec![SamplingMessage::user_text("test message")], - include_context: Some(ContextInclusion::ThisServer), - model_preferences: None, - system_prompt: None, - temperature: None, - max_tokens: 100, - stop_sequences: None, - metadata: None, - tools: None, - tool_choice: None, - }, - extensions: Default::default(), - }); + let request = ServerRequest::CreateMessageRequest(CreateMessageRequest::new( + CreateMessageRequestParams::new(vec![SamplingMessage::user_text("test message")], 100) + .with_include_context(ContextInclusion::ThisServer), + )); let result = handler .handle_request( @@ -501,24 +396,10 @@ async fn test_selective_context_handling_integration() -> anyhow::Result<()> { } // Test AllServers is ignored - let request = ServerRequest::CreateMessageRequest(CreateMessageRequest { - method: Default::default(), - params: CreateMessageRequestParams { - meta: None, - task: None, - messages: vec![SamplingMessage::user_text("test message")], - include_context: Some(ContextInclusion::AllServers), - model_preferences: None, - system_prompt: None, - temperature: None, - max_tokens: 100, - stop_sequences: None, - metadata: None, - tools: None, - tool_choice: None, - }, - extensions: Default::default(), - }); + let request = ServerRequest::CreateMessageRequest(CreateMessageRequest::new( + CreateMessageRequestParams::new(vec![SamplingMessage::user_text("test message")], 100) + .with_include_context(ContextInclusion::AllServers), + )); let result = handler .handle_request( @@ -567,24 +448,10 @@ async fn test_context_inclusion() -> anyhow::Result<()> { let client = handler.clone().serve(client_transport).await?; // Test context handling - let request = ServerRequest::CreateMessageRequest(CreateMessageRequest { - method: Default::default(), - params: CreateMessageRequestParams { - meta: None, - task: None, - messages: vec![SamplingMessage::user_text("test")], - include_context: Some(ContextInclusion::ThisServer), - model_preferences: None, - system_prompt: None, - temperature: None, - max_tokens: 100, - stop_sequences: None, - metadata: None, - tools: None, - tool_choice: None, - }, - extensions: Default::default(), - }); + let request = ServerRequest::CreateMessageRequest(CreateMessageRequest::new( + CreateMessageRequestParams::new(vec![SamplingMessage::user_text("test")], 100) + .with_include_context(ContextInclusion::ThisServer), + )); let result = handler .handle_request( diff --git a/crates/rmcp/tests/test_notification.rs b/crates/rmcp/tests/test_notification.rs index 018374212..7d930678e 100644 --- a/crates/rmcp/tests/test_notification.rs +++ b/crates/rmcp/tests/test_notification.rs @@ -15,14 +15,13 @@ pub struct Server {} impl ServerHandler for Server { fn get_info(&self) -> ServerInfo { - ServerInfo { - capabilities: ServerCapabilities::builder() + ServerInfo::new( + ServerCapabilities::builder() .enable_resources() .enable_resources_subscribe() .enable_resources_list_changed() .build(), - ..Default::default() - } + ) } async fn subscribe( @@ -87,10 +86,7 @@ async fn test_server_notification() -> anyhow::Result<()> { .serve(client_transport) .await?; client - .subscribe(SubscribeRequestParams { - meta: None, - uri: "test://test-resource".to_owned(), - }) + .subscribe(SubscribeRequestParams::new("test://test-resource")) .await?; receive_signal.notified().await; client.cancel().await?; diff --git a/crates/rmcp/tests/test_progress_subscriber.rs b/crates/rmcp/tests/test_progress_subscriber.rs index 521219a3b..092f35747 100644 --- a/crates/rmcp/tests/test_progress_subscriber.rs +++ b/crates/rmcp/tests/test_progress_subscriber.rs @@ -107,12 +107,9 @@ async fn test_progress_subscriber() -> anyhow::Result<()> { let client_service = client.serve(transport_client).await?; let handle = client_service .send_cancellable_request( - ClientRequest::CallToolRequest(Request::new(CallToolRequestParams { - meta: None, - name: "some_progress".into(), - arguments: None, - task: None, - })), + ClientRequest::CallToolRequest(Request::new(CallToolRequestParams::new( + "some_progress", + ))), PeerRequestOptions::no_options(), ) .await?; diff --git a/crates/rmcp/tests/test_prompt_macro_annotations.rs b/crates/rmcp/tests/test_prompt_macro_annotations.rs index f313927f5..caa017936 100644 --- a/crates/rmcp/tests/test_prompt_macro_annotations.rs +++ b/crates/rmcp/tests/test_prompt_macro_annotations.rs @@ -109,13 +109,11 @@ async fn complex_args_prompt( _server: &TestServer, _args: Parameters, ) -> GetPromptResult { - GetPromptResult { - description: Some("Complex args result".to_string()), - messages: vec![PromptMessage::new_text( - PromptMessageRole::Assistant, - "Complex response", - )], - } + GetPromptResult::new(vec![PromptMessage::new_text( + PromptMessageRole::Assistant, + "Complex response", + )]) + .with_description("Complex args result") } // Test sync prompt diff --git a/crates/rmcp/tests/test_prompt_macros.rs b/crates/rmcp/tests/test_prompt_macros.rs index 2407571d7..a41d2e7e5 100644 --- a/crates/rmcp/tests/test_prompt_macros.rs +++ b/crates/rmcp/tests/test_prompt_macros.rs @@ -107,22 +107,20 @@ impl GenericServer { #[prompt(description = "Get contextual help from the service")] async fn get_help(&self) -> GetPromptResult { let context = self.data_service.get_context(); - GetPromptResult { - description: Some("Contextual help based on service data".to_string()), - messages: vec![ - PromptMessage::new_text( - PromptMessageRole::User, - "I need help with the current context.".to_string(), - ), - PromptMessage::new_text( - PromptMessageRole::Assistant, - format!( - "Based on the context '{}', here's how I can help...", - context - ), + GetPromptResult::new(vec![ + PromptMessage::new_text( + PromptMessageRole::User, + "I need help with the current context.".to_string(), + ), + PromptMessage::new_text( + PromptMessageRole::Assistant, + format!( + "Based on the context '{}', here's how I can help...", + context ), - ], - } + ), + ]) + .with_description("Contextual help based on service data") } } @@ -250,13 +248,11 @@ impl OptionalSchemaTester { None => "Received null count".to_string(), }; - GetPromptResult { - description: Some("Test result for optional i64".to_string()), - messages: vec![PromptMessage::new_text( - PromptMessageRole::Assistant, - message, - )], - } + GetPromptResult::new(vec![PromptMessage::new_text( + PromptMessageRole::Assistant, + message, + )]) + .with_description("Test result for optional i64") } } @@ -327,10 +323,8 @@ async fn test_optional_i64_field_with_null_input() -> anyhow::Result<()> { // Test null case let result = client - .get_prompt(GetPromptRequestParams { - meta: None, - name: "test_optional_i64".into(), - arguments: Some( + .get_prompt( + GetPromptRequestParams::new("test_optional_i64").with_arguments( serde_json::json!({ "count": null, "mandatory_field": "test_null" @@ -339,7 +333,7 @@ async fn test_optional_i64_field_with_null_input() -> anyhow::Result<()> { .unwrap() .clone(), ), - }) + ) .await?; let result_text = match &result.messages.first().unwrap().content { @@ -354,10 +348,8 @@ async fn test_optional_i64_field_with_null_input() -> anyhow::Result<()> { // Test Some case let some_result = client - .get_prompt(GetPromptRequestParams { - meta: None, - name: "test_optional_i64".into(), - arguments: Some( + .get_prompt( + GetPromptRequestParams::new("test_optional_i64").with_arguments( serde_json::json!({ "count": 42, "mandatory_field": "test_some" @@ -366,7 +358,7 @@ async fn test_optional_i64_field_with_null_input() -> anyhow::Result<()> { .unwrap() .clone(), ), - }) + ) .await?; let some_result_text = match &some_result.messages.first().unwrap().content { diff --git a/crates/rmcp/tests/test_prompt_routers.rs b/crates/rmcp/tests/test_prompt_routers.rs index 0917a7f1d..53b13b131 100644 --- a/crates/rmcp/tests/test_prompt_routers.rs +++ b/crates/rmcp/tests/test_prompt_routers.rs @@ -64,13 +64,11 @@ async fn async_function(Parameters(Request { fields }): Parameters) -> #[rmcp::prompt] fn async_function2(_callee: &TestHandler) -> BoxFuture<'_, GetPromptResult> { Box::pin(async move { - GetPromptResult { - description: Some("Async function 2".to_string()), - messages: vec![PromptMessage::new_text( - PromptMessageRole::Assistant, - "Async function 2 response", - )], - } + GetPromptResult::new(vec![PromptMessage::new_text( + PromptMessageRole::Assistant, + "Async function 2 response", + )]) + .with_description("Async function 2") }) } diff --git a/crates/rmcp/tests/test_resource_link_integration.rs b/crates/rmcp/tests/test_resource_link_integration.rs index ab6635258..7507d71e0 100644 --- a/crates/rmcp/tests/test_resource_link_integration.rs +++ b/crates/rmcp/tests/test_resource_link_integration.rs @@ -82,10 +82,10 @@ fn test_resource_link_roundtrip() { } // Test with prompt message - let prompt_message = PromptMessage { - role: PromptMessageRole::User, - content: PromptMessageContent::resource_link(resource.no_annotation()), - }; + let prompt_message = PromptMessage::new( + PromptMessageRole::User, + PromptMessageContent::resource_link(resource.no_annotation()), + ); let prompt_json = serde_json::to_string(&prompt_message).unwrap(); let prompt_deserialized: PromptMessage = serde_json::from_str(&prompt_json).unwrap(); diff --git a/crates/rmcp/tests/test_sampling.rs b/crates/rmcp/tests/test_sampling.rs index e5191d3c1..d885e46ce 100644 --- a/crates/rmcp/tests/test_sampling.rs +++ b/crates/rmcp/tests/test_sampling.rs @@ -23,27 +23,20 @@ async fn test_basic_sampling_message_creation() -> Result<()> { #[tokio::test] async fn test_sampling_request_params() -> Result<()> { - let params = CreateMessageRequestParams { - meta: None, - task: None, - messages: vec![SamplingMessage::user_text("Hello, world!")], - model_preferences: Some(ModelPreferences { - hints: Some(vec![ModelHint { - name: Some("claude".to_string()), - }]), - cost_priority: Some(0.5), - speed_priority: Some(0.8), - intelligence_priority: Some(0.7), - }), - system_prompt: Some("You are a helpful assistant.".to_string()), - temperature: Some(0.7), - max_tokens: 100, - stop_sequences: Some(vec!["STOP".to_string()]), - include_context: Some(ContextInclusion::None), - metadata: Some(serde_json::json!({"test": "value"})), - tools: None, - tool_choice: None, - }; + let params = + CreateMessageRequestParams::new(vec![SamplingMessage::user_text("Hello, world!")], 100) + .with_model_preferences( + ModelPreferences::new() + .with_hints(vec![ModelHint::new("claude")]) + .with_cost_priority(0.5) + .with_speed_priority(0.8) + .with_intelligence_priority(0.7), + ) + .with_system_prompt("You are a helpful assistant.") + .with_temperature(0.7) + .with_stop_sequences(vec!["STOP".to_string()]) + .with_include_context(ContextInclusion::None) + .with_metadata(serde_json::json!({"test": "value"})); let json = serde_json::to_string(¶ms)?; let deserialized: CreateMessageRequestParams = serde_json::from_str(&json)?; @@ -58,11 +51,11 @@ async fn test_sampling_request_params() -> Result<()> { #[tokio::test] async fn test_sampling_result_structure() -> Result<()> { - let result = CreateMessageResult { - message: SamplingMessage::assistant_text("The capital of France is Paris."), - model: "test-model".to_string(), - stop_reason: Some(CreateMessageResult::STOP_REASON_END_TURN.to_string()), - }; + let result = CreateMessageResult::new( + SamplingMessage::assistant_text("The capital of France is Paris."), + "test-model".to_string(), + ) + .with_stop_reason(Some(CreateMessageResult::STOP_REASON_END_TURN.to_string())); let json = serde_json::to_string(&result)?; let deserialized: CreateMessageResult = serde_json::from_str(&json)?; @@ -112,31 +105,22 @@ async fn test_sampling_integration_with_test_handlers() -> Result<()> { tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; - let request = ServerRequest::CreateMessageRequest(CreateMessageRequest { - method: Default::default(), - params: CreateMessageRequestParams { - meta: None, - task: None, - messages: vec![SamplingMessage::user_text("What is the capital of France?")], - include_context: Some(ContextInclusion::ThisServer), - model_preferences: Some(ModelPreferences { - hints: Some(vec![ModelHint { - name: Some("test-model".to_string()), - }]), - cost_priority: Some(0.5), - speed_priority: Some(0.8), - intelligence_priority: Some(0.7), - }), - system_prompt: Some("You are a helpful assistant.".to_string()), - temperature: Some(0.7), - max_tokens: 100, - stop_sequences: None, - metadata: None, - tools: None, - tool_choice: None, - }, - extensions: Default::default(), - }); + let request = ServerRequest::CreateMessageRequest(CreateMessageRequest::new( + CreateMessageRequestParams::new( + vec![SamplingMessage::user_text("What is the capital of France?")], + 100, + ) + .with_include_context(ContextInclusion::ThisServer) + .with_model_preferences( + ModelPreferences::new() + .with_hints(vec![ModelHint::new("test-model")]) + .with_cost_priority(0.5) + .with_speed_priority(0.8) + .with_intelligence_priority(0.7), + ) + .with_system_prompt("You are a helpful assistant.") + .with_temperature(0.7), + )); let result = handler .handle_request( @@ -196,24 +180,10 @@ async fn test_sampling_no_context_inclusion() -> Result<()> { tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; - let request = ServerRequest::CreateMessageRequest(CreateMessageRequest { - method: Default::default(), - params: CreateMessageRequestParams { - meta: None, - task: None, - messages: vec![SamplingMessage::user_text("Hello")], - include_context: Some(ContextInclusion::None), - model_preferences: None, - system_prompt: None, - temperature: None, - max_tokens: 50, - stop_sequences: None, - metadata: None, - tools: None, - tool_choice: None, - }, - extensions: Default::default(), - }); + let request = ServerRequest::CreateMessageRequest(CreateMessageRequest::new( + CreateMessageRequestParams::new(vec![SamplingMessage::user_text("Hello")], 50) + .with_include_context(ContextInclusion::None), + )); let result = handler .handle_request( @@ -269,26 +239,15 @@ async fn test_sampling_error_invalid_message_sequence() -> Result<()> { tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; - let request = ServerRequest::CreateMessageRequest(CreateMessageRequest { - method: Default::default(), - params: CreateMessageRequestParams { - meta: None, - task: None, - messages: vec![SamplingMessage::assistant_text( + let request = ServerRequest::CreateMessageRequest(CreateMessageRequest::new( + CreateMessageRequestParams::new( + vec![SamplingMessage::assistant_text( "I'm an assistant message without a user message", )], - include_context: Some(ContextInclusion::None), - model_preferences: None, - system_prompt: None, - temperature: None, - max_tokens: 50, - stop_sequences: None, - metadata: None, - tools: None, - tool_choice: None, - }, - extensions: Default::default(), - }); + 50, + ) + .with_include_context(ContextInclusion::None), + )); let result = handler .handle_request( @@ -357,22 +316,14 @@ async fn test_sampling_with_tools() -> Result<()> { ), ); - let params = CreateMessageRequestParams { - meta: None, - task: None, - messages: vec![SamplingMessage::user_text( + let params = CreateMessageRequestParams::new( + vec![SamplingMessage::user_text( "What's the weather in San Francisco?", )], - model_preferences: None, - system_prompt: None, - include_context: None, - temperature: None, - max_tokens: 100, - stop_sequences: None, - metadata: None, - tools: Some(vec![tool]), - tool_choice: Some(ToolChoice::auto()), - }; + 100, + ) + .with_tools(vec![tool]) + .with_tool_choice(ToolChoice::auto()); let json = serde_json::to_string(¶ms)?; let deserialized: CreateMessageRequestParams = serde_json::from_str(&json)?; @@ -472,8 +423,8 @@ async fn test_sampling_message_with_tool_result() -> Result<()> { #[tokio::test] async fn test_create_message_result_tool_use_stop_reason() -> Result<()> { - let result = CreateMessageResult { - message: SamplingMessage::assistant_tool_use( + let result = CreateMessageResult::new( + SamplingMessage::assistant_tool_use( "call_123", "get_weather", serde_json::json!({ @@ -483,9 +434,9 @@ async fn test_create_message_result_tool_use_stop_reason() -> Result<()> { .unwrap() .clone(), ), - model: "test-model".to_string(), - stop_reason: Some(CreateMessageResult::STOP_REASON_TOOL_USE.to_string()), - }; + "test-model".to_string(), + ) + .with_stop_reason(Some(CreateMessageResult::STOP_REASON_TOOL_USE.to_string())); let json = serde_json::to_string(&result)?; let deserialized: CreateMessageResult = serde_json::from_str(&json)?; @@ -623,23 +574,13 @@ async fn test_content_conversion_unsupported_variants() { #[tokio::test] async fn test_validate_rejects_tool_use_in_user_message() { - let params = CreateMessageRequestParams { - meta: None, - task: None, - messages: vec![SamplingMessage::new( + let params = CreateMessageRequestParams::new( + vec![SamplingMessage::new( Role::User, SamplingMessageContent::tool_use("call_1", "some_tool", Default::default()), )], - model_preferences: None, - system_prompt: None, - include_context: None, - temperature: None, - max_tokens: 100, - stop_sequences: None, - metadata: None, - tools: None, - tool_choice: None, - }; + 100, + ); let err = params.validate().unwrap_err(); assert!( @@ -650,23 +591,13 @@ async fn test_validate_rejects_tool_use_in_user_message() { #[tokio::test] async fn test_validate_rejects_tool_result_in_assistant_message() { - let params = CreateMessageRequestParams { - meta: None, - task: None, - messages: vec![SamplingMessage::new( + let params = CreateMessageRequestParams::new( + vec![SamplingMessage::new( Role::Assistant, SamplingMessageContent::tool_result("call_1", vec![Content::text("result")]), )], - model_preferences: None, - system_prompt: None, - include_context: None, - temperature: None, - max_tokens: 100, - stop_sequences: None, - metadata: None, - tools: None, - tool_choice: None, - }; + 100, + ); let err = params.validate().unwrap_err(); assert!( @@ -677,26 +608,16 @@ async fn test_validate_rejects_tool_result_in_assistant_message() { #[tokio::test] async fn test_validate_rejects_mixed_content_with_tool_result() { - let params = CreateMessageRequestParams { - meta: None, - task: None, - messages: vec![SamplingMessage::new_multiple( + let params = CreateMessageRequestParams::new( + vec![SamplingMessage::new_multiple( Role::User, vec![ SamplingMessageContent::tool_result("call_1", vec![Content::text("result")]), SamplingMessageContent::text("some extra text"), ], )], - model_preferences: None, - system_prompt: None, - include_context: None, - temperature: None, - max_tokens: 100, - stop_sequences: None, - metadata: None, - tools: None, - tool_choice: None, - }; + 100, + ); let err = params.validate().unwrap_err(); assert!( @@ -707,23 +628,13 @@ async fn test_validate_rejects_mixed_content_with_tool_result() { #[tokio::test] async fn test_validate_rejects_unbalanced_tool_use_result() { - let params = CreateMessageRequestParams { - meta: None, - task: None, - messages: vec![ + let params = CreateMessageRequestParams::new( + vec![ SamplingMessage::user_text("Hello"), SamplingMessage::assistant_tool_use("call_1", "some_tool", Default::default()), ], - model_preferences: None, - system_prompt: None, - include_context: None, - temperature: None, - max_tokens: 100, - stop_sequences: None, - metadata: None, - tools: None, - tool_choice: None, - }; + 100, + ); let err = params.validate().unwrap_err(); assert!( @@ -734,23 +645,13 @@ async fn test_validate_rejects_unbalanced_tool_use_result() { #[tokio::test] async fn test_validate_rejects_tool_result_without_matching_use() { - let params = CreateMessageRequestParams { - meta: None, - task: None, - messages: vec![ + let params = CreateMessageRequestParams::new( + vec![ SamplingMessage::user_text("Hello"), SamplingMessage::user_tool_result("nonexistent_call", vec![Content::text("result")]), ], - model_preferences: None, - system_prompt: None, - include_context: None, - temperature: None, - max_tokens: 100, - stop_sequences: None, - metadata: None, - tools: None, - tool_choice: None, - }; + 100, + ); let err = params.validate().unwrap_err(); assert!( @@ -761,10 +662,8 @@ async fn test_validate_rejects_tool_result_without_matching_use() { #[tokio::test] async fn test_validate_accepts_valid_tool_conversation() { - let params = CreateMessageRequestParams { - meta: None, - task: None, - messages: vec![ + let params = CreateMessageRequestParams::new( + vec![ SamplingMessage::user_text("What's the weather?"), SamplingMessage::assistant_tool_use( "call_1", @@ -777,27 +676,19 @@ async fn test_validate_accepts_valid_tool_conversation() { SamplingMessage::user_tool_result("call_1", vec![Content::text("72°F and sunny")]), SamplingMessage::assistant_text("It's 72°F and sunny in SF."), ], - model_preferences: None, - system_prompt: None, - include_context: None, - temperature: None, - max_tokens: 100, - stop_sequences: None, - metadata: None, - tools: None, - tool_choice: None, - }; + 100, + ); assert!(params.validate().is_ok()); } #[tokio::test] async fn test_create_message_result_validate_rejects_user_role() { - let result = CreateMessageResult { - message: SamplingMessage::user_text("This should not be a user message"), - model: "test-model".to_string(), - stop_reason: Some(CreateMessageResult::STOP_REASON_END_TURN.to_string()), - }; + let result = CreateMessageResult::new( + SamplingMessage::user_text("This should not be a user message"), + "test-model".to_string(), + ) + .with_stop_reason(Some(CreateMessageResult::STOP_REASON_END_TURN.to_string())); let err = result.validate().unwrap_err(); assert!( @@ -808,11 +699,11 @@ async fn test_create_message_result_validate_rejects_user_role() { #[tokio::test] async fn test_create_message_result_validate_accepts_assistant_role() { - let result = CreateMessageResult { - message: SamplingMessage::assistant_text("Hello!"), - model: "test-model".to_string(), - stop_reason: Some(CreateMessageResult::STOP_REASON_END_TURN.to_string()), - }; + let result = CreateMessageResult::new( + SamplingMessage::assistant_text("Hello!"), + "test-model".to_string(), + ) + .with_stop_reason(Some(CreateMessageResult::STOP_REASON_END_TURN.to_string())); assert!(result.validate().is_ok()); } diff --git a/crates/rmcp/tests/test_sse_concurrent_streams.rs b/crates/rmcp/tests/test_sse_concurrent_streams.rs index b54ed5562..33625a741 100644 --- a/crates/rmcp/tests/test_sse_concurrent_streams.rs +++ b/crates/rmcp/tests/test_sse_concurrent_streams.rs @@ -17,7 +17,7 @@ use std::time::Duration; use futures::StreamExt; use rmcp::{ RoleServer, ServerHandler, - model::{Implementation, ProtocolVersion, ServerCapabilities, ServerInfo, ToolsCapability}, + model::{Implementation, ServerCapabilities, ServerInfo, ToolsCapability}, service::NotificationContext, transport::streamable_http_server::{ StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager, @@ -45,20 +45,14 @@ impl TestServer { impl ServerHandler for TestServer { fn get_info(&self) -> ServerInfo { - ServerInfo { - protocol_version: ProtocolVersion::LATEST, - capabilities: ServerCapabilities::builder() + ServerInfo::new( + ServerCapabilities::builder() .enable_tools_with(ToolsCapability { list_changed: Some(true), }) .build(), - server_info: Implementation { - name: "test-server".to_string(), - version: "1.0.0".to_string(), - ..Default::default() - }, - instructions: None, - } + ) + .with_server_info(Implementation::new("test-server", "1.0.0")) } async fn on_initialized(&self, context: NotificationContext) { diff --git a/crates/rmcp/tests/test_structured_output.rs b/crates/rmcp/tests/test_structured_output.rs index 0edb8bce8..082d3e439 100644 --- a/crates/rmcp/tests/test_structured_output.rs +++ b/crates/rmcp/tests/test_structured_output.rs @@ -327,12 +327,7 @@ async fn test_empty_content_deserializes_as_call_tool_result_variant() { #[tokio::test] async fn test_empty_content_roundtrip() { - let result = CallToolResult { - content: vec![], - structured_content: None, - is_error: Some(false), - meta: None, - }; + let result = CallToolResult::success(vec![]); let v = serde_json::to_value(&result).unwrap(); assert_eq!(v["content"], json!([])); let deserialized: CallToolResult = serde_json::from_value(v).unwrap(); diff --git a/crates/rmcp/tests/test_task_support_validation.rs b/crates/rmcp/tests/test_task_support_validation.rs index 016ed2403..cd9997684 100644 --- a/crates/rmcp/tests/test_task_support_validation.rs +++ b/crates/rmcp/tests/test_task_support_validation.rs @@ -5,6 +5,7 @@ //! - `Required`: MUST be invoked as a task, returns -32601 otherwise //! - `Forbidden`: MUST NOT be invoked as a task, returns error otherwise //! - `Optional`: MAY be invoked either way +#![cfg(feature = "client")] use rmcp::{ ClientHandler, ServerHandler, ServiceError, ServiceExt, @@ -93,12 +94,7 @@ async fn test_required_task_tool_without_task_returns_method_not_found() -> anyh // Call the task-required tool without a task - should fail with -32601 let result = client - .call_tool(CallToolRequestParams { - meta: None, - name: "required_task_tool".into(), - arguments: None, - task: None, // No task provided! - }) + .call_tool(CallToolRequestParams::new("required_task_tool")) .await; // Should be an error with code -32601 (METHOD_NOT_FOUND) @@ -147,12 +143,7 @@ async fn test_forbidden_task_tool_with_task_returns_error() -> anyhow::Result<() // Call the forbidden task tool WITH a task - should fail let result = client - .call_tool(CallToolRequestParams { - meta: None, - name: "forbidden_task_tool".into(), - arguments: None, - task: make_task(), // Task provided but forbidden! - }) + .call_tool(CallToolRequestParams::new("forbidden_task_tool").with_task(make_task())) .await; // Should be an error with code INVALID_PARAMS @@ -201,12 +192,7 @@ async fn test_forbidden_task_tool_without_task_succeeds() -> anyhow::Result<()> // Call the forbidden task tool WITHOUT a task - should succeed let result = client - .call_tool(CallToolRequestParams { - meta: None, - name: "forbidden_task_tool".into(), - arguments: None, - task: None, // No task - allowed for forbidden - }) + .call_tool(CallToolRequestParams::new("forbidden_task_tool")) .await; assert!( @@ -242,12 +228,7 @@ async fn test_optional_task_tool_without_task_succeeds() -> anyhow::Result<()> { // Call the optional task tool WITHOUT a task - should succeed let result = client - .call_tool(CallToolRequestParams { - meta: None, - name: "optional_task_tool".into(), - arguments: None, - task: None, // No task - allowed for optional - }) + .call_tool(CallToolRequestParams::new("optional_task_tool")) .await; assert!( diff --git a/crates/rmcp/tests/test_tool_macros.rs b/crates/rmcp/tests/test_tool_macros.rs index 837198cbb..bd06ca6ea 100644 --- a/crates/rmcp/tests/test_tool_macros.rs +++ b/crates/rmcp/tests/test_tool_macros.rs @@ -308,10 +308,8 @@ async fn test_optional_i64_field_with_null_input() -> anyhow::Result<()> { // Test null case let result = client - .call_tool(CallToolRequestParams { - meta: None, - name: "test_optional_i64".into(), - arguments: Some( + .call_tool( + CallToolRequestParams::new("test_optional_i64").with_arguments( serde_json::json!({ "count": null, "mandatory_field": "test_null" @@ -320,8 +318,7 @@ async fn test_optional_i64_field_with_null_input() -> anyhow::Result<()> { .unwrap() .clone(), ), - task: None, - }) + ) .await?; let result_text = result @@ -338,10 +335,8 @@ async fn test_optional_i64_field_with_null_input() -> anyhow::Result<()> { // Test Some case let some_result = client - .call_tool(CallToolRequestParams { - meta: None, - name: "test_optional_i64".into(), - arguments: Some( + .call_tool( + CallToolRequestParams::new("test_optional_i64").with_arguments( serde_json::json!({ "count": 42, "mandatory_field": "test_some" @@ -350,8 +345,7 @@ async fn test_optional_i64_field_with_null_input() -> anyhow::Result<()> { .unwrap() .clone(), ), - task: None, - }) + ) .await?; let some_result_text = some_result diff --git a/crates/rmcp/tests/test_tool_result_meta.rs b/crates/rmcp/tests/test_tool_result_meta.rs index 78e1809ef..f64d8e3f1 100644 --- a/crates/rmcp/tests/test_tool_result_meta.rs +++ b/crates/rmcp/tests/test_tool_result_meta.rs @@ -6,12 +6,7 @@ fn serialize_tool_result_with_meta() { let content = vec![Content::text("ok")]; let mut meta = Meta::new(); meta.insert("foo".to_string(), json!("bar")); - let result = CallToolResult { - content, - structured_content: None, - is_error: Some(false), - meta: Some(meta), - }; + let result = CallToolResult::success(content).with_meta(Some(meta)); let v = serde_json::to_value(&result).unwrap(); let expected = json!({ "content": [{"type":"text","text":"ok"}], diff --git a/examples/clients/src/collection.rs b/examples/clients/src/collection.rs index a4c734824..62088e48d 100644 --- a/examples/clients/src/collection.rs +++ b/examples/clients/src/collection.rs @@ -46,12 +46,14 @@ async fn main() -> Result<()> { // Call tool 'git_status' with arguments = {"repo_path": "."} let _tool_result = client - .call_tool(CallToolRequestParams { - meta: None, - name: "git_status".into(), - arguments: serde_json::json!({ "repo_path": "." }).as_object().cloned(), - task: None, - }) + .call_tool( + CallToolRequestParams::new("git_status").with_arguments( + serde_json::json!({ "repo_path": "." }) + .as_object() + .unwrap() + .clone(), + ), + ) .await?; } for (_, service) in clients_map { diff --git a/examples/clients/src/everything_stdio.rs b/examples/clients/src/everything_stdio.rs index 763a880a6..8a7fce7de 100644 --- a/examples/clients/src/everything_stdio.rs +++ b/examples/clients/src/everything_stdio.rs @@ -37,23 +37,19 @@ async fn main() -> Result<()> { // Call tool echo let tool_result = client - .call_tool(CallToolRequestParams { - meta: None, - name: "echo".into(), - arguments: Some(object!({ "message": "hi from rmcp" })), - task: None, - }) + .call_tool( + CallToolRequestParams::new("echo") + .with_arguments(object!({ "message": "hi from rmcp" })), + ) .await?; tracing::info!("Tool result for echo: {tool_result:#?}"); // Call tool longRunningOperation let tool_result = client - .call_tool(CallToolRequestParams { - meta: None, - name: "longRunningOperation".into(), - arguments: Some(object!({ "duration": 3, "steps": 1 })), - task: None, - }) + .call_tool( + CallToolRequestParams::new("longRunningOperation") + .with_arguments(object!({ "duration": 3, "steps": 1 })), + ) .await?; tracing::info!("Tool result for longRunningOperation: {tool_result:#?}"); @@ -63,10 +59,7 @@ async fn main() -> Result<()> { // Read resource let resource = client - .read_resource(ReadResourceRequestParams { - meta: None, - uri: "test://static/resource/3".into(), - }) + .read_resource(ReadResourceRequestParams::new("test://static/resource/3")) .await?; tracing::info!("Resource: {resource:#?}"); @@ -76,21 +69,16 @@ async fn main() -> Result<()> { // Get simple prompt let prompt = client - .get_prompt(GetPromptRequestParams { - meta: None, - name: "simple_prompt".into(), - arguments: None, - }) + .get_prompt(GetPromptRequestParams::new("simple_prompt")) .await?; tracing::info!("Prompt - simple: {prompt:#?}"); // Get complex prompt (returns text & image) let prompt = client - .get_prompt(GetPromptRequestParams { - meta: None, - name: "complex_prompt".into(), - arguments: Some(object!({ "temperature": "0.5", "style": "formal" })), - }) + .get_prompt( + GetPromptRequestParams::new("complex_prompt") + .with_arguments(object!({ "temperature": "0.5", "style": "formal" })), + ) .await?; tracing::info!("Prompt - complex: {prompt:#?}"); diff --git a/examples/clients/src/git_stdio.rs b/examples/clients/src/git_stdio.rs index 9960c16b9..703258b16 100644 --- a/examples/clients/src/git_stdio.rs +++ b/examples/clients/src/git_stdio.rs @@ -39,12 +39,14 @@ async fn main() -> Result<(), RmcpError> { // Call tool 'git_status' with arguments = {"repo_path": "."} let tool_result = client - .call_tool(CallToolRequestParams { - meta: None, - name: "git_status".into(), - arguments: serde_json::json!({ "repo_path": "." }).as_object().cloned(), - task: None, - }) + .call_tool( + CallToolRequestParams::new("git_status").with_arguments( + serde_json::json!({ "repo_path": "." }) + .as_object() + .unwrap() + .clone(), + ), + ) .await?; tracing::info!("Tool result: {tool_result:#?}"); client.cancel().await?; diff --git a/examples/clients/src/progress_client.rs b/examples/clients/src/progress_client.rs index db66a8ed6..89c48738d 100644 --- a/examples/clients/src/progress_client.rs +++ b/examples/clients/src/progress_client.rs @@ -122,16 +122,10 @@ impl ClientHandler for ProgressAwareClient { } fn get_info(&self) -> ClientInfo { - ClientInfo { - meta: None, - protocol_version: Default::default(), - capabilities: ClientCapabilities::default(), - client_info: Implementation { - name: "progress-test-client".to_string(), - version: "1.0.0".to_string(), - ..Default::default() - }, - } + ClientInfo::new( + ClientCapabilities::default(), + Implementation::new("progress-test-client", "1.0.0"), + ) } } @@ -182,12 +176,7 @@ async fn test_stdio_transport(records: u32) -> Result<()> { // Call stream processor tool tracing::info!("Starting to process {} records...", records); let tool_result = service - .call_tool(CallToolRequestParams { - meta: None, - name: "stream_processor".into(), - arguments: None, - task: None, - }) + .call_tool(CallToolRequestParams::new("stream_processor")) .await?; if let Some(content) = tool_result.content.first() { @@ -238,12 +227,7 @@ async fn test_http_transport(http_url: &str, records: u32) -> Result<()> { // Call stream processor tool tracing::info!("Starting to process {} records...", records); let tool_result = client - .call_tool(CallToolRequestParams { - meta: None, - name: "stream_processor".into(), - arguments: None, - task: None, - }) + .call_tool(CallToolRequestParams::new("stream_processor")) .await?; if let Some(content) = tool_result.content.first() { diff --git a/examples/clients/src/sampling_stdio.rs b/examples/clients/src/sampling_stdio.rs index e2a7a6d51..27b9273c0 100644 --- a/examples/clients/src/sampling_stdio.rs +++ b/examples/clients/src/sampling_stdio.rs @@ -40,11 +40,11 @@ impl ClientHandler for SamplingDemoClient { let response_text = self.mock_llm_response(¶ms.messages, params.system_prompt.as_deref()); - Ok(CreateMessageResult { - message: SamplingMessage::assistant_text(response_text), - model: "mock_llm".to_string(), - stop_reason: Some(CreateMessageResult::STOP_REASON_END_TURN.to_string()), - }) + Ok(CreateMessageResult::new( + SamplingMessage::assistant_text(response_text), + "mock_llm".to_string(), + ) + .with_stop_reason(Some(CreateMessageResult::STOP_REASON_END_TURN.to_string()))) } } @@ -98,14 +98,11 @@ async fn main() -> Result<()> { // Test the ask_llm tool tracing::info!("Testing ask_llm tool..."); match client - .call_tool(CallToolRequestParams { - meta: None, - name: "ask_llm".into(), - arguments: Some(object!({ + .call_tool( + CallToolRequestParams::new("ask_llm").with_arguments(object!({ "question": "Hello world" })), - task: None, - }) + ) .await { Ok(result) => tracing::info!("Ask LLM result: {result:#?}"), diff --git a/examples/clients/src/streamable_http.rs b/examples/clients/src/streamable_http.rs index baf1838a3..0c27fd358 100644 --- a/examples/clients/src/streamable_http.rs +++ b/examples/clients/src/streamable_http.rs @@ -17,19 +17,10 @@ async fn main() -> Result<()> { .with(tracing_subscriber::fmt::layer()) .init(); let transport = StreamableHttpClientTransport::from_uri("http://localhost:8000/mcp"); - let client_info = ClientInfo { - meta: None, - protocol_version: Default::default(), - capabilities: ClientCapabilities::default(), - client_info: Implementation { - name: "test sse client".to_string(), - title: None, - version: "0.0.1".to_string(), - description: None, - website_url: None, - icons: None, - }, - }; + let client_info = ClientInfo::new( + ClientCapabilities::default(), + Implementation::new("test sse client", "0.0.1"), + ); let client = client_info.serve(transport).await.inspect_err(|e| { tracing::error!("client error: {:?}", e); })?; @@ -43,12 +34,10 @@ async fn main() -> Result<()> { tracing::info!("Available tools: {tools:#?}"); let tool_result = client - .call_tool(CallToolRequestParams { - meta: None, - name: "increment".into(), - arguments: serde_json::json!({}).as_object().cloned(), - task: None, - }) + .call_tool( + CallToolRequestParams::new("increment") + .with_arguments(serde_json::json!({}).as_object().cloned().unwrap()), + ) .await?; tracing::info!("Tool result: {tool_result:#?}"); client.cancel().await?; diff --git a/examples/rig-integration/src/mcp_adaptor.rs b/examples/rig-integration/src/mcp_adaptor.rs index af57935ee..41de15768 100644 --- a/examples/rig-integration/src/mcp_adaptor.rs +++ b/examples/rig-integration/src/mcp_adaptor.rs @@ -41,13 +41,11 @@ impl RigTool for McpToolAdaptor { let server = self.server.clone(); Box::pin(async move { let call_mcp_tool_result = server - .call_tool(CallToolRequestParams { - meta: None, - name: self.tool.name.clone(), - arguments: serde_json::from_str(&args) - .map_err(rig::tool::ToolError::JsonError)?, - task: None, - }) + .call_tool( + CallToolRequestParams::new(self.tool.name.clone()).with_arguments( + serde_json::from_str(&args).map_err(rig::tool::ToolError::JsonError)?, + ), + ) .await .inspect(|result| tracing::info!(?result)) .inspect_err(|error| tracing::error!(%error)) diff --git a/examples/servers/src/common/calculator.rs b/examples/servers/src/common/calculator.rs index e6f97ce0f..2b0ab8e33 100644 --- a/examples/servers/src/common/calculator.rs +++ b/examples/servers/src/common/calculator.rs @@ -49,10 +49,7 @@ impl Calculator { #[tool_handler] impl ServerHandler for Calculator { fn get_info(&self) -> ServerInfo { - ServerInfo { - instructions: Some("A simple calculator".into()), - capabilities: ServerCapabilities::builder().enable_tools().build(), - ..Default::default() - } + ServerInfo::new(ServerCapabilities::builder().enable_tools().build()) + .with_instructions("A simple calculator".to_string()) } } diff --git a/examples/servers/src/common/counter.rs b/examples/servers/src/common/counter.rs index e92b142af..1806a9fa3 100644 --- a/examples/servers/src/common/counter.rs +++ b/examples/servers/src/common/counter.rs @@ -171,10 +171,10 @@ impl Counter { "This is an example prompt with your message here: '{}'", args.message ); - Ok(vec![PromptMessage { - role: PromptMessageRole::User, - content: PromptMessageContent::text(prompt), - }]) + Ok(vec![PromptMessage::new_text( + PromptMessageRole::User, + prompt, + )]) } /// Analyze the current counter value and suggest next steps @@ -202,13 +202,10 @@ impl Counter { ), ]; - Ok(GetPromptResult { - description: Some(format!( - "Counter analysis for reaching {} from {}", - args.goal, current_value - )), - messages, - }) + Ok(GetPromptResult::new(messages).with_description(format!( + "Counter analysis for reaching {} from {}", + args.goal, current_value + ))) } } @@ -217,16 +214,16 @@ impl Counter { #[task_handler] impl ServerHandler for Counter { fn get_info(&self) -> ServerInfo { - ServerInfo { - protocol_version: ProtocolVersion::V_2024_11_05, - capabilities: ServerCapabilities::builder() + ServerInfo::new( + ServerCapabilities::builder() .enable_prompts() .enable_resources() .enable_tools() .build(), - server_info: Implementation::from_build_env(), - instructions: Some("This server provides counter tools and prompts. Tools: increment, decrement, get_value, say_hello, echo, sum. Prompts: example_prompt (takes a message), counter_analysis (analyzes counter state with a goal).".to_string()), - } + ) + .with_server_info(Implementation::from_build_env()) + .with_protocol_version(ProtocolVersion::V_2024_11_05) + .with_instructions("This server provides counter tools and prompts. Tools: increment, decrement, get_value, say_hello, echo, sum. Prompts: example_prompt (takes a message), counter_analysis (analyzes counter state with a goal).".to_string()) } async fn list_resources( @@ -246,21 +243,24 @@ impl ServerHandler for Counter { async fn read_resource( &self, - ReadResourceRequestParams { meta: _, uri }: ReadResourceRequestParams, + request: ReadResourceRequestParams, _: RequestContext, ) -> Result { + let uri = &request.uri; match uri.as_str() { "str:////Users/to/some/path/" => { let cwd = "/Users/to/some/path/"; - Ok(ReadResourceResult { - contents: vec![ResourceContents::text(cwd, uri)], - }) + Ok(ReadResourceResult::new(vec![ResourceContents::text( + cwd, + uri.clone(), + )])) } "memo://insights" => { let memo = "Business Intelligence Memo\n\nAnalysis has revealed 5 key insights ..."; - Ok(ReadResourceResult { - contents: vec![ResourceContents::text(memo, uri)], - }) + Ok(ReadResourceResult::new(vec![ResourceContents::text( + memo, + uri.clone(), + )])) } _ => Err(McpError::resource_not_found( "resource_not_found", @@ -364,12 +364,7 @@ mod tests { "source".into(), serde_json::Value::String("integration-test".into()), ); - let params = CallToolRequestParams { - meta: None, - name: "long_task".into(), - arguments: None, - task: Some(task_meta), - }; + let params = CallToolRequestParams::new("long_task").with_task(Some(task_meta)); let response = client_service .send_request(ClientRequest::CallToolRequest(Request::new(params.clone()))) .await?; diff --git a/examples/servers/src/common/generic_service.rs b/examples/servers/src/common/generic_service.rs index de1b9c184..8034d5214 100644 --- a/examples/servers/src/common/generic_service.rs +++ b/examples/servers/src/common/generic_service.rs @@ -77,10 +77,7 @@ impl GenericService { #[tool_handler] impl ServerHandler for GenericService { fn get_info(&self) -> ServerInfo { - ServerInfo { - instructions: Some("generic data service".into()), - capabilities: ServerCapabilities::builder().enable_tools().build(), - ..Default::default() - } + ServerInfo::new(ServerCapabilities::builder().enable_tools().build()) + .with_instructions("generic data service".to_string()) } } diff --git a/examples/servers/src/common/progress_demo.rs b/examples/servers/src/common/progress_demo.rs index 8fec1179d..1a613e0c7 100644 --- a/examples/servers/src/common/progress_demo.rs +++ b/examples/servers/src/common/progress_demo.rs @@ -118,15 +118,13 @@ impl ProgressDemo { #[tool_handler] impl ServerHandler for ProgressDemo { fn get_info(&self) -> ServerInfo { - ServerInfo { - protocol_version: ProtocolVersion::V_2024_11_05, - capabilities: ServerCapabilities::builder().enable_tools().build(), - server_info: Implementation::from_build_env(), - instructions: Some( + ServerInfo::new(ServerCapabilities::builder().enable_tools().build()) + .with_protocol_version(ProtocolVersion::V_2024_11_05) + .with_server_info(Implementation::from_build_env()) + .with_instructions( "This server demonstrates progress notifications during long-running operations. \ Use the tools to see real-time progress updates for batch processing" .to_string(), - ), - } + ) } } diff --git a/examples/servers/src/completion_stdio.rs b/examples/servers/src/completion_stdio.rs index e4365cadc..7caa8e8c3 100644 --- a/examples/servers/src/completion_stdio.rs +++ b/examples/servers/src/completion_stdio.rs @@ -292,47 +292,41 @@ impl SqlQueryServer { ] }; - Ok(GetPromptResult { - description: Some(format!( - "SQL Query: {} on {}", - if args.operation.is_empty() { - "Unknown" - } else { - &args.operation - }, - if args.table.is_empty() { - "table" - } else { - &args.table - } - )), - messages, - }) + Ok(GetPromptResult::new(messages).with_description(format!( + "SQL Query: {} on {}", + if args.operation.is_empty() { + "Unknown" + } else { + &args.operation + }, + if args.table.is_empty() { + "table" + } else { + &args.table + } + ))) } } #[prompt_handler] impl ServerHandler for SqlQueryServer { fn get_info(&self) -> ServerInfo { - ServerInfo { - capabilities: ServerCapabilities::builder() + ServerInfo::new( + ServerCapabilities::builder() .enable_completions() .enable_prompts() .build(), - server_info: Implementation::from_build_env(), - instructions: Some( - "Smart SQL query builder with progressive completion that adapts based on your choices:\n\n\ - Step 1: Choose operation type ('sel' → SELECT, 'ins' → INSERT, 'upd' → UPDATE, 'del' → DELETE)\n\ - Step 2: Specify table name ('users', 'orders', 'products')\n\ - Step 3: Add relevant fields based on operation type:\n\ - • SELECT/UPDATE: columns ('name', 'email', 'id')\n\ - • INSERT: values to insert\n\ - • All: optional WHERE clause\n\n\ - The completion adapts - only relevant fields appear based on your SQL operation!" - .to_string(), - ), - ..Default::default() - } + ) + .with_instructions( + "Smart SQL query builder with progressive completion that adapts based on your choices:\n\n\ + Step 1: Choose operation type ('sel' → SELECT, 'ins' → INSERT, 'upd' → UPDATE, 'del' → DELETE)\n\ + Step 2: Specify table name ('users', 'orders', 'products')\n\ + Step 3: Add relevant fields based on operation type:\n\ + • SELECT/UPDATE: columns ('name', 'email', 'id')\n\ + • INSERT: values to insert\n\ + • All: optional WHERE clause\n\n\ + The completion adapts - only relevant fields appear based on your SQL operation!", + ) } async fn complete( @@ -417,7 +411,7 @@ impl ServerHandler for SqlQueryServer { has_more: Some(false), }; - Ok(CompleteResult { completion }) + Ok(CompleteResult::new(completion)) } } diff --git a/examples/servers/src/elicitation_enum_inference.rs b/examples/servers/src/elicitation_enum_inference.rs index 2ecec3115..27bde508e 100644 --- a/examples/servers/src/elicitation_enum_inference.rs +++ b/examples/servers/src/elicitation_enum_inference.rs @@ -156,14 +156,11 @@ impl ElicitationEnumFormServer { #[tool_handler] impl ServerHandler for ElicitationEnumFormServer { fn get_info(&self) -> ServerInfo { - ServerInfo { - capabilities: ServerCapabilities::builder().enable_tools().build(), - server_info: Implementation::from_build_env(), - instructions: Some( + ServerInfo::new(ServerCapabilities::builder().enable_tools().build()) + .with_server_info(Implementation::from_build_env()) + .with_instructions( "Simple server demonstrating elicitation for enum selection".to_string(), - ), - ..Default::default() - } + ) } } diff --git a/examples/servers/src/elicitation_stdio.rs b/examples/servers/src/elicitation_stdio.rs index 82f8d696a..3bf38056e 100644 --- a/examples/servers/src/elicitation_stdio.rs +++ b/examples/servers/src/elicitation_stdio.rs @@ -154,14 +154,11 @@ impl ElicitationServer { #[tool_handler] impl ServerHandler for ElicitationServer { fn get_info(&self) -> ServerInfo { - ServerInfo { - capabilities: ServerCapabilities::builder().enable_tools().build(), - server_info: Implementation::from_build_env(), - instructions: Some( + ServerInfo::new(ServerCapabilities::builder().enable_tools().build()) + .with_server_info(Implementation::from_build_env()) + .with_instructions( "Simple server demonstrating elicitation for user name collection".to_string(), - ), - ..Default::default() - } + ) } } diff --git a/examples/servers/src/prompt_stdio.rs b/examples/servers/src/prompt_stdio.rs index 0937c3e2c..812ce0e1b 100644 --- a/examples/servers/src/prompt_stdio.rs +++ b/examples/servers/src/prompt_stdio.rs @@ -174,14 +174,11 @@ impl PromptServer { ), ]; - Ok(GetPromptResult { - description: Some(format!( - "Code review for {} file focusing on {}", - args.language, - focus_areas.join(", ") - )), - messages, - }) + Ok(GetPromptResult::new(messages).with_description(format!( + "Code review for {} file focusing on {}", + args.language, + focus_areas.join(", ") + ))) } /// Data analysis prompt demonstrating context usage @@ -270,13 +267,10 @@ impl PromptServer { )); } - GetPromptResult { - description: Some(format!( - "Writing {} for {} audience with {} tone", - args.content_type, args.audience, tone - )), - messages, - } + GetPromptResult::new(messages).with_description(format!( + "Writing {} for {} audience with {} tone", + args.content_type, args.audience, tone + )) } /// Debug assistant demonstrating error handling patterns @@ -332,14 +326,11 @@ impl PromptServer { "Let's debug this systematically. First, let me understand the error context better.", )); - Ok(GetPromptResult { - description: Some(format!( - "Debugging {} error in {}", - args.error_message.chars().take(50).collect::(), - args.stack.first().map(|s| s.as_str()).unwrap_or("unknown") - )), - messages, - }) + Ok(GetPromptResult::new(messages).with_description(format!( + "Debugging {} error in {}", + args.error_message.chars().take(50).collect::(), + args.stack.first().map(|s| s.as_str()).unwrap_or("unknown") + ))) } /// Learning path prompt that uses server state @@ -376,17 +367,11 @@ impl PromptServer { #[prompt_handler] impl ServerHandler for PromptServer { fn get_info(&self) -> ServerInfo { - ServerInfo { - capabilities: ServerCapabilities::builder().enable_prompts().build(), - server_info: Implementation::from_build_env(), - instructions: Some( - "This server provides various prompt templates for code review, data analysis, \ + ServerInfo::new(ServerCapabilities::builder().enable_prompts().build()).with_instructions( + "This server provides various prompt templates for code review, data analysis, \ writing assistance, debugging help, and personalized learning paths. \ - All prompts are designed to provide structured, context-aware assistance." - .to_string(), - ), - ..Default::default() - } + All prompts are designed to provide structured, context-aware assistance.", + ) } } diff --git a/examples/servers/src/sampling_stdio.rs b/examples/servers/src/sampling_stdio.rs index 297af9d03..bd244d871 100644 --- a/examples/servers/src/sampling_stdio.rs +++ b/examples/servers/src/sampling_stdio.rs @@ -18,17 +18,12 @@ pub struct SamplingDemoServer; impl ServerHandler for SamplingDemoServer { fn get_info(&self) -> ServerInfo { - ServerInfo { - instructions: Some(concat!( + ServerInfo::new(ServerCapabilities::builder().enable_tools().build()) + .with_instructions(concat!( "This is a demo server that requests sampling from clients. It provides tools that use LLM capabilities.\n\n", "IMPORTANT: This server requires a client that supports the 'sampling/createMessage' method. ", "Without sampling support, the tools will return errors." - ).into()), - capabilities: ServerCapabilities::builder() - .enable_tools() - .build(), - ..Default::default() - } + )) } async fn call_tool( @@ -48,27 +43,22 @@ impl ServerHandler for SamplingDemoServer { let response = context .peer - .create_message(CreateMessageRequestParams { - meta: None, - task: None, - messages: vec![SamplingMessage::user_text(question)], - model_preferences: Some(ModelPreferences { - hints: Some(vec![ModelHint { - name: Some("claude".to_string()), - }]), - cost_priority: Some(0.3), - speed_priority: Some(0.8), - intelligence_priority: Some(0.7), - }), - system_prompt: Some("You are a helpful assistant.".to_string()), - include_context: Some(ContextInclusion::None), - temperature: Some(0.7), - max_tokens: 150, - stop_sequences: None, - metadata: None, - tools: None, - tool_choice: None, - }) + .create_message( + CreateMessageRequestParams::new( + vec![SamplingMessage::user_text(question)], + 150, + ) + .with_model_preferences( + ModelPreferences::new() + .with_hints(vec![ModelHint::new("claude")]) + .with_cost_priority(0.3) + .with_speed_priority(0.8) + .with_intelligence_priority(0.7), + ) + .with_system_prompt("You are a helpful assistant.") + .with_include_context(ContextInclusion::None) + .with_temperature(0.7), + ) .await .map_err(|e| { ErrorData::new( @@ -105,11 +95,10 @@ impl ServerHandler for SamplingDemoServer { _context: RequestContext, ) -> Result { Ok(ListToolsResult { - tools: vec![Tool { - name: "ask_llm".into(), - title: None, - description: Some("Ask a question to the LLM through sampling".into()), - input_schema: Arc::new( + tools: vec![Tool::new( + "ask_llm", + "Ask a question to the LLM through sampling", + Arc::new( serde_json::from_value(serde_json::json!({ "type": "object", "properties": { @@ -122,12 +111,7 @@ impl ServerHandler for SamplingDemoServer { })) .unwrap(), ), - output_schema: None, - annotations: None, - execution: None, - icons: None, - meta: None, - }], + )], meta: None, next_cursor: None, }) diff --git a/examples/simple-chat-client/src/tool.rs b/examples/simple-chat-client/src/tool.rs index 14f073a24..173a0a260 100644 --- a/examples/simple-chat-client/src/tool.rs +++ b/examples/simple-chat-client/src/tool.rs @@ -57,15 +57,11 @@ impl Tool for McpToolAdapter { _ => None, }; println!("arguments: {:?}", arguments); - let call_result = self - .server - .call_tool(CallToolRequestParams { - meta: None, - name: self.tool.name.clone(), - arguments, - task: None, - }) - .await?; + let mut params = CallToolRequestParams::new(self.tool.name.clone()); + if let Some(args) = arguments { + params = params.with_arguments(args); + } + let call_result = self.server.call_tool(params).await?; Ok(call_result) } diff --git a/examples/transport/src/common/calculator.rs b/examples/transport/src/common/calculator.rs index 9ae475dd9..f6d4c2a74 100644 --- a/examples/transport/src/common/calculator.rs +++ b/examples/transport/src/common/calculator.rs @@ -53,10 +53,7 @@ impl Calculator { #[tool_handler] impl ServerHandler for Calculator { fn get_info(&self) -> ServerInfo { - ServerInfo { - instructions: Some("A simple calculator".into()), - capabilities: ServerCapabilities::builder().enable_tools().build(), - ..Default::default() - } + ServerInfo::new(ServerCapabilities::builder().enable_tools().build()) + .with_instructions("A simple calculator") } } diff --git a/examples/transport/src/named-pipe.rs b/examples/transport/src/named-pipe.rs index 1231059bc..6f08ef221 100644 --- a/examples/transport/src/named-pipe.rs +++ b/examples/transport/src/named-pipe.rs @@ -48,15 +48,14 @@ async fn main() -> anyhow::Result<()> { println!("Calling sum tool: {}", sum_tool.name); let result = client .peer() - .call_tool(rmcp::model::CallToolRequestParams { - meta: None, - name: sum_tool.name.clone(), - arguments: Some(rmcp::object!({ - "a": 10, - "b": 20 - })), - task: None, - }) + .call_tool( + rmcp::model::CallToolRequestParams::new(sum_tool.name.clone()).with_arguments( + rmcp::object!({ + "a": 10, + "b": 20 + }), + ), + ) .await?; println!("Result: {:?}", result); diff --git a/examples/transport/src/unix_socket.rs b/examples/transport/src/unix_socket.rs index 666a61f3a..a8eb6271d 100644 --- a/examples/transport/src/unix_socket.rs +++ b/examples/transport/src/unix_socket.rs @@ -46,15 +46,14 @@ async fn main() -> anyhow::Result<()> { println!("Calling sum tool: {}", sum_tool.name); let result = client .peer() - .call_tool(rmcp::model::CallToolRequestParams { - meta: None, - name: sum_tool.name.clone(), - arguments: Some(rmcp::object!({ - "a": 10, - "b": 20 - })), - task: None, - }) + .call_tool( + rmcp::model::CallToolRequestParams::new(sum_tool.name.clone()).with_arguments( + rmcp::object!({ + "a": 10, + "b": 20 + }), + ), + ) .await?; println!("Result: {:?}", result); diff --git a/examples/wasi/src/calculator.rs b/examples/wasi/src/calculator.rs index 1806aeffd..a6f63fbe5 100644 --- a/examples/wasi/src/calculator.rs +++ b/examples/wasi/src/calculator.rs @@ -60,10 +60,7 @@ impl Calculator { #[tool_handler] impl ServerHandler for Calculator { fn get_info(&self) -> ServerInfo { - ServerInfo { - instructions: Some("A simple calculator".into()), - capabilities: ServerCapabilities::builder().enable_tools().build(), - ..Default::default() - } + ServerInfo::new(ServerCapabilities::builder().enable_tools().build()) + .with_instructions("A simple calculator") } }