From 4939e5de41e938af3400407f51545ab8f17b5622 Mon Sep 17 00:00:00 2001 From: Aadam Zocolo Date: Sat, 28 Feb 2026 17:44:16 +0000 Subject: [PATCH 01/20] refactor(service): remove uses of tokio::spawn --- crates/rmcp/Cargo.toml | 41 +- crates/rmcp/src/lib.rs | 1 + crates/rmcp/src/service.rs | 633 ++++++++++-------- .../transport/streamable_http_server/tower.rs | 7 +- crates/rmcp/src/util.rs | 5 + 5 files changed, 396 insertions(+), 291 deletions(-) create mode 100644 crates/rmcp/src/util.rs diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml index 96c319dc..76820c81 100644 --- a/crates/rmcp/Cargo.toml +++ b/crates/rmcp/Cargo.toml @@ -19,7 +19,7 @@ async-trait = "0.1.89" serde = { version = "1.0", features = ["derive", "rc"] } serde_json = "1.0" thiserror = "2" -tokio = { version = "1", features = ["sync", "macros", "rt", "time"] } +tokio = { version = "1", features = ["sync", "macros", "time"] } futures = "0.3" tracing = { version = "0.1" } tokio-util = { version = "0.7" } @@ -109,7 +109,10 @@ client-side-sse = ["dep:sse-stream", "dep:http"] # Streamable HTTP client transport-streamable-http-client = ["client-side-sse", "transport-worker"] -transport-streamable-http-client-reqwest = ["transport-streamable-http-client", "__reqwest"] +transport-streamable-http-client-reqwest = [ + "transport-streamable-http-client", + "__reqwest", +] transport-async-rw = ["tokio/io-util", "tokio-util/codec"] transport-io = ["transport-async-rw", "tokio/io-std"] @@ -135,7 +138,10 @@ schemars = ["dep:schemars"] [dev-dependencies] tokio = { version = "1", features = ["full"] } schemars = { version = "1.1.0", features = ["chrono04"] } -axum = { version = "0.8", default-features = false, features = ["http1", "tokio"] } +axum = { version = "0.8", default-features = false, features = [ + "http1", + "tokio", +] } anyhow = "1.0" tracing-subscriber = { version = "0.3", features = [ "env-filter", @@ -150,12 +156,7 @@ path = "tests/test_tool_macros.rs" [[test]] name = "test_with_python" -required-features = [ - "reqwest", - "server", - "client", - "transport-child-process", -] +required-features = ["reqwest", "server", "client", "transport-child-process"] path = "tests/test_with_python.rs" [[test]] @@ -207,12 +208,22 @@ path = "tests/test_task.rs" [[test]] name = "test_streamable_http_priming" -required-features = ["server", "client", "transport-streamable-http-server", "reqwest"] +required-features = [ + "server", + "client", + "transport-streamable-http-server", + "reqwest", +] path = "tests/test_streamable_http_priming.rs" [[test]] name = "test_streamable_http_json_response" -required-features = ["server", "client", "transport-streamable-http-server", "reqwest"] +required-features = [ + "server", + "client", + "transport-streamable-http-server", + "reqwest", +] path = "tests/test_streamable_http_json_response.rs" @@ -249,5 +260,11 @@ path = "tests/test_custom_headers.rs" [[test]] name = "test_sse_concurrent_streams" -required-features = ["server", "client", "transport-streamable-http-server", "transport-streamable-http-client", "reqwest"] +required-features = [ + "server", + "client", + "transport-streamable-http-server", + "transport-streamable-http-client", + "reqwest", +] path = "tests/test_sse_concurrent_streams.rs" diff --git a/crates/rmcp/src/lib.rs b/crates/rmcp/src/lib.rs index 9ae3f958..456bc3ea 100644 --- a/crates/rmcp/src/lib.rs +++ b/crates/rmcp/src/lib.rs @@ -3,6 +3,7 @@ #![doc = include_str!("../README.md")] mod error; +mod util; #[allow(deprecated)] pub use error::{Error, ErrorData, RmcpError}; diff --git a/crates/rmcp/src/service.rs b/crates/rmcp/src/service.rs index b12839c6..0e8e95ea 100644 --- a/crates/rmcp/src/service.rs +++ b/crates/rmcp/src/service.rs @@ -1,5 +1,6 @@ -use futures::{FutureExt, future::BoxFuture}; +use futures::{FutureExt, Stream, StreamExt, future::BoxFuture, stream::FuturesUnordered}; use thiserror::Error; +use tokio_stream::wrappers::ReceiverStream; #[cfg(feature = "server")] use crate::model::ServerJsonRpcMessage; @@ -11,6 +12,7 @@ use crate::{ NumberOrString, ProgressToken, RequestId, }, transport::{DynamicTransportError, IntoTransport, Transport}, + util::PinnedFuture, }; #[cfg(feature = "client")] mod client; @@ -188,6 +190,7 @@ impl> DynService for S { use std::{ collections::{HashMap, VecDeque}, + fmt::Debug, ops::Deref, sync::{Arc, atomic::AtomicU64}, time::Duration, @@ -246,6 +249,8 @@ impl RequestHandle { pub const REQUEST_TIMEOUT_REASON: &str = "request timeout"; pub async fn await_response(self) -> Result { if let Some(timeout) = self.options.timeout { + // TODO: tokio timeout won't work if not in the tokio RT + // Find an alternative let timeout_result = tokio::time::timeout(timeout, async move { self.rx.await.map_err(|_e| ServiceError::TransportClosed)? }) @@ -426,14 +431,29 @@ impl Peer { } } -#[derive(Debug)] pub struct RunningService> { service: Arc, peer: Peer, - handle: Option>, + handle: Option>, cancellation_token: CancellationToken, dg: DropGuard, } + +impl> Debug for RunningService +where + S: Debug, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("RunningService") + .field("service", &self.service) + .field("peer", &self.peer) + .field("handle", &self.handle.as_ref().map(|_| "")) + .field("cancellation_token", &self.cancellation_token) + .field("dg", &self.dg) + .finish() + } +} + impl> Deref for RunningService { type Target = Peer; @@ -467,10 +487,10 @@ impl> RunningService { /// This will block until the service loop terminates (either due to /// cancellation, transport closure, or an error). #[inline] - pub async fn waiting(mut self) -> Result { + pub async fn waiting(mut self) -> QuitReason { match self.handle.take() { Some(handle) => handle.await, - None => Ok(QuitReason::Closed), + None => QuitReason::Closed, } } @@ -491,7 +511,7 @@ impl> RunningService { /// // ... use the client ... /// client.close().await?; /// ``` - pub async fn close(&mut self) -> Result { + pub async fn close(&mut self) -> QuitReason { if let Some(handle) = self.handle.take() { // Disarm the drop guard so it doesn't try to cancel again // We need to cancel manually and wait for completion @@ -499,7 +519,7 @@ impl> RunningService { handle.await } else { // Already closed - Ok(QuitReason::Closed) + QuitReason::Closed } } @@ -511,24 +531,22 @@ impl> RunningService { /// /// Returns `Ok(Some(reason))` if shutdown completed within the timeout, /// `Ok(None)` if the timeout was reached, or `Err` if there was a join error. - pub async fn close_with_timeout( - &mut self, - timeout: Duration, - ) -> Result, tokio::task::JoinError> { + pub async fn close_with_timeout(&mut self, timeout: Duration) -> Option { if let Some(handle) = self.handle.take() { self.cancellation_token.cancel(); + // TODO: tokio timeout won't work if not in the tokio RT, find an alternative match tokio::time::timeout(timeout, handle).await { - Ok(result) => result.map(Some), + Ok(reason) => Some(reason), Err(_elapsed) => { tracing::warn!( "close_with_timeout: cleanup did not complete within {:?}", timeout ); - Ok(None) + None } } } else { - Ok(Some(QuitReason::Closed)) + Some(QuitReason::Closed) } } @@ -536,7 +554,7 @@ impl> RunningService { /// /// This consumes the `RunningService` and ensures the connection is properly /// closed. For a non-consuming alternative, see [`close`](Self::close). - pub async fn cancel(mut self) -> Result { + pub async fn cancel(mut self) -> QuitReason { // Disarm the drop guard since we're handling cancellation explicitly let _ = std::mem::replace(&mut self.dg, self.cancellation_token.clone().drop_guard()); self.close().await @@ -594,11 +612,20 @@ pub struct NotificationContext { } /// Use this function to skip initialization process +/// +/// TODO: What initialization process? Reference that here +/// +/// Creates a handle to the running service, and the async task that runs the service +/// business logic. +/// +/// The caller is responsible for running the business logic task, either by spawning it on +/// a runtime or awaiting it directly. You can use the [RunningService] to cancel the +/// business logic or wait for it to finish. pub fn serve_directly( service: S, transport: T, peer_info: Option, -) -> RunningService +) -> (RunningService, impl Future) where R: ServiceRole, S: Service, @@ -609,12 +636,21 @@ where } /// Use this function to skip initialization process +/// +/// TODO: What initialization process? Reference that here +/// +/// Creates a handle to the running service, and the async task that runs the service +/// business logic. +/// +/// The caller is responsible for running the business logic task, either by spawning it on +/// a runtime or awaiting it directly. You can use the [RunningService] to cancel the +/// business logic or wait for it to finish. pub fn serve_directly_with_ct( service: S, transport: T, peer_info: Option, ct: CancellationToken, -) -> RunningService +) -> (RunningService, impl Future) where R: ServiceRole, S: Service, @@ -622,25 +658,30 @@ where E: std::error::Error + Send + Sync + 'static, { let (peer, peer_rx) = Peer::new(Arc::new(AtomicU32RequestIdProvider::default()), peer_info); + let peer_rx = ReceiverStream::new(peer_rx); serve_inner(service, transport.into_transport(), peer, peer_rx, ct) } +/// Creates a handle to the running service, and the async task that runs the service +/// business logic. +/// +/// The caller is responsible for running the business logic task, either by spawning it on +/// a runtime or awaiting it directly. You can use the [RunningService] to cancel the +/// business logic or wait for it to finish. #[instrument(skip_all)] -fn serve_inner( +fn serve_inner( service: S, transport: T, peer: Peer, - mut peer_rx: tokio::sync::mpsc::Receiver>, + peer_rx: PeerStream, ct: CancellationToken, -) -> RunningService +) -> (RunningService, impl Future) where R: ServiceRole, S: Service, T: Transport + 'static, + PeerStream: Stream> + Unpin, { - const SINK_PROXY_BUFFER_SIZE: usize = 64; - let (sink_proxy_tx, mut sink_proxy_rx) = - tokio::sync::mpsc::channel::>(SINK_PROXY_BUFFER_SIZE); let peer_info = peer.peer_info(); if R::IS_CLIENT { tracing::info!(?peer_info, "Service initialized as client"); @@ -648,9 +689,6 @@ where tracing::info!(?peer_info, "Service initialized as server"); } - let mut local_responder_pool = - HashMap::>>::new(); - let mut local_ct_pool = HashMap::::new(); let shared_service = Arc::new(service); // for return let service = shared_service.clone(); @@ -658,283 +696,330 @@ where // let message_sink = tokio::sync:: // let mut stream = std::pin::pin!(stream); let serve_loop_ct = ct.child_token(); - let peer_return: Peer = peer.clone(); + let peer_return = peer.clone(); let current_span = tracing::Span::current(); - let handle = tokio::spawn(async move { - let mut transport = transport.into_transport(); - let mut batch_messages = VecDeque::>::new(); - let mut send_task_set = tokio::task::JoinSet::::new(); - #[derive(Debug)] - enum SendTaskResult { - Request { - id: RequestId, - result: Result<(), DynamicTransportError>, - }, - Notification { - responder: Responder>, - cancellation_param: Option, - result: Result<(), DynamicTransportError>, - }, - } - #[derive(Debug)] - enum Event { - ProxyMessage(PeerSinkMessage), - PeerMessage(RxJsonRpcMessage), - ToSink(TxJsonRpcMessage), - SendTaskResult(SendTaskResult), - } - let quit_reason = loop { - let evt = if let Some(m) = batch_messages.pop_front() { - Event::PeerMessage(m) - } else { - tokio::select! { - m = sink_proxy_rx.recv(), if !sink_proxy_rx.is_closed() => { - if let Some(m) = m { - Event::ToSink(m) - } else { - continue - } - } - m = transport.receive() => { - if let Some(m) = m { - Event::PeerMessage(m) - } else { - // input stream closed - tracing::info!("input stream terminated"); - break QuitReason::Closed - } - } - m = peer_rx.recv(), if !peer_rx.is_closed() => { - if let Some(m) = m { - Event::ProxyMessage(m) - } else { - continue - } - } - m = send_task_set.join_next(), if !send_task_set.is_empty() => { - let Some(result) = m else { - continue - }; - match result { - Err(e) => { - // join error, which is serious, we should quit. - tracing::error!(%e, "send request task encounter a tokio join error"); - break QuitReason::JoinError(e) - } - Ok(result) => { - Event::SendTaskResult(result) - } - } - } - _ = serve_loop_ct.cancelled() => { - tracing::info!("task cancelled"); - break QuitReason::Cancelled + let work = controller(transport, peer_rx, serve_loop_ct, shared_service, peer) + .instrument(current_span); + + let (work, work_handle) = work.remote_handle(); + // If the handle is dropped, don't stop the work. + // We don't want to force the user to keep the `RunningService` + // struct alive just to keep the work running (since the work + // future will be explicitly managed by the caller) + work_handle.forget(); + + let running_service = RunningService { + service, + peer: peer_return, + handle: Some(work_handle.boxed()), + cancellation_token: ct.clone(), + dg: ct.drop_guard(), + }; + + (running_service, work) +} + +/// Main business logic for event dispatching and handling. +async fn controller( + transport: T, + peer_rx: PeerStream, + cancel_token: CancellationToken, + shared_service: Arc>, + peer: Peer, +) -> QuitReason +where + R: ServiceRole, + T: Transport + 'static, + PeerStream: Stream> + Unpin, +{ + let mut transport = transport.into_transport(); + let mut batch_messages = VecDeque::>::new(); + let mut send_task_set = FuturesUnordered::>::new(); + let mut side_effects_set = FuturesUnordered::>::new(); + + let mut local_responder_pool = + HashMap::>>::new(); + let mut local_ct_pool = HashMap::::new(); + + const SINK_PROXY_BUFFER_SIZE: usize = 64; + let (sink_proxy_tx, mut rpc_rx) = + tokio::sync::mpsc::channel::>(SINK_PROXY_BUFFER_SIZE); + + // Fuse the stream, so that once it return `None` it is guaranteed to never + // be polled again. Additionally, we can check if it have been fused by checking + // `is_done()`, which we use in the select branches below. + let mut peer_rx = peer_rx.fuse(); + + #[derive(Debug)] + enum SendTaskResult { + Request { + id: RequestId, + result: Result<(), DynamicTransportError>, + }, + Notification { + responder: Responder>, + cancellation_param: Option, + result: Result<(), DynamicTransportError>, + }, + } + #[derive(Debug)] + enum Event { + ProxyMessage(PeerSinkMessage), + PeerMessage(RxJsonRpcMessage), + ToSink(TxJsonRpcMessage), + SendTaskResult(SendTaskResult), + } + + let quit_reason = loop { + // Prioritize processing batch messages before other things + let evt = if let Some(m) = batch_messages.pop_front() { + Event::PeerMessage(m) + } else { + tokio::select! { + m = rpc_rx.recv(), if !rpc_rx.is_closed() => { + if let Some(m) = m { + Event::ToSink(m) + } else { + continue } } - }; - - tracing::trace!(?evt, "new event"); - match evt { - Event::SendTaskResult(SendTaskResult::Request { id, result }) => { - if let Err(e) = result { - if let Some(responder) = local_responder_pool.remove(&id) { - let _ = responder.send(Err(ServiceError::TransportSend(e))); - } + m = transport.receive() => { + if let Some(m) = m { + Event::PeerMessage(m) + } else { + // input stream closed + tracing::info!("input stream terminated"); + break QuitReason::Closed } } - Event::SendTaskResult(SendTaskResult::Notification { - responder, - result, - cancellation_param, - }) => { - let response = if let Err(e) = result { - Err(ServiceError::TransportSend(e)) + m = peer_rx.next(), if !peer_rx.is_done() => { + if let Some(m) = m { + Event::ProxyMessage(m) } else { - Ok(()) - }; - let _ = responder.send(response); - if let Some(param) = cancellation_param { - if let Some(responder) = local_responder_pool.remove(¶m.request_id) { - tracing::info!(id = %param.request_id, reason = param.reason, "cancelled"); - let _response_result = responder.send(Err(ServiceError::Cancelled { - reason: param.reason.clone(), - })); - } + continue } } - // response and error - Event::ToSink(m) => { - if let Some(id) = match &m { - JsonRpcMessage::Response(response) => Some(&response.id), - JsonRpcMessage::Error(error) => Some(&error.id), - _ => None, - } { - if let Some(ct) = local_ct_pool.remove(id) { - ct.cancel(); - } - let send = transport.send(m); - let current_span = tracing::Span::current(); - tokio::spawn(async move { - let send_result = send.await; - if let Err(error) = send_result { - tracing::error!(%error, "fail to response message"); - } - }.instrument(current_span)); + m = send_task_set.next(), if !send_task_set.is_empty() => { + let Some(send_result) = m else { + continue + }; + Event::SendTaskResult(send_result) + } + _ = side_effects_set.next(), if !side_effects_set.is_empty() => { + // just drive the future, we don't care about the result + continue + } + _ = cancel_token.cancelled() => { + tracing::info!("task cancelled"); + break QuitReason::Cancelled + } + } + }; + + tracing::trace!(?evt, "new event"); + match evt { + Event::SendTaskResult(SendTaskResult::Request { id, result }) => { + if let Err(e) = result { + if let Some(responder) = local_responder_pool.remove(&id) { + let _ = responder.send(Err(ServiceError::TransportSend(e))); } } - Event::ProxyMessage(PeerSinkMessage::Request { - request, - id, - responder, - }) => { - local_responder_pool.insert(id.clone(), responder); - let send = transport.send(JsonRpcMessage::request(request, id.clone())); - { - let id = id.clone(); - let current_span = tracing::Span::current(); - send_task_set.spawn(send.map(move |r| SendTaskResult::Request { - id, - result: r.map_err(DynamicTransportError::new::), - }).instrument(current_span)); + } + Event::SendTaskResult(SendTaskResult::Notification { + responder, + result, + cancellation_param, + }) => { + let response = if let Err(e) = result { + Err(ServiceError::TransportSend(e)) + } else { + Ok(()) + }; + let _ = responder.send(response); + if let Some(param) = cancellation_param { + if let Some(responder) = local_responder_pool.remove(¶m.request_id) { + tracing::info!(id = %param.request_id, reason = param.reason, "cancelled"); + let _response_result = responder.send(Err(ServiceError::Cancelled { + reason: param.reason.clone(), + })); } } - Event::ProxyMessage(PeerSinkMessage::Notification { - notification, - responder, - }) => { - // catch cancellation notification - let mut cancellation_param = None; - let notification = match notification.try_into() { - Ok::(cancelled) => { - cancellation_param.replace(cancelled.params.clone()); - cancelled.into() - } - Err(notification) => notification, - }; - let send = transport.send(JsonRpcMessage::notification(notification)); + } + // response and error + Event::ToSink(m) => { + if let Some(id) = match &m { + JsonRpcMessage::Response(response) => Some(&response.id), + JsonRpcMessage::Error(error) => Some(&error.id), + _ => None, + } { + if let Some(ct) = local_ct_pool.remove(id) { + ct.cancel(); + } + let send = transport.send(m); let current_span = tracing::Span::current(); - send_task_set.spawn(send.map(move |result| SendTaskResult::Notification { + let send_work = async move { + let send_result = send.await; + if let Err(error) = send_result { + tracing::error!(%error, "fail to response message"); + } + } + .instrument(current_span) + .boxed(); + side_effects_set.push(send_work); + } + } + Event::ProxyMessage(PeerSinkMessage::Request { + request, + id, + responder, + }) => { + local_responder_pool.insert(id.clone(), responder); + let send = transport.send(JsonRpcMessage::request(request, id.clone())); + let id = id.clone(); + let current_span = tracing::Span::current(); + + let send = send + .map(move |r| SendTaskResult::Request { + id, + result: r.map_err(DynamicTransportError::new::), + }) + .instrument(current_span) + .boxed(); + send_task_set.push(send); + } + Event::ProxyMessage(PeerSinkMessage::Notification { + notification, + responder, + }) => { + // catch cancellation notification + let mut cancellation_param = None; + let notification = match notification.try_into() { + Ok::(cancelled) => { + cancellation_param.replace(cancelled.params.clone()); + cancelled.into() + } + Err(notification) => notification, + }; + let send = transport.send(JsonRpcMessage::notification(notification)); + let current_span = tracing::Span::current(); + let send = send + .map(move |result| SendTaskResult::Notification { responder, cancellation_param, result: result.map_err(DynamicTransportError::new::), - }).instrument(current_span)); - } - Event::PeerMessage(JsonRpcMessage::Request(JsonRpcRequest { - id, - mut request, - .. - })) => { - tracing::debug!(%id, ?request, "received request"); - { - let service = shared_service.clone(); - let sink = sink_proxy_tx.clone(); - let request_ct = serve_loop_ct.child_token(); - let context_ct = request_ct.child_token(); - local_ct_pool.insert(id.clone(), request_ct); - let mut extensions = Extensions::new(); - let mut meta = Meta::new(); - // avoid clone - // swap meta firstly, otherwise progress token will be lost - std::mem::swap(&mut meta, request.get_meta_mut()); - std::mem::swap(&mut extensions, request.extensions_mut()); - let context = RequestContext { - ct: context_ct, - id: id.clone(), - peer: peer.clone(), - meta, - extensions, + }) + .instrument(current_span) + .boxed(); + send_task_set.push(send); + } + Event::PeerMessage(JsonRpcMessage::Request(JsonRpcRequest { + id, mut request, .. + })) => { + tracing::debug!(%id, ?request, "received request"); + { + let service = shared_service.clone(); + let sink = sink_proxy_tx.clone(); + let request_ct = cancel_token.child_token(); + let context_ct = request_ct.child_token(); + local_ct_pool.insert(id.clone(), request_ct); + let mut extensions = Extensions::new(); + let mut meta = Meta::new(); + // avoid clone + // swap meta firstly, otherwise progress token will be lost + std::mem::swap(&mut meta, request.get_meta_mut()); + std::mem::swap(&mut extensions, request.extensions_mut()); + let context = RequestContext { + ct: context_ct, + id: id.clone(), + peer: peer.clone(), + meta, + extensions, + }; + let current_span = tracing::Span::current(); + let work = async move { + let result = service.handle_request(request, context).await; + let response = match result { + Ok(result) => { + tracing::debug!(%id, ?result, "response message"); + JsonRpcMessage::response(result, id) + } + Err(error) => { + tracing::warn!(%id, ?error, "response error"); + JsonRpcMessage::error(error, id) + } }; - let current_span = tracing::Span::current(); - tokio::spawn(async move { - let result = service - .handle_request(request, context) - .await; - let response = match result { - Ok(result) => { - tracing::debug!(%id, ?result, "response message"); - JsonRpcMessage::response(result, id) - } - Err(error) => { - tracing::warn!(%id, ?error, "response error"); - JsonRpcMessage::error(error, id) - } - }; - let _send_result = sink.send(response).await; - }.instrument(current_span)); + let _send_result = sink.send(response).await; } + .instrument(current_span) + .boxed(); + side_effects_set.push(work); } - Event::PeerMessage(JsonRpcMessage::Notification(JsonRpcNotification { - notification, - .. - })) => { - tracing::info!(?notification, "received notification"); - // catch cancelled notification - let mut notification = match notification.try_into() { - Ok::(cancelled) => { - if let Some(ct) = local_ct_pool.remove(&cancelled.params.request_id) { - tracing::info!(id = %cancelled.params.request_id, reason = cancelled.params.reason, "cancelled"); - ct.cancel(); - } - cancelled.into() + } + Event::PeerMessage(JsonRpcMessage::Notification(JsonRpcNotification { + notification, + .. + })) => { + tracing::info!(?notification, "received notification"); + // catch cancelled notification + let mut notification = match notification.try_into() { + Ok::(cancelled) => { + if let Some(ct) = local_ct_pool.remove(&cancelled.params.request_id) { + tracing::info!(id = %cancelled.params.request_id, reason = cancelled.params.reason, "cancelled"); + ct.cancel(); } - Err(notification) => notification, + cancelled.into() + } + Err(notification) => notification, + }; + { + let service = shared_service.clone(); + let mut extensions = Extensions::new(); + let mut meta = Meta::new(); + // avoid clone + std::mem::swap(&mut extensions, notification.extensions_mut()); + std::mem::swap(&mut meta, notification.get_meta_mut()); + let context = NotificationContext { + peer: peer.clone(), + meta, + extensions, }; - { - let service = shared_service.clone(); - let mut extensions = Extensions::new(); - let mut meta = Meta::new(); - // avoid clone - std::mem::swap(&mut extensions, notification.extensions_mut()); - std::mem::swap(&mut meta, notification.get_meta_mut()); - let context = NotificationContext { - peer: peer.clone(), - meta, - extensions, - }; - let current_span = tracing::Span::current(); - tokio::spawn(async move { - let result = service.handle_notification(notification, context).await; - if let Err(error) = result { - tracing::warn!(%error, "Error sending notification"); - } - }.instrument(current_span)); + let current_span = tracing::Span::current(); + let work = async move { + let result = service.handle_notification(notification, context).await; + if let Err(error) = result { + tracing::warn!(%error, "Error sending notification"); + } } + .instrument(current_span) + .boxed(); + side_effects_set.push(work); } - Event::PeerMessage(JsonRpcMessage::Response(JsonRpcResponse { - result, - id, - .. - })) => { - if let Some(responder) = local_responder_pool.remove(&id) { - let response_result = responder.send(Ok(result)); - if let Err(_error) = response_result { - tracing::warn!(%id, "Error sending response"); - } + } + Event::PeerMessage(JsonRpcMessage::Response(JsonRpcResponse { + result, id, .. + })) => { + if let Some(responder) = local_responder_pool.remove(&id) { + let response_result = responder.send(Ok(result)); + if let Err(_error) = response_result { + tracing::warn!(%id, "Error sending response"); } } - Event::PeerMessage(JsonRpcMessage::Error(JsonRpcError { error, id, .. })) => { - if let Some(responder) = local_responder_pool.remove(&id) { - let _response_result = responder.send(Err(ServiceError::McpError(error))); - if let Err(_error) = _response_result { - tracing::warn!(%id, "Error sending response"); - } + } + Event::PeerMessage(JsonRpcMessage::Error(JsonRpcError { error, id, .. })) => { + if let Some(responder) = local_responder_pool.remove(&id) { + let _response_result = responder.send(Err(ServiceError::McpError(error))); + if let Err(_error) = _response_result { + tracing::warn!(%id, "Error sending response"); } } } - }; - let sink_close_result = transport.close().await; - if let Err(e) = sink_close_result { - tracing::error!(%e, "fail to close sink"); } - tracing::info!(?quit_reason, "serve finished"); - quit_reason - }.instrument(current_span)); - RunningService { - service, - peer: peer_return, - handle: Some(handle), - cancellation_token: ct.clone(), - dg: ct.drop_guard(), + }; + let sink_close_result = transport.close().await; + if let Err(e) = sink_close_result { + tracing::error!(%e, "fail to close sink"); } + tracing::info!(?quit_reason, "serve finished"); + quit_reason } diff --git a/crates/rmcp/src/transport/streamable_http_server/tower.rs b/crates/rmcp/src/transport/streamable_http_server/tower.rs index 74b1fd79..ba28a8bb 100644 --- a/crates/rmcp/src/transport/streamable_http_server/tower.rs +++ b/crates/rmcp/src/transport/streamable_http_server/tower.rs @@ -593,11 +593,8 @@ where request.request.extensions_mut().insert(part); let (transport, mut receiver) = OneshotTransport::::new(ClientJsonRpcMessage::Request(request)); - let service = serve_directly(service, transport, None); - tokio::spawn(async move { - // on service created - let _ = service.waiting().await; - }); + let (_, work) = serve_directly(service, transport, None); + tokio::spawn(work); if self.config.json_response { // JSON-direct mode: await the single response and return as // application/json, eliminating SSE framing overhead. diff --git a/crates/rmcp/src/util.rs b/crates/rmcp/src/util.rs new file mode 100644 index 00000000..06a563e3 --- /dev/null +++ b/crates/rmcp/src/util.rs @@ -0,0 +1,5 @@ +use std::pin::Pin; + +pub type PinnedFuture<'a, T> = Pin + Send + 'a>>; + +pub type PinnedLocalFuture<'a, T> = Pin + 'a>>; From 3b79366999a4131546e755ff2d20a4ae031b5e1e Mon Sep 17 00:00:00 2001 From: Aadam Zocolo Date: Sat, 28 Feb 2026 22:09:59 +0000 Subject: [PATCH 02/20] refactor(operation-processor): remove uses of tokio::spawn Refactor by using a worker future and bubling that up to the top-level of the API. The callee is now responsible for polling the worker task, or else no work will get done. --- crates/rmcp/Cargo.toml | 7 +- crates/rmcp/src/service.rs | 26 +++--- crates/rmcp/src/service/client.rs | 14 +++- crates/rmcp/src/service/server.rs | 17 +++- crates/rmcp/src/task_manager.rs | 126 +++++++++++++++++++++++------- crates/rmcp/tests/test_task.rs | 10 ++- 6 files changed, 151 insertions(+), 49 deletions(-) diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml index 76820c81..e5489986 100644 --- a/crates/rmcp/Cargo.toml +++ b/crates/rmcp/Cargo.toml @@ -77,7 +77,12 @@ chrono = { version = "0.4.38", default-features = false, features = [ [features] default = ["base64", "macros", "server"] client = ["dep:tokio-stream"] -server = ["transport-async-rw", "dep:schemars", "dep:pastey"] +server = [ + "transport-async-rw", + "dep:schemars", + "dep:pastey", + "dep:tokio-stream", +] macros = ["dep:rmcp-macros", "dep:pastey"] elicitation = ["dep:url"] diff --git a/crates/rmcp/src/service.rs b/crates/rmcp/src/service.rs index 0e8e95ea..e3fae794 100644 --- a/crates/rmcp/src/service.rs +++ b/crates/rmcp/src/service.rs @@ -1,4 +1,8 @@ -use futures::{FutureExt, Stream, StreamExt, future::BoxFuture, stream::FuturesUnordered}; +use futures::{ + FutureExt, Stream, StreamExt, + future::{BoxFuture, RemoteHandle}, + stream::FuturesUnordered, +}; use thiserror::Error; use tokio_stream::wrappers::ReceiverStream; @@ -113,7 +117,9 @@ pub trait ServiceExt: Service + Sized { fn serve( self, transport: T, - ) -> impl Future, R::InitializeError>> + Send + ) -> impl Future< + Output = Result<(RunningService, impl Future), R::InitializeError>, + > + Send where T: IntoTransport, E: std::error::Error + Send + Sync + 'static, @@ -125,7 +131,9 @@ pub trait ServiceExt: Service + Sized { self, transport: T, ct: CancellationToken, - ) -> impl Future, R::InitializeError>> + Send + ) -> impl Future< + Output = Result<(RunningService, impl Future), R::InitializeError>, + > + Send where T: IntoTransport, E: std::error::Error + Send + Sync + 'static, @@ -434,7 +442,7 @@ impl Peer { pub struct RunningService> { service: Arc, peer: Peer, - handle: Option>, + handle: Option>, cancellation_token: CancellationToken, dg: DropGuard, } @@ -564,6 +572,9 @@ impl> RunningService { impl> Drop for RunningService { fn drop(&mut self) { if self.handle.is_some() && !self.cancellation_token.is_cancelled() { + // Make sure we don't stop the work itself, the work future should + // handle that via cancellation token or drop guard + self.handle.take().unwrap().forget(); tracing::debug!( "RunningService dropped without explicit close(). \ The connection will be closed asynchronously. \ @@ -703,16 +714,11 @@ where .instrument(current_span); let (work, work_handle) = work.remote_handle(); - // If the handle is dropped, don't stop the work. - // We don't want to force the user to keep the `RunningService` - // struct alive just to keep the work running (since the work - // future will be explicitly managed by the caller) - work_handle.forget(); let running_service = RunningService { service, peer: peer_return, - handle: Some(work_handle.boxed()), + handle: Some(work_handle), cancellation_token: ct.clone(), dg: ct.drop_guard(), }; diff --git a/crates/rmcp/src/service/client.rs b/crates/rmcp/src/service/client.rs index 837fafef..5c971791 100644 --- a/crates/rmcp/src/service/client.rs +++ b/crates/rmcp/src/service/client.rs @@ -161,7 +161,12 @@ impl> ServiceExt for S { self, transport: T, ct: CancellationToken, - ) -> impl Future, ClientInitializeError>> + Send + ) -> impl Future< + Output = Result< + (RunningService, impl Future), + ClientInitializeError, + >, + > + Send where T: IntoTransport, E: std::error::Error + Send + Sync + 'static, @@ -174,7 +179,7 @@ impl> ServiceExt for S { pub async fn serve_client( service: S, transport: T, -) -> Result, ClientInitializeError> +) -> Result<(RunningService, impl Future), ClientInitializeError> where S: Service, T: IntoTransport, @@ -187,7 +192,7 @@ pub async fn serve_client_with_ct( service: S, transport: T, ct: CancellationToken, -) -> Result, ClientInitializeError> +) -> Result<(RunningService, impl Future), ClientInitializeError> where S: Service, T: IntoTransport, @@ -205,7 +210,7 @@ async fn serve_client_with_ct_inner( service: S, transport: T, ct: CancellationToken, -) -> Result, ClientInitializeError> +) -> Result<(RunningService, impl Future), ClientInitializeError> where S: Service, T: Transport + 'static, @@ -263,6 +268,7 @@ where transport.send(notification).await.map_err(|error| { ClientInitializeError::transport::(error, "send initialized notification") })?; + let peer_rx = ReceiverStream::new(peer_rx); Ok(serve_inner(service, transport, peer, peer_rx, ct)) } diff --git a/crates/rmcp/src/service/server.rs b/crates/rmcp/src/service/server.rs index 5f54f3dc..b19a1602 100644 --- a/crates/rmcp/src/service/server.rs +++ b/crates/rmcp/src/service/server.rs @@ -94,7 +94,12 @@ impl> ServiceExt for S { self, transport: T, ct: CancellationToken, - ) -> impl Future, ServerInitializeError>> + Send + ) -> impl Future< + Output = Result< + (RunningService, impl Future), + ServerInitializeError, + >, + > + Send where T: IntoTransport, E: std::error::Error + Send + Sync + 'static, @@ -107,7 +112,7 @@ impl> ServiceExt for S { pub async fn serve_server( service: S, transport: T, -) -> Result, ServerInitializeError> +) -> Result<(RunningService, impl Future), ServerInitializeError> where S: Service, T: IntoTransport, @@ -166,7 +171,7 @@ pub async fn serve_server_with_ct( service: S, transport: T, ct: CancellationToken, -) -> Result, ServerInitializeError> +) -> Result<(RunningService, impl Future), ServerInitializeError> where S: Service, T: IntoTransport, @@ -180,11 +185,14 @@ where } } +/// Performs handshake and initial protocol setup through the transport, +/// and returns a [RunningService] with a separate work future that will +/// need polled to run the service. async fn serve_server_with_ct_inner( service: S, transport: T, ct: CancellationToken, -) -> Result, ServerInitializeError> +) -> Result<(RunningService, impl Future), ServerInitializeError> where S: Service, T: Transport + 'static, @@ -258,6 +266,7 @@ where peer: peer.clone(), }; let _ = service.handle_notification(notification, context).await; + let peer_rx = ReceiverStream::new(peer_rx); // Continue processing service Ok(serve_inner(service, transport, peer, peer_rx, ct)) } diff --git a/crates/rmcp/src/task_manager.rs b/crates/rmcp/src/task_manager.rs index 774c542f..c1b6edc3 100644 --- a/crates/rmcp/src/task_manager.rs +++ b/crates/rmcp/src/task_manager.rs @@ -1,6 +1,10 @@ -use std::{any::Any, collections::HashMap, pin::Pin}; +use std::{any::Any, collections::HashMap}; -use futures::Future; +use futures::{ + Future, FutureExt, StreamExt, + future::abortable, + stream::{AbortHandle, FuturesUnordered}, +}; use tokio::{ sync::mpsc, time::{Duration, timeout}, @@ -11,11 +15,14 @@ use crate::{ error::{ErrorData as McpError, RmcpError as Error}, model::{CallToolResult, ClientRequest}, service::RequestContext, + util::PinnedFuture, }; +/// Result of running an operation +pub type OperationResult = Result, Error>; + /// Boxed future that represents an asynchronous operation managed by the processor. -pub type OperationFuture = - Pin, Error>> + Send>>; +pub type OperationFuture<'a> = PinnedFuture<'a, OperationResult>; /// Describes metadata associated with an enqueued task. #[derive(Debug, Clone)] @@ -57,11 +64,11 @@ impl OperationDescriptor { /// Operation message describing a unit of asynchronous work. pub struct OperationMessage { pub descriptor: OperationDescriptor, - pub future: OperationFuture, + pub future: OperationFuture<'static>, } impl OperationMessage { - pub fn new(descriptor: OperationDescriptor, future: OperationFuture) -> Self { + pub fn new(descriptor: OperationDescriptor, future: OperationFuture<'static>) -> Self { Self { descriptor, future } } } @@ -80,17 +87,23 @@ pub struct OperationProcessor { running_tasks: HashMap, /// Completed results waiting to be collected completed_results: Vec, + /// Receiver for asynchronously completed task results. Used + /// to collect back into `completed_results` task_result_receiver: mpsc::UnboundedReceiver, - task_result_sender: mpsc::UnboundedSender, + /// Sender to spawn futures on the worker task associated with this + /// processor. The worker future is created as part of [OperationProcessor::new] + spawn_tx: mpsc::UnboundedSender<(OperationDescriptor, OperationFuture<'static>)>, } +/// A handle to a running operation. struct RunningTask { - task_handle: tokio::task::JoinHandle<()>, + task_handle: AbortHandle, started_at: std::time::Instant, timeout: Option, descriptor: OperationDescriptor, } +/// The result of a running operation. pub struct TaskResult { pub descriptor: OperationDescriptor, pub result: Result, Error>, @@ -126,21 +139,63 @@ impl OperationResultTransport for ToolCallTaskResult { } } -impl Default for OperationProcessor { - fn default() -> Self { - Self::new() - } -} - impl OperationProcessor { - pub fn new() -> Self { + /// Create a new operation processor. + /// + /// This function will return the new [OperationProcessor] + /// facade you can use to queue operations, and also a future + /// that must be polled to handle these operations. + /// + /// Spawn the work function on your runtime of choice, or poll it + /// manually. + pub fn new() -> (Self, impl Future) { let (task_result_sender, task_result_receiver) = mpsc::unbounded_channel(); - Self { + let (spawn_tx, mut spawn_rx) = + mpsc::unbounded_channel::<(OperationDescriptor, OperationFuture)>(); + + let work = async move { + let mut work_set = + FuturesUnordered::>::new(); + + // Loop and listen for new operations incoming that need to be added to the future pool, + // and also listen to operation completions via the future pool. + loop { + tokio::select! { + spawn_req = spawn_rx.recv(), if !spawn_rx.is_closed() => { + if let Some((descriptor, fut)) = spawn_req { + // Map the future back to a descriptor and result tuple + let operation_work = fut.map(|result| (descriptor, result)).boxed(); + // Add it to the set we are polling + work_set.push(operation_work); + } + }, + operation_result = work_set.next(), if !work_set.is_empty() => { + if let Some((descriptor, result)) = operation_result { + match task_result_sender.send(TaskResult { descriptor, result }) { + Err(e) => { + // TODO: Produce an error message here! + } + _ => {} + } + }; + }, + else => { + // Work was empty, and spawn channel was closed. Time + // to break the loop. + break; + } + } + } + }; + + let this = Self { running_tasks: HashMap::new(), completed_results: Vec::new(), task_result_receiver, - task_result_sender, - } + spawn_tx, + }; + + (this, work) } /// Submit an operation for asynchronous execution. @@ -159,12 +214,11 @@ impl OperationProcessor { Ok(()) } + /// Spawns an operation to be executed to completion. fn spawn_async_task(&mut self, message: OperationMessage) { let OperationMessage { descriptor, future } = message; let task_id = descriptor.operation_id.clone(); let timeout_secs = descriptor.ttl.or(Some(DEFAULT_TASK_TIMEOUT_SECS)); - let sender = self.task_result_sender.clone(); - let descriptor_for_result = descriptor.clone(); let timed_future = async move { if let Some(secs) = timeout_secs { @@ -177,16 +231,32 @@ impl OperationProcessor { } }; - let handle = tokio::spawn(async move { - let result = timed_future.await; - let task_result = TaskResult { - descriptor: descriptor_for_result, - result, - }; - let _ = sender.send(task_result); + // Below, we want to give the user a handle to the long-running operation, + // but we don't want to send the result to the user's handle. Rather the + // result gets consumed in the worker task created in the `Self::new` + // function. So here we will use the `Abortable` future utility. + let (work, abort_handle) = abortable(timed_future); + + // Map the error type of abortion (for now) + let work = work.map(|result| { + match result { + // Was not aborted, true operation result + Ok(inner_result) => inner_result, + // Was aborted, flatten to expected error type + Err(e) => Err(Error::TaskError(e.to_string())), + } }); + + // Then send the work to be executed + match self.spawn_tx.send((descriptor.clone(), work.boxed())) { + Ok(_) => {} + Err(e) => { + // TODO: Produce an error message! + } + } + let running_task = RunningTask { - task_handle: handle, + task_handle: abort_handle, started_at: std::time::Instant::now(), timeout: timeout_secs, descriptor, diff --git a/crates/rmcp/tests/test_task.rs b/crates/rmcp/tests/test_task.rs index 9ad0b200..c0f08de0 100644 --- a/crates/rmcp/tests/test_task.rs +++ b/crates/rmcp/tests/test_task.rs @@ -21,7 +21,10 @@ impl OperationResultTransport for DummyTransport { #[tokio::test] async fn executes_enqueued_future() { - let mut processor = OperationProcessor::new(); + let (mut processor, work) = OperationProcessor::new(); + + tokio::spawn(work); + let descriptor = OperationDescriptor::new("op1", "dummy"); let future = Box::pin(async { tokio::time::sleep(Duration::from_millis(10)).await; @@ -50,7 +53,10 @@ async fn executes_enqueued_future() { #[tokio::test] async fn rejects_duplicate_operation_ids() { - let mut processor = OperationProcessor::new(); + let (mut processor, work) = OperationProcessor::new(); + + tokio::spawn(work); + let descriptor = OperationDescriptor::new("dup", "dummy"); let future = Box::pin(async { Ok(Box::new(DummyTransport { From 8d6655b802663477f6d506ccf657d920f5fddff5 Mon Sep 17 00:00:00 2001 From: Aadam Zocolo Date: Sun, 1 Mar 2026 00:14:16 +0000 Subject: [PATCH 03/20] refactor(progress): remove need for spawning on drop larger refactor for the way progress is multiplexed this needed a redesign of the broadcast multiplex logic to a more stateless design. this design removes the need for mutating any state on drop, the stream dropping implicitly removes broadcast listeners. this design also allows for multiple subscribers of the same progress token. --- crates/rmcp/Cargo.toml | 2 +- crates/rmcp/src/handler/client/progress.rs | 151 +++++++++++++----- crates/rmcp/src/util.rs | 5 + crates/rmcp/tests/test_progress_subscriber.rs | 8 +- 4 files changed, 123 insertions(+), 43 deletions(-) diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml index e5489986..35049780 100644 --- a/crates/rmcp/Cargo.toml +++ b/crates/rmcp/Cargo.toml @@ -56,7 +56,7 @@ process-wrap = { version = "9.0", features = ["tokio1"], optional = true } # for http-server transport rand = { version = "0.10", optional = true } -tokio-stream = { version = "0.1", optional = true } +tokio-stream = { version = "0.1", optional = true, features = ["sync"] } uuid = { version = "1", features = ["v4"], optional = true } http-body = { version = "1", optional = true } http-body-util = { version = "0.1", optional = true } diff --git a/crates/rmcp/src/handler/client/progress.rs b/crates/rmcp/src/handler/client/progress.rs index 04f31610..7dd84f08 100644 --- a/crates/rmcp/src/handler/client/progress.rs +++ b/crates/rmcp/src/handler/client/progress.rs @@ -1,32 +1,55 @@ use std::{collections::HashMap, sync::Arc}; use futures::{Stream, StreamExt}; -use tokio::sync::RwLock; -use tokio_stream::wrappers::ReceiverStream; +use tokio::sync::{RwLock, broadcast}; +use tokio_stream::wrappers::BroadcastStream; -use crate::model::{ProgressNotificationParam, ProgressToken}; -type Dispatcher = - Arc>>>; +use crate::{ + model::{ProgressNotificationParam, ProgressToken}, + util::PinnedStream, +}; /// A dispatcher for progress notifications. -#[derive(Debug, Clone, Default)] +/// +/// See [ProgressNotificationParam] and [ProgressToken] for more details on +/// how progress is dispatched to a particular listener. +#[derive(Debug, Clone)] pub struct ProgressDispatcher { - pub(crate) dispatcher: Dispatcher, + /// A channel of any progress notification. Subscribers will filter + /// on this channel. + pub(crate) any_progress_notification_tx: broadcast::Sender, + pub(crate) unsubscribe_tx: broadcast::Sender, + pub(crate) unsubscribe_all_tx: broadcast::Sender<()>, } impl ProgressDispatcher { const CHANNEL_SIZE: usize = 16; pub fn new() -> Self { - Self::default() + // Note that channel size is per-receiver for broadcast channel. It is up to the receiver to + // keep up with the notifications to avoid missing any (via propper polling) + let (any_progress_notification_tx, _) = broadcast::channel(Self::CHANNEL_SIZE); + let (unsubscribe_tx, _) = broadcast::channel(Self::CHANNEL_SIZE); + let (unsubscribe_all_tx, _) = broadcast::channel(Self::CHANNEL_SIZE); + Self { + any_progress_notification_tx, + unsubscribe_tx, + unsubscribe_all_tx, + } } /// Handle a progress notification by sending it to the appropriate subscriber pub async fn handle_notification(&self, notification: ProgressNotificationParam) { - let token = ¬ification.progress_token; - if let Some(sender) = self.dispatcher.read().await.get(token).cloned() { - let send_result = sender.send(notification).await; - if let Err(e) = send_result { - tracing::warn!("Failed to send progress notification: {e}"); + // Broadcast the notification to all subscribers. Interested subscribers + // will filter on their end. + // ! Note that this implementaiton is very stateless and simple, we cannot + // ! easily inspect which subscribers are interested in which notifications. + // ! However, the stateless-ness and simplicity is also a plus! + // ! Cleanup becomes much easier. Just drop the `ProgressSubscriber`. + match self.any_progress_notification_tx.send(notification) { + Ok(_) => {} + Err(_) => { + // This error only happens if there are no active receivers of the `broadcast` channel. + // Silent error. } } } @@ -35,35 +58,97 @@ impl ProgressDispatcher { /// /// If you drop the returned `ProgressSubscriber`, it will automatically unsubscribe from notifications for that token. pub async fn subscribe(&self, progress_token: ProgressToken) -> ProgressSubscriber { - let (sender, receiver) = tokio::sync::mpsc::channel(Self::CHANNEL_SIZE); - self.dispatcher - .write() - .await - .insert(progress_token.clone(), sender); - let receiver = ReceiverStream::new(receiver); + // First, set up the unsubscribe listeners. This will fuse the notifiaction stream below. + let progress_token_clone = progress_token.clone(); + let unsub_this_token_rx = BroadcastStream::new(self.unsubscribe_tx.subscribe()).filter_map( + move |token| { + let progress_token_clone = progress_token_clone.clone(); + async move { + match token { + Ok(token) => { + if token == progress_token_clone { + Some(()) + } else { + None + } + } + Err(e) => { + // An error here means the broadcast stream did not receive values quick enough and + // and we missed some notification. This implies there are notifications + // we missed, but we cannot assume they were for us :( + tracing::warn!( + "Error receiving unsubscribe notification from broadcast channel: {e}" + ); + None + } + } + } + }, + ); + let unsub_any_token_tx = + BroadcastStream::new(self.unsubscribe_all_tx.subscribe()).map(|_| { + // Any reception of a result here indicates we should unsubscribe, + // regardless of if we received an `Ok(())` or an `Err(_)` (which + // indicates the broadcast receiver lagged behind) + () + }); + let unsub_fut = futures::stream::select(unsub_this_token_rx, unsub_any_token_tx) + .boxed() + .into_future(); // If the unsub streams end, this will cause unsubscription from the subscriber below. + + // Now setup the notification stream. We will receive all notifications and only forward progress notifications + // for the token we're interested in. + let progress_token_clone = progress_token.clone(); + let receiver = BroadcastStream::new(self.any_progress_notification_tx.subscribe()) + .filter_map(move |notification| { + let progress_token_clone = progress_token_clone.clone(); + async move { + // We need to kneed-out the broadcast receive error type here. + match notification { + Ok(notification) => { + let token = notification.progress_token.clone(); + if token == progress_token_clone { + Some(notification) + } else { + None + } + } + Err(e) => { + tracing::warn!( + "Error receiving progress notification from broadcast channel: {e}" + ); + None + } + } + } + }) + // Fuse this stream so it stops once we receive an unsubscribe notification from the stream + // created above + .take_until(unsub_fut) + .boxed(); + ProgressSubscriber { progress_token, receiver, - dispatcher: self.dispatcher.clone(), } } /// Unsubscribe from progress notifications for a specific token. - pub async fn unsubscribe(&self, token: &ProgressToken) { - self.dispatcher.write().await.remove(token); + pub fn unsubscribe(&self, token: ProgressToken) { + // The only error defined is if there are no listeners, which is fine. Ignore the result. + let _ = self.unsubscribe_tx.send(token); } /// Clear all dispatcher. - pub async fn clear(&self) { - let mut dispatcher = self.dispatcher.write().await; - dispatcher.clear(); + pub fn clear(&self) { + // The only error defined is if there are no listeners, which is fine. Ignore the result. + let _ = self.unsubscribe_all_tx.send(()); } } pub struct ProgressSubscriber { pub(crate) progress_token: ProgressToken, - pub(crate) receiver: ReceiverStream, - pub(crate) dispatcher: Dispatcher, + pub(crate) receiver: PinnedStream<'static, ProgressNotificationParam>, } impl ProgressSubscriber { @@ -86,15 +171,3 @@ impl Stream for ProgressSubscriber { self.receiver.size_hint() } } - -impl Drop for ProgressSubscriber { - fn drop(&mut self) { - let token = self.progress_token.clone(); - self.receiver.close(); - let dispatcher = self.dispatcher.clone(); - tokio::spawn(async move { - let mut dispatcher = dispatcher.write_owned().await; - dispatcher.remove(&token); - }); - } -} diff --git a/crates/rmcp/src/util.rs b/crates/rmcp/src/util.rs index 06a563e3..33b273f8 100644 --- a/crates/rmcp/src/util.rs +++ b/crates/rmcp/src/util.rs @@ -1,5 +1,10 @@ +use futures::Stream; use std::pin::Pin; pub type PinnedFuture<'a, T> = Pin + Send + 'a>>; pub type PinnedLocalFuture<'a, T> = Pin + 'a>>; + +pub type PinnedStream<'a, T> = Pin + Send + 'a>>; + +pub type PinnedLocalStream<'a, T> = Pin + 'a>>; diff --git a/crates/rmcp/tests/test_progress_subscriber.rs b/crates/rmcp/tests/test_progress_subscriber.rs index 521219a3..5c5715b9 100644 --- a/crates/rmcp/tests/test_progress_subscriber.rs +++ b/crates/rmcp/tests/test_progress_subscriber.rs @@ -100,11 +100,13 @@ async fn test_progress_subscriber() -> anyhow::Result<()> { let server = MyServer::new(); let (transport_server, transport_client) = tokio::io::duplex(4096); tokio::spawn(async move { - let service = server.serve(transport_server).await?; - service.waiting().await?; + let (service, work) = server.serve(transport_server).await?; + tokio::spawn(work); + service.waiting().await; anyhow::Ok(()) }); - let client_service = client.serve(transport_client).await?; + let (client_service, client_work) = client.serve(transport_client).await?; + tokio::spawn(client_work); let handle = client_service .send_cancellable_request( ClientRequest::CallToolRequest(Request::new(CallToolRequestParams { From dc4204ca270f4f7fd0c730b9d30e54a585c7047f Mon Sep 17 00:00:00 2001 From: Aadam Zocolo Date: Sun, 1 Mar 2026 05:16:26 +0000 Subject: [PATCH 04/20] refactor(child-process): wip experiment with new child process transport --- crates/rmcp/Cargo.toml | 2 +- crates/rmcp/src/lib.rs | 1 + crates/rmcp/src/transport.rs | 2 + crates/rmcp/src/transport/async_rw.rs | 6 +- crates/rmcp/src/transport/child_process.rs | 4 +- crates/rmcp/src/transport/child_process2.rs | 2 + .../src/transport/child_process2/runner.rs | 314 ++++++++++++++++++ .../src/transport/child_process2/transport.rs | 71 ++++ crates/rmcp/src/util.rs | 65 +++- 9 files changed, 461 insertions(+), 6 deletions(-) create mode 100644 crates/rmcp/src/transport/child_process2.rs create mode 100644 crates/rmcp/src/transport/child_process2/runner.rs create mode 100644 crates/rmcp/src/transport/child_process2/transport.rs diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml index 35049780..421bd07a 100644 --- a/crates/rmcp/Cargo.toml +++ b/crates/rmcp/Cargo.toml @@ -119,7 +119,7 @@ transport-streamable-http-client-reqwest = [ "__reqwest", ] -transport-async-rw = ["tokio/io-util", "tokio-util/codec"] +transport-async-rw = ["tokio/io-util", "tokio-util/codec", "tokio-util/compat"] transport-io = ["transport-async-rw", "tokio/io-std"] transport-child-process = [ "transport-async-rw", diff --git a/crates/rmcp/src/lib.rs b/crates/rmcp/src/lib.rs index 456bc3ea..c70e61b5 100644 --- a/crates/rmcp/src/lib.rs +++ b/crates/rmcp/src/lib.rs @@ -4,6 +4,7 @@ mod error; mod util; + #[allow(deprecated)] pub use error::{Error, ErrorData, RmcpError}; diff --git a/crates/rmcp/src/transport.rs b/crates/rmcp/src/transport.rs index d7dfa979..f22b63d8 100644 --- a/crates/rmcp/src/transport.rs +++ b/crates/rmcp/src/transport.rs @@ -85,6 +85,8 @@ pub use worker::WorkerTransport; pub mod child_process; #[cfg(feature = "transport-child-process")] pub use child_process::{ConfigureCommandExt, TokioChildProcess}; +#[cfg(feature = "transport-child-process")] +pub mod child_process2; #[cfg(feature = "transport-io")] pub mod io; diff --git a/crates/rmcp/src/transport/async_rw.rs b/crates/rmcp/src/transport/async_rw.rs index ff4ecc65..cb5ea75d 100644 --- a/crates/rmcp/src/transport/async_rw.rs +++ b/crates/rmcp/src/transport/async_rw.rs @@ -43,9 +43,13 @@ where } pub type TransportWriter = FramedWrite>>; +pub type TransportReader = FramedRead>>; pub struct AsyncRwTransport { - read: FramedRead>>, + read: TransportReader, + /// This is behind a mutex so that concurrent writes can happen. + /// Naturally, the mutex will block parallel writes, but allow + /// multiple futures to be executed at once, even if some are waiting. write: Arc>>>, } diff --git a/crates/rmcp/src/transport/child_process.rs b/crates/rmcp/src/transport/child_process.rs index e33800b1..58e8b9a8 100644 --- a/crates/rmcp/src/transport/child_process.rs +++ b/crates/rmcp/src/transport/child_process.rs @@ -23,7 +23,7 @@ type ChildProcessParts = ( /// Returns `(child, stdout, stdin, stderr)` where `stderr` is `Some` only /// if the process was spawned with `Stdio::piped()`. #[inline] -fn child_process(mut child: Box) -> std::io::Result { +fn split_child_process(mut child: Box) -> std::io::Result { let child_stdin = match child.inner_mut().stdin().take() { Some(stdin) => stdin, None => return Err(std::io::Error::other("stdin was already taken")), @@ -192,7 +192,7 @@ impl TokioChildProcessBuilder { .stdout(self.stdout) .stderr(self.stderr); - let (child, stdout, stdin, stderr_opt) = child_process(self.cmd.spawn()?)?; + let (child, stdout, stdin, stderr_opt) = split_child_process(self.cmd.spawn()?)?; let transport = AsyncRwTransport::new(stdout, stdin); let proc = TokioChildProcess { diff --git a/crates/rmcp/src/transport/child_process2.rs b/crates/rmcp/src/transport/child_process2.rs new file mode 100644 index 00000000..27f1eb46 --- /dev/null +++ b/crates/rmcp/src/transport/child_process2.rs @@ -0,0 +1,2 @@ +pub mod runner; +pub mod transport; diff --git a/crates/rmcp/src/transport/child_process2/runner.rs b/crates/rmcp/src/transport/child_process2/runner.rs new file mode 100644 index 00000000..d169970f --- /dev/null +++ b/crates/rmcp/src/transport/child_process2/runner.rs @@ -0,0 +1,314 @@ +use futures::{ + FutureExt, + io::{AsyncRead, AsyncWrite}, +}; +use std::process::Stdio; + +use crate::util::PinnedFuture; + +/// A simple enum for describing if a stream is available, unused, or already taken. +#[derive(Debug)] +pub enum StreamSlot { + /// The stream is not used in this implementation. + Unused, + /// The stream is available for use, and can be taken. + Available(S), + /// The stream has already been taken, and is no longer available. + Taken, +} + +impl From> for Option { + fn from(slot: StreamSlot) -> Self { + match slot { + StreamSlot::Unused => None, + StreamSlot::Available(s) => Some(s), + StreamSlot::Taken => None, + } + } +} + +/// A structure that requests how the child process streams should +/// be configured when spawning. +pub struct StdioConfig { + pub stdin: Stdio, + pub stdout: Stdio, + pub stderr: Stdio, +} + +/// The contract for what an instance of a child process +/// must provide to be used with a transport. +pub trait ChildProcessInstance { + /// The input stream for the command + type Stdin: AsyncWrite + Unpin + Send; + + /// The output stream of the command + type Stdout: AsyncRead + Unpin + Send; + + /// The error stream of the command + type Stderr: AsyncRead + Unpin + Send; + + fn take_stdin(&mut self) -> StreamSlot; + fn take_stdout(&mut self) -> StreamSlot; + fn take_stderr(&mut self) -> StreamSlot; + + fn pid(&self) -> u32; + fn wait( + &mut self, + ) -> impl Future> + Send + 'static; + fn graceful_shutdown(&mut self) -> impl Future> + Send + 'static; + fn kill(&mut self) -> impl Future> + Send + 'static; +} + +/// A subset of functionality of [ChildProcessInstance] that only includes the +/// functions used to control or wait for the process. +pub trait ChildProcessControl { + fn pid(&self) -> u32; + fn wait(&mut self) -> PinnedFuture<'static, std::io::Result>; + fn graceful_shutdown(&mut self) -> PinnedFuture<'static, std::io::Result<()>>; + fn kill(&mut self) -> PinnedFuture<'static, std::io::Result<()>>; +} + +/// Auto-implement ChildProcessControl for any ChildProcessInstance, since it has all the required methods. +impl ChildProcessControl for T +where + T: ChildProcessInstance, +{ + fn pid(&self) -> u32 { + ChildProcessInstance::pid(self) + } + + fn wait(&mut self) -> PinnedFuture<'static, std::io::Result> { + ChildProcessInstance::wait(self).boxed() + } + + fn graceful_shutdown(&mut self) -> PinnedFuture<'static, std::io::Result<()>> { + ChildProcessInstance::graceful_shutdown(self).boxed() + } + + fn kill(&mut self) -> PinnedFuture<'static, std::io::Result<()>> { + ChildProcessInstance::kill(self).boxed() + } +} + +#[derive(Debug)] +pub enum RunnerSpawnError { + /// The child process instance failed to spawn. + SpawnError(std::io::Error), + Other(Box), +} + +pub trait ChildProcessRunner { + /// The implementation of the child process instance that this runner will spawn. + type Instance: ChildProcessInstance; + + fn spawn( + command: &str, + args: &[&str], + stdio_config: StdioConfig, + ) -> Result; +} + +/// A containing wrapper around a child process instance. This struct erases the type +/// by extracting some parts of the [ChildProcessInstance] trait into a common struct, +/// and then only exposes the control methods. +pub struct ChildProcess { + stdin: StreamSlot>, + stdout: StreamSlot>, + stderr: StreamSlot>, + inner: Box, +} + +impl ChildProcess { + pub fn new(mut instance: C) -> Self + where + C: ChildProcessInstance + Send + 'static, + { + Self { + stdin: match instance.take_stdin() { + StreamSlot::Available(s) => StreamSlot::Available(Box::new(s)), + StreamSlot::Unused => StreamSlot::Unused, + StreamSlot::Taken => { + panic!("Stdin stream was already taken during ChildProcess construction") + } + }, + stdout: match instance.take_stdout() { + StreamSlot::Available(s) => StreamSlot::Available(Box::new(s)), + StreamSlot::Unused => StreamSlot::Unused, + StreamSlot::Taken => { + panic!("Stdout stream was already taken during ChildProcess construction") + } + }, + stderr: match instance.take_stderr() { + StreamSlot::Available(s) => StreamSlot::Available(Box::new(s)), + StreamSlot::Unused => StreamSlot::Unused, + StreamSlot::Taken => { + panic!("Stderr stream was already taken during ChildProcess construction") + } + }, + inner: Box::new(instance), + } + } + + pub fn split( + self, + ) -> ( + Option>, + Option>, + Option>, + Box, + ) { + ( + self.stdout.into(), + self.stdin.into(), + self.stderr.into(), + self.inner, + ) + } +} + +impl ChildProcessInstance for ChildProcess { + type Stdin = Box; + + type Stdout = Box; + + type Stderr = Box; + + fn take_stdin(&mut self) -> StreamSlot { + match self.stdin { + StreamSlot::Available(_) => std::mem::replace(&mut self.stdin, StreamSlot::Taken), + StreamSlot::Unused => StreamSlot::Unused, + StreamSlot::Taken => StreamSlot::Taken, + } + } + + fn take_stdout(&mut self) -> StreamSlot { + match self.stdout { + StreamSlot::Available(_) => std::mem::replace(&mut self.stdout, StreamSlot::Taken), + StreamSlot::Unused => StreamSlot::Unused, + StreamSlot::Taken => StreamSlot::Taken, + } + } + + fn take_stderr(&mut self) -> StreamSlot { + match self.stderr { + StreamSlot::Available(_) => std::mem::replace(&mut self.stderr, StreamSlot::Taken), + StreamSlot::Unused => StreamSlot::Unused, + StreamSlot::Taken => StreamSlot::Taken, + } + } + + fn pid(&self) -> u32 { + self.inner.pid() + } + + fn wait( + &mut self, + ) -> impl Future> + Send + 'static { + self.inner.wait() + } + + fn graceful_shutdown(&mut self) -> impl Future> + Send + 'static { + self.inner.graceful_shutdown() + } + + fn kill(&mut self) -> impl Future> + Send + 'static { + self.inner.kill() + } +} + +pub struct CommandBuilder { + command: String, + args: Vec, + _marker: std::marker::PhantomData, + stderr: Stdio, +} + +pub enum CommandBuilderError { + EmptyCommand, +} + +impl CommandBuilder { + /// Create a CommandBuilder from an argv-style list of strings, where the first element is the command, and the rest are the args. + pub fn from_argv(argv: I) -> Result + where + I: IntoIterator, + S: Into, + { + let mut iter = argv.into_iter(); + + // Pop the first element as the command, and use the rest as args + let command = match iter.next() { + Some(cmd) => cmd.into(), + None => return Err(CommandBuilderError::EmptyCommand), + }; + + let args = iter.map(|s| s.into()).collect(); + Ok(Self { + command, + args, + _marker: std::marker::PhantomData, + stderr: Stdio::inherit(), + }) + } + + /// Create a CommandBuilder from a command and an optional list of args. + pub fn new(command: impl Into) -> Self { + Self { + command: command.into(), + args: Vec::new(), + _marker: std::marker::PhantomData, + stderr: Stdio::inherit(), + } + } + + /// Add a single argument to the command. + pub fn arg(mut self, arg: impl Into) -> Self { + self.args.push(arg.into()); + self + } + + /// Add multiple arguments to the command. + pub fn args(mut self, args: impl IntoIterator>) -> Self { + self.args.extend(args.into_iter().map(|arg| arg.into())); + self + } + + /// Sets what happens to stderr for the command. + /// By default if not set, stderr is inherited from the parent process. + pub fn stderr(mut self, _stdio: Stdio) -> Self { + self.stderr = _stdio; + self + } +} + +impl CommandBuilder +where + R: ChildProcessRunner, +{ + /// Spawn the command into its typed child process instance type. + pub fn spawn_raw(self) -> Result { + // We should always pipe stdin and stdout. + let stdio_config = StdioConfig { + stdin: Stdio::piped(), + stdout: Stdio::piped(), + stderr: self.stderr, + }; + + R::spawn( + &self.command, + &self.args.iter().map(|s| s.as_str()).collect::>(), + stdio_config, + ) + } + + /// Spawn a child process struct that erases the specific child process instance type, and only exposes the control methods. + /// + /// Requires `R::Instance` to be [Send] and `'static`. + pub fn spawn_dyn(self) -> Result + where + R::Instance: Send + 'static, + { + let instance = self.spawn_raw()?; + Ok(ChildProcess::new(instance)) + } +} diff --git a/crates/rmcp/src/transport/child_process2/transport.rs b/crates/rmcp/src/transport/child_process2/transport.rs new file mode 100644 index 00000000..b378309a --- /dev/null +++ b/crates/rmcp/src/transport/child_process2/transport.rs @@ -0,0 +1,71 @@ +use tokio::io::{AsyncRead as TokioAsyncRead, AsyncWrite as TokioAsyncWrite}; +use tokio_util::compat::{FuturesAsyncReadCompatExt, FuturesAsyncWriteCompatExt}; + +use crate::{ + service::ServiceRole, + transport::{ + Transport, + async_rw::AsyncRwTransport, + child_process2::runner::{ChildProcess, ChildProcessControl}, + }, +}; + +pub struct ChildProcessTransport { + child: Box, + framed_transport: AsyncRwTransport< + R, + Box, + Box, + >, +} + +impl ChildProcessTransport +where + R: ServiceRole, +{ + pub fn new(child: ChildProcess) -> Result> { + let (stdout, stdin, stderr, control) = child.split(); + + let framed_transport: AsyncRwTransport = AsyncRwTransport::new( + Box::new( + stdout + .ok_or("Failed to capture stdout of child process")? + .compat(), + ) as Box, + Box::new( + stdin + .ok_or("Failed to capture stdin of child process")? + .compat_write(), + ) as Box, + ); + + Ok(Self { + child: control, + framed_transport, + }) + } +} + +impl Transport for ChildProcessTransport +where + R: ServiceRole, +{ + type Error = std::io::Error; + + fn send( + &mut self, + item: crate::service::TxJsonRpcMessage, + ) -> impl Future> + Send + 'static { + self.framed_transport.send(item) + } + + fn receive( + &mut self, + ) -> impl Future>> + Send { + self.framed_transport.receive() + } + + fn close(&mut self) -> impl Future> + Send { + self.framed_transport.close() + } +} diff --git a/crates/rmcp/src/util.rs b/crates/rmcp/src/util.rs index 33b273f8..97121ac3 100644 --- a/crates/rmcp/src/util.rs +++ b/crates/rmcp/src/util.rs @@ -1,5 +1,5 @@ -use futures::Stream; -use std::pin::Pin; +use futures::{Sink, Stream}; +use std::{pin::Pin, task::Poll}; pub type PinnedFuture<'a, T> = Pin + Send + 'a>>; @@ -8,3 +8,64 @@ pub type PinnedLocalFuture<'a, T> = Pin + 'a>>; pub type PinnedStream<'a, T> = Pin + Send + 'a>>; pub type PinnedLocalStream<'a, T> = Pin + 'a>>; + +pub enum UnboundedSenderSinkError { + SendError(tokio::sync::mpsc::error::SendError), + Closed, +} + +/// A simple [Sink] wrapper for Tokio's [tokio::sync::mpsc::UnboundedSender] +#[derive(Debug, Clone)] +pub struct UnboundedSenderSink { + sender: tokio::sync::mpsc::UnboundedSender, +} + +impl UnboundedSenderSink { + pub fn new(sender: tokio::sync::mpsc::UnboundedSender) -> Self { + Self { sender } + } +} + +impl Sink for UnboundedSenderSink { + type Error = UnboundedSenderSinkError; + + fn poll_ready( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let this = self.get_mut(); + if this.sender.is_closed() { + Poll::Ready(Err(UnboundedSenderSinkError::Closed)) + } else { + Poll::Ready(Ok(())) + } + } + + fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> { + let this = self.get_mut(); + match this.sender.send(item) { + Ok(_) => Ok(()), + Err(e) => Err(UnboundedSenderSinkError::SendError(e)), + } + } + + fn poll_flush( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + // tokio's unbounded mpsc senders have no flushing required, since the + // receiver is unbounded and will get all messages we send (unless we run + // out of memory) + Poll::Ready(Ok(())) + } + + fn poll_close( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + // Like `poll_flush`, there is nothing to wait on here. A single + // call to `mpsc_sender.send(...)` is immediate from the perspective + // of the sender + Poll::Ready(Ok(())) + } +} From 1db16c8ddae7fdbbd94eb192d3e24123bc514238 Mon Sep 17 00:00:00 2001 From: Aadam Zocolo Date: Sun, 1 Mar 2026 06:28:21 +0000 Subject: [PATCH 05/20] refactor(child-process): implement tokio child process and use in test --- crates/rmcp/Cargo.toml | 2 + crates/rmcp/src/task_manager.rs | 4 +- crates/rmcp/src/transport/child_process2.rs | 1 + .../src/transport/child_process2/runner.rs | 41 +++++----- .../src/transport/child_process2/tokio.rs | 81 +++++++++++++++++++ .../src/transport/child_process2/transport.rs | 2 +- crates/rmcp/src/transport/common.rs | 2 +- crates/rmcp/src/transport/common/reqwest.rs | 2 +- .../transport/streamable_http_server/tower.rs | 14 ++-- crates/rmcp/src/util.rs | 6 +- crates/rmcp/tests/test_with_js.rs | 31 ++++--- 11 files changed, 145 insertions(+), 41 deletions(-) create mode 100644 crates/rmcp/src/transport/child_process2/tokio.rs diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml index 421bd07a..f8e8011e 100644 --- a/crates/rmcp/Cargo.toml +++ b/crates/rmcp/Cargo.toml @@ -74,6 +74,8 @@ chrono = { version = "0.4.38", default-features = false, features = [ "oldtime", ] } +[target.'cfg(test)'] + [features] default = ["base64", "macros", "server"] client = ["dep:tokio-stream"] diff --git a/crates/rmcp/src/task_manager.rs b/crates/rmcp/src/task_manager.rs index c1b6edc3..7cc81575 100644 --- a/crates/rmcp/src/task_manager.rs +++ b/crates/rmcp/src/task_manager.rs @@ -173,7 +173,7 @@ impl OperationProcessor { if let Some((descriptor, result)) = operation_result { match task_result_sender.send(TaskResult { descriptor, result }) { Err(e) => { - // TODO: Produce an error message here! + tracing::error!("Failed to send completed task result: {e}"); } _ => {} } @@ -251,7 +251,7 @@ impl OperationProcessor { match self.spawn_tx.send((descriptor.clone(), work.boxed())) { Ok(_) => {} Err(e) => { - // TODO: Produce an error message! + tracing::error!("Failed to spawn task on worker: {e}"); } } diff --git a/crates/rmcp/src/transport/child_process2.rs b/crates/rmcp/src/transport/child_process2.rs index 27f1eb46..c82e16db 100644 --- a/crates/rmcp/src/transport/child_process2.rs +++ b/crates/rmcp/src/transport/child_process2.rs @@ -1,2 +1,3 @@ pub mod runner; +pub mod tokio; pub mod transport; diff --git a/crates/rmcp/src/transport/child_process2/runner.rs b/crates/rmcp/src/transport/child_process2/runner.rs index d169970f..6f062065 100644 --- a/crates/rmcp/src/transport/child_process2/runner.rs +++ b/crates/rmcp/src/transport/child_process2/runner.rs @@ -52,20 +52,21 @@ pub trait ChildProcessInstance { fn take_stderr(&mut self) -> StreamSlot; fn pid(&self) -> u32; - fn wait( - &mut self, - ) -> impl Future> + Send + 'static; - fn graceful_shutdown(&mut self) -> impl Future> + Send + 'static; - fn kill(&mut self) -> impl Future> + Send + 'static; + fn wait<'s>( + &'s mut self, + ) -> impl Future> + Send + 's; + fn graceful_shutdown<'s>(&'s mut self) + -> impl Future> + Send + 's; + fn kill<'s>(&'s mut self) -> impl Future> + Send + 's; } /// A subset of functionality of [ChildProcessInstance] that only includes the /// functions used to control or wait for the process. pub trait ChildProcessControl { fn pid(&self) -> u32; - fn wait(&mut self) -> PinnedFuture<'static, std::io::Result>; - fn graceful_shutdown(&mut self) -> PinnedFuture<'static, std::io::Result<()>>; - fn kill(&mut self) -> PinnedFuture<'static, std::io::Result<()>>; + fn wait<'s>(&'s mut self) -> PinnedFuture<'s, std::io::Result>; + fn graceful_shutdown<'s>(&'s mut self) -> PinnedFuture<'s, std::io::Result<()>>; + fn kill<'s>(&'s mut self) -> PinnedFuture<'s, std::io::Result<()>>; } /// Auto-implement ChildProcessControl for any ChildProcessInstance, since it has all the required methods. @@ -77,23 +78,25 @@ where ChildProcessInstance::pid(self) } - fn wait(&mut self) -> PinnedFuture<'static, std::io::Result> { + fn wait<'s>(&'s mut self) -> PinnedFuture<'s, std::io::Result> { ChildProcessInstance::wait(self).boxed() } - fn graceful_shutdown(&mut self) -> PinnedFuture<'static, std::io::Result<()>> { + fn graceful_shutdown<'s>(&'s mut self) -> PinnedFuture<'s, std::io::Result<()>> { ChildProcessInstance::graceful_shutdown(self).boxed() } - fn kill(&mut self) -> PinnedFuture<'static, std::io::Result<()>> { + fn kill<'s>(&'s mut self) -> PinnedFuture<'s, std::io::Result<()>> { ChildProcessInstance::kill(self).boxed() } } -#[derive(Debug)] +#[derive(Debug, thiserror::Error)] pub enum RunnerSpawnError { /// The child process instance failed to spawn. - SpawnError(std::io::Error), + #[error("Failed to spawn child process: {0}")] + SpawnError(#[from] std::io::Error), + #[error("Other error: {0}")] Other(Box), } @@ -201,17 +204,19 @@ impl ChildProcessInstance for ChildProcess { self.inner.pid() } - fn wait( - &mut self, - ) -> impl Future> + Send + 'static { + fn wait<'s>( + &'s mut self, + ) -> impl Future> + Send + 's { self.inner.wait() } - fn graceful_shutdown(&mut self) -> impl Future> + Send + 'static { + fn graceful_shutdown<'s>( + &'s mut self, + ) -> impl Future> + Send + 's { self.inner.graceful_shutdown() } - fn kill(&mut self) -> impl Future> + Send + 'static { + fn kill<'s>(&'s mut self) -> impl Future> + Send + 's { self.inner.kill() } } diff --git a/crates/rmcp/src/transport/child_process2/tokio.rs b/crates/rmcp/src/transport/child_process2/tokio.rs new file mode 100644 index 00000000..dd75e0e1 --- /dev/null +++ b/crates/rmcp/src/transport/child_process2/tokio.rs @@ -0,0 +1,81 @@ +use tokio_util::compat::{Compat, TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; + +use crate::transport::child_process2::runner::{ + ChildProcessInstance, ChildProcessRunner, RunnerSpawnError, StdioConfig, +}; + +pub struct TokioChildProcessRunner {} + +pub struct TokioChildProcess { + inner: tokio::process::Child, +} + +impl ChildProcessInstance for TokioChildProcess { + type Stdin = Compat; + + type Stdout = Compat; + + type Stderr = Compat; + + fn take_stdin(&mut self) -> super::runner::StreamSlot { + match self.inner.stdin.take() { + Some(stdin) => super::runner::StreamSlot::Available(stdin.compat_write()), + None => super::runner::StreamSlot::Unused, + } + } + + fn take_stdout(&mut self) -> super::runner::StreamSlot { + match self.inner.stdout.take() { + Some(stdout) => super::runner::StreamSlot::Available(stdout.compat()), + None => super::runner::StreamSlot::Unused, + } + } + + fn take_stderr(&mut self) -> super::runner::StreamSlot { + match self.inner.stderr.take() { + Some(stderr) => super::runner::StreamSlot::Available(stderr.compat()), + None => super::runner::StreamSlot::Unused, + } + } + + fn pid(&self) -> u32 { + // TODO: Consider refactor to return Option to avoid confusion of 0 as a valid PID. + self.inner.id().unwrap_or(0) + } + + fn wait<'s>( + &'s mut self, + ) -> impl Future> + Send + 's { + self.inner.wait() + } + + fn graceful_shutdown<'s>( + &'s mut self, + ) -> impl Future> + Send + 's { + // TODO: Implement graceful shutdown on unix with SIGTERM. And look into graceful shutdown on windows as well. + self.inner.kill() + } + + fn kill<'s>(&'s mut self) -> impl Future> + Send + 's { + self.inner.kill() + } +} + +impl ChildProcessRunner for TokioChildProcessRunner { + type Instance = TokioChildProcess; + fn spawn( + command: &str, + args: &[&str], + stdio_configuration: StdioConfig, + ) -> Result { + tokio::process::Command::new(command) + .args(args) + .stdin(stdio_configuration.stdin) + .stdout(stdio_configuration.stdout) + .stderr(stdio_configuration.stderr) + .kill_on_drop(true) + .spawn() + .map(|child| TokioChildProcess { inner: child }) + .map_err(RunnerSpawnError::SpawnError) + } +} diff --git a/crates/rmcp/src/transport/child_process2/transport.rs b/crates/rmcp/src/transport/child_process2/transport.rs index b378309a..a6731347 100644 --- a/crates/rmcp/src/transport/child_process2/transport.rs +++ b/crates/rmcp/src/transport/child_process2/transport.rs @@ -24,7 +24,7 @@ where R: ServiceRole, { pub fn new(child: ChildProcess) -> Result> { - let (stdout, stdin, stderr, control) = child.split(); + let (stdout, stdin, _stderr, control) = child.split(); let framed_transport: AsyncRwTransport = AsyncRwTransport::new( Box::new( diff --git a/crates/rmcp/src/transport/common.rs b/crates/rmcp/src/transport/common.rs index 615b0e27..b41a8f3c 100644 --- a/crates/rmcp/src/transport/common.rs +++ b/crates/rmcp/src/transport/common.rs @@ -4,7 +4,7 @@ pub mod server_side_http; pub mod http_header; #[cfg(feature = "__reqwest")] -mod reqwest; +pub mod reqwest; // Note: This module provides SSE stream parsing and auto-reconnect utilities. // It's used by the streamable HTTP client (which receives SSE-formatted responses), diff --git a/crates/rmcp/src/transport/common/reqwest.rs b/crates/rmcp/src/transport/common/reqwest.rs index 42075921..696aa912 100644 --- a/crates/rmcp/src/transport/common/reqwest.rs +++ b/crates/rmcp/src/transport/common/reqwest.rs @@ -1,2 +1,2 @@ #[cfg(feature = "transport-streamable-http-client-reqwest")] -mod streamable_http_client; +pub mod streamable_http_client; diff --git a/crates/rmcp/src/transport/streamable_http_server/tower.rs b/crates/rmcp/src/transport/streamable_http_server/tower.rs index ba28a8bb..f62708e8 100644 --- a/crates/rmcp/src/transport/streamable_http_server/tower.rs +++ b/crates/rmcp/src/transport/streamable_http_server/tower.rs @@ -515,13 +515,15 @@ where let session_manager = self.session_manager.clone(); let session_id = session_id.clone(); async move { - let service = serve_server::( - service, transport, - ) - .await; - match service { - Ok(service) => { + let serve_result = + serve_server::( + service, transport, + ) + .await; + match serve_result { + Ok((service, work)) => { // on service created + tokio::spawn(work); let _ = service.waiting().await; } Err(e) => { diff --git a/crates/rmcp/src/util.rs b/crates/rmcp/src/util.rs index 97121ac3..912e1378 100644 --- a/crates/rmcp/src/util.rs +++ b/crates/rmcp/src/util.rs @@ -31,7 +31,7 @@ impl Sink for UnboundedSenderSink { fn poll_ready( self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, + _cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { let this = self.get_mut(); if this.sender.is_closed() { @@ -51,7 +51,7 @@ impl Sink for UnboundedSenderSink { fn poll_flush( self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, + _cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { // tokio's unbounded mpsc senders have no flushing required, since the // receiver is unbounded and will get all messages we send (unless we run @@ -61,7 +61,7 @@ impl Sink for UnboundedSenderSink { fn poll_close( self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, + _cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { // Like `poll_flush`, there is nothing to wait on here. A single // call to `mpsc_sender.send(...)` is immediate from the perspective diff --git a/crates/rmcp/tests/test_with_js.rs b/crates/rmcp/tests/test_with_js.rs index c1e5d81a..31842773 100644 --- a/crates/rmcp/tests/test_with_js.rs +++ b/crates/rmcp/tests/test_with_js.rs @@ -3,7 +3,11 @@ use rmcp::{ service::QuitReason, transport::{ ConfigureCommandExt, StreamableHttpClientTransport, StreamableHttpServerConfig, - TokioChildProcess, + child_process2::{ + runner::{ChildProcessControl, CommandBuilder}, + tokio::TokioChildProcessRunner, + transport::ChildProcessTransport, + }, streamable_http_server::{ session::local::LocalSessionManager, tower::StreamableHttpService, }, @@ -32,18 +36,26 @@ async fn test_with_js_stdio_server() -> anyhow::Result<()> { .spawn()? .wait() .await?; - let transport = - TokioChildProcess::new(tokio::process::Command::new("node").configure(|cmd| { - cmd.arg("tests/test_with_js/server.js"); - }))?; - let client = ().serve(transport).await?; + let node_cmd = CommandBuilder::::new("node") + .args(["tests/test_with_js/server.js"]) + .spawn_dyn()?; + + tracing::info!("Spawned child process with PID: {}", node_cmd.pid()); + + let transport = ChildProcessTransport::new(node_cmd) + .map_err(|e| anyhow::anyhow!("Failed to spawn child process: {e}"))?; + + let (client, work) = ().serve(transport).await?; + + tokio::spawn(work); + let resources = client.list_all_resources().await?; tracing::info!("{:#?}", resources); let tools = client.list_all_tools().await?; tracing::info!("{:#?}", tools); - client.cancel().await?; + client.cancel().await; Ok(()) } @@ -124,12 +136,13 @@ async fn test_with_js_streamable_http_server() -> anyhow::Result<()> { // waiting for server up tokio::time::sleep(tokio::time::Duration::from_millis(1000)).await; - let client = ().serve(transport).await?; + let (client, work) = ().serve(transport).await?; + tokio::spawn(work); let resources = client.list_all_resources().await?; tracing::info!("{:#?}", resources); let tools = client.list_all_tools().await?; tracing::info!("{:#?}", tools); - let quit_reason = client.cancel().await?; + let quit_reason = client.cancel().await; server.kill().await?; assert!(matches!(quit_reason, QuitReason::Cancelled)); Ok(()) From d0bd6ca699f500e91c0d7f043323baa38ad72fee Mon Sep 17 00:00:00 2001 From: Aadam Zocolo Date: Sun, 1 Mar 2026 07:06:25 +0000 Subject: [PATCH 06/20] refactor(child-process): continue to build command abstraction also update some unit tests, remove old child_process --- crates/rmcp/Cargo.toml | 14 +- crates/rmcp/src/transport.rs | 8 +- crates/rmcp/src/transport/child_process.rs | 309 ------------------ crates/rmcp/src/transport/child_process2.rs | 4 +- .../src/transport/child_process2/runner.rs | 71 ++-- .../src/transport/child_process2/tokio.rs | 24 +- crates/rmcp/tests/test_with_js.rs | 4 +- crates/rmcp/tests/test_with_python.rs | 48 ++- 8 files changed, 104 insertions(+), 378 deletions(-) delete mode 100644 crates/rmcp/src/transport/child_process.rs diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml index f8e8011e..d08951d0 100644 --- a/crates/rmcp/Cargo.toml +++ b/crates/rmcp/Cargo.toml @@ -123,10 +123,11 @@ transport-streamable-http-client-reqwest = [ transport-async-rw = ["tokio/io-util", "tokio-util/codec", "tokio-util/compat"] transport-io = ["transport-async-rw", "tokio/io-std"] -transport-child-process = [ +transport-child-process = ["transport-async-rw", "tokio/process"] +transport-child-process-tokio = [ "transport-async-rw", "tokio/process", - "dep:process-wrap", + "tokio/rt", ] transport-streamable-http-server = [ "transport-streamable-http-server-session", @@ -163,7 +164,13 @@ path = "tests/test_tool_macros.rs" [[test]] name = "test_with_python" -required-features = ["reqwest", "server", "client", "transport-child-process"] +required-features = [ + "reqwest", + "server", + "client", + "transport-child-process", + "transport-child-process-tokio", +] path = "tests/test_with_python.rs" [[test]] @@ -172,6 +179,7 @@ required-features = [ "server", "client", "transport-child-process", + "transport-child-process-tokio", "transport-streamable-http-server", "transport-streamable-http-client", "__reqwest", diff --git a/crates/rmcp/src/transport.rs b/crates/rmcp/src/transport.rs index f22b63d8..99f84c05 100644 --- a/crates/rmcp/src/transport.rs +++ b/crates/rmcp/src/transport.rs @@ -81,12 +81,12 @@ pub mod worker; #[cfg(feature = "transport-worker")] pub use worker::WorkerTransport; -#[cfg(feature = "transport-child-process")] -pub mod child_process; -#[cfg(feature = "transport-child-process")] -pub use child_process::{ConfigureCommandExt, TokioChildProcess}; #[cfg(feature = "transport-child-process")] pub mod child_process2; +#[cfg(feature = "transport-child-process")] +pub use child_process2::runner::{ + ChildProcess, ChildProcessInstance, ChildProcessRunner, CommandBuilder, +}; #[cfg(feature = "transport-io")] pub mod io; diff --git a/crates/rmcp/src/transport/child_process.rs b/crates/rmcp/src/transport/child_process.rs deleted file mode 100644 index 58e8b9a8..00000000 --- a/crates/rmcp/src/transport/child_process.rs +++ /dev/null @@ -1,309 +0,0 @@ -use std::process::Stdio; - -use futures::future::Future; -use process_wrap::tokio::{ChildWrapper, CommandWrap}; -use tokio::{ - io::AsyncRead, - process::{ChildStderr, ChildStdin, ChildStdout}, -}; - -use super::{RxJsonRpcMessage, Transport, TxJsonRpcMessage, async_rw::AsyncRwTransport}; -use crate::RoleClient; - -const MAX_WAIT_ON_DROP_SECS: u64 = 3; -/// The parts of a child process. -type ChildProcessParts = ( - Box, - ChildStdout, - ChildStdin, - Option, -); - -/// Extract the stdio handles from a spawned child. -/// Returns `(child, stdout, stdin, stderr)` where `stderr` is `Some` only -/// if the process was spawned with `Stdio::piped()`. -#[inline] -fn split_child_process(mut child: Box) -> std::io::Result { - let child_stdin = match child.inner_mut().stdin().take() { - Some(stdin) => stdin, - None => return Err(std::io::Error::other("stdin was already taken")), - }; - let child_stdout = match child.inner_mut().stdout().take() { - Some(stdout) => stdout, - None => return Err(std::io::Error::other("stdout was already taken")), - }; - let child_stderr = child.inner_mut().stderr().take(); - Ok((child, child_stdout, child_stdin, child_stderr)) -} - -pub struct TokioChildProcess { - child: ChildWithCleanup, - transport: AsyncRwTransport, -} - -pub struct ChildWithCleanup { - inner: Option>, -} - -impl Drop for ChildWithCleanup { - fn drop(&mut self) { - // We should not use start_kill(), instead we should use kill() to avoid zombies - if let Some(mut inner) = self.inner.take() { - // We don't care about the result, just try to kill it - tokio::spawn(async move { - if let Err(e) = Box::into_pin(inner.kill()).await { - tracing::warn!("Error killing child process: {}", e); - } - }); - } - } -} - -// we hold the child process with stdout, for it's easier to implement AsyncRead -pin_project_lite::pin_project! { - pub struct TokioChildProcessOut { - child: ChildWithCleanup, - #[pin] - child_stdout: ChildStdout, - } -} - -impl TokioChildProcessOut { - /// Get the process ID of the child process. - pub fn id(&self) -> Option { - self.child.inner.as_ref()?.id() - } -} - -impl AsyncRead for TokioChildProcessOut { - fn poll_read( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &mut tokio::io::ReadBuf<'_>, - ) -> std::task::Poll> { - self.project().child_stdout.poll_read(cx, buf) - } -} - -impl TokioChildProcess { - /// Convenience: spawn with default `piped` stdio - pub fn new(command: impl Into) -> std::io::Result { - let (proc, _ignored) = TokioChildProcessBuilder::new(command).spawn()?; - Ok(proc) - } - - /// Builder entry-point allowing fine-grained stdio control. - pub fn builder(command: impl Into) -> TokioChildProcessBuilder { - TokioChildProcessBuilder::new(command) - } - - /// Get the process ID of the child process. - pub fn id(&self) -> Option { - self.child.inner.as_ref()?.id() - } - - /// Gracefully shutdown the child process - /// - /// This will first close the transport to the child process (the server), - /// and wait for the child process to exit normally with a timeout. - /// If the child process doesn't exit within the timeout, it will be killed. - pub async fn graceful_shutdown(&mut self) -> std::io::Result<()> { - if let Some(mut child) = self.child.inner.take() { - self.transport.close().await?; - - let wait_fut = child.wait(); - tokio::select! { - _ = tokio::time::sleep(std::time::Duration::from_secs(MAX_WAIT_ON_DROP_SECS)) => { - if let Err(e) = Box::into_pin(child.kill()).await { - tracing::warn!("Error killing child: {e}"); - return Err(e); - } - }, - res = wait_fut => { - match res { - Ok(status) => { - tracing::info!("Child exited gracefully {}", status); - } - Err(e) => { - tracing::warn!("Error waiting for child: {e}"); - return Err(e); - } - } - } - } - } - Ok(()) - } - - /// Take ownership of the inner child process - pub fn into_inner(mut self) -> Option> { - self.child.inner.take() - } - - /// Split this helper into a reader (stdout) and writer (stdin). - #[deprecated( - since = "0.5.0", - note = "use the Transport trait implementation instead" - )] - pub fn split(self) -> (TokioChildProcessOut, ChildStdin) { - unimplemented!("This method is deprecated, use the Transport trait implementation instead"); - } -} - -/// Builder for `TokioChildProcess` allowing custom `Stdio` configuration. -pub struct TokioChildProcessBuilder { - cmd: CommandWrap, - stdin: Stdio, - stdout: Stdio, - stderr: Stdio, -} - -impl TokioChildProcessBuilder { - fn new(cmd: impl Into) -> Self { - Self { - cmd: cmd.into(), - stdin: Stdio::piped(), - stdout: Stdio::piped(), - stderr: Stdio::inherit(), - } - } - - /// Override the child stdin configuration. - pub fn stdin(mut self, io: impl Into) -> Self { - self.stdin = io.into(); - self - } - /// Override the child stdout configuration. - pub fn stdout(mut self, io: impl Into) -> Self { - self.stdout = io.into(); - self - } - /// Override the child stderr configuration. - pub fn stderr(mut self, io: impl Into) -> Self { - self.stderr = io.into(); - self - } - - /// Spawn the child process. Returns the transport plus an optional captured stderr handle. - pub fn spawn(mut self) -> std::io::Result<(TokioChildProcess, Option)> { - self.cmd - .command_mut() - .stdin(self.stdin) - .stdout(self.stdout) - .stderr(self.stderr); - - let (child, stdout, stdin, stderr_opt) = split_child_process(self.cmd.spawn()?)?; - - let transport = AsyncRwTransport::new(stdout, stdin); - let proc = TokioChildProcess { - child: ChildWithCleanup { inner: Some(child) }, - transport, - }; - Ok((proc, stderr_opt)) - } -} - -impl Transport for TokioChildProcess { - type Error = std::io::Error; - - fn send( - &mut self, - item: TxJsonRpcMessage, - ) -> impl Future> + Send + 'static { - self.transport.send(item) - } - - fn receive(&mut self) -> impl Future>> + Send { - self.transport.receive() - } - - fn close(&mut self) -> impl Future> + Send { - self.graceful_shutdown() - } -} - -pub trait ConfigureCommandExt { - fn configure(self, f: impl FnOnce(&mut Self)) -> Self; -} - -impl ConfigureCommandExt for tokio::process::Command { - fn configure(mut self, f: impl FnOnce(&mut Self)) -> Self { - f(&mut self); - self - } -} - -#[cfg(unix)] -#[cfg(test)] -mod tests { - use tokio::process::Command; - - use super::*; - - #[tokio::test] - async fn test_tokio_child_process_drop() { - let r = TokioChildProcess::new(Command::new("sleep").configure(|cmd| { - cmd.arg("30"); - })); - assert!(r.is_ok()); - let child_process = r.unwrap(); - let id = child_process.id(); - assert!(id.is_some()); - let id = id.unwrap(); - // Drop the child process - drop(child_process); - // Wait a moment to allow the cleanup task to run - tokio::time::sleep(std::time::Duration::from_secs(MAX_WAIT_ON_DROP_SECS + 1)).await; - // Check if the process is still running - let status = Command::new("ps") - .arg("-p") - .arg(id.to_string()) - .status() - .await; - match status { - Ok(status) => { - assert!( - !status.success(), - "Process with PID {} is still running", - id - ); - } - Err(e) => { - panic!("Failed to check process status: {}", e); - } - } - } - - #[tokio::test] - async fn test_tokio_child_process_graceful_shutdown() { - let r = TokioChildProcess::new(Command::new("sleep").configure(|cmd| { - cmd.arg("30"); - })); - assert!(r.is_ok()); - let mut child_process = r.unwrap(); - let id = child_process.id(); - assert!(id.is_some()); - let id = id.unwrap(); - child_process.graceful_shutdown().await.unwrap(); - // Wait a moment to allow the cleanup task to run - tokio::time::sleep(std::time::Duration::from_secs(MAX_WAIT_ON_DROP_SECS + 1)).await; - // Check if the process is still running - let status = Command::new("ps") - .arg("-p") - .arg(id.to_string()) - .status() - .await; - match status { - Ok(status) => { - assert!( - !status.success(), - "Process with PID {} is still running", - id - ); - } - Err(e) => { - panic!("Failed to check process status: {}", e); - } - } - } -} diff --git a/crates/rmcp/src/transport/child_process2.rs b/crates/rmcp/src/transport/child_process2.rs index c82e16db..7e74551d 100644 --- a/crates/rmcp/src/transport/child_process2.rs +++ b/crates/rmcp/src/transport/child_process2.rs @@ -1,3 +1,5 @@ pub mod runner; -pub mod tokio; pub mod transport; + +#[cfg(feature = "transport-child-process-tokio")] +pub mod tokio; diff --git a/crates/rmcp/src/transport/child_process2/runner.rs b/crates/rmcp/src/transport/child_process2/runner.rs index 6f062065..784f1697 100644 --- a/crates/rmcp/src/transport/child_process2/runner.rs +++ b/crates/rmcp/src/transport/child_process2/runner.rs @@ -2,7 +2,7 @@ use futures::{ FutureExt, io::{AsyncRead, AsyncWrite}, }; -use std::process::Stdio; +use std::{path::PathBuf, process::Stdio}; use crate::util::PinnedFuture; @@ -104,11 +104,7 @@ pub trait ChildProcessRunner { /// The implementation of the child process instance that this runner will spawn. type Instance: ChildProcessInstance; - fn spawn( - command: &str, - args: &[&str], - stdio_config: StdioConfig, - ) -> Result; + fn spawn(command_config: CommandConfig) -> Result; } /// A containing wrapper around a child process instance. This struct erases the type @@ -222,10 +218,8 @@ impl ChildProcessInstance for ChildProcess { } pub struct CommandBuilder { - command: String, - args: Vec, + config: CommandConfig, _marker: std::marker::PhantomData, - stderr: Stdio, } pub enum CommandBuilderError { @@ -249,61 +243,78 @@ impl CommandBuilder { let args = iter.map(|s| s.into()).collect(); Ok(Self { - command, - args, + config: CommandConfig { + command, + args, + cwd: None, + stdio_config: StdioConfig { + stdin: Stdio::piped(), + stdout: Stdio::piped(), + stderr: Stdio::inherit(), + }, + }, _marker: std::marker::PhantomData, - stderr: Stdio::inherit(), }) } /// Create a CommandBuilder from a command and an optional list of args. pub fn new(command: impl Into) -> Self { Self { - command: command.into(), - args: Vec::new(), + config: CommandConfig { + command: command.into(), + args: Vec::new(), + cwd: None, + stdio_config: StdioConfig { + stdin: Stdio::piped(), + stdout: Stdio::piped(), + stderr: Stdio::inherit(), + }, + }, _marker: std::marker::PhantomData, - stderr: Stdio::inherit(), } } /// Add a single argument to the command. pub fn arg(mut self, arg: impl Into) -> Self { - self.args.push(arg.into()); + self.config.args.push(arg.into()); self } /// Add multiple arguments to the command. pub fn args(mut self, args: impl IntoIterator>) -> Self { - self.args.extend(args.into_iter().map(|arg| arg.into())); + self.config + .args + .extend(args.into_iter().map(|arg| arg.into())); self } /// Sets what happens to stderr for the command. /// By default if not set, stderr is inherited from the parent process. pub fn stderr(mut self, _stdio: Stdio) -> Self { - self.stderr = _stdio; + self.config.stdio_config.stderr = _stdio; + self + } + + pub fn current_dir(mut self, cwd: impl Into) -> Self { + self.config.cwd = Some(cwd.into()); self } } +pub struct CommandConfig { + pub command: String, + pub args: Vec, + pub cwd: Option, + pub stdio_config: StdioConfig, +} + impl CommandBuilder where R: ChildProcessRunner, { /// Spawn the command into its typed child process instance type. pub fn spawn_raw(self) -> Result { - // We should always pipe stdin and stdout. - let stdio_config = StdioConfig { - stdin: Stdio::piped(), - stdout: Stdio::piped(), - stderr: self.stderr, - }; - - R::spawn( - &self.command, - &self.args.iter().map(|s| s.as_str()).collect::>(), - stdio_config, - ) + R::spawn(self.config) } /// Spawn a child process struct that erases the specific child process instance type, and only exposes the control methods. diff --git a/crates/rmcp/src/transport/child_process2/tokio.rs b/crates/rmcp/src/transport/child_process2/tokio.rs index dd75e0e1..0beed76c 100644 --- a/crates/rmcp/src/transport/child_process2/tokio.rs +++ b/crates/rmcp/src/transport/child_process2/tokio.rs @@ -1,11 +1,12 @@ use tokio_util::compat::{Compat, TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; use crate::transport::child_process2::runner::{ - ChildProcessInstance, ChildProcessRunner, RunnerSpawnError, StdioConfig, + ChildProcessInstance, ChildProcessRunner, CommandConfig, RunnerSpawnError, }; pub struct TokioChildProcessRunner {} +/// An implementation for the tokio Child Process pub struct TokioChildProcess { inner: tokio::process::Child, } @@ -63,16 +64,17 @@ impl ChildProcessInstance for TokioChildProcess { impl ChildProcessRunner for TokioChildProcessRunner { type Instance = TokioChildProcess; - fn spawn( - command: &str, - args: &[&str], - stdio_configuration: StdioConfig, - ) -> Result { - tokio::process::Command::new(command) - .args(args) - .stdin(stdio_configuration.stdin) - .stdout(stdio_configuration.stdout) - .stderr(stdio_configuration.stderr) + fn spawn(command_config: CommandConfig) -> Result { + tokio::process::Command::new(command_config.command) + .args(command_config.args) + .stdin(command_config.stdio_config.stdin) + .stdout(command_config.stdio_config.stdout) + .stderr(command_config.stdio_config.stderr) + .current_dir( + command_config + .cwd + .unwrap_or_else(|| std::env::current_dir().unwrap()), + ) .kill_on_drop(true) .spawn() .map(|child| TokioChildProcess { inner: child }) diff --git a/crates/rmcp/tests/test_with_js.rs b/crates/rmcp/tests/test_with_js.rs index 31842773..6e39b5da 100644 --- a/crates/rmcp/tests/test_with_js.rs +++ b/crates/rmcp/tests/test_with_js.rs @@ -2,7 +2,7 @@ use rmcp::{ ServiceExt, service::QuitReason, transport::{ - ConfigureCommandExt, StreamableHttpClientTransport, StreamableHttpServerConfig, + StreamableHttpClientTransport, StreamableHttpServerConfig, child_process2::{ runner::{ChildProcessControl, CommandBuilder}, tokio::TokioChildProcessRunner, @@ -44,7 +44,7 @@ async fn test_with_js_stdio_server() -> anyhow::Result<()> { tracing::info!("Spawned child process with PID: {}", node_cmd.pid()); let transport = ChildProcessTransport::new(node_cmd) - .map_err(|e| anyhow::anyhow!("Failed to spawn child process: {e}"))?; + .map_err(|e| anyhow::anyhow!("Failed to wrap child process: {e}"))?; let (client, work) = ().serve(transport).await?; diff --git a/crates/rmcp/tests/test_with_python.rs b/crates/rmcp/tests/test_with_python.rs index 3f883c96..014b959d 100644 --- a/crates/rmcp/tests/test_with_python.rs +++ b/crates/rmcp/tests/test_with_python.rs @@ -1,10 +1,16 @@ use std::process::Stdio; +use futures::AsyncReadExt; use rmcp::{ ServiceExt, - transport::{ConfigureCommandExt, TokioChildProcess}, + transport::{ + ChildProcess, ChildProcessInstance, + child_process2::{ + runner::CommandBuilder, tokio::TokioChildProcessRunner, + transport::ChildProcessTransport, + }, + }, }; -use tokio::io::AsyncReadExt; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; mod common; @@ -29,18 +35,21 @@ async fn init() -> anyhow::Result<()> { async fn test_with_python_server() -> anyhow::Result<()> { init().await?; - let transport = TokioChildProcess::new(tokio::process::Command::new("uv").configure(|cmd| { - cmd.arg("run") - .arg("server.py") - .current_dir("tests/test_with_python"); - }))?; + let server_command = CommandBuilder::::new("uv") + .args(["run", "server.py"]) + .current_dir("tests/test_with_python") + .spawn_dyn()?; + + let transport = ChildProcessTransport::new(server_command) + .map_err(|e| anyhow::anyhow!("Failed to wrap child process: {e}"))?; - let client = ().serve(transport).await?; + let (client, work) = ().serve(transport).await?; + tokio::spawn(work); let resources = client.list_all_resources().await?; tracing::info!("{:#?}", resources); let tools = client.list_all_tools().await?; tracing::info!("{:#?}", tools); - client.cancel().await?; + client.cancel().await; Ok(()) } @@ -48,15 +57,14 @@ async fn test_with_python_server() -> anyhow::Result<()> { async fn test_with_python_server_stderr() -> anyhow::Result<()> { init().await?; - let (transport, stderr) = - TokioChildProcess::builder(tokio::process::Command::new("uv").configure(|cmd| { - cmd.arg("run") - .arg("server.py") - .current_dir("tests/test_with_python"); - })) + let mut server_command = CommandBuilder::::new("uv") + .args(["run", "server.py"]) + .current_dir("tests/test_with_python") .stderr(Stdio::piped()) - .spawn()?; + .spawn_dyn()?; + let stderr: Option<::Stderr> = + server_command.take_stderr().into(); let mut stderr = stderr.expect("stderr must be piped"); let stderr_task = tokio::spawn(async move { @@ -65,10 +73,14 @@ async fn test_with_python_server_stderr() -> anyhow::Result<()> { Ok::<_, std::io::Error>(buffer) }); - let client = ().serve(transport).await?; + let transport = ChildProcessTransport::new(server_command) + .map_err(|e| anyhow::anyhow!("Failed to wrap child process: {e}"))?; + + let (client, work) = ().serve(transport).await?; + tokio::spawn(work); let _ = client.list_all_resources().await?; let _ = client.list_all_tools().await?; - client.cancel().await?; + client.cancel().await; let stderr_output = stderr_task.await??; assert!(stderr_output.contains("server starting up...")); From 02b53cd500705864fc61e42dd40776901b4f389c Mon Sep 17 00:00:00 2001 From: Aadam Zocolo Date: Sun, 1 Mar 2026 07:18:22 +0000 Subject: [PATCH 07/20] refactor(child-process): add env to command, move builder to separate file --- crates/rmcp/src/transport.rs | 6 +- crates/rmcp/src/transport/child_process2.rs | 1 + .../src/transport/child_process2/builder.rs | 149 ++++++++++++++++++ .../src/transport/child_process2/runner.rs | 123 +-------------- .../src/transport/child_process2/tokio.rs | 6 +- crates/rmcp/tests/test_with_python.rs | 2 +- 6 files changed, 159 insertions(+), 128 deletions(-) create mode 100644 crates/rmcp/src/transport/child_process2/builder.rs diff --git a/crates/rmcp/src/transport.rs b/crates/rmcp/src/transport.rs index 99f84c05..c23f3a07 100644 --- a/crates/rmcp/src/transport.rs +++ b/crates/rmcp/src/transport.rs @@ -84,9 +84,9 @@ pub use worker::WorkerTransport; #[cfg(feature = "transport-child-process")] pub mod child_process2; #[cfg(feature = "transport-child-process")] -pub use child_process2::runner::{ - ChildProcess, ChildProcessInstance, ChildProcessRunner, CommandBuilder, -}; +pub use child_process2::builder::CommandBuilder; +#[cfg(feature = "transport-child-process")] +pub use child_process2::runner::{ChildProcess, ChildProcessInstance, ChildProcessRunner}; #[cfg(feature = "transport-io")] pub mod io; diff --git a/crates/rmcp/src/transport/child_process2.rs b/crates/rmcp/src/transport/child_process2.rs index 7e74551d..44db4a45 100644 --- a/crates/rmcp/src/transport/child_process2.rs +++ b/crates/rmcp/src/transport/child_process2.rs @@ -1,3 +1,4 @@ +pub mod builder; pub mod runner; pub mod transport; diff --git a/crates/rmcp/src/transport/child_process2/builder.rs b/crates/rmcp/src/transport/child_process2/builder.rs new file mode 100644 index 00000000..f1216dac --- /dev/null +++ b/crates/rmcp/src/transport/child_process2/builder.rs @@ -0,0 +1,149 @@ +use std::{collections::HashMap, hash::Hash, path::PathBuf, process::Stdio}; + +use crate::transport::{ + ChildProcess, ChildProcessRunner, child_process2::runner::RunnerSpawnError, +}; + +/// A builder for constructing a command to spawn a child process, with typical command +/// configuration options like `args` and `current_dir`. +pub struct CommandBuilder { + config: CommandConfig, + _marker: std::marker::PhantomData, +} + +#[derive(Debug, thiserror::Error)] +pub enum CommandBuilderError { + #[error("Command cannot be empty")] + EmptyCommand, +} + +impl CommandBuilder { + /// Create a CommandBuilder from an argv-style list of strings, where the first element is the command, and the rest are the args. + pub fn from_argv(argv: I) -> Result + where + I: IntoIterator, + S: Into, + { + let mut iter = argv.into_iter(); + + // Pop the first element as the command, and use the rest as args + let command = match iter.next() { + Some(cmd) => cmd.into(), + None => return Err(CommandBuilderError::EmptyCommand), + }; + + let args = iter.map(|s| s.into()).collect(); + Ok(Self { + config: CommandConfig { + command, + args, + ..Default::default() + }, + _marker: std::marker::PhantomData, + }) + } + + /// Create a CommandBuilder from a command and an optional list of args. + pub fn new(command: impl Into) -> Self { + Self { + config: CommandConfig { + command: command.into(), + ..Default::default() + }, + _marker: std::marker::PhantomData, + } + } + + /// Add a single argument to the command. + pub fn arg(mut self, arg: impl Into) -> Self { + self.config.args.push(arg.into()); + self + } + + /// Add multiple arguments to the command. + pub fn args(mut self, args: impl IntoIterator>) -> Self { + self.config + .args + .extend(args.into_iter().map(|arg| arg.into())); + self + } + + /// Set an environment variable for the command. + pub fn env(mut self, key: impl Into, value: impl Into) -> Self { + self.config.env.insert(key.into(), value.into()); + self + } + + /// Set multiple environment variables for the command. + pub fn envs( + mut self, + envs: impl IntoIterator, impl Into)>, + ) -> Self { + self.config + .env + .extend(envs.into_iter().map(|(k, v)| (k.into(), v.into()))); + self + } + + /// Sets what happens to stderr for the command. + /// By default if not set, stderr is inherited from the parent process. + pub fn stderr(mut self, _stdio: Stdio) -> Self { + self.config.stdio_config.stderr = _stdio; + self + } + + pub fn current_dir(mut self, cwd: impl Into) -> Self { + self.config.cwd = Some(cwd.into()); + self + } +} + +/// A structure that requests how the child process streams should +/// be configured when spawning. +#[derive(Debug)] +pub struct StdioConfig { + pub stdin: Stdio, + pub stdout: Stdio, + pub stderr: Stdio, +} + +impl Default for StdioConfig { + fn default() -> Self { + Self { + stdin: Stdio::piped(), + stdout: Stdio::piped(), + stderr: Stdio::inherit(), + } + } +} + +/// A structure that requests how the command should be executed +#[derive(Debug, Default)] +pub struct CommandConfig { + pub command: String, + pub args: Vec, + pub cwd: Option, + pub stdio_config: StdioConfig, + pub env: HashMap, +} + +impl CommandBuilder +where + R: ChildProcessRunner, +{ + /// Spawn the command into its typed child process instance type. + pub fn spawn_raw(self) -> Result { + R::spawn(self.config) + } + + /// Spawn a child process struct that erases the specific child process instance type, and only exposes the control methods. + /// + /// Requires `R::Instance` to be [Send] and `'static`. + pub fn spawn_dyn(self) -> Result + where + R::Instance: Send + 'static, + { + let instance = self.spawn_raw()?; + Ok(ChildProcess::new(instance)) + } +} diff --git a/crates/rmcp/src/transport/child_process2/runner.rs b/crates/rmcp/src/transport/child_process2/runner.rs index 784f1697..f447d9aa 100644 --- a/crates/rmcp/src/transport/child_process2/runner.rs +++ b/crates/rmcp/src/transport/child_process2/runner.rs @@ -2,9 +2,8 @@ use futures::{ FutureExt, io::{AsyncRead, AsyncWrite}, }; -use std::{path::PathBuf, process::Stdio}; -use crate::util::PinnedFuture; +use crate::{transport::child_process2::builder::CommandConfig, util::PinnedFuture}; /// A simple enum for describing if a stream is available, unused, or already taken. #[derive(Debug)] @@ -27,14 +26,6 @@ impl From> for Option { } } -/// A structure that requests how the child process streams should -/// be configured when spawning. -pub struct StdioConfig { - pub stdin: Stdio, - pub stdout: Stdio, - pub stderr: Stdio, -} - /// The contract for what an instance of a child process /// must provide to be used with a transport. pub trait ChildProcessInstance { @@ -216,115 +207,3 @@ impl ChildProcessInstance for ChildProcess { self.inner.kill() } } - -pub struct CommandBuilder { - config: CommandConfig, - _marker: std::marker::PhantomData, -} - -pub enum CommandBuilderError { - EmptyCommand, -} - -impl CommandBuilder { - /// Create a CommandBuilder from an argv-style list of strings, where the first element is the command, and the rest are the args. - pub fn from_argv(argv: I) -> Result - where - I: IntoIterator, - S: Into, - { - let mut iter = argv.into_iter(); - - // Pop the first element as the command, and use the rest as args - let command = match iter.next() { - Some(cmd) => cmd.into(), - None => return Err(CommandBuilderError::EmptyCommand), - }; - - let args = iter.map(|s| s.into()).collect(); - Ok(Self { - config: CommandConfig { - command, - args, - cwd: None, - stdio_config: StdioConfig { - stdin: Stdio::piped(), - stdout: Stdio::piped(), - stderr: Stdio::inherit(), - }, - }, - _marker: std::marker::PhantomData, - }) - } - - /// Create a CommandBuilder from a command and an optional list of args. - pub fn new(command: impl Into) -> Self { - Self { - config: CommandConfig { - command: command.into(), - args: Vec::new(), - cwd: None, - stdio_config: StdioConfig { - stdin: Stdio::piped(), - stdout: Stdio::piped(), - stderr: Stdio::inherit(), - }, - }, - _marker: std::marker::PhantomData, - } - } - - /// Add a single argument to the command. - pub fn arg(mut self, arg: impl Into) -> Self { - self.config.args.push(arg.into()); - self - } - - /// Add multiple arguments to the command. - pub fn args(mut self, args: impl IntoIterator>) -> Self { - self.config - .args - .extend(args.into_iter().map(|arg| arg.into())); - self - } - - /// Sets what happens to stderr for the command. - /// By default if not set, stderr is inherited from the parent process. - pub fn stderr(mut self, _stdio: Stdio) -> Self { - self.config.stdio_config.stderr = _stdio; - self - } - - pub fn current_dir(mut self, cwd: impl Into) -> Self { - self.config.cwd = Some(cwd.into()); - self - } -} - -pub struct CommandConfig { - pub command: String, - pub args: Vec, - pub cwd: Option, - pub stdio_config: StdioConfig, -} - -impl CommandBuilder -where - R: ChildProcessRunner, -{ - /// Spawn the command into its typed child process instance type. - pub fn spawn_raw(self) -> Result { - R::spawn(self.config) - } - - /// Spawn a child process struct that erases the specific child process instance type, and only exposes the control methods. - /// - /// Requires `R::Instance` to be [Send] and `'static`. - pub fn spawn_dyn(self) -> Result - where - R::Instance: Send + 'static, - { - let instance = self.spawn_raw()?; - Ok(ChildProcess::new(instance)) - } -} diff --git a/crates/rmcp/src/transport/child_process2/tokio.rs b/crates/rmcp/src/transport/child_process2/tokio.rs index 0beed76c..4733d457 100644 --- a/crates/rmcp/src/transport/child_process2/tokio.rs +++ b/crates/rmcp/src/transport/child_process2/tokio.rs @@ -1,7 +1,8 @@ use tokio_util::compat::{Compat, TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; -use crate::transport::child_process2::runner::{ - ChildProcessInstance, ChildProcessRunner, CommandConfig, RunnerSpawnError, +use crate::transport::child_process2::{ + builder::CommandConfig, + runner::{ChildProcessInstance, ChildProcessRunner, RunnerSpawnError}, }; pub struct TokioChildProcessRunner {} @@ -67,6 +68,7 @@ impl ChildProcessRunner for TokioChildProcessRunner { fn spawn(command_config: CommandConfig) -> Result { tokio::process::Command::new(command_config.command) .args(command_config.args) + .envs(command_config.env) .stdin(command_config.stdio_config.stdin) .stdout(command_config.stdio_config.stdout) .stderr(command_config.stdio_config.stderr) diff --git a/crates/rmcp/tests/test_with_python.rs b/crates/rmcp/tests/test_with_python.rs index 014b959d..d880eb03 100644 --- a/crates/rmcp/tests/test_with_python.rs +++ b/crates/rmcp/tests/test_with_python.rs @@ -6,7 +6,7 @@ use rmcp::{ transport::{ ChildProcess, ChildProcessInstance, child_process2::{ - runner::CommandBuilder, tokio::TokioChildProcessRunner, + builder::CommandBuilder, tokio::TokioChildProcessRunner, transport::ChildProcessTransport, }, }, From 86977b678b0d0118fec8be6b1ed25f0917604c4d Mon Sep 17 00:00:00 2001 From: Aadam Zocolo Date: Sun, 1 Mar 2026 07:32:01 +0000 Subject: [PATCH 08/20] refactor(example): fix example compilation --- crates/rmcp/src/error.rs | 2 -- crates/rmcp/src/handler/client/progress.rs | 4 +-- crates/rmcp/src/service.rs | 1 - .../src/transport/child_process2/builder.rs | 2 +- .../src/transport/child_process2/transport.rs | 4 +-- examples/simple-chat-client/Cargo.toml | 3 +- examples/simple-chat-client/src/config.rs | 34 ++++++++++++------- 7 files changed, 28 insertions(+), 22 deletions(-) diff --git a/crates/rmcp/src/error.rs b/crates/rmcp/src/error.rs index c7901f4b..3bd528f4 100644 --- a/crates/rmcp/src/error.rs +++ b/crates/rmcp/src/error.rs @@ -30,8 +30,6 @@ pub enum RmcpError { #[cfg(feature = "server")] #[error("Server initialization error: {0}")] ServerInitialize(#[from] crate::service::ServerInitializeError), - #[error("Runtime error: {0}")] - Runtime(#[from] tokio::task::JoinError), #[error("Transport creation error: {error}")] // TODO: Maybe we can introduce something like `TryIntoTransport` to auto wrap transport type, // but it could be an breaking change, so we could do it in the future. diff --git a/crates/rmcp/src/handler/client/progress.rs b/crates/rmcp/src/handler/client/progress.rs index 7dd84f08..892de197 100644 --- a/crates/rmcp/src/handler/client/progress.rs +++ b/crates/rmcp/src/handler/client/progress.rs @@ -1,7 +1,5 @@ -use std::{collections::HashMap, sync::Arc}; - use futures::{Stream, StreamExt}; -use tokio::sync::{RwLock, broadcast}; +use tokio::sync::broadcast; use tokio_stream::wrappers::BroadcastStream; use crate::{ diff --git a/crates/rmcp/src/service.rs b/crates/rmcp/src/service.rs index e3fae794..f4374dea 100644 --- a/crates/rmcp/src/service.rs +++ b/crates/rmcp/src/service.rs @@ -598,7 +598,6 @@ impl RunningServiceCancellationToken { pub enum QuitReason { Cancelled, Closed, - JoinError(tokio::task::JoinError), } /// Request execution context diff --git a/crates/rmcp/src/transport/child_process2/builder.rs b/crates/rmcp/src/transport/child_process2/builder.rs index f1216dac..cf1a2e11 100644 --- a/crates/rmcp/src/transport/child_process2/builder.rs +++ b/crates/rmcp/src/transport/child_process2/builder.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, hash::Hash, path::PathBuf, process::Stdio}; +use std::{collections::HashMap, path::PathBuf, process::Stdio}; use crate::transport::{ ChildProcess, ChildProcessRunner, child_process2::runner::RunnerSpawnError, diff --git a/crates/rmcp/src/transport/child_process2/transport.rs b/crates/rmcp/src/transport/child_process2/transport.rs index a6731347..30f15a2d 100644 --- a/crates/rmcp/src/transport/child_process2/transport.rs +++ b/crates/rmcp/src/transport/child_process2/transport.rs @@ -11,7 +11,7 @@ use crate::{ }; pub struct ChildProcessTransport { - child: Box, + _child: Box, framed_transport: AsyncRwTransport< R, Box, @@ -40,7 +40,7 @@ where ); Ok(Self { - child: control, + _child: control, framed_transport, }) } diff --git a/examples/simple-chat-client/Cargo.toml b/examples/simple-chat-client/Cargo.toml index e382e63c..ab24878c 100644 --- a/examples/simple-chat-client/Cargo.toml +++ b/examples/simple-chat-client/Cargo.toml @@ -17,6 +17,7 @@ toml = "1.0" rmcp = { workspace = true, features = [ "client", "transport-child-process", - "transport-streamable-http-client-reqwest" + "transport-child-process-tokio", + "transport-streamable-http-client-reqwest", ] } clap = { version = "4.0", features = ["derive"] } diff --git a/examples/simple-chat-client/src/config.rs b/examples/simple-chat-client/src/config.rs index e469c280..6d0292b3 100644 --- a/examples/simple-chat-client/src/config.rs +++ b/examples/simple-chat-client/src/config.rs @@ -1,7 +1,14 @@ -use std::{collections::HashMap, path::Path, process::Stdio}; +use std::{collections::HashMap, path::Path}; use anyhow::Result; -use rmcp::{RoleClient, ServiceExt, service::RunningService, transport::ConfigureCommandExt}; +use rmcp::{ + RoleClient, ServiceExt, + service::RunningService, + transport::{ + CommandBuilder, + child_process2::{tokio::TokioChildProcessRunner, transport::ChildProcessTransport}, + }, +}; use serde::{Deserialize, Serialize}; #[derive(Debug, Serialize, Deserialize)] @@ -47,22 +54,25 @@ impl McpServerTransportConfig { McpServerTransportConfig::Streamable { url } => { let transport = rmcp::transport::StreamableHttpClientTransport::from_uri(url.to_string()); - ().serve(transport).await? + let (service, work) = ().serve(transport).await?; + tokio::spawn(work); + service } McpServerTransportConfig::Stdio { command, args, envs, } => { - let transport = rmcp::transport::child_process::TokioChildProcess::new( - tokio::process::Command::new(command).configure(|cmd| { - cmd.args(args) - .envs(envs) - .stderr(Stdio::inherit()) - .stdout(Stdio::inherit()); - }), - )?; - ().serve(transport).await? + let cmd = CommandBuilder::::new(command) + .args(args) + .envs(envs) + .spawn_dyn()?; + + let transport = ChildProcessTransport::new(cmd) + .map_err(|e| anyhow::anyhow!("Failed to wrap child process: {e}"))?; + let (service, work) = ().serve(transport).await?; + tokio::spawn(work); + service } }; Ok(client) From b746384d19e4cbde6d0c9d748165c3f571aa3317 Mon Sep 17 00:00:00 2001 From: Aadam Zocolo Date: Sun, 1 Mar 2026 07:34:11 +0000 Subject: [PATCH 09/20] refactor(child-process): rename module back to "child-process" --- crates/rmcp/src/transport.rs | 6 +++--- .../src/transport/{child_process2.rs => child_process.rs} | 0 .../transport/{child_process2 => child_process}/builder.rs | 4 +--- .../transport/{child_process2 => child_process}/runner.rs | 2 +- .../transport/{child_process2 => child_process}/tokio.rs | 2 +- .../{child_process2 => child_process}/transport.rs | 2 +- crates/rmcp/tests/test_with_js.rs | 2 +- crates/rmcp/tests/test_with_python.rs | 2 +- examples/simple-chat-client/src/config.rs | 2 +- 9 files changed, 10 insertions(+), 12 deletions(-) rename crates/rmcp/src/transport/{child_process2.rs => child_process.rs} (100%) rename crates/rmcp/src/transport/{child_process2 => child_process}/builder.rs (97%) rename crates/rmcp/src/transport/{child_process2 => child_process}/runner.rs (98%) rename crates/rmcp/src/transport/{child_process2 => child_process}/tokio.rs (98%) rename crates/rmcp/src/transport/{child_process2 => child_process}/transport.rs (96%) diff --git a/crates/rmcp/src/transport.rs b/crates/rmcp/src/transport.rs index c23f3a07..de12321c 100644 --- a/crates/rmcp/src/transport.rs +++ b/crates/rmcp/src/transport.rs @@ -82,11 +82,11 @@ pub mod worker; pub use worker::WorkerTransport; #[cfg(feature = "transport-child-process")] -pub mod child_process2; +pub mod child_process; #[cfg(feature = "transport-child-process")] -pub use child_process2::builder::CommandBuilder; +pub use child_process::builder::CommandBuilder; #[cfg(feature = "transport-child-process")] -pub use child_process2::runner::{ChildProcess, ChildProcessInstance, ChildProcessRunner}; +pub use child_process::runner::{ChildProcess, ChildProcessInstance, ChildProcessRunner}; #[cfg(feature = "transport-io")] pub mod io; diff --git a/crates/rmcp/src/transport/child_process2.rs b/crates/rmcp/src/transport/child_process.rs similarity index 100% rename from crates/rmcp/src/transport/child_process2.rs rename to crates/rmcp/src/transport/child_process.rs diff --git a/crates/rmcp/src/transport/child_process2/builder.rs b/crates/rmcp/src/transport/child_process/builder.rs similarity index 97% rename from crates/rmcp/src/transport/child_process2/builder.rs rename to crates/rmcp/src/transport/child_process/builder.rs index cf1a2e11..e23303ae 100644 --- a/crates/rmcp/src/transport/child_process2/builder.rs +++ b/crates/rmcp/src/transport/child_process/builder.rs @@ -1,8 +1,6 @@ use std::{collections::HashMap, path::PathBuf, process::Stdio}; -use crate::transport::{ - ChildProcess, ChildProcessRunner, child_process2::runner::RunnerSpawnError, -}; +use crate::transport::{ChildProcess, ChildProcessRunner, child_process::runner::RunnerSpawnError}; /// A builder for constructing a command to spawn a child process, with typical command /// configuration options like `args` and `current_dir`. diff --git a/crates/rmcp/src/transport/child_process2/runner.rs b/crates/rmcp/src/transport/child_process/runner.rs similarity index 98% rename from crates/rmcp/src/transport/child_process2/runner.rs rename to crates/rmcp/src/transport/child_process/runner.rs index f447d9aa..7c2bb3f4 100644 --- a/crates/rmcp/src/transport/child_process2/runner.rs +++ b/crates/rmcp/src/transport/child_process/runner.rs @@ -3,7 +3,7 @@ use futures::{ io::{AsyncRead, AsyncWrite}, }; -use crate::{transport::child_process2::builder::CommandConfig, util::PinnedFuture}; +use crate::{transport::child_process::builder::CommandConfig, util::PinnedFuture}; /// A simple enum for describing if a stream is available, unused, or already taken. #[derive(Debug)] diff --git a/crates/rmcp/src/transport/child_process2/tokio.rs b/crates/rmcp/src/transport/child_process/tokio.rs similarity index 98% rename from crates/rmcp/src/transport/child_process2/tokio.rs rename to crates/rmcp/src/transport/child_process/tokio.rs index 4733d457..edb3c061 100644 --- a/crates/rmcp/src/transport/child_process2/tokio.rs +++ b/crates/rmcp/src/transport/child_process/tokio.rs @@ -1,6 +1,6 @@ use tokio_util::compat::{Compat, TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; -use crate::transport::child_process2::{ +use crate::transport::child_process::{ builder::CommandConfig, runner::{ChildProcessInstance, ChildProcessRunner, RunnerSpawnError}, }; diff --git a/crates/rmcp/src/transport/child_process2/transport.rs b/crates/rmcp/src/transport/child_process/transport.rs similarity index 96% rename from crates/rmcp/src/transport/child_process2/transport.rs rename to crates/rmcp/src/transport/child_process/transport.rs index 30f15a2d..9f46d9d9 100644 --- a/crates/rmcp/src/transport/child_process2/transport.rs +++ b/crates/rmcp/src/transport/child_process/transport.rs @@ -6,7 +6,7 @@ use crate::{ transport::{ Transport, async_rw::AsyncRwTransport, - child_process2::runner::{ChildProcess, ChildProcessControl}, + child_process::runner::{ChildProcess, ChildProcessControl}, }, }; diff --git a/crates/rmcp/tests/test_with_js.rs b/crates/rmcp/tests/test_with_js.rs index 6e39b5da..ed1c0b7d 100644 --- a/crates/rmcp/tests/test_with_js.rs +++ b/crates/rmcp/tests/test_with_js.rs @@ -3,7 +3,7 @@ use rmcp::{ service::QuitReason, transport::{ StreamableHttpClientTransport, StreamableHttpServerConfig, - child_process2::{ + child_process::{ runner::{ChildProcessControl, CommandBuilder}, tokio::TokioChildProcessRunner, transport::ChildProcessTransport, diff --git a/crates/rmcp/tests/test_with_python.rs b/crates/rmcp/tests/test_with_python.rs index d880eb03..074b8447 100644 --- a/crates/rmcp/tests/test_with_python.rs +++ b/crates/rmcp/tests/test_with_python.rs @@ -5,7 +5,7 @@ use rmcp::{ ServiceExt, transport::{ ChildProcess, ChildProcessInstance, - child_process2::{ + child_process::{ builder::CommandBuilder, tokio::TokioChildProcessRunner, transport::ChildProcessTransport, }, diff --git a/examples/simple-chat-client/src/config.rs b/examples/simple-chat-client/src/config.rs index 6d0292b3..946436fd 100644 --- a/examples/simple-chat-client/src/config.rs +++ b/examples/simple-chat-client/src/config.rs @@ -6,7 +6,7 @@ use rmcp::{ service::RunningService, transport::{ CommandBuilder, - child_process2::{tokio::TokioChildProcessRunner, transport::ChildProcessTransport}, + child_process::{tokio::TokioChildProcessRunner, transport::ChildProcessTransport}, }, }; use serde::{Deserialize, Serialize}; From 0adab2536b347810afd743c849ecb1fd74b14aa1 Mon Sep 17 00:00:00 2001 From: Aadam Zocolo Date: Sun, 1 Mar 2026 19:32:28 +0000 Subject: [PATCH 10/20] refactor(test): re-introduce tests for child process dropping --- crates/rmcp/Cargo.toml | 2 - .../src/transport/child_process/runner.rs | 3 + .../rmcp/src/transport/child_process/tokio.rs | 79 ++++++++++++++++++- 3 files changed, 79 insertions(+), 5 deletions(-) diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml index d08951d0..ae5262c8 100644 --- a/crates/rmcp/Cargo.toml +++ b/crates/rmcp/Cargo.toml @@ -74,8 +74,6 @@ chrono = { version = "0.4.38", default-features = false, features = [ "oldtime", ] } -[target.'cfg(test)'] - [features] default = ["base64", "macros", "server"] client = ["dep:tokio-stream"] diff --git a/crates/rmcp/src/transport/child_process/runner.rs b/crates/rmcp/src/transport/child_process/runner.rs index 7c2bb3f4..bc457bb2 100644 --- a/crates/rmcp/src/transport/child_process/runner.rs +++ b/crates/rmcp/src/transport/child_process/runner.rs @@ -87,6 +87,9 @@ pub enum RunnerSpawnError { /// The child process instance failed to spawn. #[error("Failed to spawn child process: {0}")] SpawnError(#[from] std::io::Error), + /// The child process instance did not have a PID assigned (this is unexpected for a spawned process). + #[error("Child process did not have a PID assigned after spawning")] + NoPidAssigned, #[error("Other error: {0}")] Other(Box), } diff --git a/crates/rmcp/src/transport/child_process/tokio.rs b/crates/rmcp/src/transport/child_process/tokio.rs index edb3c061..db835afb 100644 --- a/crates/rmcp/src/transport/child_process/tokio.rs +++ b/crates/rmcp/src/transport/child_process/tokio.rs @@ -10,6 +10,8 @@ pub struct TokioChildProcessRunner {} /// An implementation for the tokio Child Process pub struct TokioChildProcess { inner: tokio::process::Child, + /// The PID at the time of spawning. + pid: u32, } impl ChildProcessInstance for TokioChildProcess { @@ -41,8 +43,7 @@ impl ChildProcessInstance for TokioChildProcess { } fn pid(&self) -> u32 { - // TODO: Consider refactor to return Option to avoid confusion of 0 as a valid PID. - self.inner.id().unwrap_or(0) + self.pid } fn wait<'s>( @@ -79,7 +80,79 @@ impl ChildProcessRunner for TokioChildProcessRunner { ) .kill_on_drop(true) .spawn() - .map(|child| TokioChildProcess { inner: child }) .map_err(RunnerSpawnError::SpawnError) + .and_then(|child| { + let pid = child.id().ok_or_else(|| RunnerSpawnError::NoPidAssigned)?; + Ok(TokioChildProcess { inner: child, pid }) + }) + } +} + +#[cfg(test)] +mod test { + + use crate::transport::CommandBuilder; + use tokio::process::Command; + + use super::*; + + async fn check_pid(pid: u32) -> std::io::Result { + // This command will output only process numbers on each line. + let output = Command::new("ps") + .arg("-o") + .arg("pid=") + .arg("-p") + .arg(pid.to_string()) + .output() + .await?; + + let output_str = String::from_utf8_lossy(&output.stdout); + Ok(output_str + .lines() + .any(|line| line.trim() == pid.to_string())) + } + + #[cfg(unix)] + #[tokio::test] + async fn test_kill_on_drop() { + let child = CommandBuilder::::new("sleep") + .args(["10"]) + .spawn_raw() + .expect("Failed to spawn child process"); + + let pid = child.pid(); + + // Drop the child process without waiting for it to exit, which should kill it due to `kill_on_drop(true)`. + drop(child); + + // Wait a moment to ensure the process has been killed. + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + + let pid_found = check_pid(pid).await.expect("Failed to check if PID exists"); + + assert!(!pid_found, "Child process was not killed on drop"); + } + + #[tokio::test] + async fn test_graceful_shutdown() { + let mut child = CommandBuilder::::new("sleep") + .args(["10"]) + .spawn_raw() + .expect("Failed to spawn child process"); + + let pid = child.pid(); + + // Sleep a moment to ensure the process is running before we attempt to shut it down. + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + + child + .graceful_shutdown() + .await + .expect("Failed to gracefully shutdown child process"); + + // We should not need to wait here since we await the graceful shutdown above. + // Graceful shutdown *should* cover waiting for the process to exit. + let pid_found = check_pid(pid).await.expect("Failed to check if PID exists"); + assert!(!pid_found, "Child process was not shutdown"); } } From 007dd925c5ceb6f47edb4ef2db770821a9160fad Mon Sep 17 00:00:00 2001 From: Aadam Zocolo Date: Sun, 1 Mar 2026 19:58:50 +0000 Subject: [PATCH 11/20] refactor: revert some unnecessary module visibility changes fix unit test --- crates/rmcp/Cargo.toml | 1 + crates/rmcp/src/transport/common.rs | 2 +- crates/rmcp/src/transport/common/reqwest.rs | 2 +- crates/rmcp/tests/test_with_js.rs | 5 ++--- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml index ae5262c8..9117db70 100644 --- a/crates/rmcp/Cargo.toml +++ b/crates/rmcp/Cargo.toml @@ -180,6 +180,7 @@ required-features = [ "transport-child-process-tokio", "transport-streamable-http-server", "transport-streamable-http-client", + "transport-streamable-http-client-reqwest", "__reqwest", ] path = "tests/test_with_js.rs" diff --git a/crates/rmcp/src/transport/common.rs b/crates/rmcp/src/transport/common.rs index b41a8f3c..615b0e27 100644 --- a/crates/rmcp/src/transport/common.rs +++ b/crates/rmcp/src/transport/common.rs @@ -4,7 +4,7 @@ pub mod server_side_http; pub mod http_header; #[cfg(feature = "__reqwest")] -pub mod reqwest; +mod reqwest; // Note: This module provides SSE stream parsing and auto-reconnect utilities. // It's used by the streamable HTTP client (which receives SSE-formatted responses), diff --git a/crates/rmcp/src/transport/common/reqwest.rs b/crates/rmcp/src/transport/common/reqwest.rs index 696aa912..42075921 100644 --- a/crates/rmcp/src/transport/common/reqwest.rs +++ b/crates/rmcp/src/transport/common/reqwest.rs @@ -1,2 +1,2 @@ #[cfg(feature = "transport-streamable-http-client-reqwest")] -pub mod streamable_http_client; +mod streamable_http_client; diff --git a/crates/rmcp/tests/test_with_js.rs b/crates/rmcp/tests/test_with_js.rs index ed1c0b7d..40228b58 100644 --- a/crates/rmcp/tests/test_with_js.rs +++ b/crates/rmcp/tests/test_with_js.rs @@ -2,10 +2,9 @@ use rmcp::{ ServiceExt, service::QuitReason, transport::{ - StreamableHttpClientTransport, StreamableHttpServerConfig, + CommandBuilder, StreamableHttpClientTransport, StreamableHttpServerConfig, child_process::{ - runner::{ChildProcessControl, CommandBuilder}, - tokio::TokioChildProcessRunner, + runner::ChildProcessControl, tokio::TokioChildProcessRunner, transport::ChildProcessTransport, }, streamable_http_server::{ From 319a77a04a8f4766a03699ec76b9570a5c761760 Mon Sep 17 00:00:00 2001 From: Aadam Zocolo Date: Tue, 3 Mar 2026 04:26:54 +0000 Subject: [PATCH 12/20] refactor(tests): update calls to serve in all unit tests --- conformance/src/bin/client.rs | 28 ++++++--- crates/rmcp/src/service.rs | 21 +++++++ crates/rmcp/tests/test_close_connection.rs | 46 ++++++++------ crates/rmcp/tests/test_custom_request.rs | 20 ++++--- crates/rmcp/tests/test_logging.rs | 40 ++++++++----- crates/rmcp/tests/test_message_protocol.rs | 60 +++++++++++-------- crates/rmcp/tests/test_notification.rs | 30 ++++++---- crates/rmcp/tests/test_prompt_macros.rs | 9 ++- crates/rmcp/tests/test_sampling.rs | 30 ++++++---- .../tests/test_task_support_validation.rs | 36 +++++++---- crates/rmcp/tests/test_tool_macros.rs | 9 ++- 11 files changed, 214 insertions(+), 115 deletions(-) diff --git a/conformance/src/bin/client.rs b/conformance/src/bin/client.rs index 53a44d9e..061c2e2b 100644 --- a/conformance/src/bin/client.rs +++ b/conformance/src/bin/client.rs @@ -325,7 +325,10 @@ async fn run_auth_client(server_url: &str, ctx: &ConformanceContext) -> anyhow:: StreamableHttpClientTransportConfig::with_uri(server_url), ); - let client = BasicClientHandler.serve(transport).await?; + let (client, work) = BasicClientHandler.serve(transport).await?; + // Run the client work loop in the background while we interact with it + tokio::spawn(work); + tracing::debug!("Connected (authenticated)"); let tools = client.list_tools(Default::default()).await?; @@ -379,7 +382,9 @@ async fn run_auth_scope_step_up_client( StreamableHttpClientTransportConfig::with_uri(server_url), ); - let client = BasicClientHandler.serve(transport).await?; + let (client, work) = BasicClientHandler.serve(transport).await?; + // Run the client work loop in the background while we interact with it + tokio::spawn(work); let tools = client.list_tools(Default::default()).await?; tracing::debug!("Listed {} tools", tools.tools.len()); @@ -426,7 +431,8 @@ async fn run_auth_scope_step_up_client( auth_client2, StreamableHttpClientTransportConfig::with_uri(server_url), ); - let client2 = BasicClientHandler.serve(transport2).await?; + let (client2, work2) = BasicClientHandler.serve(transport2).await?; + tokio::spawn(work2); let _ = client2 .call_tool(CallToolRequestParams { meta: None, @@ -474,7 +480,9 @@ async fn run_auth_scope_retry_limit_client( StreamableHttpClientTransportConfig::with_uri(server_url), ); - let client = BasicClientHandler.serve(transport).await?; + let (client, work) = BasicClientHandler.serve(transport).await?; + tokio::spawn(work); + let tools = client.list_tools(Default::default()).await?; let mut got_403 = false; @@ -532,7 +540,9 @@ async fn run_auth_preregistered_client( StreamableHttpClientTransportConfig::with_uri(server_url), ); - let client = BasicClientHandler.serve(transport).await?; + let (client, work) = BasicClientHandler.serve(transport).await?; + tokio::spawn(work); + let tools = client.list_tools(Default::default()).await?; tracing::debug!("Listed {} tools", tools.tools.len()); @@ -591,7 +601,9 @@ async fn run_client_credentials_basic( .auth_header(access_token.to_string()), ); - let client = BasicClientHandler.serve(transport).await?; + let (client, work) = BasicClientHandler.serve(transport).await?; + tokio::spawn(work); + let tools = client.list_tools(Default::default()).await?; tracing::debug!("Listed {} tools", tools.tools.len()); for tool in &tools.tools { @@ -661,7 +673,9 @@ async fn run_client_credentials_jwt( .auth_header(access_token.to_string()), ); - let client = BasicClientHandler.serve(transport).await?; + let (client, work) = BasicClientHandler.serve(transport).await?; + tokio::spawn(work); + let tools = client.list_tools(Default::default()).await?; tracing::debug!("Listed {} tools", tools.tools.len()); for tool in &tools.tools { diff --git a/crates/rmcp/src/service.rs b/crates/rmcp/src/service.rs index f4374dea..0320f371 100644 --- a/crates/rmcp/src/service.rs +++ b/crates/rmcp/src/service.rs @@ -114,6 +114,27 @@ pub trait ServiceExt: Service + Sized { fn into_dyn(self) -> Box> { Box::new(self) } + + /// Serve this service with the provided transport + /// + /// This function returns a facade to the running service, and a future that runs the + /// service business logic. The caller is responsible for running the business logic + /// future, either by spawning it on a runtime or awaiting it directly. + /// + /// Ex: + /// ```rust,ignore + /// // Try to initialize a service with the given transport + /// let (client, work) = MyServiceImpl.serve(transport).await?; + /// + /// // Spawn the service business logic on a runtime (e.g. tokio) + /// tokio::spawn(work); + /// + /// // Now we can interact with the service + /// let response = client.send_request(...).await?; + /// ``` + /// + /// The returned [RunningService] provides methods to interact with the running service, such + /// as sending requests or notifications to the peer, and cancelling the service. fn serve( self, transport: T, diff --git a/crates/rmcp/tests/test_close_connection.rs b/crates/rmcp/tests/test_close_connection.rs index b3bb5b63..d79fac85 100644 --- a/crates/rmcp/tests/test_close_connection.rs +++ b/crates/rmcp/tests/test_close_connection.rs @@ -3,8 +3,9 @@ mod common; use std::time::Duration; +use anyhow::anyhow; use common::handlers::{TestClientHandler, TestServer}; -use rmcp::{ServiceExt, service::QuitReason}; +use rmcp::{ServiceExt, handler::client, service::QuitReason}; /// Test that close() properly shuts down the connection #[tokio::test] @@ -13,27 +14,29 @@ async fn test_close_method() -> anyhow::Result<()> { // Start server let server_handle = tokio::spawn(async move { - let server = TestServer::new().serve(server_transport).await?; - server.waiting().await?; + let (server, work) = TestServer::new().serve(server_transport).await?; + tokio::spawn(work); + server.waiting().await; anyhow::Ok(()) }); // Start client let handler = TestClientHandler::new(true, true); - let mut client = handler.serve(client_transport).await?; + let (mut client, work) = handler.serve(client_transport).await?; + tokio::spawn(work); // Verify client is not closed assert!(!client.is_closed()); // Call close() and verify it returns - let result = client.close().await?; + let result = client.close().await; assert!(matches!(result, QuitReason::Cancelled)); // Verify client is now closed assert!(client.is_closed()); // Calling close() again should return Closed immediately - let result = client.close().await?; + let result = client.close().await; assert!(matches!(result, QuitReason::Closed)); // Wait for server to finish @@ -48,19 +51,23 @@ async fn test_close_with_timeout() -> anyhow::Result<()> { // Start server let server_handle = tokio::spawn(async move { - let server = TestServer::new().serve(server_transport).await?; - server.waiting().await?; + let (server, work) = TestServer::new().serve(server_transport).await?; + tokio::spawn(work); + server.waiting().await; anyhow::Ok(()) }); // Start client let handler = TestClientHandler::new(true, true); - let mut client = handler.serve(client_transport).await?; + let (mut client, work) = handler.serve(client_transport).await?; + tokio::spawn(work); // Close with a reasonable timeout - let result = client.close_with_timeout(Duration::from_secs(5)).await?; - assert!(result.is_some()); - assert!(matches!(result.unwrap(), QuitReason::Cancelled)); + let result = client + .close_with_timeout(Duration::from_secs(5)) + .await + .ok_or(anyhow!("close_with_timeout should return Ok on timeout"))?; + assert!(matches!(result, QuitReason::Cancelled)); // Verify client is now closed assert!(client.is_closed()); @@ -77,17 +84,19 @@ async fn test_cancel_method() -> anyhow::Result<()> { // Start server let server_handle = tokio::spawn(async move { - let server = TestServer::new().serve(server_transport).await?; - server.waiting().await?; + let (server, work) = TestServer::new().serve(server_transport).await?; + tokio::spawn(work); + server.waiting().await; anyhow::Ok(()) }); // Start client let handler = TestClientHandler::new(true, true); - let client = handler.serve(client_transport).await?; + let (client, work) = handler.serve(client_transport).await?; + tokio::spawn(work); // Cancel should work as before - let result = client.cancel().await?; + let result = client.cancel().await; assert!(matches!(result, QuitReason::Cancelled)); // Wait for server to finish @@ -103,9 +112,10 @@ async fn test_drop_without_close() -> anyhow::Result<()> { // Start server that will handle the drop let server_handle = tokio::spawn(async move { - let server = TestServer::new().serve(server_transport).await?; + let (server, work) = TestServer::new().serve(server_transport).await?; + tokio::spawn(work); // The server should close when the client drops - let result = server.waiting().await?; + let result = server.waiting().await; // Server should detect closure assert!(matches!(result, QuitReason::Closed | QuitReason::Cancelled)); anyhow::Ok(()) diff --git a/crates/rmcp/tests/test_custom_request.rs b/crates/rmcp/tests/test_custom_request.rs index 83a8d347..58533bf8 100644 --- a/crates/rmcp/tests/test_custom_request.rs +++ b/crates/rmcp/tests/test_custom_request.rs @@ -48,18 +48,20 @@ async fn test_custom_client_request_reaches_server() -> anyhow::Result<()> { let receive_signal = receive_signal.clone(); let payload = payload.clone(); tokio::spawn(async move { - let server = CustomRequestServer { + let (server, work) = CustomRequestServer { receive_signal, payload, } .serve(server_transport) .await?; - server.waiting().await?; + tokio::spawn(work); + server.waiting().await; anyhow::Ok(()) }); } - let client = ().serve(client_transport).await?; + let (client, work) = ().serve(client_transport).await?; + tokio::spawn(work); let response = client .send_request(ClientRequest::CustomRequest(CustomRequest::new( @@ -81,7 +83,7 @@ async fn test_custom_client_request_reaches_server() -> anyhow::Result<()> { other => panic!("Expected custom result, got: {other:?}"), } - client.cancel().await?; + client.cancel().await; Ok(()) } @@ -148,13 +150,14 @@ async fn test_custom_server_request_reaches_client() -> anyhow::Result<()> { let response_signal = response_signal.clone(); let response = response.clone(); async move { - let server = CustomRequestServerNotifier { + let (server, work) = CustomRequestServerNotifier { receive_signal: response_signal, response, } .serve(server_transport) .await?; - server.waiting().await?; + tokio::spawn(work); + server.waiting().await; anyhow::Ok(()) } }); @@ -162,12 +165,13 @@ async fn test_custom_server_request_reaches_client() -> anyhow::Result<()> { let receive_signal = Arc::new(Notify::new()); let payload = Arc::new(Mutex::new(None)); - let client = CustomRequestClient { + let (client, work) = CustomRequestClient { receive_signal: receive_signal.clone(), payload: payload.clone(), } .serve(client_transport) .await?; + tokio::spawn(work); tokio::time::timeout(std::time::Duration::from_secs(5), receive_signal.notified()).await?; tokio::time::timeout( @@ -184,6 +188,6 @@ async fn test_custom_server_request_reaches_client() -> anyhow::Result<()> { let response = response.expect("custom request response ok"); assert_eq!(response, json!({ "status": "ok" })); - client.cancel().await?; + client.cancel().await; Ok(()) } diff --git a/crates/rmcp/tests/test_logging.rs b/crates/rmcp/tests/test_logging.rs index be63b24f..3134f8db 100644 --- a/crates/rmcp/tests/test_logging.rs +++ b/crates/rmcp/tests/test_logging.rs @@ -19,7 +19,8 @@ async fn test_logging_spec_compliance() -> anyhow::Result<()> { // Start server in a separate task let server_handle = tokio::spawn(async move { - let server = TestServer::new().serve(server_transport).await?; + let (server, work) = TestServer::new().serve(server_transport).await?; + tokio::spawn(work); // Test server can send messages before level is set server @@ -34,11 +35,11 @@ async fn test_logging_spec_compliance() -> anyhow::Result<()> { }) .await?; - server.waiting().await?; + server.waiting().await; anyhow::Ok(()) }); - let client = TestClientHandler::with_notification( + let (client, work) = TestClientHandler::with_notification( true, true, receive_signal.clone(), @@ -46,6 +47,7 @@ async fn test_logging_spec_compliance() -> anyhow::Result<()> { ) .serve(client_transport) .await?; + tokio::spawn(work); // Wait for the initial server message receive_signal.notified().await; @@ -89,7 +91,7 @@ async fn test_logging_spec_compliance() -> anyhow::Result<()> { } // Important: Cancel the client before ending the test - client.cancel().await?; + client.cancel().await; // Wait for server to complete server_handle.await??; @@ -104,12 +106,13 @@ async fn test_logging_user_scenarios() -> anyhow::Result<()> { let received_messages = Arc::new(Mutex::new(Vec::::new())); let server_handle = tokio::spawn(async move { - let server = TestServer::new().serve(server_transport).await?; - server.waiting().await?; + let (server, work) = TestServer::new().serve(server_transport).await?; + tokio::spawn(work); + server.waiting().await; anyhow::Ok(()) }); - let client = TestClientHandler::with_notification( + let (client, work) = TestClientHandler::with_notification( true, true, receive_signal.clone(), @@ -117,6 +120,7 @@ async fn test_logging_user_scenarios() -> anyhow::Result<()> { ) .serve(client_transport) .await?; + tokio::spawn(work); // Test 1: Error reporting scenario client @@ -187,7 +191,7 @@ async fn test_logging_user_scenarios() -> anyhow::Result<()> { } // Important: Cancel client and wait for server before ending - client.cancel().await?; + client.cancel().await; server_handle.await??; Ok(()) @@ -237,12 +241,13 @@ async fn test_logging_edge_cases() -> anyhow::Result<()> { let received_messages = Arc::new(Mutex::new(Vec::::new())); let server_handle = tokio::spawn(async move { - let server = TestServer::new().serve(server_transport).await?; - server.waiting().await?; + let (server, work) = TestServer::new().serve(server_transport).await?; + tokio::spawn(work); + server.waiting().await; anyhow::Ok(()) }); - let client = TestClientHandler::with_notification( + let (client, work) = TestClientHandler::with_notification( true, true, receive_signal.clone(), @@ -250,6 +255,7 @@ async fn test_logging_edge_cases() -> anyhow::Result<()> { ) .serve(client_transport) .await?; + tokio::spawn(work); // Test all logging levels from spec for level in [ @@ -268,7 +274,7 @@ async fn test_logging_edge_cases() -> anyhow::Result<()> { assert_eq!(msg.level, level); } - client.cancel().await?; + client.cancel().await; server_handle.await??; Ok(()) } @@ -280,7 +286,8 @@ async fn test_logging_optional_fields() -> anyhow::Result<()> { let received_messages = Arc::new(Mutex::new(Vec::::new())); let server_handle = tokio::spawn(async move { - let server = TestServer::new().serve(server_transport).await?; + let (server, work) = TestServer::new().serve(server_transport).await?; + tokio::spawn(work); // Test message with and without optional logger field for (level, has_logger) in [(LoggingLevel::Info, true), (LoggingLevel::Debug, false)] { @@ -294,11 +301,11 @@ async fn test_logging_optional_fields() -> anyhow::Result<()> { .await?; } - server.waiting().await?; + server.waiting().await; anyhow::Ok(()) }); - let client = TestClientHandler::with_notification( + let (client, work) = TestClientHandler::with_notification( true, true, receive_signal.clone(), @@ -306,6 +313,7 @@ async fn test_logging_optional_fields() -> anyhow::Result<()> { ) .serve(client_transport) .await?; + tokio::spawn(work); // Wait for the initial server message receive_signal.notified().await; @@ -345,7 +353,7 @@ async fn test_logging_optional_fields() -> anyhow::Result<()> { } // Important: Cancel the client before ending the test - client.cancel().await?; + client.cancel().await; // Wait for server to complete server_handle.await??; diff --git a/crates/rmcp/tests/test_message_protocol.rs b/crates/rmcp/tests/test_message_protocol.rs index 7ec3258c..1d636a55 100644 --- a/crates/rmcp/tests/test_message_protocol.rs +++ b/crates/rmcp/tests/test_message_protocol.rs @@ -29,14 +29,16 @@ async fn test_context_inclusion_integration() -> anyhow::Result<()> { // Start server let server_handle = tokio::spawn(async move { - let server = TestServer::new().serve(server_transport).await?; - server.waiting().await?; + let (server, work) = TestServer::new().serve(server_transport).await?; + tokio::spawn(work); + server.waiting().await; anyhow::Ok(()) }); // Start client that honors context requests let handler = TestClientHandler::new(true, true); - let client = handler.clone().serve(client_transport).await?; + let (client, work) = handler.clone().serve(client_transport).await?; + tokio::spawn(work); // Test ThisServer context inclusion let request = ServerRequest::CreateMessageRequest(CreateMessageRequest { @@ -191,7 +193,7 @@ async fn test_context_inclusion_integration() -> anyhow::Result<()> { panic!("Expected CreateMessageResult"); } - client.cancel().await?; + client.cancel().await; server_handle.await??; Ok(()) } @@ -202,14 +204,16 @@ async fn test_context_inclusion_ignored_integration() -> anyhow::Result<()> { // Start server let server_handle = tokio::spawn(async move { - let server = TestServer::new().serve(server_transport).await?; - server.waiting().await?; + let (server, work) = TestServer::new().serve(server_transport).await?; + tokio::spawn(work); + server.waiting().await; anyhow::Ok(()) }); // Start client that ignores context requests let handler = TestClientHandler::new(false, false); - let client = handler.clone().serve(client_transport).await?; + let (client, work) = handler.clone().serve(client_transport).await?; + tokio::spawn(work); // Test that context requests are ignored let request = ServerRequest::CreateMessageRequest(CreateMessageRequest { @@ -262,7 +266,7 @@ async fn test_context_inclusion_ignored_integration() -> anyhow::Result<()> { panic!("Expected CreateMessageResult"); } - client.cancel().await?; + client.cancel().await; server_handle.await??; Ok(()) } @@ -273,14 +277,16 @@ async fn test_message_sequence_integration() -> anyhow::Result<()> { // Start server let server_handle = tokio::spawn(async move { - let server = TestServer::new().serve(server_transport).await?; - server.waiting().await?; + let (server, work) = TestServer::new().serve(server_transport).await?; + tokio::spawn(work); + server.waiting().await; anyhow::Ok(()) }); // Start client let handler = TestClientHandler::new(true, true); - let client = handler.clone().serve(client_transport).await?; + let (client, work) = handler.clone().serve(client_transport).await?; + tokio::spawn(work); let request = ServerRequest::CreateMessageRequest(CreateMessageRequest { method: Default::default(), @@ -340,7 +346,7 @@ async fn test_message_sequence_integration() -> anyhow::Result<()> { panic!("Expected CreateMessageResult"); } - client.cancel().await?; + client.cancel().await; server_handle.await??; Ok(()) } @@ -350,13 +356,15 @@ async fn test_message_sequence_validation_integration() -> anyhow::Result<()> { let (server_transport, client_transport) = tokio::io::duplex(4096); let server_handle = tokio::spawn(async move { - let server = TestServer::new().serve(server_transport).await?; - server.waiting().await?; + let (server, work) = TestServer::new().serve(server_transport).await?; + tokio::spawn(work); + server.waiting().await; anyhow::Ok(()) }); let handler = TestClientHandler::new(true, true); - let client = handler.clone().serve(client_transport).await?; + let (client, work) = handler.clone().serve(client_transport).await?; + tokio::spawn(work); // Test valid sequence: User -> Assistant -> User let request = ServerRequest::CreateMessageRequest(CreateMessageRequest { @@ -432,7 +440,7 @@ async fn test_message_sequence_validation_integration() -> anyhow::Result<()> { assert!(result.is_err()); - client.cancel().await?; + client.cancel().await; server_handle.await??; Ok(()) } @@ -442,14 +450,16 @@ async fn test_selective_context_handling_integration() -> anyhow::Result<()> { let (server_transport, client_transport) = tokio::io::duplex(4096); let server_handle = tokio::spawn(async move { - let server = TestServer::new().serve(server_transport).await?; - server.waiting().await?; + let (server, work) = TestServer::new().serve(server_transport).await?; + tokio::spawn(work); + server.waiting().await; anyhow::Ok(()) }); // Client that only honors ThisServer but ignores AllServers let handler = TestClientHandler::new(true, false); - let client = handler.clone().serve(client_transport).await?; + let (client, work) = handler.clone().serve(client_transport).await?; + tokio::spawn(work); // Test ThisServer is honored let request = ServerRequest::CreateMessageRequest(CreateMessageRequest { @@ -549,7 +559,7 @@ async fn test_selective_context_handling_integration() -> anyhow::Result<()> { ); } - client.cancel().await?; + client.cancel().await; server_handle.await??; Ok(()) } @@ -558,13 +568,15 @@ async fn test_selective_context_handling_integration() -> anyhow::Result<()> { async fn test_context_inclusion() -> anyhow::Result<()> { let (server_transport, client_transport) = tokio::io::duplex(4096); let server_handle = tokio::spawn(async move { - let server = TestServer::new().serve(server_transport).await?; - server.waiting().await?; + let (server, work) = TestServer::new().serve(server_transport).await?; + tokio::spawn(work); + server.waiting().await; anyhow::Ok(()) }); let handler = TestClientHandler::new(true, true); - let client = handler.clone().serve(client_transport).await?; + let (client, work) = handler.clone().serve(client_transport).await?; + tokio::spawn(work); // Test context handling let request = ServerRequest::CreateMessageRequest(CreateMessageRequest { @@ -612,7 +624,7 @@ async fn test_context_inclusion() -> anyhow::Result<()> { assert!(text.contains("test context")); } - client.cancel().await?; + client.cancel().await; server_handle.await??; Ok(()) } diff --git a/crates/rmcp/tests/test_notification.rs b/crates/rmcp/tests/test_notification.rs index 01837421..ca2e92b2 100644 --- a/crates/rmcp/tests/test_notification.rs +++ b/crates/rmcp/tests/test_notification.rs @@ -76,16 +76,18 @@ async fn test_server_notification() -> anyhow::Result<()> { .try_init(); let (server_transport, client_transport) = tokio::io::duplex(4096); tokio::spawn(async move { - let server = Server {}.serve(server_transport).await?; - server.waiting().await?; + let (server, work) = Server {}.serve(server_transport).await?; + tokio::spawn(work); + server.waiting().await; anyhow::Ok(()) }); let receive_signal = Arc::new(Notify::new()); - let client = Client { + let (client, work) = Client { receive_signal: receive_signal.clone(), } .serve(client_transport) .await?; + tokio::spawn(work); client .subscribe(SubscribeRequestParams { meta: None, @@ -93,7 +95,7 @@ async fn test_server_notification() -> anyhow::Result<()> { }) .await?; receive_signal.notified().await; - client.cancel().await?; + client.cancel().await; Ok(()) } @@ -134,18 +136,20 @@ async fn test_custom_client_notification_reaches_server() -> anyhow::Result<()> let receive_signal = receive_signal.clone(); let payload = payload.clone(); tokio::spawn(async move { - let server = CustomServer { + let (server, work) = CustomServer { receive_signal, payload, } .serve(server_transport) .await?; - server.waiting().await?; + tokio::spawn(work); + server.waiting().await; anyhow::Ok(()) }); } - let client = ().serve(client_transport).await?; + let (client, work) = ().serve(client_transport).await?; + tokio::spawn(work); client .send_notification(ClientNotification::CustomNotification( @@ -159,7 +163,7 @@ async fn test_custom_client_notification_reaches_server() -> anyhow::Result<()> assert_eq!("notifications/custom-test", method); assert_eq!(Some(json!({ "foo": "bar" })), params); - client.cancel().await?; + client.cancel().await; Ok(()) } @@ -210,20 +214,22 @@ async fn test_custom_server_notification_reaches_client() -> anyhow::Result<()> let (server_transport, client_transport) = tokio::io::duplex(4096); tokio::spawn(async move { - let server = CustomServerNotifier {}.serve(server_transport).await?; - server.waiting().await?; + let (server, work) = CustomServerNotifier {}.serve(server_transport).await?; + tokio::spawn(work); + server.waiting().await; anyhow::Ok(()) }); let receive_signal = Arc::new(Notify::new()); let payload = Arc::new(Mutex::new(None)); - let client = CustomClient { + let (client, work) = CustomClient { receive_signal: receive_signal.clone(), payload: payload.clone(), } .serve(client_transport) .await?; + tokio::spawn(work); tokio::time::timeout(std::time::Duration::from_secs(5), receive_signal.notified()).await?; @@ -231,6 +237,6 @@ async fn test_custom_server_notification_reaches_client() -> anyhow::Result<()> assert_eq!("notifications/custom-test", method); assert_eq!(Some(json!({ "hello": "world" })), params); - client.cancel().await?; + client.cancel().await; Ok(()) } diff --git a/crates/rmcp/tests/test_prompt_macros.rs b/crates/rmcp/tests/test_prompt_macros.rs index 2407571d..ab3488c4 100644 --- a/crates/rmcp/tests/test_prompt_macros.rs +++ b/crates/rmcp/tests/test_prompt_macros.rs @@ -317,13 +317,16 @@ async fn test_optional_i64_field_with_null_input() -> anyhow::Result<()> { // Server setup let server = OptionalSchemaTester::new(); let server_handle = tokio::spawn(async move { - server.serve(server_transport).await?.waiting().await?; + let (server, work) = server.serve(server_transport).await?; + tokio::spawn(work); + server.waiting().await; anyhow::Ok(()) }); // Create a simple client handler that just forwards prompt calls let client_handler = DummyClientHandler::default(); - let client = client_handler.serve(client_transport).await?; + let (client, work) = client_handler.serve(client_transport).await?; + tokio::spawn(work); // Test null case let result = client @@ -379,7 +382,7 @@ async fn test_optional_i64_field_with_null_input() -> anyhow::Result<()> { "Some case should return expected message" ); - client.cancel().await?; + client.cancel().await; server_handle.await??; Ok(()) } diff --git a/crates/rmcp/tests/test_sampling.rs b/crates/rmcp/tests/test_sampling.rs index e5191d3c..f8c8a194 100644 --- a/crates/rmcp/tests/test_sampling.rs +++ b/crates/rmcp/tests/test_sampling.rs @@ -102,13 +102,15 @@ async fn test_sampling_integration_with_test_handlers() -> Result<()> { let (server_transport, client_transport) = tokio::io::duplex(4096); let server_handle = tokio::spawn(async move { - let server = TestServer::new().serve(server_transport).await?; - server.waiting().await?; + let (server, work) = TestServer::new().serve(server_transport).await?; + tokio::spawn(work); + server.waiting().await; anyhow::Ok(()) }); let handler = TestClientHandler::new(true, true); - let client = handler.clone().serve(client_transport).await?; + let (client, work) = handler.clone().serve(client_transport).await?; + tokio::spawn(work); tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; @@ -176,7 +178,7 @@ async fn test_sampling_integration_with_test_handlers() -> Result<()> { panic!("Expected CreateMessageResult"); } - client.cancel().await?; + client.cancel().await; server_handle.await??; Ok(()) } @@ -186,13 +188,15 @@ async fn test_sampling_no_context_inclusion() -> Result<()> { let (server_transport, client_transport) = tokio::io::duplex(4096); let server_handle = tokio::spawn(async move { - let server = TestServer::new().serve(server_transport).await?; - server.waiting().await?; + let (server, work) = TestServer::new().serve(server_transport).await?; + tokio::spawn(work); + server.waiting().await; anyhow::Ok(()) }); let handler = TestClientHandler::new(true, true); - let client = handler.clone().serve(client_transport).await?; + let (client, work) = handler.clone().serve(client_transport).await?; + tokio::spawn(work); tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; @@ -249,7 +253,7 @@ async fn test_sampling_no_context_inclusion() -> Result<()> { panic!("Expected CreateMessageResult"); } - client.cancel().await?; + client.cancel().await; server_handle.await??; Ok(()) } @@ -259,13 +263,15 @@ async fn test_sampling_error_invalid_message_sequence() -> Result<()> { let (server_transport, client_transport) = tokio::io::duplex(4096); let server_handle = tokio::spawn(async move { - let server = TestServer::new().serve(server_transport).await?; - server.waiting().await?; + let (server, work) = TestServer::new().serve(server_transport).await?; + tokio::spawn(work); + server.waiting().await; anyhow::Ok(()) }); let handler = TestClientHandler::new(true, true); - let client = handler.clone().serve(client_transport).await?; + let (client, work) = handler.clone().serve(client_transport).await?; + tokio::spawn(work); tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; @@ -305,7 +311,7 @@ async fn test_sampling_error_invalid_message_sequence() -> Result<()> { assert!(result.is_err()); - client.cancel().await?; + client.cancel().await; server_handle.await??; Ok(()) } diff --git a/crates/rmcp/tests/test_task_support_validation.rs b/crates/rmcp/tests/test_task_support_validation.rs index 016ed240..14415697 100644 --- a/crates/rmcp/tests/test_task_support_validation.rs +++ b/crates/rmcp/tests/test_task_support_validation.rs @@ -84,12 +84,15 @@ async fn test_required_task_tool_without_task_returns_method_not_found() -> anyh let server = TaskSupportTestServer::new(); let server_handle = tokio::spawn(async move { - server.serve(server_transport).await?.waiting().await?; + let (service, work) = server.serve(server_transport).await?; + tokio::spawn(work); + service.waiting().await; anyhow::Ok(()) }); let client_handler = DummyClientHandler::default(); - let client = client_handler.serve(client_transport).await?; + let (client, work) = client_handler.serve(client_transport).await?; + tokio::spawn(work); // Call the task-required tool without a task - should fail with -32601 let result = client @@ -127,7 +130,7 @@ async fn test_required_task_tool_without_task_returns_method_not_found() -> anyh _ => panic!("Expected McpError variant, got: {:?}", error), } - client.cancel().await?; + client.cancel().await; server_handle.await??; Ok(()) } @@ -138,12 +141,15 @@ async fn test_forbidden_task_tool_with_task_returns_error() -> anyhow::Result<() let server = TaskSupportTestServer::new(); let server_handle = tokio::spawn(async move { - server.serve(server_transport).await?.waiting().await?; + let (service, work) = server.serve(server_transport).await?; + tokio::spawn(work); + service.waiting().await; anyhow::Ok(()) }); let client_handler = DummyClientHandler::default(); - let client = client_handler.serve(client_transport).await?; + let (client, work) = client_handler.serve(client_transport).await?; + tokio::spawn(work); // Call the forbidden task tool WITH a task - should fail let result = client @@ -181,7 +187,7 @@ async fn test_forbidden_task_tool_with_task_returns_error() -> anyhow::Result<() _ => panic!("Expected McpError variant, got: {:?}", error), } - client.cancel().await?; + client.cancel().await; server_handle.await??; Ok(()) } @@ -192,12 +198,15 @@ async fn test_forbidden_task_tool_without_task_succeeds() -> anyhow::Result<()> let server = TaskSupportTestServer::new(); let server_handle = tokio::spawn(async move { - server.serve(server_transport).await?.waiting().await?; + let (service, work) = server.serve(server_transport).await?; + tokio::spawn(work); + service.waiting().await; anyhow::Ok(()) }); let client_handler = DummyClientHandler::default(); - let client = client_handler.serve(client_transport).await?; + let (client, work) = client_handler.serve(client_transport).await?; + tokio::spawn(work); // Call the forbidden task tool WITHOUT a task - should succeed let result = client @@ -222,7 +231,7 @@ async fn test_forbidden_task_tool_without_task_succeeds() -> anyhow::Result<()> .unwrap_or(""); assert_eq!(text, "forbidden task executed"); - client.cancel().await?; + client.cancel().await; server_handle.await??; Ok(()) } @@ -233,12 +242,15 @@ async fn test_optional_task_tool_without_task_succeeds() -> anyhow::Result<()> { let server = TaskSupportTestServer::new(); let server_handle = tokio::spawn(async move { - server.serve(server_transport).await?.waiting().await?; + let (service, work) = server.serve(server_transport).await?; + tokio::spawn(work); + service.waiting().await; anyhow::Ok(()) }); let client_handler = DummyClientHandler::default(); - let client = client_handler.serve(client_transport).await?; + let (client, work) = client_handler.serve(client_transport).await?; + tokio::spawn(work); // Call the optional task tool WITHOUT a task - should succeed let result = client @@ -263,7 +275,7 @@ async fn test_optional_task_tool_without_task_succeeds() -> anyhow::Result<()> { .unwrap_or(""); assert_eq!(text, "optional task executed"); - client.cancel().await?; + client.cancel().await; server_handle.await??; Ok(()) } diff --git a/crates/rmcp/tests/test_tool_macros.rs b/crates/rmcp/tests/test_tool_macros.rs index 837198cb..9bac64a5 100644 --- a/crates/rmcp/tests/test_tool_macros.rs +++ b/crates/rmcp/tests/test_tool_macros.rs @@ -298,13 +298,16 @@ async fn test_optional_i64_field_with_null_input() -> anyhow::Result<()> { // Server setup let server = OptionalSchemaTester::new(); let server_handle = tokio::spawn(async move { - server.serve(server_transport).await?.waiting().await?; + let (service, work) = server.serve(server_transport).await?; + tokio::spawn(work); + service.waiting().await; anyhow::Ok(()) }); // Create a simple client handler that just forwards tool calls let client_handler = DummyClientHandler::default(); - let client = client_handler.serve(client_transport).await?; + let (client, work) = client_handler.serve(client_transport).await?; + tokio::spawn(work); // Test null case let result = client @@ -366,7 +369,7 @@ async fn test_optional_i64_field_with_null_input() -> anyhow::Result<()> { "Some case should return expected message" ); - client.cancel().await?; + client.cancel().await; server_handle.await??; Ok(()) } From aee0d4f2b7d330b2920302e2a77575ac04abe289 Mon Sep 17 00:00:00 2001 From: Aadam Zocolo Date: Tue, 3 Mar 2026 23:32:03 +0000 Subject: [PATCH 13/20] refactor(test,examples): update remaining calls to `serve(...)` --- crates/rmcp/src/service.rs | 5 +- crates/rmcp/src/transport.rs | 8 ++- .../src/transport/child_process/transport.rs | 14 +++- crates/rmcp/src/util.rs | 65 +------------------ crates/rmcp/tests/test_custom_headers.rs | 10 +-- examples/clients/Cargo.toml | 3 +- examples/clients/src/auth/oauth_client.rs | 3 +- examples/clients/src/collection.rs | 23 +++---- examples/clients/src/everything_stdio.rs | 23 ++++--- examples/clients/src/git_stdio.rs | 28 ++++---- examples/clients/src/progress_client.rs | 28 ++++---- examples/clients/src/sampling_stdio.rs | 35 +++++----- examples/clients/src/streamable_http.rs | 6 +- examples/rig-integration/Cargo.toml | 3 +- examples/rig-integration/src/config/mcp.rs | 28 +++++--- examples/servers/src/calculator_stdio.rs | 6 +- examples/servers/src/common/counter.rs | 21 +++--- examples/servers/src/completion_stdio.rs | 6 +- examples/servers/src/counter_stdio.rs | 6 +- examples/servers/src/elicitation_stdio.rs | 6 +- examples/servers/src/progress_demo.rs | 5 +- examples/servers/src/prompt_stdio.rs | 6 +- examples/servers/src/sampling_stdio.rs | 6 +- examples/servers/src/structured_output.rs | 5 +- examples/transport/src/http_upgrade.rs | 10 +-- examples/transport/src/websocket.rs | 10 +-- examples/wasi/src/lib.rs | 5 +- 27 files changed, 193 insertions(+), 181 deletions(-) diff --git a/crates/rmcp/src/service.rs b/crates/rmcp/src/service.rs index 0320f371..5d6e93a0 100644 --- a/crates/rmcp/src/service.rs +++ b/crates/rmcp/src/service.rs @@ -536,7 +536,10 @@ impl> RunningService { /// # Example /// /// ```rust,ignore - /// let mut client = ().serve(transport).await?; + /// let mut (client, work) = ().serve(transport).await?; + /// // spawn the work (e.g. on tokio) + /// tokio::spawn(work); + /// /// // ... use the client ... /// client.close().await?; /// ``` diff --git a/crates/rmcp/src/transport.rs b/crates/rmcp/src/transport.rs index de12321c..83017a33 100644 --- a/crates/rmcp/src/transport.rs +++ b/crates/rmcp/src/transport.rs @@ -51,7 +51,9 @@ //! let stream = tokio::net::TcpSocket::new_v4()? //! .connect("127.0.0.1:8001".parse()?) //! .await?; -//! let client = ().serve(stream).await?; +//! let (client, work) = ().serve(stream).await?; +//! // spawn the work on a runtime (or poll it somehow) +//! tokio::spawn(work); //! let tools = client.peer().list_tools(Default::default()).await?; //! println!("{:?}", tools); //! Ok(()) @@ -60,7 +62,9 @@ //! // create transport from std io //! #[cfg(feature = "client")] //! async fn io() -> Result<(), Box> { -//! let client = ().serve((tokio::io::stdin(), tokio::io::stdout())).await?; +//! let (client, work) = ().serve((tokio::io::stdin(), tokio::io::stdout())).await?; +//! // spawn the work on a runtime (or poll it somehow) +//! tokio::spawn(work); //! let tools = client.peer().list_tools(Default::default()).await?; //! println!("{:?}", tools); //! Ok(()) diff --git a/crates/rmcp/src/transport/child_process/transport.rs b/crates/rmcp/src/transport/child_process/transport.rs index 9f46d9d9..d3c5b2b7 100644 --- a/crates/rmcp/src/transport/child_process/transport.rs +++ b/crates/rmcp/src/transport/child_process/transport.rs @@ -10,6 +10,14 @@ use crate::{ }, }; +#[derive(thiserror::Error, Debug)] +pub enum ChildProcessTransportError { + #[error("Missing stdout")] + MissingStdout, + #[error("Missing stdin")] + MissingStdin, +} + pub struct ChildProcessTransport { _child: Box, framed_transport: AsyncRwTransport< @@ -23,18 +31,18 @@ impl ChildProcessTransport where R: ServiceRole, { - pub fn new(child: ChildProcess) -> Result> { + pub fn new(child: ChildProcess) -> Result { let (stdout, stdin, _stderr, control) = child.split(); let framed_transport: AsyncRwTransport = AsyncRwTransport::new( Box::new( stdout - .ok_or("Failed to capture stdout of child process")? + .ok_or(ChildProcessTransportError::MissingStdout)? .compat(), ) as Box, Box::new( stdin - .ok_or("Failed to capture stdin of child process")? + .ok_or(ChildProcessTransportError::MissingStdin)? .compat_write(), ) as Box, ); diff --git a/crates/rmcp/src/util.rs b/crates/rmcp/src/util.rs index 912e1378..33b273f8 100644 --- a/crates/rmcp/src/util.rs +++ b/crates/rmcp/src/util.rs @@ -1,5 +1,5 @@ -use futures::{Sink, Stream}; -use std::{pin::Pin, task::Poll}; +use futures::Stream; +use std::pin::Pin; pub type PinnedFuture<'a, T> = Pin + Send + 'a>>; @@ -8,64 +8,3 @@ pub type PinnedLocalFuture<'a, T> = Pin + 'a>>; pub type PinnedStream<'a, T> = Pin + Send + 'a>>; pub type PinnedLocalStream<'a, T> = Pin + 'a>>; - -pub enum UnboundedSenderSinkError { - SendError(tokio::sync::mpsc::error::SendError), - Closed, -} - -/// A simple [Sink] wrapper for Tokio's [tokio::sync::mpsc::UnboundedSender] -#[derive(Debug, Clone)] -pub struct UnboundedSenderSink { - sender: tokio::sync::mpsc::UnboundedSender, -} - -impl UnboundedSenderSink { - pub fn new(sender: tokio::sync::mpsc::UnboundedSender) -> Self { - Self { sender } - } -} - -impl Sink for UnboundedSenderSink { - type Error = UnboundedSenderSinkError; - - fn poll_ready( - self: Pin<&mut Self>, - _cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - let this = self.get_mut(); - if this.sender.is_closed() { - Poll::Ready(Err(UnboundedSenderSinkError::Closed)) - } else { - Poll::Ready(Ok(())) - } - } - - fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> { - let this = self.get_mut(); - match this.sender.send(item) { - Ok(_) => Ok(()), - Err(e) => Err(UnboundedSenderSinkError::SendError(e)), - } - } - - fn poll_flush( - self: Pin<&mut Self>, - _cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - // tokio's unbounded mpsc senders have no flushing required, since the - // receiver is unbounded and will get all messages we send (unless we run - // out of memory) - Poll::Ready(Ok(())) - } - - fn poll_close( - self: Pin<&mut Self>, - _cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - // Like `poll_flush`, there is nothing to wait on here. A single - // call to `mpsc_sender.send(...)` is immediate from the perspective - // of the sender - Poll::Ready(Ok(())) - } -} diff --git a/crates/rmcp/tests/test_custom_headers.rs b/crates/rmcp/tests/test_custom_headers.rs index 82537a80..b88d0cd6 100644 --- a/crates/rmcp/tests/test_custom_headers.rs +++ b/crates/rmcp/tests/test_custom_headers.rs @@ -496,7 +496,8 @@ async fn test_mcp_custom_headers_sent_to_server() -> anyhow::Result<()> { let transport = StreamableHttpClientTransport::from_config(config); // Start MCP client with empty handler (this will trigger initialize request) - let client = ().serve(transport).await.expect("Failed to start client"); + let (client, work) = ().serve(transport).await.expect("Failed to start client"); + tokio::spawn(work); // Wait for initialize to be called tokio::time::timeout( @@ -526,7 +527,7 @@ async fn test_mcp_custom_headers_sent_to_server() -> anyhow::Result<()> { ); // Cleanup - drop(client); + client.cancel().await; server_handle.abort(); Ok(()) @@ -665,7 +666,8 @@ async fn test_mcp_protocol_version_header_sent_after_init() -> anyhow::Result<() StreamableHttpClientTransportConfig::with_uri(format!("http://127.0.0.1:{}/mcp", port)); let transport = StreamableHttpClientTransport::from_config(config); - let client = ().serve(transport).await.expect("Failed to start client"); + let (client, work) = ().serve(transport).await.expect("Failed to start client"); + tokio::spawn(work); tokio::time::timeout( std::time::Duration::from_secs(5), @@ -701,7 +703,7 @@ async fn test_mcp_protocol_version_header_sent_after_init() -> anyhow::Result<() "Initialized notification should include MCP-Protocol-Version: 2025-03-26" ); - drop(client); + client.cancel().await; server_handle.abort(); Ok(()) diff --git a/examples/clients/Cargo.toml b/examples/clients/Cargo.toml index ea35b021..05238fc8 100644 --- a/examples/clients/Cargo.toml +++ b/examples/clients/Cargo.toml @@ -12,8 +12,9 @@ rmcp = { workspace = true, features = [ "reqwest", "transport-streamable-http-client-reqwest", "transport-child-process", + "transport-child-process-tokio", "tower", - "auth" + "auth", ] } tokio = { version = "1", features = ["full"] } serde = { version = "1.0", features = ["derive"] } diff --git a/examples/clients/src/auth/oauth_client.rs b/examples/clients/src/auth/oauth_client.rs index 456f3269..97f0e497 100644 --- a/examples/clients/src/auth/oauth_client.rs +++ b/examples/clients/src/auth/oauth_client.rs @@ -180,7 +180,8 @@ async fn main() -> Result<()> { // Create client and connect to MCP server let client_service = ClientInfo::default(); - let client = client_service.serve(transport).await?; + let (client, work) = client_service.serve(transport).await?; + tokio::spawn(work); tracing::info!("Successfully connected to MCP server"); // Test API requests diff --git a/examples/clients/src/collection.rs b/examples/clients/src/collection.rs index a4c73482..a181a423 100644 --- a/examples/clients/src/collection.rs +++ b/examples/clients/src/collection.rs @@ -8,9 +8,11 @@ use anyhow::Result; use rmcp::{ model::CallToolRequestParams, service::ServiceExt, - transport::{ConfigureCommandExt, TokioChildProcess}, + transport::{ + CommandBuilder, + child_process::{tokio::TokioChildProcessRunner, transport::ChildProcessTransport}, + }, }; -use tokio::process::Command; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; #[tokio::main] @@ -26,14 +28,13 @@ async fn main() -> Result<()> { let mut clients_map = HashMap::new(); for idx in 0..10 { - let client = () - .into_dyn() - .serve(TokioChildProcess::new(Command::new("uvx").configure( - |cmd| { - cmd.arg("mcp-client-git"); - }, - ))?) - .await?; + let child_process = CommandBuilder::::new("uvx") + .arg("mcp-client-git") + .spawn_dyn()?; + let transport = ChildProcessTransport::new(child_process)?; + + let (client, work) = ().into_dyn().serve(transport).await?; + tokio::spawn(work); clients_map.insert(idx, client); } @@ -55,7 +56,7 @@ async fn main() -> Result<()> { .await?; } for (_, service) in clients_map { - service.cancel().await?; + service.cancel().await; } Ok(()) } diff --git a/examples/clients/src/everything_stdio.rs b/examples/clients/src/everything_stdio.rs index 763a880a..75731204 100644 --- a/examples/clients/src/everything_stdio.rs +++ b/examples/clients/src/everything_stdio.rs @@ -3,7 +3,10 @@ use rmcp::{ ServiceExt, model::{CallToolRequestParams, GetPromptRequestParams, ReadResourceRequestParams}, object, - transport::{ConfigureCommandExt, TokioChildProcess}, + transport::{ + self, CommandBuilder, + child_process::{tokio::TokioChildProcessRunner, transport::ChildProcessTransport}, + }, }; use tokio::process::Command; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; @@ -19,13 +22,15 @@ async fn main() -> Result<()> { .with(tracing_subscriber::fmt::layer()) .init(); - let client = () - .serve(TokioChildProcess::new(Command::new("npx").configure( - |cmd| { - cmd.arg("-y").arg("@modelcontextprotocol/server-everything"); - }, - ))?) - .await?; + let command = CommandBuilder::::new("npx") + .arg("-y") + .arg("@modelcontextprotocol/server-everything") + .spawn_dyn()?; + + let transport = ChildProcessTransport::new(command)?; + + let (client, work) = ().serve(transport).await?; + tokio::spawn(work); // Initialize let server_info = client.peer_info(); @@ -98,7 +103,7 @@ async fn main() -> Result<()> { let resource_templates = client.list_all_resource_templates().await?; tracing::info!("Available resource templates: {resource_templates:#?}"); - client.cancel().await?; + client.cancel().await; Ok(()) } diff --git a/examples/clients/src/git_stdio.rs b/examples/clients/src/git_stdio.rs index 9960c16b..ba78f213 100644 --- a/examples/clients/src/git_stdio.rs +++ b/examples/clients/src/git_stdio.rs @@ -2,9 +2,11 @@ use rmcp::{ RmcpError, model::CallToolRequestParams, service::ServiceExt, - transport::{ConfigureCommandExt, TokioChildProcess}, + transport::{ + CommandBuilder, + child_process::{tokio::TokioChildProcessRunner, transport::ChildProcessTransport}, + }, }; -use tokio::process::Command; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; #[allow(clippy::result_large_err)] @@ -18,16 +20,18 @@ async fn main() -> Result<(), RmcpError> { ) .with(tracing_subscriber::fmt::layer()) .init(); - let client = () - .serve( - TokioChildProcess::new(Command::new("uvx").configure(|cmd| { - cmd.arg("mcp-server-git"); - })) - .map_err(RmcpError::transport_creation::)?, - ) - .await?; - // or serve_client((), TokioChildProcess::new(cmd)?).await?; + let command = CommandBuilder::::new("npx") + .arg("-y") + .arg("@modelcontextprotocol/server-everything") + .spawn_dyn() + .map_err(RmcpError::transport_creation::)?; + + let transport = ChildProcessTransport::new(command) + .map_err(RmcpError::transport_creation::)?; + + let (client, work) = ().serve(transport).await?; + tokio::spawn(work); // Initialize let server_info = client.peer_info(); @@ -47,6 +51,6 @@ async fn main() -> Result<(), RmcpError> { }) .await?; tracing::info!("Tool result: {tool_result:#?}"); - client.cancel().await?; + client.cancel().await; Ok(()) } diff --git a/examples/clients/src/progress_client.rs b/examples/clients/src/progress_client.rs index db66a8ed..33e1351e 100644 --- a/examples/clients/src/progress_client.rs +++ b/examples/clients/src/progress_client.rs @@ -12,9 +12,12 @@ use rmcp::{ ProgressNotificationParam, }, service::{NotificationContext, RoleClient}, - transport::{StreamableHttpClientTransport, TokioChildProcess}, + transport::{ + CommandBuilder, StreamableHttpClientTransport, + child_process::{tokio::TokioChildProcessRunner, transport::ChildProcessTransport}, + }, }; -use tokio::{process::Command, time::sleep}; +use tokio::time::sleep; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; #[derive(Debug, Clone, ValueEnum)] @@ -148,23 +151,25 @@ async fn test_stdio_transport(records: u32) -> Result<()> { let servers_dir = workspace_root.join("examples").join("servers"); // Start server process - let mut server_cmd = Command::new("cargo"); - server_cmd + let server_cmd = CommandBuilder::::new("cargo") .current_dir(servers_dir) .arg("run") .arg("--example") .arg("servers_progress_demo") .arg("--") - .arg("stdio"); + .arg("stdio") + .spawn_dyn()?; // Create progress-aware client handler let client_handler = ProgressAwareClient::new(); client_handler.start_tracking(); let client_handler_clone = client_handler.clone(); - let service = client_handler - .serve(TokioChildProcess::new(server_cmd)?) - .await?; + let server_transport = ChildProcessTransport::new(server_cmd)?; + + let (service, work) = client_handler.serve(server_transport).await?; + + tokio::spawn(work); // Initialize let server_info = service.peer_info(); @@ -196,7 +201,7 @@ async fn test_stdio_transport(records: u32) -> Result<()> { } } - service.cancel().await?; + service.cancel().await; client_handler_clone.stop_tracking(); tracing::info!("STDIO transport test completed successfully!"); Ok(()) @@ -218,9 +223,10 @@ async fn test_http_transport(http_url: &str, records: u32) -> Result<()> { client_handler.start_tracking(); let client_handler_clone = client_handler.clone(); - let client = client_handler.serve(transport).await.inspect_err(|e| { + let (client, work) = client_handler.serve(transport).await.inspect_err(|e| { tracing::error!("HTTP client error: {:?}", e); })?; + tokio::spawn(work); // Initialize let server_info = client.peer_info(); @@ -252,7 +258,7 @@ async fn test_http_transport(http_url: &str, records: u32) -> Result<()> { } } - client.cancel().await?; + client.cancel().await; client_handler_clone.stop_tracking(); tracing::info!("HTTP transport test completed successfully!"); Ok(()) diff --git a/examples/clients/src/sampling_stdio.rs b/examples/clients/src/sampling_stdio.rs index e2a7a6d5..7d8e5e6f 100644 --- a/examples/clients/src/sampling_stdio.rs +++ b/examples/clients/src/sampling_stdio.rs @@ -4,9 +4,11 @@ use rmcp::{ model::*, object, service::{RequestContext, RoleClient}, - transport::{ConfigureCommandExt, TokioChildProcess}, + transport::{ + CommandBuilder, + child_process::{tokio::TokioChildProcessRunner, transport::ChildProcessTransport}, + }, }; -use tokio::process::Command; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; /// Simple Sampling Demo Client /// @@ -69,19 +71,20 @@ async fn main() -> Result<()> { .expect("CARGO_MANIFEST_DIR is not set") .join("servers"); - let client = client - .serve(TokioChildProcess::new(Command::new("cargo").configure( - |cmd| { - cmd.arg("run") - .arg("--example") - .arg("servers_sampling_stdio") - .current_dir(servers_dir); - }, - ))?) - .await - .inspect_err(|e| { - tracing::error!("client error: {:?}", e); - })?; + let command = CommandBuilder::::new("cargo") + .arg("run") + .arg("--example") + .arg("servers_sampling_stdio") + .current_dir(servers_dir) + .spawn_dyn()?; + + let transport = ChildProcessTransport::new(command)?; + + let (client, work) = client.serve(transport).await.inspect_err(|e| { + tracing::error!("client error: {:?}", e); + })?; + + tokio::spawn(work); // Wait for initialization tokio::time::sleep(tokio::time::Duration::from_millis(1000)).await; @@ -118,6 +121,6 @@ async fn main() -> Result<()> { tracing::info!("Sampling demo completed successfully!"); tokio::time::sleep(tokio::time::Duration::from_millis(1000)).await; - client.cancel().await?; + client.cancel().await; Ok(()) } diff --git a/examples/clients/src/streamable_http.rs b/examples/clients/src/streamable_http.rs index baf1838a..50e76541 100644 --- a/examples/clients/src/streamable_http.rs +++ b/examples/clients/src/streamable_http.rs @@ -30,10 +30,12 @@ async fn main() -> Result<()> { icons: None, }, }; - let client = client_info.serve(transport).await.inspect_err(|e| { + let (client, work) = client_info.serve(transport).await.inspect_err(|e| { tracing::error!("client error: {:?}", e); })?; + tokio::spawn(work); + // Initialize let server_info = client.peer_info(); tracing::info!("Connected to server: {server_info:#?}"); @@ -51,6 +53,6 @@ async fn main() -> Result<()> { }) .await?; tracing::info!("Tool result: {tool_result:#?}"); - client.cancel().await?; + client.cancel().await; Ok(()) } diff --git a/examples/rig-integration/Cargo.toml b/examples/rig-integration/Cargo.toml index 9429975e..1c98aefd 100644 --- a/examples/rig-integration/Cargo.toml +++ b/examples/rig-integration/Cargo.toml @@ -18,7 +18,8 @@ tokio = { version = "1", features = ["full"] } rmcp = { workspace = true, features = [ "client", "transport-child-process", - "transport-streamable-http-client-reqwest" + "transport-child-process-tokio", + "transport-streamable-http-client-reqwest", ] } anyhow = "1.0" serde_json = "1" diff --git a/examples/rig-integration/src/config/mcp.rs b/examples/rig-integration/src/config/mcp.rs index 45e4c23c..d5b398d2 100644 --- a/examples/rig-integration/src/config/mcp.rs +++ b/examples/rig-integration/src/config/mcp.rs @@ -1,6 +1,13 @@ use std::{collections::HashMap, process::Stdio}; -use rmcp::{RoleClient, ServiceExt, service::RunningService, transport::ConfigureCommandExt}; +use rmcp::{ + RoleClient, ServiceExt, + service::RunningService, + transport::{ + CommandBuilder, + child_process::{tokio::TokioChildProcessRunner, transport::ChildProcessTransport}, + }, +}; use serde::{Deserialize, Serialize}; use crate::mcp_adaptor::McpManager; @@ -63,19 +70,24 @@ impl McpServerTransportConfig { McpServerTransportConfig::Streamable { url } => { let transport = rmcp::transport::StreamableHttpClientTransport::from_uri(url.to_string()); - ().serve(transport).await? + let (service, work) = ().serve(transport).await?; + tokio::spawn(work); + service } McpServerTransportConfig::Stdio { command, args, envs, } => { - let transport = rmcp::transport::TokioChildProcess::new( - tokio::process::Command::new(command).configure(|cmd| { - cmd.args(args).envs(envs).stderr(Stdio::null()); - }), - )?; - ().serve(transport).await? + let command = CommandBuilder::::new(command) + .args(args) + .envs(envs) + .stderr(Stdio::null()) + .spawn_dyn()?; + let transport = ChildProcessTransport::new(command)?; + let (service, work) = ().serve(transport).await?; + tokio::spawn(work); + service } }; Ok(client) diff --git a/examples/servers/src/calculator_stdio.rs b/examples/servers/src/calculator_stdio.rs index 6af82042..c27c69ad 100644 --- a/examples/servers/src/calculator_stdio.rs +++ b/examples/servers/src/calculator_stdio.rs @@ -17,10 +17,12 @@ async fn main() -> Result<()> { tracing::info!("Starting Calculator MCP server"); // Create an instance of our calculator router - let service = Calculator::new().serve(stdio()).await.inspect_err(|e| { + let (service, work) = Calculator::new().serve(stdio()).await.inspect_err(|e| { tracing::error!("serving error: {:?}", e); })?; - service.waiting().await?; + tokio::spawn(work); + + service.waiting().await; Ok(()) } diff --git a/examples/servers/src/common/counter.rs b/examples/servers/src/common/counter.rs index e92b142a..207ec37e 100644 --- a/examples/servers/src/common/counter.rs +++ b/examples/servers/src/common/counter.rs @@ -1,7 +1,6 @@ #![allow(dead_code)] use std::{any::Any, sync::Arc}; -use chrono::Utc; use rmcp::{ ErrorData as McpError, RoleServer, ServerHandler, handler::server::{ @@ -12,14 +11,11 @@ use rmcp::{ prompt, prompt_handler, prompt_router, schemars, service::RequestContext, task_handler, - task_manager::{ - OperationDescriptor, OperationMessage, OperationProcessor, OperationResultTransport, - }, + task_manager::{OperationProcessor, OperationResultTransport}, tool, tool_handler, tool_router, }; use serde_json::json; use tokio::sync::Mutex; -use tracing::info; struct ToolCallOperationResult { id: String, @@ -69,11 +65,14 @@ pub struct Counter { impl Counter { #[allow(dead_code)] pub fn new() -> Self { + let (processor, processor_work) = OperationProcessor::new(); + tokio::spawn(processor_work); + Self { counter: Arc::new(Mutex::new(0)), tool_router: Self::tool_router(), prompt_router: Self::prompt_router(), - processor: Arc::new(Mutex::new(OperationProcessor::new())), + processor: Arc::new(Mutex::new(processor)), } } @@ -353,12 +352,14 @@ mod tests { let (server_transport, client_transport) = tokio::io::duplex(4096); let server_handle = tokio::spawn(async move { - let service = counter.serve(server_transport).await?; - service.waiting().await?; + let (service, work) = counter.serve(server_transport).await?; + tokio::spawn(work); + service.waiting().await; anyhow::Ok(()) }); - let client_service = client.serve(client_transport).await?; + let (client_service, work) = client.serve(client_transport).await?; + tokio::spawn(work); let mut task_meta = serde_json::Map::new(); task_meta.insert( "source".into(), @@ -395,7 +396,7 @@ mod tests { let running = processor.lock().await.running_task_count(); assert_eq!(running, 1); - client_service.cancel().await?; + client_service.cancel().await; let _ = server_handle.await; Ok(()) } diff --git a/examples/servers/src/completion_stdio.rs b/examples/servers/src/completion_stdio.rs index e4365cad..f4371315 100644 --- a/examples/servers/src/completion_stdio.rs +++ b/examples/servers/src/completion_stdio.rs @@ -446,10 +446,12 @@ async fn main() -> Result<()> { println!(); let server = SqlQueryServer::new(); - let service = server.serve(stdio()).await.inspect_err(|e| { + let (service, work) = server.serve(stdio()).await.inspect_err(|e| { tracing::error!("Server error: {:?}", e); })?; - service.waiting().await?; + tokio::spawn(work); + + service.waiting().await; Ok(()) } diff --git a/examples/servers/src/counter_stdio.rs b/examples/servers/src/counter_stdio.rs index 9339ab86..18a21592 100644 --- a/examples/servers/src/counter_stdio.rs +++ b/examples/servers/src/counter_stdio.rs @@ -16,10 +16,12 @@ async fn main() -> Result<()> { tracing::info!("Starting MCP server"); // Create an instance of our counter router - let service = Counter::new().serve(stdio()).await.inspect_err(|e| { + let (service, work) = Counter::new().serve(stdio()).await.inspect_err(|e| { tracing::error!("serving error: {:?}", e); })?; - service.waiting().await?; + tokio::spawn(work); + + service.waiting().await; Ok(()) } diff --git a/examples/servers/src/elicitation_stdio.rs b/examples/servers/src/elicitation_stdio.rs index 82f8d696..1fec1dff 100644 --- a/examples/servers/src/elicitation_stdio.rs +++ b/examples/servers/src/elicitation_stdio.rs @@ -182,13 +182,15 @@ async fn main() -> Result<()> { println!("1. Run: npx @modelcontextprotocol/inspector"); println!("2. Enter server command: {}", current_exe); - let service = ElicitationServer::new() + let (service, work) = ElicitationServer::new() .serve(stdio()) .await .inspect_err(|e| { tracing::error!("serving error: {:?}", e); })?; - service.waiting().await?; + tokio::spawn(work); + + service.waiting().await; Ok(()) } diff --git a/examples/servers/src/progress_demo.rs b/examples/servers/src/progress_demo.rs index e9e147ea..c12572c8 100644 --- a/examples/servers/src/progress_demo.rs +++ b/examples/servers/src/progress_demo.rs @@ -33,11 +33,12 @@ async fn main() -> anyhow::Result<()> { async fn run_stdio() -> anyhow::Result<()> { let server = ProgressDemo::new(); - let service = server.serve(stdio()).await.inspect_err(|e| { + let (service, work) = server.serve(stdio()).await.inspect_err(|e| { tracing::error!("stdio serving error: {:?}", e); })?; + tokio::spawn(work); - service.waiting().await?; + service.waiting().await; Ok(()) } diff --git a/examples/servers/src/prompt_stdio.rs b/examples/servers/src/prompt_stdio.rs index 0937c3e2..ceac5d1d 100644 --- a/examples/servers/src/prompt_stdio.rs +++ b/examples/servers/src/prompt_stdio.rs @@ -414,10 +414,12 @@ async fn main() -> Result<()> { println!(); let server = PromptServer::new(); - let service = server.serve(stdio()).await.inspect_err(|e| { + let (service, work) = server.serve(stdio()).await.inspect_err(|e| { tracing::error!("Server error: {:?}", e); })?; - service.waiting().await?; + tokio::spawn(work); + + service.waiting().await; Ok(()) } diff --git a/examples/servers/src/sampling_stdio.rs b/examples/servers/src/sampling_stdio.rs index 297af9d0..6e6c774d 100644 --- a/examples/servers/src/sampling_stdio.rs +++ b/examples/servers/src/sampling_stdio.rs @@ -146,10 +146,12 @@ async fn main() -> Result<()> { tracing::info!("Starting Sampling Demo Server"); // Create and serve the sampling demo server - let service = SamplingDemoServer.serve(stdio()).await.inspect_err(|e| { + let (service, work) = SamplingDemoServer.serve(stdio()).await.inspect_err(|e| { tracing::error!("Serving error: {:?}", e); })?; - service.waiting().await?; + tokio::spawn(work); + + service.waiting().await; Ok(()) } diff --git a/examples/servers/src/structured_output.rs b/examples/servers/src/structured_output.rs index c30a5142..e4c767c1 100644 --- a/examples/servers/src/structured_output.rs +++ b/examples/servers/src/structured_output.rs @@ -151,8 +151,9 @@ async fn main() -> anyhow::Result<()> { eprintln!("Starting server. Connect with an MCP client to test the tools."); eprintln!("Press Ctrl+C to stop."); - let service = server.serve(stdio()).await?; - service.waiting().await?; + let (service, work) = server.serve(stdio()).await?; + tokio::spawn(work); + service.waiting().await; Ok(()) } diff --git a/examples/transport/src/http_upgrade.rs b/examples/transport/src/http_upgrade.rs index 6a15add3..da96f82c 100644 --- a/examples/transport/src/http_upgrade.rs +++ b/examples/transport/src/http_upgrade.rs @@ -16,7 +16,7 @@ async fn main() -> anyhow::Result<()> { start_server().await?; let client = http_client("127.0.0.1:8001").await?; let tools = client.list_all_tools().await?; - client.cancel().await?; + client.cancel().await; tracing::info!("{:#?}", tools); Ok(()) } @@ -24,8 +24,9 @@ async fn main() -> anyhow::Result<()> { async fn http_server(req: Request) -> Result, hyper::Error> { tokio::spawn(async move { let upgraded = hyper::upgrade::on(req).await?; - let service = Calculator::new().serve(TokioIo::new(upgraded)).await?; - service.waiting().await?; + let (service, work) = Calculator::new().serve(TokioIo::new(upgraded)).await?; + tokio::spawn(work); + service.waiting().await; anyhow::Result::<()>::Ok(()) }); let mut response = hyper::Response::new(String::new()); @@ -46,7 +47,8 @@ async fn http_client(uri: &str) -> anyhow::Result .insert(UPGRADE, HeaderValue::from_static("mcp")); let response = s.send_request(req).await?; let upgraded = hyper::upgrade::on(response).await?; - let client = ().serve(TokioIo::new(upgraded)).await?; + let (client, work) = ().serve(TokioIo::new(upgraded)).await?; + tokio::spawn(work); Ok(client) } diff --git a/examples/transport/src/websocket.rs b/examples/transport/src/websocket.rs index 5ba23546..4a61bc62 100644 --- a/examples/transport/src/websocket.rs +++ b/examples/transport/src/websocket.rs @@ -17,7 +17,7 @@ async fn main() -> anyhow::Result<()> { start_server().await?; let client = http_client("ws://127.0.0.1:8001").await?; let tools = client.list_all_tools().await?; - client.cancel().await?; + client.cancel().await; tracing::info!("{:#?}", tools); Ok(()) } @@ -28,7 +28,8 @@ async fn http_client(uri: &str) -> anyhow::Result return Err(anyhow::anyhow!("failed to upgrade connection")); } let transport = WebsocketTransport::new_client(stream); - let client = ().serve(transport).await?; + let (client, work) = ().serve(transport).await?; + tokio::spawn(work); Ok(client) } @@ -40,8 +41,9 @@ async fn start_server() -> anyhow::Result<()> { tokio::spawn(async move { let ws_stream = tokio_tungstenite::accept_async(stream).await?; let transport = WebsocketTransport::new_server(ws_stream); - let server = Calculator::new().serve(transport).await?; - server.waiting().await?; + let (server, work) = Calculator::new().serve(transport).await?; + tokio::spawn(work); + server.waiting().await; Ok::<(), anyhow::Error>(()) }); } diff --git a/examples/wasi/src/lib.rs b/examples/wasi/src/lib.rs index 2690cc73..dea993bb 100644 --- a/examples/wasi/src/lib.rs +++ b/examples/wasi/src/lib.rs @@ -112,11 +112,12 @@ impl wasi::exports::cli::run::Guest for TokioCliRunner { .with_writer(std::io::stderr) .with_ansi(false) .init(); - let server = calculator::Calculator::new() + let (server, work) = calculator::Calculator::new() .serve(wasi_io()) .await .unwrap(); - server.waiting().await.unwrap(); + tokio::spawn(work); + server.waiting().await; }); Ok(()) } From 54825457b9eae289392eed5a1748a10dd211a4ba Mon Sep 17 00:00:00 2001 From: Aadam Zocolo Date: Tue, 3 Mar 2026 23:56:27 +0000 Subject: [PATCH 14/20] refactor(docs): update docs to new call convention for serving add additional docs to the `ChildProcessRunner` trait --- README.md | 28 +++++++++--- crates/rmcp/README.md | 43 +++++++++++++------ .../src/transport/child_process/runner.rs | 29 +++++++++++++ docs/OAUTH_SUPPORT.md | 4 +- docs/readme/README.zh-cn.md | 26 ++++++++--- 5 files changed, 104 insertions(+), 26 deletions(-) diff --git a/README.md b/README.md index b2d17c08..7068e182 100644 --- a/README.md +++ b/README.md @@ -38,14 +38,32 @@ Json Schema generation (version 2020-12): Start a client ```rust, ignore -use rmcp::{ServiceExt, transport::{TokioChildProcess, ConfigureCommandExt}}; +use rmcp::{ + ServiceExt, + transport::{CommandBuilder, ChildProcessTransport, tokio::TokioChildProcessRunner} +}; use tokio::process::Command; #[tokio::main] async fn main() -> Result<(), Box> { - let client = ().serve(TokioChildProcess::new(Command::new("npx").configure(|cmd| { - cmd.arg("-y").arg("@modelcontextprotocol/server-everything"); - }))?).await?; + + // Build and spawn a child process + let command = CommandBuilder::::new("npx") + .arg("-y") + .arg("@modelcontextprotocol/server-everything") + .spawn_dyn()? + + // Create a transport via the child process's STDIN and STDOUT streams + let transport = ChildProcessTransport::new(command)? + + let (client, work) = ().serve(transport).await?; + // Spawn the async work loop on the background + tokio::spawn(work); + + // Use the client ... + + // Finish using the client + client.cancel().await; Ok(()) } ``` @@ -78,7 +96,7 @@ let service = common::counter::Counter::new(); ```rust, ignore // this call will finish the initialization process -let server = service.serve(transport).await?; +let (server, work) = service.serve(transport).await?; ``` diff --git a/crates/rmcp/README.md b/crates/rmcp/README.md index 217b22cd..1444deee 100644 --- a/crates/rmcp/README.md +++ b/crates/rmcp/README.md @@ -80,9 +80,12 @@ impl ServerHandler for Counter { #[tokio::main] async fn main() -> Result<(), Box> { // Create and run the server with STDIO transport - let service = Counter::new().serve(stdio()).await.inspect_err(|e| { + let (service, work) = Counter::new().serve(stdio()).await.inspect_err(|e| { println!("Error starting server: {}", e); })?; + // Spawn the async work loop on the background + tokio::spawn(work); + // Wait for the service to conclude service.waiting().await?; Ok(()) } @@ -151,20 +154,25 @@ Creating a client to interact with a server: use rmcp::{ ServiceExt, model::CallToolRequestParams, - transport::{ConfigureCommandExt, TokioChildProcess}, + transport::{CommandBuilder, ChildProcessTransport, tokio::TokioChildProcessRunner} }; use tokio::process::Command; #[tokio::main] async fn main() -> Result<(), Box> { + // Connect to a server running as a child process - let service = () - .serve(TokioChildProcess::new(Command::new("uvx").configure( - |cmd| { - cmd.arg("mcp-server-git"); - }, - ))?) - .await?; + let command = CommandBuilder::::new("uvx") + .arg("mcp-server-git") + .spawn_dyn()? + + // Create a transport via the child process's STDIN and STDOUT streams + let transport = ChildProcessTransport::new(command)? + + let (service, work) = ().serve(transport).await?; + + // Spawn the async work loop on the background + tokio::spawn(work); // Get server information let server_info = service.peer_info(); @@ -186,7 +194,7 @@ async fn main() -> Result<(), Box> { println!("Result: {result:#?}"); // Gracefully close the connection - service.cancel().await?; + service.cancel().await; Ok(()) } ``` @@ -210,11 +218,18 @@ Run MCP servers as child processes and communicate via standard I/O. Example: ```rust,ignore -use rmcp::transport::TokioChildProcess; -use tokio::process::Command; +use rmcp::transport::{ + CommandBuilder, + ChildProcessTransport, + + // Included tokio command runner implementation in the "transport-child-process-tokio" feature, + // or implement your own via the `ChildProcessRunner` trait + tokio::TokioChildProcessRunner +}; -let transport = TokioChildProcess::new(Command::new("mcp-server"))?; -let service = client.serve(transport).await?; +let command = CommandBuilder::::new("mcp-server").spawn_dyn()?; +let transport = ChildProcessTransport::new(command)?; +let (service, work) = client.serve(transport).await?; ``` ## Access with peer interface when handling message diff --git a/crates/rmcp/src/transport/child_process/runner.rs b/crates/rmcp/src/transport/child_process/runner.rs index bc457bb2..9b76d1e2 100644 --- a/crates/rmcp/src/transport/child_process/runner.rs +++ b/crates/rmcp/src/transport/child_process/runner.rs @@ -94,6 +94,35 @@ pub enum RunnerSpawnError { Other(Box), } +/// A trait that defines the implementation needed for spawning a child process via the [CommandBuilder] and used in the [ChildProcessTransport]. +/// You can implement this trait however you'd like, usually using a unit-struct. Your implmentation can then be used with the [CommandBuilder]. +/// +/// Here is a high-level example of how you'd implement your own command runner: +/// ```rust,ignore +/// // Define the type for a child process instance that your runner will spawn. +/// struct MyChildProcessInstance { +/// // ... +/// } +/// +/// impl ChildProcessInstance for MyChildProcessInstance { +/// type Stdin = ...; +/// type Stdout = ...; +/// type Stderr = ...; +/// // Implement the required methods for taking the streams, getting the PID, and waiting/shutting down/killing the process. +/// } +/// +/// impl ChildProcessRunner for MyRunner { +/// type Instance = MyChildProcessInstance; +/// +/// fn spawn(command_config: CommandConfig) -> Result { +/// // Use the information in command_config to spawn your child process instance, and return it +/// } +/// } +/// +/// +/// // Use your implementation with the command builder +/// let command = CommandBuilder::::new("my_command").arg("some_arg").spawn_dyn(); +/// ``` pub trait ChildProcessRunner { /// The implementation of the child process instance that this runner will spawn. type Instance: ChildProcessInstance; diff --git a/docs/OAUTH_SUPPORT.md b/docs/OAUTH_SUPPORT.md index b0b59f9f..20346d46 100644 --- a/docs/OAUTH_SUPPORT.md +++ b/docs/OAUTH_SUPPORT.md @@ -81,7 +81,9 @@ If you know the scopes you need, you can still pass them explicitly: // create client and connect to MCP server let client_service = ClientInfo::default(); - let client = client_service.serve(transport).await?; + let (client, work) = client_service.serve(transport).await?; + // poll the async work loop, ex spawn it on the background + tokio::spawn(work); ``` ### 5. Handle scope upgrades diff --git a/docs/readme/README.zh-cn.md b/docs/readme/README.zh-cn.md index f666928f..8c188119 100644 --- a/docs/readme/README.zh-cn.md +++ b/docs/readme/README.zh-cn.md @@ -38,14 +38,27 @@ JSON Schema 生成 (version 2020-12): 启动客户端 ```rust, ignore -use rmcp::{ServiceExt, transport::{TokioChildProcess, ConfigureCommandExt}}; -use tokio::process::Command; +use rmcp::transport::{ + CommandBuilder, + ChildProcessTransport, + tokio::TokioChildProcessRunner +}; #[tokio::main] async fn main() -> Result<(), Box> { - let client = ().serve(TokioChildProcess::new(Command::new("npx").configure(|cmd| { - cmd.arg("-y").arg("@modelcontextprotocol/server-everything"); - }))?).await?; + let command = CommandBuilder::::new("npx") + .arg("-y") + .arg("@modelcontextprotocol/server-everything") + .spawn_dyn()? + + let transport = ChildProcessTransport::new(command)? + + let (client, work) = ().serve(transport).await?; + tokio::spawn(work); + + // ... + + client.canel().await; Ok(()) } ``` @@ -78,7 +91,8 @@ let service = common::counter::Counter::new(); ```rust, ignore // 此调用将完成初始化过程 -let server = service.serve(transport).await?; +let (server, work) = service.serve(transport).await?; +tokio::spawn(work); ``` From a35067a9dd75c8516804feb2886ad130bf88f3ee Mon Sep 17 00:00:00 2001 From: Aadam Zocolo Date: Wed, 4 Mar 2026 00:31:10 +0000 Subject: [PATCH 15/20] refactor(http): change to futures unordered --- .../src/transport/streamable_http_client.rs | 49 ++++++++++++------- 1 file changed, 31 insertions(+), 18 deletions(-) diff --git a/crates/rmcp/src/transport/streamable_http_client.rs b/crates/rmcp/src/transport/streamable_http_client.rs index 779dfe1c..6878a158 100644 --- a/crates/rmcp/src/transport/streamable_http_client.rs +++ b/crates/rmcp/src/transport/streamable_http_client.rs @@ -1,6 +1,10 @@ use std::{borrow::Cow, collections::HashMap, sync::Arc, time::Duration}; -use futures::{Stream, StreamExt, future::BoxFuture, stream::BoxStream}; +use futures::{ + FutureExt, Stream, StreamExt, + future::BoxFuture, + stream::{BoxStream, FuturesUnordered}, +}; use http::{HeaderName, HeaderValue}; pub use sse_stream::Error as SseError; use sse_stream::Sse; @@ -423,7 +427,7 @@ impl Worker for StreamableHttpClientWorker { ServerMessage(ServerJsonRpcMessage), StreamResult(Result<(), StreamableHttpError>), } - let mut streams = tokio::task::JoinSet::new(); + let mut streams = FuturesUnordered::new(); if let Some(session_id) = &session_id { let client = self.client.clone(); let uri = config.uri.clone(); @@ -436,7 +440,7 @@ impl Worker for StreamableHttpClientWorker { let config_auth_header = config.auth_header.clone(); let spawn_headers = protocol_headers.clone(); - streams.spawn(async move { + let work = async move { match client .get_stream( uri.clone(), @@ -477,7 +481,10 @@ impl Worker for StreamableHttpClientWorker { Err(e) } } - }); + } + .boxed(); + + streams.push(work); } // Main event loop - capture exit reason so we can do cleanup before returning let loop_result: Result<(), WorkerQuitReason> = 'main_loop: loop { @@ -499,10 +506,10 @@ impl Worker for StreamableHttpClientWorker { }; Event::ServerMessage(message) }, - terminated_stream = streams.join_next(), if !streams.is_empty() => { + terminated_stream = streams.next(), if !streams.is_empty() => { match terminated_stream { Some(result) => { - Event::StreamResult(result.map_err(StreamableHttpError::TokioJoinError).and_then(std::convert::identity)) + Event::StreamResult(result) } None => { continue @@ -546,23 +553,29 @@ impl Worker for StreamableHttpClientWorker { }, self.config.retry_config.clone(), ); - streams.spawn(Self::execute_sse_stream( - sse_stream, - sse_worker_tx.clone(), - true, - transport_task_ct.child_token(), - )); + streams.push( + Self::execute_sse_stream( + sse_stream, + sse_worker_tx.clone(), + true, + transport_task_ct.child_token(), + ) + .boxed(), + ); } else { let sse_stream = SseAutoReconnectStream::never_reconnect( stream, StreamableHttpError::::UnexpectedEndOfStream, ); - streams.spawn(Self::execute_sse_stream( - sse_stream, - sse_worker_tx.clone(), - true, - transport_task_ct.child_token(), - )); + streams.push( + Self::execute_sse_stream( + sse_stream, + sse_worker_tx.clone(), + true, + transport_task_ct.child_token(), + ) + .boxed(), + ); } tracing::trace!("got new sse stream"); Ok(()) From a0724d36b5931284147c1b57df0019a9ed4701d2 Mon Sep 17 00:00:00 2001 From: Aadam Zocolo Date: Wed, 4 Mar 2026 00:50:16 +0000 Subject: [PATCH 16/20] refactor(worker): remove spawn from worker this will provoke some cascading changes in the streamable HTTP client too now work will need to be explicitly managed and bubled to the top --- .../src/transport/streamable_http_client.rs | 7 ++- .../streamable_http_server/session/local.rs | 3 +- crates/rmcp/src/transport/worker.rs | 43 ++++++++++++------- 3 files changed, 34 insertions(+), 19 deletions(-) diff --git a/crates/rmcp/src/transport/streamable_http_client.rs b/crates/rmcp/src/transport/streamable_http_client.rs index 6878a158..7bdf580f 100644 --- a/crates/rmcp/src/transport/streamable_http_client.rs +++ b/crates/rmcp/src/transport/streamable_http_client.rs @@ -822,9 +822,12 @@ impl StreamableHttpClientTransport { /// StreamableHttpClientTransportConfig::with_uri("http://localhost:8000/mcp") /// ); /// ``` - pub fn with_client(client: C, config: StreamableHttpClientTransportConfig) -> Self { + pub fn with_client( + client: C, + config: StreamableHttpClientTransportConfig, + ) -> (Self, impl Future + Send + 'static) { let worker = StreamableHttpClientWorker::new(client, config); - WorkerTransport::spawn(worker) + WorkerTransport::new(worker) } } #[derive(Debug, Clone)] diff --git a/crates/rmcp/src/transport/streamable_http_server/session/local.rs b/crates/rmcp/src/transport/streamable_http_server/session/local.rs index 6e197b5b..229f5396 100644 --- a/crates/rmcp/src/transport/streamable_http_server/session/local.rs +++ b/crates/rmcp/src/transport/streamable_http_server/session/local.rs @@ -43,6 +43,7 @@ pub enum LocalSessionManagerError { #[error("Invalid event id: {0}")] InvalidEventId(#[from] EventIdParseError), } + impl SessionManager for LocalSessionManager { type Error = LocalSessionManagerError; type Transport = WorkerTransport; @@ -50,7 +51,7 @@ impl SessionManager for LocalSessionManager { let id = session_id(); let (handle, worker) = create_local_session(id.clone(), self.session_config.clone()); self.sessions.write().await.insert(id.clone(), handle); - Ok((id, WorkerTransport::spawn(worker))) + Ok((id, WorkerTransport::new(worker))) } async fn initialize_session( &self, diff --git a/crates/rmcp/src/transport/worker.rs b/crates/rmcp/src/transport/worker.rs index 769d448a..3ea5b2e7 100644 --- a/crates/rmcp/src/transport/worker.rs +++ b/crates/rmcp/src/transport/worker.rs @@ -1,9 +1,10 @@ use std::borrow::Cow; +use futures::{FutureExt, future::RemoteHandle}; use tokio_util::sync::CancellationToken; use tracing::{Instrument, Level}; -use super::{IntoTransport, Transport}; +use super::Transport; use crate::service::{RxJsonRpcMessage, ServiceRole, TxJsonRpcMessage}; #[derive(Debug, thiserror::Error)] @@ -60,7 +61,7 @@ pub struct WorkerSendRequest { pub struct WorkerTransport { rx: tokio::sync::mpsc::Receiver>, send_service: tokio::sync::mpsc::Sender>, - join_handle: Option>>>, + join_handle: Option>>>, _drop_guard: tokio_util::sync::DropGuard, ct: CancellationToken, } @@ -80,20 +81,25 @@ impl Default for WorkerConfig { } pub enum WorkerAdapter {} -impl IntoTransport for W { - fn into_transport(self) -> impl Transport + 'static { - WorkerTransport::spawn(self) - } -} +// This can't be implemented if we are removing integrated "spawning" on +// an async runtime. +// impl IntoTransport for W { +// fn into_transport(self) -> impl Transport + 'static { +// WorkerTransport::new(self) +// } +// } impl WorkerTransport { pub fn cancel_token(&self) -> CancellationToken { self.ct.clone() } - pub fn spawn(worker: W) -> Self { - Self::spawn_with_ct(worker, CancellationToken::new()) + pub fn new(worker: W) -> (Self, impl Future + Send + 'static) { + Self::new_with_ct(worker, CancellationToken::new()) } - pub fn spawn_with_ct(worker: W, transport_task_ct: CancellationToken) -> Self { + pub fn new_with_ct( + worker: W, + transport_task_ct: CancellationToken, + ) -> (Self, impl Future + Send + 'static) { let config = worker.config(); let worker_name = config.name; let (to_transport_tx, from_handler_rx) = @@ -106,7 +112,7 @@ impl WorkerTransport { cancellation_token: transport_task_ct.clone(), }; - let join_handle = tokio::spawn(async move { + let work = async move { worker .run(context) .instrument(tracing::span!( @@ -131,14 +137,19 @@ impl WorkerTransport { .inspect(|_| { tracing::debug!("worker quit"); }) - }); - Self { + }; + + let (work, remote_handle) = work.remote_handle(); + + let this = Self { rx: from_transport_rx, send_service: to_transport_tx, - join_handle: Some(join_handle), + join_handle: Some(remote_handle), ct: transport_task_ct.clone(), _drop_guard: transport_task_ct.drop_guard(), - } + }; + + (this, work) } } @@ -199,7 +210,7 @@ impl Transport for WorkerTransport { async fn close(&mut self) -> Result<(), Self::Error> { if let Some(handle) = self.join_handle.take() { self.ct.cancel(); - let _quit_reason = handle.await.map_err(W::err_join)?; + let _quit_reason = handle.await; Ok(()) } else { Ok(()) From d85b6df7982cef10f743f546b0859107233237d2 Mon Sep 17 00:00:00 2001 From: Aadam Zocolo Date: Thu, 5 Mar 2026 04:31:17 +0000 Subject: [PATCH 17/20] refactor(http): explicitly bubble work task up to be spawned --- conformance/src/bin/client.rs | 69 ++++++++++------- conformance/src/bin/server.rs | 3 +- .../common/reqwest/streamable_http_client.rs | 6 +- .../streamable_http_server/session.rs | 11 ++- .../streamable_http_server/session/local.rs | 14 +++- .../streamable_http_server/session/never.rs | 16 +++- .../transport/streamable_http_server/tower.rs | 76 ++++++++++++++++--- crates/rmcp/src/util.rs | 4 - crates/rmcp/tests/test_close_connection.rs | 2 +- crates/rmcp/tests/test_custom_headers.rs | 10 ++- .../rmcp/tests/test_sse_concurrent_streams.rs | 4 +- .../test_streamable_http_json_response.rs | 4 +- .../tests/test_streamable_http_priming.rs | 8 +- crates/rmcp/tests/test_with_js.rs | 6 +- examples/clients/src/auth/oauth_client.rs | 3 +- examples/clients/src/progress_client.rs | 3 +- examples/clients/src/sampling_stdio.rs | 2 +- examples/clients/src/streamable_http.rs | 5 +- examples/rig-integration/src/config/mcp.rs | 3 +- examples/servers/src/cimd_auth_streamhttp.rs | 12 +-- .../servers/src/complex_auth_streamhttp.rs | 12 +-- .../src/counter_hyper_streamable_http.rs | 6 +- examples/servers/src/counter_streamhttp.rs | 3 +- .../servers/src/elicitation_enum_inference.rs | 3 +- examples/servers/src/progress_demo.rs | 6 +- .../servers/src/simple_auth_streamhttp.rs | 12 +-- examples/simple-chat-client/src/config.rs | 3 +- 27 files changed, 216 insertions(+), 90 deletions(-) diff --git a/conformance/src/bin/client.rs b/conformance/src/bin/client.rs index 061c2e2b..e55856c7 100644 --- a/conformance/src/bin/client.rs +++ b/conformance/src/bin/client.rs @@ -320,10 +320,11 @@ async fn perform_oauth_flow_preregistered( async fn run_auth_client(server_url: &str, ctx: &ConformanceContext) -> anyhow::Result<()> { let auth_client = perform_oauth_flow(server_url, ctx).await?; - let transport = StreamableHttpClientTransport::with_client( + let (transport, http_work) = StreamableHttpClientTransport::with_client( auth_client, StreamableHttpClientTransportConfig::with_uri(server_url), ); + tokio::spawn(http_work); let (client, work) = BasicClientHandler.serve(transport).await?; // Run the client work loop in the background while we interact with it @@ -347,7 +348,7 @@ async fn run_auth_client(server_url: &str, ctx: &ConformanceContext) -> anyhow:: .await; } - client.cancel().await?; + client.cancel().await; Ok(()) } @@ -377,10 +378,11 @@ async fn run_auth_scope_step_up_client( .ok_or_else(|| anyhow::anyhow!("No AM"))?; let auth_client = AuthClient::new(reqwest::Client::default(), am); - let transport = StreamableHttpClientTransport::with_client( + let (transport, http_work) = StreamableHttpClientTransport::with_client( auth_client.clone(), StreamableHttpClientTransportConfig::with_uri(server_url), ); + tokio::spawn(http_work); let (client, work) = BasicClientHandler.serve(transport).await?; // Run the client work loop in the background while we interact with it @@ -407,7 +409,7 @@ async fn run_auth_scope_step_up_client( Err(_) => { tracing::debug!("Tool call failed (likely 403), attempting scope upgrade..."); // Drop old client, re-auth with upgraded scopes - client.cancel().await.ok(); + client.cancel().await; // Re-do the full flow; the server will give us the right scopes // on the second authorization request. @@ -427,10 +429,11 @@ async fn run_auth_scope_step_up_client( let am2 = oauth2.into_authorization_manager().unwrap(); let auth_client2 = AuthClient::new(reqwest::Client::default(), am2); - let transport2 = StreamableHttpClientTransport::with_client( + let (transport2, http_work2) = StreamableHttpClientTransport::with_client( auth_client2, StreamableHttpClientTransportConfig::with_uri(server_url), ); + tokio::spawn(http_work2); let (client2, work2) = BasicClientHandler.serve(transport2).await?; tokio::spawn(work2); let _ = client2 @@ -441,13 +444,13 @@ async fn run_auth_scope_step_up_client( task: None, }) .await; - client2.cancel().await.ok(); + client2.cancel().await; return Ok(()); } } } - client.cancel().await?; + client.cancel().await; Ok(()) } @@ -475,10 +478,11 @@ async fn run_auth_scope_retry_limit_client( let am = oauth.into_authorization_manager().unwrap(); let auth_client = AuthClient::new(reqwest::Client::default(), am); - let transport = StreamableHttpClientTransport::with_client( + let (transport, http_work) = StreamableHttpClientTransport::with_client( auth_client, StreamableHttpClientTransportConfig::with_uri(server_url), ); + tokio::spawn(http_work); let (client, work) = BasicClientHandler.serve(transport).await?; tokio::spawn(work); @@ -504,7 +508,7 @@ async fn run_auth_scope_retry_limit_client( } } } - client.cancel().await.ok(); + client.cancel().await; if !got_403 { break; @@ -535,10 +539,11 @@ async fn run_auth_preregistered_client( let auth_client = perform_oauth_flow_preregistered(server_url, client_id, client_secret).await?; - let transport = StreamableHttpClientTransport::with_client( + let (transport, http_work) = StreamableHttpClientTransport::with_client( auth_client, StreamableHttpClientTransportConfig::with_uri(server_url), ); + tokio::spawn(http_work); let (client, work) = BasicClientHandler.serve(transport).await?; tokio::spawn(work); @@ -557,7 +562,7 @@ async fn run_auth_preregistered_client( }) .await; } - client.cancel().await?; + client.cancel().await; Ok(()) } @@ -595,11 +600,12 @@ async fn run_client_credentials_basic( .ok_or_else(|| anyhow::anyhow!("No access_token in response"))?; // Use static token - let transport = StreamableHttpClientTransport::with_client( + let (transport, http_work) = StreamableHttpClientTransport::with_client( reqwest::Client::default(), StreamableHttpClientTransportConfig::with_uri(server_url) .auth_header(access_token.to_string()), ); + tokio::spawn(http_work); let (client, work) = BasicClientHandler.serve(transport).await?; tokio::spawn(work); @@ -617,7 +623,7 @@ async fn run_client_credentials_basic( }) .await; } - client.cancel().await?; + client.cancel().await; Ok(()) } @@ -667,11 +673,12 @@ async fn run_client_credentials_jwt( .as_str() .ok_or_else(|| anyhow::anyhow!("No access_token: {}", token_resp))?; - let transport = StreamableHttpClientTransport::with_client( + let (transport, http_work) = StreamableHttpClientTransport::with_client( reqwest::Client::default(), StreamableHttpClientTransportConfig::with_uri(server_url) .auth_header(access_token.to_string()), ); + tokio::spawn(http_work); let (client, work) = BasicClientHandler.serve(transport).await?; tokio::spawn(work); @@ -689,7 +696,7 @@ async fn run_client_credentials_jwt( }) .await; } - client.cancel().await?; + client.cancel().await; Ok(()) } @@ -839,17 +846,21 @@ fn build_tool_arguments(tool: &Tool) -> Option> { // ─── Non-auth scenarios ───────────────────────────────────────────────────── async fn run_basic_client(server_url: &str) -> anyhow::Result<()> { - let transport = StreamableHttpClientTransport::from_uri(server_url); - let client = BasicClientHandler.serve(transport).await?; + let (transport, http_work) = StreamableHttpClientTransport::from_uri(server_url); + tokio::spawn(http_work); + let (client, work) = BasicClientHandler.serve(transport).await?; + tokio::spawn(work); let tools = client.list_tools(Default::default()).await?; tracing::debug!("Listed {} tools", tools.tools.len()); - client.cancel().await?; + client.cancel().await; Ok(()) } async fn run_tools_call_client(server_url: &str) -> anyhow::Result<()> { - let transport = StreamableHttpClientTransport::from_uri(server_url); - let client = FullClientHandler.serve(transport).await?; + let (transport, http_work) = StreamableHttpClientTransport::from_uri(server_url); + tokio::spawn(http_work); + let (client, work) = FullClientHandler.serve(transport).await?; + tokio::spawn(work); let tools = client.list_tools(Default::default()).await?; for tool in &tools.tools { let args = build_tool_arguments(tool); @@ -862,13 +873,15 @@ async fn run_tools_call_client(server_url: &str) -> anyhow::Result<()> { }) .await?; } - client.cancel().await?; + client.cancel().await; Ok(()) } async fn run_elicitation_defaults_client(server_url: &str) -> anyhow::Result<()> { - let transport = StreamableHttpClientTransport::from_uri(server_url); - let client = ElicitationDefaultsClientHandler.serve(transport).await?; + let (transport, http_work) = StreamableHttpClientTransport::from_uri(server_url); + tokio::spawn(http_work); + let (client, work) = ElicitationDefaultsClientHandler.serve(transport).await?; + tokio::spawn(work); let tools = client.list_tools(Default::default()).await?; let test_tool = tools.tools.iter().find(|t| { let n = t.name.as_ref(); @@ -884,13 +897,15 @@ async fn run_elicitation_defaults_client(server_url: &str) -> anyhow::Result<()> }) .await?; } - client.cancel().await?; + client.cancel().await; Ok(()) } async fn run_sse_retry_client(server_url: &str) -> anyhow::Result<()> { - let transport = StreamableHttpClientTransport::from_uri(server_url); - let client = BasicClientHandler.serve(transport).await?; + let (transport, http_work) = StreamableHttpClientTransport::from_uri(server_url); + tokio::spawn(http_work); + let (client, work) = BasicClientHandler.serve(transport).await?; + tokio::spawn(work); let tools = client.list_tools(Default::default()).await?; if let Some(tool) = tools .tools @@ -906,7 +921,7 @@ async fn run_sse_retry_client(server_url: &str) -> anyhow::Result<()> { }) .await?; } - client.cancel().await?; + client.cancel().await; Ok(()) } diff --git a/conformance/src/bin/server.rs b/conformance/src/bin/server.rs index 97bfbcdc..87df48ee 100644 --- a/conformance/src/bin/server.rs +++ b/conformance/src/bin/server.rs @@ -943,11 +943,12 @@ async fn main() -> anyhow::Result<()> { stateful_mode: true, ..Default::default() }; - let service = StreamableHttpService::new( + let (service, http_work) = StreamableHttpService::new( move || Ok(server.clone()), LocalSessionManager::default().into(), config, ); + tokio::spawn(http_work); let router = axum::Router::new().nest_service("/mcp", service); diff --git a/crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs b/crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs index ae70f72f..1f48e0c2 100644 --- a/crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs +++ b/crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs @@ -246,7 +246,7 @@ impl StreamableHttpClientTransport { /// # Feature requirement /// /// This method requires the `transport-streamable-http-client-reqwest` feature. - pub fn from_uri(uri: impl Into>) -> Self { + pub fn from_uri(uri: impl Into>) -> (Self, impl Future + Send + 'static) { StreamableHttpClientTransport::with_client( reqwest::Client::default(), StreamableHttpClientTransportConfig { @@ -262,7 +262,9 @@ impl StreamableHttpClientTransport { /// # Arguments /// /// * `config` - The config to use with this transport - pub fn from_config(config: StreamableHttpClientTransportConfig) -> Self { + pub fn from_config( + config: StreamableHttpClientTransportConfig, + ) -> (Self, impl Future + Send + 'static) { StreamableHttpClientTransport::with_client(reqwest::Client::default(), config) } } diff --git a/crates/rmcp/src/transport/streamable_http_server/session.rs b/crates/rmcp/src/transport/streamable_http_server/session.rs index 9cf4d0db..5b354a90 100644 --- a/crates/rmcp/src/transport/streamable_http_server/session.rs +++ b/crates/rmcp/src/transport/streamable_http_server/session.rs @@ -45,7 +45,16 @@ pub trait SessionManager: Send + Sync + 'static { /// that will be used to exchange MCP messages within this session. fn create_session( &self, - ) -> impl Future> + Send; + ) -> impl Future< + Output = Result< + ( + SessionId, + Self::Transport, + impl Future + Send + 'static, + ), + Self::Error, + >, + > + Send; /// Forward the first message (the `initialize` request) to the session. fn initialize_session( diff --git a/crates/rmcp/src/transport/streamable_http_server/session/local.rs b/crates/rmcp/src/transport/streamable_http_server/session/local.rs index 229f5396..0821b4b9 100644 --- a/crates/rmcp/src/transport/streamable_http_server/session/local.rs +++ b/crates/rmcp/src/transport/streamable_http_server/session/local.rs @@ -47,11 +47,21 @@ pub enum LocalSessionManagerError { impl SessionManager for LocalSessionManager { type Error = LocalSessionManagerError; type Transport = WorkerTransport; - async fn create_session(&self) -> Result<(SessionId, Self::Transport), Self::Error> { + async fn create_session( + &self, + ) -> Result< + ( + SessionId, + Self::Transport, + impl Future + Send + 'static, + ), + Self::Error, + > { let id = session_id(); let (handle, worker) = create_local_session(id.clone(), self.session_config.clone()); self.sessions.write().await.insert(id.clone(), handle); - Ok((id, WorkerTransport::new(worker))) + let (transport, work) = WorkerTransport::new(worker); + Ok((id, transport, work)) } async fn initialize_session( &self, diff --git a/crates/rmcp/src/transport/streamable_http_server/session/never.rs b/crates/rmcp/src/transport/streamable_http_server/session/never.rs index 436d4cfc..22f88fec 100644 --- a/crates/rmcp/src/transport/streamable_http_server/session/never.rs +++ b/crates/rmcp/src/transport/streamable_http_server/session/never.rs @@ -39,8 +39,20 @@ impl SessionManager for NeverSessionManager { fn create_session( &self, - ) -> impl Future> + Send { - futures::future::ready(Err(ErrorSessionManagementNotSupported)) + ) -> impl Future< + Output = Result< + ( + SessionId, + Self::Transport, + impl Future + Send + 'static, + ), + Self::Error, + >, + > + Send { + futures::future::ready(Err::< + (SessionId, Self::Transport, futures::future::Ready<()>), + Self::Error, + >(ErrorSessionManagementNotSupported)) } fn initialize_session( diff --git a/crates/rmcp/src/transport/streamable_http_server/tower.rs b/crates/rmcp/src/transport/streamable_http_server/tower.rs index f62708e8..56ae05a0 100644 --- a/crates/rmcp/src/transport/streamable_http_server/tower.rs +++ b/crates/rmcp/src/transport/streamable_http_server/tower.rs @@ -1,7 +1,7 @@ use std::{convert::Infallible, fmt::Display, sync::Arc, time::Duration}; use bytes::Bytes; -use futures::{StreamExt, future::BoxFuture}; +use futures::{FutureExt, StreamExt, future::BoxFuture, stream::FuturesUnordered}; use http::{Method, Request, Response, header::ALLOW}; use http_body::Body; use http_body_util::{BodyExt, Full, combinators::BoxBody}; @@ -27,6 +27,7 @@ use crate::{ }, }, }, + util::PinnedFuture, }; #[derive(Debug, Clone)] @@ -189,6 +190,8 @@ pub struct StreamableHttpService, service_factory: Arc Result + Send + Sync>, + /// Used to spawn work on the session task, which drives all session and request work to completion. + work_tx: tokio::sync::mpsc::UnboundedSender>, } impl Clone for StreamableHttpService { @@ -197,6 +200,7 @@ impl Clone for StreamableHttpService { config: self.config.clone(), session_manager: self.session_manager.clone(), service_factory: self.service_factory.clone(), + work_tx: self.work_tx.clone(), } } } @@ -236,12 +240,38 @@ where service_factory: impl Fn() -> Result + Send + Sync + 'static, session_manager: Arc, config: StreamableHttpServerConfig, - ) -> Self { - Self { + ) -> (Self, impl Future + Send + 'static) { + let (work_tx, mut work_rx) = + tokio::sync::mpsc::unbounded_channel::>(); + + let session_work = async move { + let mut work_set = FuturesUnordered::new(); + + loop { + tokio::select! { + Some(work) = work_rx.recv(), if !work_rx.is_closed() => { + work_set.push(work); + } + _ = work_set.next(), if !work_set.is_empty() => { + // just drive the work futures to completion, no need to check results here + }, + else => { + // both channels closed and all work completed, we can shut down the session task + tracing::info!("Streamable HTTP server session work task is shutting down"); + break; + } + } + } + }; + + let this = Self { config, session_manager, service_factory: Arc::new(service_factory), - } + work_tx, + }; + + (this, session_work) } fn get_service(&self) -> Result { (self.service_factory)() @@ -493,11 +523,16 @@ where } } } else { - let (session_id, transport) = self + let (session_id, transport, work) = self .session_manager .create_session() .await .map_err(internal_error_response("create session"))?; + + self.work_tx + .send(work.boxed()) + .map_err(internal_error_response("spawn session task"))?; + if let ClientJsonRpcMessage::Request(req) = &mut message { if !matches!(req.request, ClientRequest::InitializeRequest(_)) { return Err(unexpected_message_response("initialize request")); @@ -511,9 +546,11 @@ where .get_service() .map_err(internal_error_response("get service"))?; // spawn a task to serve the session - tokio::spawn({ + + let work = { let session_manager = self.session_manager.clone(); let session_id = session_id.clone(); + let work_tx = self.work_tx.clone(); async move { let serve_result = serve_server::( @@ -523,8 +560,11 @@ where match serve_result { Ok((service, work)) => { // on service created - tokio::spawn(work); - let _ = service.waiting().await; + if let Err(e) = work_tx.send(work.boxed()) { + tracing::error!("Failed to spawn session work: {e}"); + } else { + let _ = service.waiting().await; + } } Err(e) => { tracing::error!("Failed to create service: {e}"); @@ -537,7 +577,12 @@ where tracing::error!("Failed to close session {session_id}: {e}"); }); } - }); + }; + + self.work_tx + .send(work.boxed()) + .map_err(internal_error_response("spawn async work"))?; + // get initialize response let response = self .session_manager @@ -595,8 +640,17 @@ where request.request.extensions_mut().insert(part); let (transport, mut receiver) = OneshotTransport::::new(ClientJsonRpcMessage::Request(request)); - let (_, work) = serve_directly(service, transport, None); - tokio::spawn(work); + let (service, work) = serve_directly(service, transport, None); + + let work = async move { + work.await; + // Need to keep the service handle alive, because if it is dropped it will cancel its work. + service.waiting().await; + }; + + self.work_tx + .send(work.boxed()) + .map_err(internal_error_response("spawn async work"))?; if self.config.json_response { // JSON-direct mode: await the single response and return as // application/json, eliminating SSE framing overhead. diff --git a/crates/rmcp/src/util.rs b/crates/rmcp/src/util.rs index 33b273f8..8aa43474 100644 --- a/crates/rmcp/src/util.rs +++ b/crates/rmcp/src/util.rs @@ -3,8 +3,4 @@ use std::pin::Pin; pub type PinnedFuture<'a, T> = Pin + Send + 'a>>; -pub type PinnedLocalFuture<'a, T> = Pin + 'a>>; - pub type PinnedStream<'a, T> = Pin + Send + 'a>>; - -pub type PinnedLocalStream<'a, T> = Pin + 'a>>; diff --git a/crates/rmcp/tests/test_close_connection.rs b/crates/rmcp/tests/test_close_connection.rs index d79fac85..ba620ee2 100644 --- a/crates/rmcp/tests/test_close_connection.rs +++ b/crates/rmcp/tests/test_close_connection.rs @@ -5,7 +5,7 @@ use std::time::Duration; use anyhow::anyhow; use common::handlers::{TestClientHandler, TestServer}; -use rmcp::{ServiceExt, handler::client, service::QuitReason}; +use rmcp::{ServiceExt, service::QuitReason}; /// Test that close() properly shuts down the connection #[tokio::test] diff --git a/crates/rmcp/tests/test_custom_headers.rs b/crates/rmcp/tests/test_custom_headers.rs index b88d0cd6..8ed91744 100644 --- a/crates/rmcp/tests/test_custom_headers.rs +++ b/crates/rmcp/tests/test_custom_headers.rs @@ -493,7 +493,8 @@ async fn test_mcp_custom_headers_sent_to_server() -> anyhow::Result<()> { StreamableHttpClientTransportConfig::with_uri(format!("http://127.0.0.1:{}/mcp", port)) .custom_headers(custom_headers); - let transport = StreamableHttpClientTransport::from_config(config); + let (transport, http_work) = StreamableHttpClientTransport::from_config(config); + tokio::spawn(http_work); // Start MCP client with empty handler (this will trigger initialize request) let (client, work) = ().serve(transport).await.expect("Failed to start client"); @@ -665,7 +666,8 @@ async fn test_mcp_protocol_version_header_sent_after_init() -> anyhow::Result<() let config = StreamableHttpClientTransportConfig::with_uri(format!("http://127.0.0.1:{}/mcp", port)); - let transport = StreamableHttpClientTransport::from_config(config); + let (transport, http_work) = StreamableHttpClientTransport::from_config(config); + tokio::spawn(http_work); let (client, work) = ().serve(transport).await.expect("Failed to start client"); tokio::spawn(work); @@ -740,12 +742,14 @@ async fn test_server_rejects_unsupported_protocol_version() { } let session_manager = Arc::new(LocalSessionManager::default()); - let service = StreamableHttpService::new( + let (service, http_work) = StreamableHttpService::new( || Ok(TestHandler), session_manager, StreamableHttpServerConfig::default(), ); + tokio::spawn(http_work); + // First, send an initialize request to create a session let init_body = json!({ "jsonrpc": "2.0", diff --git a/crates/rmcp/tests/test_sse_concurrent_streams.rs b/crates/rmcp/tests/test_sse_concurrent_streams.rs index b54ed556..92b16946 100644 --- a/crates/rmcp/tests/test_sse_concurrent_streams.rs +++ b/crates/rmcp/tests/test_sse_concurrent_streams.rs @@ -76,7 +76,7 @@ impl ServerHandler for TestServer { async fn start_test_server(ct: CancellationToken, trigger: Arc) -> String { let server = TestServer::new(trigger); - let service = StreamableHttpService::new( + let (service, http_work) = StreamableHttpService::new( move || Ok(server.clone()), Arc::new(LocalSessionManager::default()), StreamableHttpServerConfig { @@ -88,6 +88,8 @@ async fn start_test_server(ct: CancellationToken, trigger: Arc) -> Strin }, ); + tokio::spawn(http_work); + let router = axum::Router::new().nest_service("/mcp", service); let listener = tokio::net::TcpListener::bind("127.0.0.1:0") .await diff --git a/crates/rmcp/tests/test_streamable_http_json_response.rs b/crates/rmcp/tests/test_streamable_http_json_response.rs index e5b3323a..141e2333 100644 --- a/crates/rmcp/tests/test_streamable_http_json_response.rs +++ b/crates/rmcp/tests/test_streamable_http_json_response.rs @@ -12,9 +12,11 @@ async fn spawn_server( config: StreamableHttpServerConfig, ) -> (reqwest::Client, String, CancellationToken) { let ct = config.cancellation_token.clone(); - let service: StreamableHttpService = + let (service, http_work): (StreamableHttpService, _) = StreamableHttpService::new(|| Ok(Calculator::new()), Default::default(), config); + tokio::spawn(http_work); + let router = axum::Router::new().nest_service("/mcp", service); let tcp_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = tcp_listener.local_addr().unwrap(); diff --git a/crates/rmcp/tests/test_streamable_http_priming.rs b/crates/rmcp/tests/test_streamable_http_priming.rs index 778dfedf..b24f910b 100644 --- a/crates/rmcp/tests/test_streamable_http_priming.rs +++ b/crates/rmcp/tests/test_streamable_http_priming.rs @@ -13,7 +13,7 @@ async fn test_priming_on_stream_start() -> anyhow::Result<()> { let ct = CancellationToken::new(); // stateful_mode: true automatically enables priming with DEFAULT_RETRY_INTERVAL (3 seconds) - let service: StreamableHttpService = + let (service, http_work): (StreamableHttpService, _) = StreamableHttpService::new( || Ok(Calculator::new()), Default::default(), @@ -25,6 +25,8 @@ async fn test_priming_on_stream_start() -> anyhow::Result<()> { }, ); + tokio::spawn(http_work); + let router = axum::Router::new().nest_service("/mcp", service); let tcp_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?; let addr = tcp_listener.local_addr()?; @@ -83,7 +85,7 @@ async fn test_priming_on_stream_close() -> anyhow::Result<()> { let session_manager = Arc::new(LocalSessionManager::default()); // stateful_mode: true automatically enables priming with DEFAULT_RETRY_INTERVAL (3 seconds) - let service = StreamableHttpService::new( + let (service, http_work) = StreamableHttpService::new( || Ok(Calculator::new()), session_manager.clone(), StreamableHttpServerConfig { @@ -94,6 +96,8 @@ async fn test_priming_on_stream_close() -> anyhow::Result<()> { }, ); + tokio::spawn(http_work); + let router = axum::Router::new().nest_service("/mcp", service); let tcp_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?; let addr = tcp_listener.local_addr()?; diff --git a/crates/rmcp/tests/test_with_js.rs b/crates/rmcp/tests/test_with_js.rs index 40228b58..12acc40a 100644 --- a/crates/rmcp/tests/test_with_js.rs +++ b/crates/rmcp/tests/test_with_js.rs @@ -75,7 +75,7 @@ async fn test_with_js_streamable_http_client() -> anyhow::Result<()> { .await?; let ct = CancellationToken::new(); - let service: StreamableHttpService = + let (service, http_work): (StreamableHttpService, _) = StreamableHttpService::new( || Ok(Calculator::new()), Default::default(), @@ -88,6 +88,7 @@ async fn test_with_js_streamable_http_client() -> anyhow::Result<()> { ); let router = axum::Router::new().nest_service("/mcp", service); let tcp_listener = tokio::net::TcpListener::bind(STREAMABLE_HTTP_BIND_ADDRESS).await?; + tokio::spawn(http_work); let handle = tokio::spawn({ let ct = ct.clone(); @@ -124,9 +125,10 @@ async fn test_with_js_streamable_http_server() -> anyhow::Result<()> { .wait() .await?; - let transport = StreamableHttpClientTransport::from_uri(format!( + let (transport, http_work) = StreamableHttpClientTransport::from_uri(format!( "http://{STREAMABLE_HTTP_JS_BIND_ADDRESS}/mcp" )); + tokio::spawn(http_work); let mut server = tokio::process::Command::new("node") .arg("tests/test_with_js/streamable_server.js") diff --git a/examples/clients/src/auth/oauth_client.rs b/examples/clients/src/auth/oauth_client.rs index 97f0e497..15bdc357 100644 --- a/examples/clients/src/auth/oauth_client.rs +++ b/examples/clients/src/auth/oauth_client.rs @@ -173,10 +173,11 @@ async fn main() -> Result<()> { .into_authorization_manager() .ok_or_else(|| anyhow::anyhow!("Failed to get authorization manager"))?; let client = AuthClient::new(reqwest::Client::default(), am); - let transport = StreamableHttpClientTransport::with_client( + let (transport, http_work) = StreamableHttpClientTransport::with_client( client, StreamableHttpClientTransportConfig::with_uri(server_url.as_str()), ); + tokio::spawn(http_work); // Create client and connect to MCP server let client_service = ClientInfo::default(); diff --git a/examples/clients/src/progress_client.rs b/examples/clients/src/progress_client.rs index 33e1351e..17be97d1 100644 --- a/examples/clients/src/progress_client.rs +++ b/examples/clients/src/progress_client.rs @@ -216,7 +216,8 @@ async fn test_http_transport(http_url: &str, records: u32) -> Result<()> { // Wait a bit for server to be ready sleep(Duration::from_secs(1)).await; - let transport = StreamableHttpClientTransport::from_uri(http_url); + let (transport, http_work) = StreamableHttpClientTransport::from_uri(http_url); + tokio::spawn(http_work); // Create progress-aware client handler let client_handler = ProgressAwareClient::new(); diff --git a/examples/clients/src/sampling_stdio.rs b/examples/clients/src/sampling_stdio.rs index 7d8e5e6f..fc9e4c5c 100644 --- a/examples/clients/src/sampling_stdio.rs +++ b/examples/clients/src/sampling_stdio.rs @@ -6,7 +6,7 @@ use rmcp::{ service::{RequestContext, RoleClient}, transport::{ CommandBuilder, - child_process::{tokio::TokioChildProcessRunner, transport::ChildProcessTransport}, + child_process::{ChildProcessTransport, tokio::TokioChildProcessRunner}, }, }; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; diff --git a/examples/clients/src/streamable_http.rs b/examples/clients/src/streamable_http.rs index 50e76541..ecdc4105 100644 --- a/examples/clients/src/streamable_http.rs +++ b/examples/clients/src/streamable_http.rs @@ -16,7 +16,10 @@ async fn main() -> Result<()> { ) .with(tracing_subscriber::fmt::layer()) .init(); - let transport = StreamableHttpClientTransport::from_uri("http://localhost:8000/mcp"); + let (transport, http_work) = + StreamableHttpClientTransport::from_uri("http://localhost:8000/mcp"); + tokio::spawn(http_work); + let client_info = ClientInfo { meta: None, protocol_version: Default::default(), diff --git a/examples/rig-integration/src/config/mcp.rs b/examples/rig-integration/src/config/mcp.rs index d5b398d2..305e2d13 100644 --- a/examples/rig-integration/src/config/mcp.rs +++ b/examples/rig-integration/src/config/mcp.rs @@ -68,8 +68,9 @@ impl McpServerTransportConfig { pub async fn start(&self) -> anyhow::Result> { let client = match self { McpServerTransportConfig::Streamable { url } => { - let transport = + let (transport, http_work) = rmcp::transport::StreamableHttpClientTransport::from_uri(url.to_string()); + tokio::spawn(http_work); let (service, work) = ().serve(transport).await?; tokio::spawn(work); service diff --git a/examples/servers/src/cimd_auth_streamhttp.rs b/examples/servers/src/cimd_auth_streamhttp.rs index 7b402c9f..e9734460 100644 --- a/examples/servers/src/cimd_auth_streamhttp.rs +++ b/examples/servers/src/cimd_auth_streamhttp.rs @@ -478,12 +478,12 @@ async fn main() -> Result<()> { let state = AppState::new(); // Create streamable HTTP service for MCP - let mcp_service: StreamableHttpService = - StreamableHttpService::new( - || Ok(Counter::new()), - LocalSessionManager::default().into(), - StreamableHttpServerConfig::default(), - ); + let (mcp_service, http_work) = StreamableHttpService::new( + || Ok(Counter::new()), + LocalSessionManager::default().into(), + StreamableHttpServerConfig::default(), + ); + tokio::spawn(http_work); let addr = BIND_ADDRESS.parse::()?; diff --git a/examples/servers/src/complex_auth_streamhttp.rs b/examples/servers/src/complex_auth_streamhttp.rs index 4afacf8d..80ba30e7 100644 --- a/examples/servers/src/complex_auth_streamhttp.rs +++ b/examples/servers/src/complex_auth_streamhttp.rs @@ -642,12 +642,12 @@ async fn main() -> Result<()> { let addr = BIND_ADDRESS.parse::()?; // Create streamable HTTP service for MCP - let mcp_service: StreamableHttpService = - StreamableHttpService::new( - || Ok(Counter::new()), - LocalSessionManager::default().into(), - StreamableHttpServerConfig::default(), - ); + let (mcp_service, http_work) = StreamableHttpService::new( + || Ok(Counter::new()), + LocalSessionManager::default().into(), + StreamableHttpServerConfig::default(), + ); + tokio::spawn(http_work); // Create protected MCP routes (require authorization) let protected_mcp_router = diff --git a/examples/servers/src/counter_hyper_streamable_http.rs b/examples/servers/src/counter_hyper_streamable_http.rs index 6312180d..7031a350 100644 --- a/examples/servers/src/counter_hyper_streamable_http.rs +++ b/examples/servers/src/counter_hyper_streamable_http.rs @@ -11,11 +11,13 @@ use rmcp::transport::streamable_http_server::{ #[tokio::main] async fn main() -> anyhow::Result<()> { - let service = TowerToHyperService::new(StreamableHttpService::new( + let (service, http_work) = StreamableHttpService::new( || Ok(Counter::new()), LocalSessionManager::default().into(), Default::default(), - )); + ); + tokio::spawn(http_work); + let service = TowerToHyperService::new(service); let listener = tokio::net::TcpListener::bind("[::1]:8080").await?; loop { let io = tokio::select! { diff --git a/examples/servers/src/counter_streamhttp.rs b/examples/servers/src/counter_streamhttp.rs index db9b9df1..da777b3c 100644 --- a/examples/servers/src/counter_streamhttp.rs +++ b/examples/servers/src/counter_streamhttp.rs @@ -22,7 +22,7 @@ async fn main() -> anyhow::Result<()> { .init(); let ct = tokio_util::sync::CancellationToken::new(); - let service = StreamableHttpService::new( + let (service, http_work) = StreamableHttpService::new( || Ok(Counter::new()), LocalSessionManager::default().into(), StreamableHttpServerConfig { @@ -30,6 +30,7 @@ async fn main() -> anyhow::Result<()> { ..Default::default() }, ); + tokio::spawn(http_work); let router = axum::Router::new().nest_service("/mcp", service); let tcp_listener = tokio::net::TcpListener::bind(BIND_ADDRESS).await?; diff --git a/examples/servers/src/elicitation_enum_inference.rs b/examples/servers/src/elicitation_enum_inference.rs index 2ecec311..951ff0a4 100644 --- a/examples/servers/src/elicitation_enum_inference.rs +++ b/examples/servers/src/elicitation_enum_inference.rs @@ -174,11 +174,12 @@ async fn main() -> anyhow::Result<()> { .with(tracing_subscriber::fmt::layer()) .init(); - let service = StreamableHttpService::new( + let (service, http_work) = StreamableHttpService::new( || Ok(ElicitationEnumFormServer::new()), LocalSessionManager::default().into(), Default::default(), ); + tokio::spawn(http_work); let router = axum::Router::new().nest_service("/mcp", service); let tcp_listener = tokio::net::TcpListener::bind(BIND_ADDRESS).await?; diff --git a/examples/servers/src/progress_demo.rs b/examples/servers/src/progress_demo.rs index c12572c8..8a58f264 100644 --- a/examples/servers/src/progress_demo.rs +++ b/examples/servers/src/progress_demo.rs @@ -44,11 +44,12 @@ async fn run_stdio() -> anyhow::Result<()> { async fn run_streamable_http() -> anyhow::Result<()> { println!("Running Streamable HTTP server"); - let service = StreamableHttpService::new( + let (service, http_work) = StreamableHttpService::new( || Ok(ProgressDemo::new()), LocalSessionManager::default().into(), Default::default(), ); + tokio::spawn(http_work); let router = axum::Router::new().nest_service("/mcp", service); let tcp_listener = tokio::net::TcpListener::bind(HTTP_BIND_ADDRESS).await?; @@ -70,11 +71,12 @@ async fn run_all_transports() -> anyhow::Result<()> { println!("Running all transports"); // Start Streamable HTTP server - let http_service = StreamableHttpService::new( + let (http_service, http_work) = StreamableHttpService::new( || Ok(ProgressDemo::new()), LocalSessionManager::default().into(), Default::default(), ); + tokio::spawn(http_work); let http_router = axum::Router::new().nest_service("/mcp", http_service); let http_listener = tokio::net::TcpListener::bind(HTTP_BIND_ADDRESS).await?; diff --git a/examples/servers/src/simple_auth_streamhttp.rs b/examples/servers/src/simple_auth_streamhttp.rs index f68ed894..deee7aca 100644 --- a/examples/servers/src/simple_auth_streamhttp.rs +++ b/examples/servers/src/simple_auth_streamhttp.rs @@ -131,12 +131,12 @@ async fn main() -> Result<()> { let addr = BIND_ADDRESS.parse::()?; // Create streamable HTTP service - let mcp_service: StreamableHttpService = - StreamableHttpService::new( - || Ok(Counter::new()), - LocalSessionManager::default().into(), - StreamableHttpServerConfig::default(), - ); + let (mcp_service, http_work) = StreamableHttpService::new( + || Ok(Counter::new()), + LocalSessionManager::default().into(), + StreamableHttpServerConfig::default(), + ); + tokio::spawn(http_work); // Create API routes let api_routes = Router::new() diff --git a/examples/simple-chat-client/src/config.rs b/examples/simple-chat-client/src/config.rs index 946436fd..2a6c5d49 100644 --- a/examples/simple-chat-client/src/config.rs +++ b/examples/simple-chat-client/src/config.rs @@ -52,8 +52,9 @@ impl McpServerTransportConfig { pub async fn start(&self) -> Result> { let client = match self { McpServerTransportConfig::Streamable { url } => { - let transport = + let (transport, http_work) = rmcp::transport::StreamableHttpClientTransport::from_uri(url.to_string()); + tokio::spawn(http_work); let (service, work) = ().serve(transport).await?; tokio::spawn(work); service From 7fa7a3f69e36cb79c787d1d028143d120fe6c878 Mon Sep 17 00:00:00 2001 From: Aadam Zocolo Date: Thu, 5 Mar 2026 04:31:39 +0000 Subject: [PATCH 18/20] fix(docs): fix doc examples so they compile --- crates/rmcp/README.md | 8 ++++---- crates/rmcp/src/transport/child_process.rs | 4 ++++ 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/crates/rmcp/README.md b/crates/rmcp/README.md index 1444deee..2bfbfb90 100644 --- a/crates/rmcp/README.md +++ b/crates/rmcp/README.md @@ -86,7 +86,7 @@ async fn main() -> Result<(), Box> { // Spawn the async work loop on the background tokio::spawn(work); // Wait for the service to conclude - service.waiting().await?; + service.waiting().await; Ok(()) } ``` @@ -154,7 +154,7 @@ Creating a client to interact with a server: use rmcp::{ ServiceExt, model::CallToolRequestParams, - transport::{CommandBuilder, ChildProcessTransport, tokio::TokioChildProcessRunner} + transport::{CommandBuilder, child_process::{ChildProcessTransport, tokio::TokioChildProcessRunner}} }; use tokio::process::Command; @@ -164,10 +164,10 @@ async fn main() -> Result<(), Box> { // Connect to a server running as a child process let command = CommandBuilder::::new("uvx") .arg("mcp-server-git") - .spawn_dyn()? + .spawn_dyn()?; // Create a transport via the child process's STDIN and STDOUT streams - let transport = ChildProcessTransport::new(command)? + let transport = ChildProcessTransport::new(command)?; let (service, work) = ().serve(transport).await?; diff --git a/crates/rmcp/src/transport/child_process.rs b/crates/rmcp/src/transport/child_process.rs index 44db4a45..4c665b99 100644 --- a/crates/rmcp/src/transport/child_process.rs +++ b/crates/rmcp/src/transport/child_process.rs @@ -2,5 +2,9 @@ pub mod builder; pub mod runner; pub mod transport; +pub use builder::CommandBuilder; +pub use runner::ChildProcessControl; +pub use transport::ChildProcessTransport; + #[cfg(feature = "transport-child-process-tokio")] pub mod tokio; From d526a9b3215f271e452288c0f6c9601a53e56a5b Mon Sep 17 00:00:00 2001 From: Aadam Zocolo Date: Thu, 5 Mar 2026 05:05:07 +0000 Subject: [PATCH 19/20] refactor(http): use different timeout API for futures --- crates/rmcp/Cargo.toml | 1 + crates/rmcp/src/service.rs | 17 +++++++---------- crates/rmcp/src/task_manager.rs | 8 +++----- .../src/transport/streamable_http_client.rs | 13 +++++++------ .../transport/streamable_http_server/session.rs | 8 ++++++++ .../transport/streamable_http_server/tower.rs | 7 +++++++ 6 files changed, 33 insertions(+), 21 deletions(-) diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml index 9117db70..b1588cb8 100644 --- a/crates/rmcp/Cargo.toml +++ b/crates/rmcp/Cargo.toml @@ -63,6 +63,7 @@ http-body-util = { version = "0.1", optional = true } bytes = { version = "1", optional = true } # macro rmcp-macros = { workspace = true, optional = true } +futures-timeout = "0.1.3" [target.'cfg(not(all(target_family = "wasm", target_os = "unknown")))'.dependencies] chrono = { version = "0.4.38", features = ["serde"] } diff --git a/crates/rmcp/src/service.rs b/crates/rmcp/src/service.rs index 5d6e93a0..a349d517 100644 --- a/crates/rmcp/src/service.rs +++ b/crates/rmcp/src/service.rs @@ -3,6 +3,7 @@ use futures::{ future::{BoxFuture, RemoteHandle}, stream::FuturesUnordered, }; +use futures_timeout::TimeoutExt; use thiserror::Error; use tokio_stream::wrappers::ReceiverStream; @@ -278,14 +279,10 @@ impl RequestHandle { pub const REQUEST_TIMEOUT_REASON: &str = "request timeout"; pub async fn await_response(self) -> Result { if let Some(timeout) = self.options.timeout { - // TODO: tokio timeout won't work if not in the tokio RT - // Find an alternative - let timeout_result = tokio::time::timeout(timeout, async move { - self.rx.await.map_err(|_e| ServiceError::TransportClosed)? - }) - .await; + let timeout_result = self.rx.timeout(timeout).await; + match timeout_result { - Ok(response) => response, + Ok(response) => response.map_err(|_e| ServiceError::TransportClosed)?, Err(_) => { let error = Err(ServiceError::Timeout { timeout }); // cancel this request @@ -566,10 +563,10 @@ impl> RunningService { pub async fn close_with_timeout(&mut self, timeout: Duration) -> Option { if let Some(handle) = self.handle.take() { self.cancellation_token.cancel(); - // TODO: tokio timeout won't work if not in the tokio RT, find an alternative - match tokio::time::timeout(timeout, handle).await { + + match handle.timeout(timeout).await { Ok(reason) => Some(reason), - Err(_elapsed) => { + Err(_) => { tracing::warn!( "close_with_timeout: cleanup did not complete within {:?}", timeout diff --git a/crates/rmcp/src/task_manager.rs b/crates/rmcp/src/task_manager.rs index 7cc81575..3393304d 100644 --- a/crates/rmcp/src/task_manager.rs +++ b/crates/rmcp/src/task_manager.rs @@ -5,10 +5,8 @@ use futures::{ future::abortable, stream::{AbortHandle, FuturesUnordered}, }; -use tokio::{ - sync::mpsc, - time::{Duration, timeout}, -}; +use futures_timeout::TimeoutExt; +use tokio::{sync::mpsc, time::Duration}; use crate::{ RoleServer, @@ -222,7 +220,7 @@ impl OperationProcessor { let timed_future = async move { if let Some(secs) = timeout_secs { - match timeout(Duration::from_secs(secs), future).await { + match future.timeout(Duration::from_secs(secs)).await { Ok(result) => result, Err(_) => Err(Error::TaskError("Operation timed out".to_string())), } diff --git a/crates/rmcp/src/transport/streamable_http_client.rs b/crates/rmcp/src/transport/streamable_http_client.rs index 7bdf580f..4312b8aa 100644 --- a/crates/rmcp/src/transport/streamable_http_client.rs +++ b/crates/rmcp/src/transport/streamable_http_client.rs @@ -5,6 +5,7 @@ use futures::{ future::BoxFuture, stream::{BoxStream, FuturesUnordered}, }; +use futures_timeout::TimeoutExt; use http::{HeaderName, HeaderValue}; pub use sse_stream::Error as SseError; use sse_stream::Sse; @@ -605,16 +606,16 @@ impl Worker for StreamableHttpClientWorker { if let Some(cleanup) = session_cleanup_info { const SESSION_CLEANUP_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(5); let cleanup_session_id = cleanup.session_id.clone(); - match tokio::time::timeout( - SESSION_CLEANUP_TIMEOUT, - cleanup.client.delete_session( + match cleanup + .client + .delete_session( cleanup.uri, cleanup.session_id, cleanup.auth_header, cleanup.protocol_headers, - ), - ) - .await + ) + .timeout(SESSION_CLEANUP_TIMEOUT) + .await { Ok(Ok(_)) => { tracing::info!( diff --git a/crates/rmcp/src/transport/streamable_http_server/session.rs b/crates/rmcp/src/transport/streamable_http_server/session.rs index 5b354a90..ff20ead6 100644 --- a/crates/rmcp/src/transport/streamable_http_server/session.rs +++ b/crates/rmcp/src/transport/streamable_http_server/session.rs @@ -43,6 +43,14 @@ pub trait SessionManager: Send + Sync + 'static { /// Create a new session and return its ID together with the transport /// that will be used to exchange MCP messages within this session. + /// + /// The result of this async creation will be the [SessionId], the + /// [Self::Transport], and a future that drives the session's execution. + /// + /// The session will be active and able to receive messages once the future is polled or spawned. + /// The caller is responsible for polling or spawning the work future returned by this function. + /// + /// If the [Self::Transport] handle is dropped, the session's async work loop future will exit. fn create_session( &self, ) -> impl Future< diff --git a/crates/rmcp/src/transport/streamable_http_server/tower.rs b/crates/rmcp/src/transport/streamable_http_server/tower.rs index 56ae05a0..743a4e10 100644 --- a/crates/rmcp/src/transport/streamable_http_server/tower.rs +++ b/crates/rmcp/src/transport/streamable_http_server/tower.rs @@ -236,6 +236,13 @@ where S: crate::Service + Send + 'static, M: SessionManager, { + /// Create a new `StreamableHttpService` using the service factory, session manager, and configuration provided. + /// + /// This function returns a handle to the service, and a future that must be polled to drive the execution + /// of the service and its sessions. The caller is responsible for polling or spawning the future returned + /// by this function. + /// + /// If you drop the [StreamableHttpService] handle, the async work loop future will exit. pub fn new( service_factory: impl Fn() -> Result + Send + Sync + 'static, session_manager: Arc, From 4732493568692e693f257ff166fe95412822b3ec Mon Sep 17 00:00:00 2001 From: Aadam Zocolo Date: Thu, 5 Mar 2026 21:54:45 +0000 Subject: [PATCH 20/20] chore(examples): cleanup examples so they compile some examples are failing on some JSON schema issue --- examples/clients/src/auth/client_credentials.rs | 6 ++++-- examples/clients/src/everything_stdio.rs | 3 +-- examples/servers/src/common/counter.rs | 2 +- examples/transport/src/tcp.rs | 8 +++++--- examples/transport/src/unix_socket.rs | 10 +++++----- 5 files changed, 16 insertions(+), 13 deletions(-) diff --git a/examples/clients/src/auth/client_credentials.rs b/examples/clients/src/auth/client_credentials.rs index 55aa6153..5e36a1f0 100644 --- a/examples/clients/src/auth/client_credentials.rs +++ b/examples/clients/src/auth/client_credentials.rs @@ -67,14 +67,16 @@ async fn main() -> Result<()> { .into_authorization_manager() .context("Failed to get authorization manager")?; let client = AuthClient::new(reqwest::Client::default(), manager); - let transport = StreamableHttpClientTransport::with_client( + let (transport, http_work) = StreamableHttpClientTransport::with_client( client, StreamableHttpClientTransportConfig::with_uri(server_url.as_str()), ); + tokio::spawn(http_work); // Connect to MCP server and list tools let client_service = ClientInfo::default(); - let client = client_service.serve(transport).await?; + let (client, work) = client_service.serve(transport).await?; + tokio::spawn(work); tracing::info!("Connected to MCP server"); match client.peer().list_all_tools().await { diff --git a/examples/clients/src/everything_stdio.rs b/examples/clients/src/everything_stdio.rs index 93292c1f..bce8edb0 100644 --- a/examples/clients/src/everything_stdio.rs +++ b/examples/clients/src/everything_stdio.rs @@ -4,11 +4,10 @@ use rmcp::{ model::{CallToolRequestParams, GetPromptRequestParams, ReadResourceRequestParams}, object, transport::{ - self, CommandBuilder, + CommandBuilder, child_process::{tokio::TokioChildProcessRunner, transport::ChildProcessTransport}, }, }; -use tokio::process::Command; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; #[tokio::main] diff --git a/examples/servers/src/common/counter.rs b/examples/servers/src/common/counter.rs index a96273af..1b0b80ad 100644 --- a/examples/servers/src/common/counter.rs +++ b/examples/servers/src/common/counter.rs @@ -365,7 +365,7 @@ mod tests { "source".into(), serde_json::Value::String("integration-test".into()), ); - let params = CallToolRequestParams::new("long_task").with_task(Some(task_meta)); + let params = CallToolRequestParams::new("long_task").with_task(task_meta); let response = client_service .send_request(ClientRequest::CallToolRequest(Request::new(params.clone()))) .await?; diff --git a/examples/transport/src/tcp.rs b/examples/transport/src/tcp.rs index 683fb6cf..d02493ce 100644 --- a/examples/transport/src/tcp.rs +++ b/examples/transport/src/tcp.rs @@ -13,8 +13,9 @@ async fn server() -> anyhow::Result<()> { let tcp_listener = tokio::net::TcpListener::bind("127.0.0.1:8001").await?; while let Ok((stream, _)) = tcp_listener.accept().await { tokio::spawn(async move { - let server = serve_server(Calculator::new(), stream).await?; - server.waiting().await?; + let (server, work) = serve_server(Calculator::new(), stream).await?; + tokio::spawn(work); + server.waiting().await; anyhow::Ok(()) }); } @@ -25,7 +26,8 @@ async fn client() -> anyhow::Result<()> { let stream = tokio::net::TcpSocket::new_v4()? .connect("127.0.0.1:8001".parse()?) .await?; - let client = serve_client((), stream).await?; + let (client, work) = serve_client((), stream).await?; + tokio::spawn(work); let tools = client.peer().list_tools(Default::default()).await?; println!("{:?}", tools); Ok(()) diff --git a/examples/transport/src/unix_socket.rs b/examples/transport/src/unix_socket.rs index a8eb6271..ca0e69e4 100644 --- a/examples/transport/src/unix_socket.rs +++ b/examples/transport/src/unix_socket.rs @@ -15,11 +15,10 @@ async fn main() -> anyhow::Result<()> { println!("Client connected: {:?}", addr); tokio::spawn(async move { match serve_server(Calculator::new(), stream).await { - Ok(server) => { + Ok((server, work)) => { + tokio::spawn(work); println!("Server initialized successfully"); - if let Err(e) = server.waiting().await { - println!("Error while server waiting: {}", e); - } + server.waiting().await; } Err(e) => println!("Server initialization failed: {}", e), } @@ -34,7 +33,8 @@ async fn main() -> anyhow::Result<()> { println!("Client connecting to {}", SOCKET_PATH); let stream = UnixStream::connect(SOCKET_PATH).await?; - let client = serve_client((), stream).await?; + let (client, work) = serve_client((), stream).await?; + tokio::spawn(work); println!("Client connected and initialized successfully"); // List available tools