diff --git a/README.md b/README.md
index 10378b607..097289959 100644
--- a/README.md
+++ b/README.md
@@ -59,14 +59,32 @@ Json Schema generation (version 2020-12):
Start a client
```rust, ignore
-use rmcp::{ServiceExt, transport::{TokioChildProcess, ConfigureCommandExt}};
+use rmcp::{
+ ServiceExt,
+ transport::{CommandBuilder, ChildProcessTransport, tokio::TokioChildProcessRunner}
+};
use tokio::process::Command;
#[tokio::main]
async fn main() -> Result<(), Box> {
- let client = ().serve(TokioChildProcess::new(Command::new("npx").configure(|cmd| {
- cmd.arg("-y").arg("@modelcontextprotocol/server-everything");
- }))?).await?;
+
+ // Build and spawn a child process
+ let command = CommandBuilder::::new("npx")
+ .arg("-y")
+ .arg("@modelcontextprotocol/server-everything")
+ .spawn_dyn()?
+
+ // Create a transport via the child process's STDIN and STDOUT streams
+ let transport = ChildProcessTransport::new(command)?
+
+ let (client, work) = ().serve(transport).await?;
+ // Spawn the async work loop on the background
+ tokio::spawn(work);
+
+ // Use the client ...
+
+ // Finish using the client
+ client.cancel().await;
Ok(())
}
```
@@ -99,7 +117,7 @@ let service = common::counter::Counter::new();
```rust, ignore
// this call will finish the initialization process
-let server = service.serve(transport).await?;
+let (server, work) = service.serve(transport).await?;
```
diff --git a/conformance/src/bin/client.rs b/conformance/src/bin/client.rs
index 53a44d9e7..e55856c78 100644
--- a/conformance/src/bin/client.rs
+++ b/conformance/src/bin/client.rs
@@ -320,12 +320,16 @@ async fn perform_oauth_flow_preregistered(
async fn run_auth_client(server_url: &str, ctx: &ConformanceContext) -> anyhow::Result<()> {
let auth_client = perform_oauth_flow(server_url, ctx).await?;
- let transport = StreamableHttpClientTransport::with_client(
+ let (transport, http_work) = StreamableHttpClientTransport::with_client(
auth_client,
StreamableHttpClientTransportConfig::with_uri(server_url),
);
+ tokio::spawn(http_work);
+
+ let (client, work) = BasicClientHandler.serve(transport).await?;
+ // Run the client work loop in the background while we interact with it
+ tokio::spawn(work);
- let client = BasicClientHandler.serve(transport).await?;
tracing::debug!("Connected (authenticated)");
let tools = client.list_tools(Default::default()).await?;
@@ -344,7 +348,7 @@ async fn run_auth_client(server_url: &str, ctx: &ConformanceContext) -> anyhow::
.await;
}
- client.cancel().await?;
+ client.cancel().await;
Ok(())
}
@@ -374,12 +378,15 @@ async fn run_auth_scope_step_up_client(
.ok_or_else(|| anyhow::anyhow!("No AM"))?;
let auth_client = AuthClient::new(reqwest::Client::default(), am);
- let transport = StreamableHttpClientTransport::with_client(
+ let (transport, http_work) = StreamableHttpClientTransport::with_client(
auth_client.clone(),
StreamableHttpClientTransportConfig::with_uri(server_url),
);
+ tokio::spawn(http_work);
- let client = BasicClientHandler.serve(transport).await?;
+ let (client, work) = BasicClientHandler.serve(transport).await?;
+ // Run the client work loop in the background while we interact with it
+ tokio::spawn(work);
let tools = client.list_tools(Default::default()).await?;
tracing::debug!("Listed {} tools", tools.tools.len());
@@ -402,7 +409,7 @@ async fn run_auth_scope_step_up_client(
Err(_) => {
tracing::debug!("Tool call failed (likely 403), attempting scope upgrade...");
// Drop old client, re-auth with upgraded scopes
- client.cancel().await.ok();
+ client.cancel().await;
// Re-do the full flow; the server will give us the right scopes
// on the second authorization request.
@@ -422,11 +429,13 @@ async fn run_auth_scope_step_up_client(
let am2 = oauth2.into_authorization_manager().unwrap();
let auth_client2 = AuthClient::new(reqwest::Client::default(), am2);
- let transport2 = StreamableHttpClientTransport::with_client(
+ let (transport2, http_work2) = StreamableHttpClientTransport::with_client(
auth_client2,
StreamableHttpClientTransportConfig::with_uri(server_url),
);
- let client2 = BasicClientHandler.serve(transport2).await?;
+ tokio::spawn(http_work2);
+ let (client2, work2) = BasicClientHandler.serve(transport2).await?;
+ tokio::spawn(work2);
let _ = client2
.call_tool(CallToolRequestParams {
meta: None,
@@ -435,13 +444,13 @@ async fn run_auth_scope_step_up_client(
task: None,
})
.await;
- client2.cancel().await.ok();
+ client2.cancel().await;
return Ok(());
}
}
}
- client.cancel().await?;
+ client.cancel().await;
Ok(())
}
@@ -469,12 +478,15 @@ async fn run_auth_scope_retry_limit_client(
let am = oauth.into_authorization_manager().unwrap();
let auth_client = AuthClient::new(reqwest::Client::default(), am);
- let transport = StreamableHttpClientTransport::with_client(
+ let (transport, http_work) = StreamableHttpClientTransport::with_client(
auth_client,
StreamableHttpClientTransportConfig::with_uri(server_url),
);
+ tokio::spawn(http_work);
+
+ let (client, work) = BasicClientHandler.serve(transport).await?;
+ tokio::spawn(work);
- let client = BasicClientHandler.serve(transport).await?;
let tools = client.list_tools(Default::default()).await?;
let mut got_403 = false;
@@ -496,7 +508,7 @@ async fn run_auth_scope_retry_limit_client(
}
}
}
- client.cancel().await.ok();
+ client.cancel().await;
if !got_403 {
break;
@@ -527,12 +539,15 @@ async fn run_auth_preregistered_client(
let auth_client =
perform_oauth_flow_preregistered(server_url, client_id, client_secret).await?;
- let transport = StreamableHttpClientTransport::with_client(
+ let (transport, http_work) = StreamableHttpClientTransport::with_client(
auth_client,
StreamableHttpClientTransportConfig::with_uri(server_url),
);
+ tokio::spawn(http_work);
+
+ let (client, work) = BasicClientHandler.serve(transport).await?;
+ tokio::spawn(work);
- let client = BasicClientHandler.serve(transport).await?;
let tools = client.list_tools(Default::default()).await?;
tracing::debug!("Listed {} tools", tools.tools.len());
@@ -547,7 +562,7 @@ async fn run_auth_preregistered_client(
})
.await;
}
- client.cancel().await?;
+ client.cancel().await;
Ok(())
}
@@ -585,13 +600,16 @@ async fn run_client_credentials_basic(
.ok_or_else(|| anyhow::anyhow!("No access_token in response"))?;
// Use static token
- let transport = StreamableHttpClientTransport::with_client(
+ let (transport, http_work) = StreamableHttpClientTransport::with_client(
reqwest::Client::default(),
StreamableHttpClientTransportConfig::with_uri(server_url)
.auth_header(access_token.to_string()),
);
+ tokio::spawn(http_work);
+
+ let (client, work) = BasicClientHandler.serve(transport).await?;
+ tokio::spawn(work);
- let client = BasicClientHandler.serve(transport).await?;
let tools = client.list_tools(Default::default()).await?;
tracing::debug!("Listed {} tools", tools.tools.len());
for tool in &tools.tools {
@@ -605,7 +623,7 @@ async fn run_client_credentials_basic(
})
.await;
}
- client.cancel().await?;
+ client.cancel().await;
Ok(())
}
@@ -655,13 +673,16 @@ async fn run_client_credentials_jwt(
.as_str()
.ok_or_else(|| anyhow::anyhow!("No access_token: {}", token_resp))?;
- let transport = StreamableHttpClientTransport::with_client(
+ let (transport, http_work) = StreamableHttpClientTransport::with_client(
reqwest::Client::default(),
StreamableHttpClientTransportConfig::with_uri(server_url)
.auth_header(access_token.to_string()),
);
+ tokio::spawn(http_work);
+
+ let (client, work) = BasicClientHandler.serve(transport).await?;
+ tokio::spawn(work);
- let client = BasicClientHandler.serve(transport).await?;
let tools = client.list_tools(Default::default()).await?;
tracing::debug!("Listed {} tools", tools.tools.len());
for tool in &tools.tools {
@@ -675,7 +696,7 @@ async fn run_client_credentials_jwt(
})
.await;
}
- client.cancel().await?;
+ client.cancel().await;
Ok(())
}
@@ -825,17 +846,21 @@ fn build_tool_arguments(tool: &Tool) -> Option> {
// ─── Non-auth scenarios ─────────────────────────────────────────────────────
async fn run_basic_client(server_url: &str) -> anyhow::Result<()> {
- let transport = StreamableHttpClientTransport::from_uri(server_url);
- let client = BasicClientHandler.serve(transport).await?;
+ let (transport, http_work) = StreamableHttpClientTransport::from_uri(server_url);
+ tokio::spawn(http_work);
+ let (client, work) = BasicClientHandler.serve(transport).await?;
+ tokio::spawn(work);
let tools = client.list_tools(Default::default()).await?;
tracing::debug!("Listed {} tools", tools.tools.len());
- client.cancel().await?;
+ client.cancel().await;
Ok(())
}
async fn run_tools_call_client(server_url: &str) -> anyhow::Result<()> {
- let transport = StreamableHttpClientTransport::from_uri(server_url);
- let client = FullClientHandler.serve(transport).await?;
+ let (transport, http_work) = StreamableHttpClientTransport::from_uri(server_url);
+ tokio::spawn(http_work);
+ let (client, work) = FullClientHandler.serve(transport).await?;
+ tokio::spawn(work);
let tools = client.list_tools(Default::default()).await?;
for tool in &tools.tools {
let args = build_tool_arguments(tool);
@@ -848,13 +873,15 @@ async fn run_tools_call_client(server_url: &str) -> anyhow::Result<()> {
})
.await?;
}
- client.cancel().await?;
+ client.cancel().await;
Ok(())
}
async fn run_elicitation_defaults_client(server_url: &str) -> anyhow::Result<()> {
- let transport = StreamableHttpClientTransport::from_uri(server_url);
- let client = ElicitationDefaultsClientHandler.serve(transport).await?;
+ let (transport, http_work) = StreamableHttpClientTransport::from_uri(server_url);
+ tokio::spawn(http_work);
+ let (client, work) = ElicitationDefaultsClientHandler.serve(transport).await?;
+ tokio::spawn(work);
let tools = client.list_tools(Default::default()).await?;
let test_tool = tools.tools.iter().find(|t| {
let n = t.name.as_ref();
@@ -870,13 +897,15 @@ async fn run_elicitation_defaults_client(server_url: &str) -> anyhow::Result<()>
})
.await?;
}
- client.cancel().await?;
+ client.cancel().await;
Ok(())
}
async fn run_sse_retry_client(server_url: &str) -> anyhow::Result<()> {
- let transport = StreamableHttpClientTransport::from_uri(server_url);
- let client = BasicClientHandler.serve(transport).await?;
+ let (transport, http_work) = StreamableHttpClientTransport::from_uri(server_url);
+ tokio::spawn(http_work);
+ let (client, work) = BasicClientHandler.serve(transport).await?;
+ tokio::spawn(work);
let tools = client.list_tools(Default::default()).await?;
if let Some(tool) = tools
.tools
@@ -892,7 +921,7 @@ async fn run_sse_retry_client(server_url: &str) -> anyhow::Result<()> {
})
.await?;
}
- client.cancel().await?;
+ client.cancel().await;
Ok(())
}
diff --git a/conformance/src/bin/server.rs b/conformance/src/bin/server.rs
index 97bfbcdcc..87df48eec 100644
--- a/conformance/src/bin/server.rs
+++ b/conformance/src/bin/server.rs
@@ -943,11 +943,12 @@ async fn main() -> anyhow::Result<()> {
stateful_mode: true,
..Default::default()
};
- let service = StreamableHttpService::new(
+ let (service, http_work) = StreamableHttpService::new(
move || Ok(server.clone()),
LocalSessionManager::default().into(),
config,
);
+ tokio::spawn(http_work);
let router = axum::Router::new().nest_service("/mcp", service);
diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml
index c5d919ae7..a2e887133 100644
--- a/crates/rmcp/Cargo.toml
+++ b/crates/rmcp/Cargo.toml
@@ -19,7 +19,7 @@ async-trait = "0.1.89"
serde = { version = "1.0", features = ["derive", "rc"] }
serde_json = "1.0"
thiserror = "2"
-tokio = { version = "1", features = ["sync", "macros", "rt", "time"] }
+tokio = { version = "1", features = ["sync", "macros", "time"] }
futures = "0.3"
tracing = { version = "0.1" }
tokio-util = { version = "0.7" }
@@ -58,13 +58,14 @@ process-wrap = { version = "9.0", features = ["tokio1"], optional = true }
# for http-server transport
rand = { version = "0.10", optional = true }
-tokio-stream = { version = "0.1", optional = true }
+tokio-stream = { version = "0.1", optional = true, features = ["sync"] }
uuid = { version = "1", features = ["v4"], optional = true }
http-body = { version = "1", optional = true }
http-body-util = { version = "0.1", optional = true }
bytes = { version = "1", optional = true }
# macro
rmcp-macros = { workspace = true, optional = true }
+futures-timeout = "0.1.3"
[target.'cfg(not(all(target_family = "wasm", target_os = "unknown")))'.dependencies]
chrono = { version = "0.4.38", features = ["serde"] }
@@ -79,7 +80,12 @@ chrono = { version = "0.4.38", default-features = false, features = [
[features]
default = ["base64", "macros", "server"]
client = ["dep:tokio-stream"]
-server = ["transport-async-rw", "dep:schemars", "dep:pastey"]
+server = [
+ "transport-async-rw",
+ "dep:schemars",
+ "dep:pastey",
+ "dep:tokio-stream",
+]
macros = ["dep:rmcp-macros", "dep:pastey"]
elicitation = ["dep:url"]
@@ -111,14 +117,18 @@ client-side-sse = ["dep:sse-stream", "dep:http"]
# Streamable HTTP client
transport-streamable-http-client = ["client-side-sse", "transport-worker"]
-transport-streamable-http-client-reqwest = ["transport-streamable-http-client", "__reqwest"]
+transport-streamable-http-client-reqwest = [
+ "transport-streamable-http-client",
+ "__reqwest",
+]
-transport-async-rw = ["tokio/io-util", "tokio-util/codec"]
+transport-async-rw = ["tokio/io-util", "tokio-util/codec", "tokio-util/compat"]
transport-io = ["transport-async-rw", "tokio/io-std"]
-transport-child-process = [
+transport-child-process = ["transport-async-rw", "tokio/process"]
+transport-child-process-tokio = [
"transport-async-rw",
"tokio/process",
- "dep:process-wrap",
+ "tokio/rt",
]
transport-streamable-http-server = [
"transport-streamable-http-server-session",
@@ -138,7 +148,10 @@ schemars = ["dep:schemars"]
[dev-dependencies]
tokio = { version = "1", features = ["full"] }
schemars = { version = "1.1.0", features = ["chrono04"] }
-axum = { version = "0.8", default-features = false, features = ["http1", "tokio"] }
+axum = { version = "0.8", default-features = false, features = [
+ "http1",
+ "tokio",
+] }
url = "2.4"
anyhow = "1.0"
tracing-subscriber = { version = "0.3", features = [
@@ -159,6 +172,7 @@ required-features = [
"server",
"client",
"transport-child-process",
+ "transport-child-process-tokio",
]
path = "tests/test_with_python.rs"
@@ -168,8 +182,10 @@ required-features = [
"server",
"client",
"transport-child-process",
+ "transport-child-process-tokio",
"transport-streamable-http-server",
"transport-streamable-http-client",
+ "transport-streamable-http-client-reqwest",
"__reqwest",
]
path = "tests/test_with_js.rs"
@@ -211,12 +227,22 @@ path = "tests/test_task.rs"
[[test]]
name = "test_streamable_http_priming"
-required-features = ["server", "client", "transport-streamable-http-server", "reqwest"]
+required-features = [
+ "server",
+ "client",
+ "transport-streamable-http-server",
+ "reqwest",
+]
path = "tests/test_streamable_http_priming.rs"
[[test]]
name = "test_streamable_http_json_response"
-required-features = ["server", "client", "transport-streamable-http-server", "reqwest"]
+required-features = [
+ "server",
+ "client",
+ "transport-streamable-http-server",
+ "reqwest",
+]
path = "tests/test_streamable_http_json_response.rs"
@@ -253,7 +279,13 @@ path = "tests/test_custom_headers.rs"
[[test]]
name = "test_sse_concurrent_streams"
-required-features = ["server", "client", "transport-streamable-http-server", "transport-streamable-http-client", "reqwest"]
+required-features = [
+ "server",
+ "client",
+ "transport-streamable-http-server",
+ "transport-streamable-http-client",
+ "reqwest",
+]
path = "tests/test_sse_concurrent_streams.rs"
[[test]]
diff --git a/crates/rmcp/src/error.rs b/crates/rmcp/src/error.rs
index 74f7d4383..a710b0a74 100644
--- a/crates/rmcp/src/error.rs
+++ b/crates/rmcp/src/error.rs
@@ -31,8 +31,6 @@ pub enum RmcpError {
#[cfg(feature = "server")]
#[error("Server initialization error: {0}")]
ServerInitialize(#[from] crate::service::ServerInitializeError),
- #[error("Runtime error: {0}")]
- Runtime(#[from] tokio::task::JoinError),
#[error("Transport creation error: {error}")]
// TODO: Maybe we can introduce something like `TryIntoTransport` to auto wrap transport type,
// but it could be an breaking change, so we could do it in the future.
diff --git a/crates/rmcp/src/handler/client/progress.rs b/crates/rmcp/src/handler/client/progress.rs
index 04f31610f..892de197f 100644
--- a/crates/rmcp/src/handler/client/progress.rs
+++ b/crates/rmcp/src/handler/client/progress.rs
@@ -1,32 +1,53 @@
-use std::{collections::HashMap, sync::Arc};
-
use futures::{Stream, StreamExt};
-use tokio::sync::RwLock;
-use tokio_stream::wrappers::ReceiverStream;
+use tokio::sync::broadcast;
+use tokio_stream::wrappers::BroadcastStream;
-use crate::model::{ProgressNotificationParam, ProgressToken};
-type Dispatcher =
- Arc>>>;
+use crate::{
+ model::{ProgressNotificationParam, ProgressToken},
+ util::PinnedStream,
+};
/// A dispatcher for progress notifications.
-#[derive(Debug, Clone, Default)]
+///
+/// See [ProgressNotificationParam] and [ProgressToken] for more details on
+/// how progress is dispatched to a particular listener.
+#[derive(Debug, Clone)]
pub struct ProgressDispatcher {
- pub(crate) dispatcher: Dispatcher,
+ /// A channel of any progress notification. Subscribers will filter
+ /// on this channel.
+ pub(crate) any_progress_notification_tx: broadcast::Sender,
+ pub(crate) unsubscribe_tx: broadcast::Sender,
+ pub(crate) unsubscribe_all_tx: broadcast::Sender<()>,
}
impl ProgressDispatcher {
const CHANNEL_SIZE: usize = 16;
pub fn new() -> Self {
- Self::default()
+ // Note that channel size is per-receiver for broadcast channel. It is up to the receiver to
+ // keep up with the notifications to avoid missing any (via propper polling)
+ let (any_progress_notification_tx, _) = broadcast::channel(Self::CHANNEL_SIZE);
+ let (unsubscribe_tx, _) = broadcast::channel(Self::CHANNEL_SIZE);
+ let (unsubscribe_all_tx, _) = broadcast::channel(Self::CHANNEL_SIZE);
+ Self {
+ any_progress_notification_tx,
+ unsubscribe_tx,
+ unsubscribe_all_tx,
+ }
}
/// Handle a progress notification by sending it to the appropriate subscriber
pub async fn handle_notification(&self, notification: ProgressNotificationParam) {
- let token = ¬ification.progress_token;
- if let Some(sender) = self.dispatcher.read().await.get(token).cloned() {
- let send_result = sender.send(notification).await;
- if let Err(e) = send_result {
- tracing::warn!("Failed to send progress notification: {e}");
+ // Broadcast the notification to all subscribers. Interested subscribers
+ // will filter on their end.
+ // ! Note that this implementaiton is very stateless and simple, we cannot
+ // ! easily inspect which subscribers are interested in which notifications.
+ // ! However, the stateless-ness and simplicity is also a plus!
+ // ! Cleanup becomes much easier. Just drop the `ProgressSubscriber`.
+ match self.any_progress_notification_tx.send(notification) {
+ Ok(_) => {}
+ Err(_) => {
+ // This error only happens if there are no active receivers of the `broadcast` channel.
+ // Silent error.
}
}
}
@@ -35,35 +56,97 @@ impl ProgressDispatcher {
///
/// If you drop the returned `ProgressSubscriber`, it will automatically unsubscribe from notifications for that token.
pub async fn subscribe(&self, progress_token: ProgressToken) -> ProgressSubscriber {
- let (sender, receiver) = tokio::sync::mpsc::channel(Self::CHANNEL_SIZE);
- self.dispatcher
- .write()
- .await
- .insert(progress_token.clone(), sender);
- let receiver = ReceiverStream::new(receiver);
+ // First, set up the unsubscribe listeners. This will fuse the notifiaction stream below.
+ let progress_token_clone = progress_token.clone();
+ let unsub_this_token_rx = BroadcastStream::new(self.unsubscribe_tx.subscribe()).filter_map(
+ move |token| {
+ let progress_token_clone = progress_token_clone.clone();
+ async move {
+ match token {
+ Ok(token) => {
+ if token == progress_token_clone {
+ Some(())
+ } else {
+ None
+ }
+ }
+ Err(e) => {
+ // An error here means the broadcast stream did not receive values quick enough and
+ // and we missed some notification. This implies there are notifications
+ // we missed, but we cannot assume they were for us :(
+ tracing::warn!(
+ "Error receiving unsubscribe notification from broadcast channel: {e}"
+ );
+ None
+ }
+ }
+ }
+ },
+ );
+ let unsub_any_token_tx =
+ BroadcastStream::new(self.unsubscribe_all_tx.subscribe()).map(|_| {
+ // Any reception of a result here indicates we should unsubscribe,
+ // regardless of if we received an `Ok(())` or an `Err(_)` (which
+ // indicates the broadcast receiver lagged behind)
+ ()
+ });
+ let unsub_fut = futures::stream::select(unsub_this_token_rx, unsub_any_token_tx)
+ .boxed()
+ .into_future(); // If the unsub streams end, this will cause unsubscription from the subscriber below.
+
+ // Now setup the notification stream. We will receive all notifications and only forward progress notifications
+ // for the token we're interested in.
+ let progress_token_clone = progress_token.clone();
+ let receiver = BroadcastStream::new(self.any_progress_notification_tx.subscribe())
+ .filter_map(move |notification| {
+ let progress_token_clone = progress_token_clone.clone();
+ async move {
+ // We need to kneed-out the broadcast receive error type here.
+ match notification {
+ Ok(notification) => {
+ let token = notification.progress_token.clone();
+ if token == progress_token_clone {
+ Some(notification)
+ } else {
+ None
+ }
+ }
+ Err(e) => {
+ tracing::warn!(
+ "Error receiving progress notification from broadcast channel: {e}"
+ );
+ None
+ }
+ }
+ }
+ })
+ // Fuse this stream so it stops once we receive an unsubscribe notification from the stream
+ // created above
+ .take_until(unsub_fut)
+ .boxed();
+
ProgressSubscriber {
progress_token,
receiver,
- dispatcher: self.dispatcher.clone(),
}
}
/// Unsubscribe from progress notifications for a specific token.
- pub async fn unsubscribe(&self, token: &ProgressToken) {
- self.dispatcher.write().await.remove(token);
+ pub fn unsubscribe(&self, token: ProgressToken) {
+ // The only error defined is if there are no listeners, which is fine. Ignore the result.
+ let _ = self.unsubscribe_tx.send(token);
}
/// Clear all dispatcher.
- pub async fn clear(&self) {
- let mut dispatcher = self.dispatcher.write().await;
- dispatcher.clear();
+ pub fn clear(&self) {
+ // The only error defined is if there are no listeners, which is fine. Ignore the result.
+ let _ = self.unsubscribe_all_tx.send(());
}
}
pub struct ProgressSubscriber {
pub(crate) progress_token: ProgressToken,
- pub(crate) receiver: ReceiverStream,
- pub(crate) dispatcher: Dispatcher,
+ pub(crate) receiver: PinnedStream<'static, ProgressNotificationParam>,
}
impl ProgressSubscriber {
@@ -86,15 +169,3 @@ impl Stream for ProgressSubscriber {
self.receiver.size_hint()
}
}
-
-impl Drop for ProgressSubscriber {
- fn drop(&mut self) {
- let token = self.progress_token.clone();
- self.receiver.close();
- let dispatcher = self.dispatcher.clone();
- tokio::spawn(async move {
- let mut dispatcher = dispatcher.write_owned().await;
- dispatcher.remove(&token);
- });
- }
-}
diff --git a/crates/rmcp/src/lib.rs b/crates/rmcp/src/lib.rs
index 9ae3f9586..c70e61b5b 100644
--- a/crates/rmcp/src/lib.rs
+++ b/crates/rmcp/src/lib.rs
@@ -3,6 +3,8 @@
#![doc = include_str!("../README.md")]
mod error;
+mod util;
+
#[allow(deprecated)]
pub use error::{Error, ErrorData, RmcpError};
diff --git a/crates/rmcp/src/service.rs b/crates/rmcp/src/service.rs
index d6613dd3c..cf0854191 100644
--- a/crates/rmcp/src/service.rs
+++ b/crates/rmcp/src/service.rs
@@ -1,5 +1,11 @@
-use futures::{FutureExt, future::BoxFuture};
+use futures::{
+ FutureExt, Stream, StreamExt,
+ future::{BoxFuture, RemoteHandle},
+ stream::FuturesUnordered,
+};
+use futures_timeout::TimeoutExt;
use thiserror::Error;
+use tokio_stream::wrappers::ReceiverStream;
#[cfg(feature = "server")]
use crate::model::ServerJsonRpcMessage;
@@ -11,6 +17,7 @@ use crate::{
NumberOrString, ProgressToken, RequestId,
},
transport::{DynamicTransportError, IntoTransport, Transport},
+ util::PinnedFuture,
};
#[cfg(feature = "client")]
mod client;
@@ -108,10 +115,33 @@ pub trait ServiceExt: Service + Sized {
fn into_dyn(self) -> Box> {
Box::new(self)
}
+
+ /// Serve this service with the provided transport
+ ///
+ /// This function returns a facade to the running service, and a future that runs the
+ /// service business logic. The caller is responsible for running the business logic
+ /// future, either by spawning it on a runtime or awaiting it directly.
+ ///
+ /// Ex:
+ /// ```rust,ignore
+ /// // Try to initialize a service with the given transport
+ /// let (client, work) = MyServiceImpl.serve(transport).await?;
+ ///
+ /// // Spawn the service business logic on a runtime (e.g. tokio)
+ /// tokio::spawn(work);
+ ///
+ /// // Now we can interact with the service
+ /// let response = client.send_request(...).await?;
+ /// ```
+ ///
+ /// The returned [RunningService] provides methods to interact with the running service, such
+ /// as sending requests or notifications to the peer, and cancelling the service.
fn serve(
self,
transport: T,
- ) -> impl Future