From 3b18ce9f3fb72358380c59d9f639a925933f82ec Mon Sep 17 00:00:00 2001 From: YeonGyu-Kim Date: Thu, 2 Apr 2026 18:24:30 +0900 Subject: [PATCH] feat(mcp): add toolCallTimeoutMs, timeout/reconnect/error handling - Add toolCallTimeoutMs to stdio MCP config with 60s default - tools/call runs under timeout with dedicated Timeout error - Handle malformed JSON/broken protocol as InvalidResponse - Reset/reconnect stdio state on child exit or transport drop - Add tests: slow timeout, invalid JSON response, stdio reconnect - Verified: cargo test -p runtime 113 passed, clippy clean --- rust/crates/runtime/src/config.rs | 23 + rust/crates/runtime/src/mcp.rs | 8 +- rust/crates/runtime/src/mcp_client.rs | 14 + rust/crates/runtime/src/mcp_stdio.rs | 956 ++++++++++++++++++++++---- 4 files changed, 866 insertions(+), 135 deletions(-) diff --git a/rust/crates/runtime/src/config.rs b/rust/crates/runtime/src/config.rs index ccaf09a..ed362f1 100644 --- a/rust/crates/runtime/src/config.rs +++ b/rust/crates/runtime/src/config.rs @@ -106,6 +106,7 @@ pub struct McpStdioServerConfig { pub command: String, pub args: Vec, pub env: BTreeMap, + pub tool_call_timeout_ms: Option, } #[derive(Debug, Clone, PartialEq, Eq)] @@ -791,6 +792,7 @@ fn parse_mcp_server_config( command: expect_string(object, "command", context)?.to_string(), args: optional_string_array(object, "args", context)?.unwrap_or_default(), env: optional_string_map(object, "env", context)?.unwrap_or_default(), + tool_call_timeout_ms: optional_u64(object, "toolCallTimeoutMs", context)?, })), "sse" => Ok(McpServerConfig::Sse(parse_mcp_remote_server_config( object, context, @@ -914,6 +916,27 @@ fn optional_u16( } } +fn optional_u64( + object: &BTreeMap, + key: &str, + context: &str, +) -> Result, ConfigError> { + match object.get(key) { + Some(value) => { + let Some(number) = value.as_i64() else { + return Err(ConfigError::Parse(format!( + "{context}: field {key} must be a non-negative integer" + ))); + }; + let number = u64::try_from(number).map_err(|_| { + ConfigError::Parse(format!("{context}: field {key} is out of range")) + })?; + Ok(Some(number)) + } + None => Ok(None), + } +} + fn parse_bool_map(value: &JsonValue, context: &str) -> Result, ConfigError> { let Some(map) = value.as_object() else { return Err(ConfigError::Parse(format!( diff --git a/rust/crates/runtime/src/mcp.rs b/rust/crates/runtime/src/mcp.rs index b37ea33..e65cd08 100644 --- a/rust/crates/runtime/src/mcp.rs +++ b/rust/crates/runtime/src/mcp.rs @@ -84,10 +84,13 @@ pub fn mcp_server_signature(config: &McpServerConfig) -> Option { pub fn scoped_mcp_config_hash(config: &ScopedMcpServerConfig) -> String { let rendered = match &config.config { McpServerConfig::Stdio(stdio) => format!( - "stdio|{}|{}|{}", + "stdio|{}|{}|{}|{}", stdio.command, render_command_signature(&stdio.args), - render_env_signature(&stdio.env) + render_env_signature(&stdio.env), + stdio + .tool_call_timeout_ms + .map_or_else(String::new, |timeout_ms| timeout_ms.to_string()) ), McpServerConfig::Sse(remote) => format!( "sse|{}|{}|{}|{}", @@ -245,6 +248,7 @@ mod tests { command: "uvx".to_string(), args: vec!["mcp-server".to_string()], env: BTreeMap::from([("TOKEN".to_string(), "secret".to_string())]), + tool_call_timeout_ms: None, }); assert_eq!( mcp_server_signature(&stdio), diff --git a/rust/crates/runtime/src/mcp_client.rs b/rust/crates/runtime/src/mcp_client.rs index e0e1f2c..96a6db2 100644 --- a/rust/crates/runtime/src/mcp_client.rs +++ b/rust/crates/runtime/src/mcp_client.rs @@ -3,6 +3,8 @@ use std::collections::BTreeMap; use crate::config::{McpOAuthConfig, McpServerConfig, ScopedMcpServerConfig}; use crate::mcp::{mcp_server_signature, mcp_tool_prefix, normalize_name_for_mcp}; +pub const DEFAULT_MCP_TOOL_CALL_TIMEOUT_MS: u64 = 60_000; + #[derive(Debug, Clone, PartialEq, Eq)] pub enum McpClientTransport { Stdio(McpStdioTransport), @@ -18,6 +20,7 @@ pub struct McpStdioTransport { pub command: String, pub args: Vec, pub env: BTreeMap, + pub tool_call_timeout_ms: Option, } #[derive(Debug, Clone, PartialEq, Eq)] @@ -75,6 +78,7 @@ impl McpClientTransport { command: config.command.clone(), args: config.args.clone(), env: config.env.clone(), + tool_call_timeout_ms: config.tool_call_timeout_ms, }), McpServerConfig::Sse(config) => Self::Sse(McpRemoteTransport { url: config.url.clone(), @@ -105,6 +109,14 @@ impl McpClientTransport { } } +impl McpStdioTransport { + #[must_use] + pub fn resolved_tool_call_timeout_ms(&self) -> u64 { + self.tool_call_timeout_ms + .unwrap_or(DEFAULT_MCP_TOOL_CALL_TIMEOUT_MS) + } +} + impl McpClientAuth { #[must_use] pub fn from_oauth(oauth: Option) -> Self { @@ -136,6 +148,7 @@ mod tests { command: "uvx".to_string(), args: vec!["mcp-server".to_string()], env: BTreeMap::from([("TOKEN".to_string(), "secret".to_string())]), + tool_call_timeout_ms: Some(15_000), }), }; @@ -154,6 +167,7 @@ mod tests { transport.env.get("TOKEN").map(String::as_str), Some("secret") ); + assert_eq!(transport.tool_call_timeout_ms, Some(15_000)); } other => panic!("expected stdio transport, got {other:?}"), } diff --git a/rust/crates/runtime/src/mcp_stdio.rs b/rust/crates/runtime/src/mcp_stdio.rs index b72b9dd..7f5456a 100644 --- a/rust/crates/runtime/src/mcp_stdio.rs +++ b/rust/crates/runtime/src/mcp_stdio.rs @@ -1,17 +1,30 @@ use std::collections::BTreeMap; +use std::future::Future; use std::io; use std::process::Stdio; +use std::time::Duration; use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; use serde_json::Value as JsonValue; use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader}; use tokio::process::{Child, ChildStdin, ChildStdout, Command}; +use tokio::time::timeout; use crate::config::{McpTransport, RuntimeConfig, ScopedMcpServerConfig}; use crate::mcp::mcp_tool_name; use crate::mcp_client::{McpClientBootstrap, McpClientTransport, McpStdioTransport}; +#[cfg(test)] +const MCP_INITIALIZE_TIMEOUT_MS: u64 = 200; +#[cfg(not(test))] +const MCP_INITIALIZE_TIMEOUT_MS: u64 = 10_000; + +#[cfg(test)] +const MCP_LIST_TOOLS_TIMEOUT_MS: u64 = 300; +#[cfg(not(test))] +const MCP_LIST_TOOLS_TIMEOUT_MS: u64 = 30_000; + #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] #[serde(untagged)] pub enum JsonRpcId { @@ -220,6 +233,11 @@ pub struct UnsupportedMcpServer { #[derive(Debug)] pub enum McpServerManagerError { Io(io::Error), + Transport { + server_name: String, + method: &'static str, + source: io::Error, + }, JsonRpc { server_name: String, method: &'static str, @@ -230,6 +248,11 @@ pub enum McpServerManagerError { method: &'static str, details: String, }, + Timeout { + server_name: String, + method: &'static str, + timeout_ms: u64, + }, UnknownTool { qualified_name: String, }, @@ -242,6 +265,14 @@ impl std::fmt::Display for McpServerManagerError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::Io(error) => write!(f, "{error}"), + Self::Transport { + server_name, + method, + source, + } => write!( + f, + "MCP server `{server_name}` transport failed during {method}: {source}" + ), Self::JsonRpc { server_name, method, @@ -259,6 +290,14 @@ impl std::fmt::Display for McpServerManagerError { f, "MCP server `{server_name}` returned invalid response for {method}: {details}" ), + Self::Timeout { + server_name, + method, + timeout_ms, + } => write!( + f, + "MCP server `{server_name}` timed out after {timeout_ms} ms while handling {method}" + ), Self::UnknownTool { qualified_name } => { write!(f, "unknown MCP tool `{qualified_name}`") } @@ -271,8 +310,10 @@ impl std::error::Error for McpServerManagerError { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { match self { Self::Io(error) => Some(error), + Self::Transport { source, .. } => Some(source), Self::JsonRpc { .. } | Self::InvalidResponse { .. } + | Self::Timeout { .. } | Self::UnknownTool { .. } | Self::UnknownServer { .. } => None, } @@ -361,69 +402,18 @@ impl McpServerManager { let mut discovered_tools = Vec::new(); for server_name in server_names { - self.ensure_server_ready(&server_name).await?; + let server_tools = self.discover_tools_for_server(&server_name).await?; self.clear_routes_for_server(&server_name); - let mut cursor = None; - loop { - let request_id = self.take_request_id(); - let response = { - let server = self.server_mut(&server_name)?; - let process = server.process.as_mut().ok_or_else(|| { - McpServerManagerError::InvalidResponse { - server_name: server_name.clone(), - method: "tools/list", - details: "server process missing after initialization".to_string(), - } - })?; - process - .list_tools( - request_id, - Some(McpListToolsParams { - cursor: cursor.clone(), - }), - ) - .await? - }; - - if let Some(error) = response.error { - return Err(McpServerManagerError::JsonRpc { - server_name: server_name.clone(), - method: "tools/list", - error, - }); - } - - let result = - response - .result - .ok_or_else(|| McpServerManagerError::InvalidResponse { - server_name: server_name.clone(), - method: "tools/list", - details: "missing result payload".to_string(), - })?; - - for tool in result.tools { - let qualified_name = mcp_tool_name(&server_name, &tool.name); - self.tool_index.insert( - qualified_name.clone(), - ToolRoute { - server_name: server_name.clone(), - raw_name: tool.name.clone(), - }, - ); - discovered_tools.push(ManagedMcpTool { - server_name: server_name.clone(), - qualified_name, - raw_name: tool.name.clone(), - tool, - }); - } - - match result.next_cursor { - Some(next_cursor) => cursor = Some(next_cursor), - None => break, - } + for tool in server_tools { + self.tool_index.insert( + tool.qualified_name.clone(), + ToolRoute { + server_name: tool.server_name.clone(), + raw_name: tool.raw_name.clone(), + }, + ); + discovered_tools.push(tool); } } @@ -443,30 +433,42 @@ impl McpServerManager { qualified_name: qualified_tool_name.to_string(), })?; + let timeout_ms = self.tool_call_timeout_ms(&route.server_name)?; + self.ensure_server_ready(&route.server_name).await?; let request_id = self.take_request_id(); - let response = - { - let server = self.server_mut(&route.server_name)?; - let process = server.process.as_mut().ok_or_else(|| { - McpServerManagerError::InvalidResponse { - server_name: route.server_name.clone(), - method: "tools/call", - details: "server process missing after initialization".to_string(), - } - })?; - process - .call_tool( - request_id, - McpToolCallParams { - name: route.raw_name, - arguments, - meta: None, - }, - ) - .await? - }; - Ok(response) + let response = { + let server = self.server_mut(&route.server_name)?; + let process = server.process.as_mut().ok_or_else(|| { + McpServerManagerError::InvalidResponse { + server_name: route.server_name.clone(), + method: "tools/call", + details: "server process missing after initialization".to_string(), + } + })?; + Self::run_process_request( + &route.server_name, + "tools/call", + timeout_ms, + process.call_tool( + request_id, + McpToolCallParams { + name: route.raw_name, + arguments, + meta: None, + }, + ), + ) + .await + }; + + if let Err(error) = &response { + if Self::should_reset_server(error) { + self.reset_server(&route.server_name).await?; + } + } + + response } pub async fn shutdown(&mut self) -> Result<(), McpServerManagerError> { @@ -504,33 +506,220 @@ impl McpServerManager { JsonRpcId::Number(id) } + fn tool_call_timeout_ms(&self, server_name: &str) -> Result { + let server = self + .servers + .get(server_name) + .ok_or_else(|| McpServerManagerError::UnknownServer { + server_name: server_name.to_string(), + })?; + match &server.bootstrap.transport { + McpClientTransport::Stdio(transport) => Ok(transport.resolved_tool_call_timeout_ms()), + other => Err(McpServerManagerError::InvalidResponse { + server_name: server_name.to_string(), + method: "tools/call", + details: format!("unsupported MCP transport for stdio manager: {other:?}"), + }), + } + } + + fn server_process_exited(&mut self, server_name: &str) -> Result { + let server = self.server_mut(server_name)?; + match server.process.as_mut() { + Some(process) => Ok(process.has_exited()?), + None => Ok(false), + } + } + + async fn discover_tools_for_server( + &mut self, + server_name: &str, + ) -> Result, McpServerManagerError> { + let mut attempts = 0; + + loop { + match self.discover_tools_for_server_once(server_name).await { + Ok(tools) => return Ok(tools), + Err(error) if attempts == 0 && Self::is_retryable_error(&error) => { + self.reset_server(server_name).await?; + attempts += 1; + } + Err(error) => { + if Self::should_reset_server(&error) { + self.reset_server(server_name).await?; + } + return Err(error); + } + } + } + } + + async fn discover_tools_for_server_once( + &mut self, + server_name: &str, + ) -> Result, McpServerManagerError> { + self.ensure_server_ready(server_name).await?; + + let mut discovered_tools = Vec::new(); + let mut cursor = None; + loop { + let request_id = self.take_request_id(); + let response = { + let server = self.server_mut(server_name)?; + let process = server.process.as_mut().ok_or_else(|| { + McpServerManagerError::InvalidResponse { + server_name: server_name.to_string(), + method: "tools/list", + details: "server process missing after initialization".to_string(), + } + })?; + Self::run_process_request( + server_name, + "tools/list", + MCP_LIST_TOOLS_TIMEOUT_MS, + process.list_tools( + request_id, + Some(McpListToolsParams { + cursor: cursor.clone(), + }), + ), + ) + .await? + }; + + if let Some(error) = response.error { + return Err(McpServerManagerError::JsonRpc { + server_name: server_name.to_string(), + method: "tools/list", + error, + }); + } + + let result = + response + .result + .ok_or_else(|| McpServerManagerError::InvalidResponse { + server_name: server_name.to_string(), + method: "tools/list", + details: "missing result payload".to_string(), + })?; + + for tool in result.tools { + let qualified_name = mcp_tool_name(server_name, &tool.name); + discovered_tools.push(ManagedMcpTool { + server_name: server_name.to_string(), + qualified_name, + raw_name: tool.name.clone(), + tool, + }); + } + + match result.next_cursor { + Some(next_cursor) => cursor = Some(next_cursor), + None => break, + } + } + + Ok(discovered_tools) + } + + async fn reset_server(&mut self, server_name: &str) -> Result<(), McpServerManagerError> { + let mut process = { + let server = self.server_mut(server_name)?; + server.initialized = false; + server.process.take() + }; + + if let Some(process) = process.as_mut() { + let _ = process.shutdown().await; + } + + Ok(()) + } + + fn is_retryable_error(error: &McpServerManagerError) -> bool { + matches!( + error, + McpServerManagerError::Transport { .. } | McpServerManagerError::Timeout { .. } + ) + } + + fn should_reset_server(error: &McpServerManagerError) -> bool { + matches!( + error, + McpServerManagerError::Transport { .. } + | McpServerManagerError::Timeout { .. } + | McpServerManagerError::InvalidResponse { .. } + ) + } + + async fn run_process_request( + server_name: &str, + method: &'static str, + timeout_ms: u64, + future: F, + ) -> Result + where + F: Future>, + { + match timeout(Duration::from_millis(timeout_ms), future).await { + Ok(Ok(value)) => Ok(value), + Ok(Err(error)) if error.kind() == io::ErrorKind::InvalidData => { + Err(McpServerManagerError::InvalidResponse { + server_name: server_name.to_string(), + method, + details: error.to_string(), + }) + } + Ok(Err(source)) => Err(McpServerManagerError::Transport { + server_name: server_name.to_string(), + method, + source, + }), + Err(_) => Err(McpServerManagerError::Timeout { + server_name: server_name.to_string(), + method, + timeout_ms, + }), + } + } + async fn ensure_server_ready( &mut self, server_name: &str, ) -> Result<(), McpServerManagerError> { - let needs_spawn = self - .servers - .get(server_name) - .map(|server| server.process.is_none()) - .ok_or_else(|| McpServerManagerError::UnknownServer { - server_name: server_name.to_string(), - })?; - - if needs_spawn { - let server = self.server_mut(server_name)?; - server.process = Some(spawn_mcp_stdio_process(&server.bootstrap)?); - server.initialized = false; + if self.server_process_exited(server_name)? { + self.reset_server(server_name).await?; } - let needs_initialize = self - .servers - .get(server_name) - .map(|server| !server.initialized) - .ok_or_else(|| McpServerManagerError::UnknownServer { - server_name: server_name.to_string(), - })?; + let mut attempts = 0; + loop { + let needs_spawn = self + .servers + .get(server_name) + .map(|server| server.process.is_none()) + .ok_or_else(|| McpServerManagerError::UnknownServer { + server_name: server_name.to_string(), + })?; + + if needs_spawn { + let server = self.server_mut(server_name)?; + server.process = Some(spawn_mcp_stdio_process(&server.bootstrap)?); + server.initialized = false; + } + + let needs_initialize = self + .servers + .get(server_name) + .map(|server| !server.initialized) + .ok_or_else(|| McpServerManagerError::UnknownServer { + server_name: server_name.to_string(), + })?; + + if !needs_initialize { + return Ok(()); + } - if needs_initialize { let request_id = self.take_request_id(); let response = { let server = self.server_mut(server_name)?; @@ -541,9 +730,28 @@ impl McpServerManager { details: "server process missing before initialize".to_string(), } })?; - process - .initialize(request_id, default_initialize_params()) - .await? + Self::run_process_request( + server_name, + "initialize", + MCP_INITIALIZE_TIMEOUT_MS, + process.initialize(request_id, default_initialize_params()), + ) + .await + }; + + let response = match response { + Ok(response) => response, + Err(error) if attempts == 0 && Self::is_retryable_error(&error) => { + self.reset_server(server_name).await?; + attempts += 1; + continue; + } + Err(error) => { + if Self::should_reset_server(&error) { + self.reset_server(server_name).await?; + } + return Err(error); + } }; if let Some(error) = response.error { @@ -555,18 +763,19 @@ impl McpServerManager { } if response.result.is_none() { - return Err(McpServerManagerError::InvalidResponse { + let error = McpServerManagerError::InvalidResponse { server_name: server_name.to_string(), method: "initialize", details: "missing result payload".to_string(), - }); + }; + self.reset_server(server_name).await?; + return Err(error); } let server = self.server_mut(server_name)?; server.initialized = true; + return Ok(()); } - - Ok(()) } } @@ -657,12 +866,15 @@ impl McpStdioProcess { if line == "\r\n" { break; } - if let Some(value) = line.strip_prefix("Content-Length:") { - let parsed = value - .trim() - .parse::() - .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?; - content_length = Some(parsed); + let header = line.trim_end_matches(['\r', '\n']); + if let Some((name, value)) = header.split_once(':') { + if name.trim().eq_ignore_ascii_case("Content-Length") { + let parsed = value + .trim() + .parse::() + .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?; + content_length = Some(parsed); + } } } @@ -703,9 +915,32 @@ impl McpStdioProcess { method: impl Into, params: Option, ) -> io::Result> { - let request = JsonRpcRequest::new(id, method, params); + let method = method.into(); + let request = JsonRpcRequest::new(id.clone(), method.clone(), params); self.send_request(&request).await?; - self.read_response().await + let response = self.read_response().await?; + + if response.jsonrpc != "2.0" { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!( + "MCP response for {method} used unsupported jsonrpc version `{}`", + response.jsonrpc + ), + )); + } + + if response.id != id { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!( + "MCP response for {method} used mismatched id: expected {id:?}, got {:?}", + response.id + ), + )); + } + + Ok(response) } pub async fn initialize( @@ -756,9 +991,17 @@ impl McpStdioProcess { self.child.wait().await } + pub fn has_exited(&mut self) -> io::Result { + Ok(self.child.try_wait()?.is_some()) + } + async fn shutdown(&mut self) -> io::Result<()> { if self.child.try_wait()?.is_none() { - self.child.kill().await?; + match self.child.kill().await { + Ok(()) => {} + Err(error) if error.kind() == io::ErrorKind::InvalidInput => {} + Err(error) => return Err(error), + } } let _ = self.child.wait().await?; Ok(()) @@ -809,6 +1052,7 @@ mod tests { use std::io::ErrorKind; use std::os::unix::fs::PermissionsExt; use std::path::{Path, PathBuf}; + use std::sync::atomic::{AtomicU64, Ordering}; use std::time::{SystemTime, UNIX_EPOCH}; use serde_json::json; @@ -829,11 +1073,13 @@ mod tests { }; fn temp_dir() -> PathBuf { + static NEXT_TEMP_DIR_ID: AtomicU64 = AtomicU64::new(0); let nanos = SystemTime::now() .duration_since(UNIX_EPOCH) .expect("time should be after epoch") .as_nanos(); - std::env::temp_dir().join(format!("runtime-mcp-stdio-{nanos}")) + let unique_id = NEXT_TEMP_DIR_ID.fetch_add(1, Ordering::Relaxed); + std::env::temp_dir().join(format!("runtime-mcp-stdio-{nanos}-{unique_id}")) } fn write_echo_script() -> PathBuf { @@ -857,7 +1103,9 @@ mod tests { let script_path = root.join("jsonrpc-mcp.py"); let script = [ "#!/usr/bin/env python3", - "import json, sys", + "import json, os, sys", + "LOWERCASE_CONTENT_LENGTH = os.environ.get('MCP_LOWERCASE_CONTENT_LENGTH') == '1'", + "MISMATCHED_RESPONSE_ID = os.environ.get('MCP_MISMATCHED_RESPONSE_ID') == '1'", "header = b''", r"while not header.endswith(b'\r\n\r\n'):", " chunk = sys.stdin.buffer.read(1)", @@ -872,16 +1120,18 @@ mod tests { "request = json.loads(payload.decode())", r"assert request['jsonrpc'] == '2.0'", r"assert request['method'] == 'initialize'", + "response_id = 'wrong-id' if MISMATCHED_RESPONSE_ID else request['id']", + "header_name = 'content-length' if LOWERCASE_CONTENT_LENGTH else 'Content-Length'", r"response = json.dumps({", r" 'jsonrpc': '2.0',", - r" 'id': request['id'],", + r" 'id': response_id,", r" 'result': {", r" 'protocolVersion': request['params']['protocolVersion'],", r" 'capabilities': {'tools': {}},", r" 'serverInfo': {'name': 'fake-mcp', 'version': '0.1.0'}", r" }", r"}).encode()", - r"sys.stdout.buffer.write(f'Content-Length: {len(response)}\r\n\r\n'.encode() + response)", + r"sys.stdout.buffer.write(f'{header_name}: {len(response)}\r\n\r\n'.encode() + response)", "sys.stdout.buffer.flush()", "", ] @@ -900,7 +1150,9 @@ mod tests { let script_path = root.join("fake-mcp-server.py"); let script = [ "#!/usr/bin/env python3", - "import json, sys", + "import json, os, sys, time", + "TOOL_CALL_DELAY_MS = int(os.environ.get('MCP_TOOL_CALL_DELAY_MS', '0'))", + "INVALID_TOOL_CALL_RESPONSE = os.environ.get('MCP_INVALID_TOOL_CALL_RESPONSE') == '1'", "", "def read_message():", " header = b''", @@ -955,6 +1207,12 @@ mod tests { " }", " })", " elif method == 'tools/call':", + " if INVALID_TOOL_CALL_RESPONSE:", + " sys.stdout.buffer.write(b'Content-Length: 5\\r\\n\\r\\nnope!')", + " sys.stdout.buffer.flush()", + " continue", + " if TOOL_CALL_DELAY_MS:", + " time.sleep(TOOL_CALL_DELAY_MS / 1000)", " args = request['params'].get('arguments') or {}", " if request['params']['name'] == 'fail':", " send_message({", @@ -1026,10 +1284,13 @@ mod tests { let script_path = root.join("manager-mcp-server.py"); let script = [ "#!/usr/bin/env python3", - "import json, os, sys", + "import json, os, sys, time", "", "LABEL = os.environ.get('MCP_SERVER_LABEL', 'server')", "LOG_PATH = os.environ.get('MCP_LOG_PATH')", + "EXIT_AFTER_TOOLS_LIST = os.environ.get('MCP_EXIT_AFTER_TOOLS_LIST') == '1'", + "FAIL_ONCE_MODE = os.environ.get('MCP_FAIL_ONCE_MODE')", + "FAIL_ONCE_MARKER = os.environ.get('MCP_FAIL_ONCE_MARKER')", "initialize_count = 0", "", "def log(method):", @@ -1037,6 +1298,15 @@ mod tests { " with open(LOG_PATH, 'a', encoding='utf-8') as handle:", " handle.write(f'{method}\\n')", "", + "def should_fail_once():", + " if not FAIL_ONCE_MODE or not FAIL_ONCE_MARKER:", + " return False", + " if os.path.exists(FAIL_ONCE_MARKER):", + " return False", + " with open(FAIL_ONCE_MARKER, 'w', encoding='utf-8') as handle:", + " handle.write(FAIL_ONCE_MODE)", + " return True", + "", "def read_message():", " header = b''", r" while not header.endswith(b'\r\n\r\n'):", @@ -1063,6 +1333,10 @@ mod tests { " method = request['method']", " log(method)", " if method == 'initialize':", + " if FAIL_ONCE_MODE == 'initialize_hang' and should_fail_once():", + " log('initialize-hang')", + " while True:", + " time.sleep(1)", " initialize_count += 1", " send_message({", " 'jsonrpc': '2.0',", @@ -1091,7 +1365,12 @@ mod tests { " ]", " }", " })", + " if EXIT_AFTER_TOOLS_LIST:", + " raise SystemExit(0)", " elif method == 'tools/call':", + " if FAIL_ONCE_MODE == 'tool_call_disconnect' and should_fail_once():", + " log('tools/call-disconnect')", + " raise SystemExit(0)", " args = request['params'].get('arguments') or {}", " text = args.get('text', '')", " send_message({", @@ -1130,16 +1409,25 @@ mod tests { command: "/bin/sh".to_string(), args: vec![script_path.to_string_lossy().into_owned()], env: BTreeMap::from([("MCP_TEST_TOKEN".to_string(), "secret-value".to_string())]), + tool_call_timeout_ms: None, }), }; McpClientBootstrap::from_scoped_config("stdio server", &config) } fn script_transport(script_path: &Path) -> crate::mcp_client::McpStdioTransport { + script_transport_with_env(script_path, BTreeMap::new()) + } + + fn script_transport_with_env( + script_path: &Path, + env: BTreeMap, + ) -> crate::mcp_client::McpStdioTransport { crate::mcp_client::McpStdioTransport { command: "python3".to_string(), args: vec![script_path.to_string_lossy().into_owned()], - env: BTreeMap::new(), + env, + tool_call_timeout_ms: None, } } @@ -1165,18 +1453,30 @@ mod tests { label: &str, log_path: &Path, ) -> ScopedMcpServerConfig { + manager_server_config_with_env(script_path, label, log_path, BTreeMap::new()) + } + + fn manager_server_config_with_env( + script_path: &Path, + label: &str, + log_path: &Path, + extra_env: BTreeMap, + ) -> ScopedMcpServerConfig { + let mut env = BTreeMap::from([ + ("MCP_SERVER_LABEL".to_string(), label.to_string()), + ( + "MCP_LOG_PATH".to_string(), + log_path.to_string_lossy().into_owned(), + ), + ]); + env.extend(extra_env); ScopedMcpServerConfig { scope: ConfigSource::Local, config: McpServerConfig::Stdio(McpStdioServerConfig { command: "python3".to_string(), args: vec![script_path.to_string_lossy().into_owned()], - env: BTreeMap::from([ - ("MCP_SERVER_LABEL".to_string(), label.to_string()), - ( - "MCP_LOG_PATH".to_string(), - log_path.to_string_lossy().into_owned(), - ), - ]), + env, + tool_call_timeout_ms: None, }), } } @@ -1304,6 +1604,91 @@ mod tests { }); } + #[test] + fn given_lowercase_content_length_when_initialize_then_response_parses() { + let runtime = Builder::new_current_thread() + .enable_all() + .build() + .expect("runtime"); + runtime.block_on(async { + let script_path = write_jsonrpc_script(); + let transport = script_transport_with_env( + &script_path, + BTreeMap::from([( + "MCP_LOWERCASE_CONTENT_LENGTH".to_string(), + "1".to_string(), + )]), + ); + let mut process = McpStdioProcess::spawn(&transport).expect("spawn transport directly"); + + let response = process + .initialize( + JsonRpcId::Number(8), + McpInitializeParams { + protocol_version: "2025-03-26".to_string(), + capabilities: json!({"roots": {}}), + client_info: McpInitializeClientInfo { + name: "runtime-tests".to_string(), + version: "0.1.0".to_string(), + }, + }, + ) + .await + .expect("initialize roundtrip"); + + assert_eq!(response.id, JsonRpcId::Number(8)); + assert_eq!(response.error, None); + assert!(response.result.is_some()); + + let status = process.wait().await.expect("wait for exit"); + assert!(status.success()); + + cleanup_script(&script_path); + }); + } + + #[test] + fn given_mismatched_response_id_when_initialize_then_invalid_data_is_returned() { + let runtime = Builder::new_current_thread() + .enable_all() + .build() + .expect("runtime"); + runtime.block_on(async { + let script_path = write_jsonrpc_script(); + let transport = script_transport_with_env( + &script_path, + BTreeMap::from([( + "MCP_MISMATCHED_RESPONSE_ID".to_string(), + "1".to_string(), + )]), + ); + let mut process = McpStdioProcess::spawn(&transport).expect("spawn transport directly"); + + let error = process + .initialize( + JsonRpcId::Number(9), + McpInitializeParams { + protocol_version: "2025-03-26".to_string(), + capabilities: json!({"roots": {}}), + client_info: McpInitializeClientInfo { + name: "runtime-tests".to_string(), + version: "0.1.0".to_string(), + }, + }, + ) + .await + .expect_err("mismatched response id should fail"); + + assert_eq!(error.kind(), ErrorKind::InvalidData); + assert!(error.to_string().contains("mismatched id")); + + let status = process.wait().await.expect("wait for exit"); + assert!(status.success()); + + cleanup_script(&script_path); + }); + } + #[test] fn direct_spawn_uses_transport_env() { let runtime = Builder::new_current_thread() @@ -1316,6 +1701,7 @@ mod tests { command: "/bin/sh".to_string(), args: vec![script_path.to_string_lossy().into_owned()], env: BTreeMap::from([("MCP_TEST_TOKEN".to_string(), "direct-secret".to_string())]), + tool_call_timeout_ms: None, }; let mut process = McpStdioProcess::spawn(&transport).expect("spawn transport directly"); let ready = process.read_available().await.expect("read ready"); @@ -1556,6 +1942,310 @@ mod tests { }); } + #[test] + fn manager_times_out_slow_tool_calls() { + let runtime = Builder::new_current_thread() + .enable_all() + .build() + .expect("runtime"); + runtime.block_on(async { + let script_path = write_mcp_server_script(); + let root = script_path.parent().expect("script parent"); + let log_path = root.join("timeout.log"); + let servers = BTreeMap::from([( + "slow".to_string(), + ScopedMcpServerConfig { + scope: ConfigSource::Local, + config: McpServerConfig::Stdio(McpStdioServerConfig { + command: "python3".to_string(), + args: vec![script_path.to_string_lossy().into_owned()], + env: BTreeMap::from([( + "MCP_TOOL_CALL_DELAY_MS".to_string(), + "200".to_string(), + )]), + tool_call_timeout_ms: Some(25), + }), + }, + )]); + let mut manager = McpServerManager::from_servers(&servers); + + manager.discover_tools().await.expect("discover tools"); + let error = manager + .call_tool(&mcp_tool_name("slow", "echo"), Some(json!({"text": "slow"}))) + .await + .expect_err("slow tool call should time out"); + + match error { + McpServerManagerError::Timeout { + server_name, + method, + timeout_ms, + } => { + assert_eq!(server_name, "slow"); + assert_eq!(method, "tools/call"); + assert_eq!(timeout_ms, 25); + } + other => panic!("expected timeout error, got {other:?}"), + } + + manager.shutdown().await.expect("shutdown"); + cleanup_script(&script_path); + let _ = fs::remove_file(log_path); + }); + } + + #[test] + fn manager_surfaces_parse_errors_from_tool_calls() { + let runtime = Builder::new_current_thread() + .enable_all() + .build() + .expect("runtime"); + runtime.block_on(async { + let script_path = write_mcp_server_script(); + let servers = BTreeMap::from([( + "broken".to_string(), + ScopedMcpServerConfig { + scope: ConfigSource::Local, + config: McpServerConfig::Stdio(McpStdioServerConfig { + command: "python3".to_string(), + args: vec![script_path.to_string_lossy().into_owned()], + env: BTreeMap::from([( + "MCP_INVALID_TOOL_CALL_RESPONSE".to_string(), + "1".to_string(), + )]), + tool_call_timeout_ms: Some(1_000), + }), + }, + )]); + let mut manager = McpServerManager::from_servers(&servers); + + manager.discover_tools().await.expect("discover tools"); + let error = manager + .call_tool( + &mcp_tool_name("broken", "echo"), + Some(json!({"text": "invalid-json"})), + ) + .await + .expect_err("invalid json should fail"); + + match error { + McpServerManagerError::InvalidResponse { + server_name, + method, + details, + } => { + assert_eq!(server_name, "broken"); + assert_eq!(method, "tools/call"); + assert!(details.contains("expected ident") || details.contains("expected value")); + } + other => panic!("expected invalid response error, got {other:?}"), + } + + manager.shutdown().await.expect("shutdown"); + cleanup_script(&script_path); + }); + } + + #[test] + fn given_child_exits_after_discovery_when_calling_twice_then_second_call_succeeds_after_reset() { + let runtime = Builder::new_current_thread() + .enable_all() + .build() + .expect("runtime"); + runtime.block_on(async { + let script_path = write_manager_mcp_server_script(); + let root = script_path.parent().expect("script parent"); + let log_path = root.join("dropping.log"); + let servers = BTreeMap::from([( + "alpha".to_string(), + manager_server_config_with_env( + &script_path, + "alpha", + &log_path, + BTreeMap::from([( + "MCP_EXIT_AFTER_TOOLS_LIST".to_string(), + "1".to_string(), + )]), + ), + )]); + let mut manager = McpServerManager::from_servers(&servers); + + manager.discover_tools().await.expect("discover tools"); + let first_error = manager + .call_tool( + &mcp_tool_name("alpha", "echo"), + Some(json!({"text": "reconnect"})), + ) + .await + .expect_err("first call should fail after transport drops"); + + match first_error { + McpServerManagerError::Transport { + server_name, + method, + source, + } => { + assert_eq!(server_name, "alpha"); + assert_eq!(method, "tools/call"); + assert_eq!(source.kind(), ErrorKind::UnexpectedEof); + } + other => panic!("expected transport error, got {other:?}"), + } + + let response = manager + .call_tool( + &mcp_tool_name("alpha", "echo"), + Some(json!({"text": "reconnect"})), + ) + .await + .expect("second tool call should succeed after reset"); + + assert_eq!( + response + .result + .as_ref() + .and_then(|result| result.structured_content.as_ref()) + .and_then(|value| value.get("server")), + Some(&json!("alpha")) + ); + let log = fs::read_to_string(&log_path).expect("read log"); + assert_eq!( + log.lines().collect::>(), + vec!["initialize", "tools/list", "initialize", "tools/call"] + ); + + manager.shutdown().await.expect("shutdown"); + cleanup_script(&script_path); + }); + } + + #[test] + fn given_initialize_hangs_once_when_discover_tools_then_manager_retries_and_succeeds() { + let runtime = Builder::new_current_thread() + .enable_all() + .build() + .expect("runtime"); + runtime.block_on(async { + let script_path = write_manager_mcp_server_script(); + let root = script_path.parent().expect("script parent"); + let log_path = root.join("initialize-hang.log"); + let marker_path = root.join("initialize-hang.marker"); + let servers = BTreeMap::from([( + "alpha".to_string(), + manager_server_config_with_env( + &script_path, + "alpha", + &log_path, + BTreeMap::from([ + ( + "MCP_FAIL_ONCE_MODE".to_string(), + "initialize_hang".to_string(), + ), + ( + "MCP_FAIL_ONCE_MARKER".to_string(), + marker_path.to_string_lossy().into_owned(), + ), + ]), + ), + )]); + let mut manager = McpServerManager::from_servers(&servers); + + let tools = manager.discover_tools().await.expect("discover tools after retry"); + + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].qualified_name, mcp_tool_name("alpha", "echo")); + let log = fs::read_to_string(&log_path).expect("read log"); + assert_eq!( + log.lines().collect::>(), + vec!["initialize", "initialize-hang", "initialize", "tools/list"] + ); + + manager.shutdown().await.expect("shutdown"); + cleanup_script(&script_path); + }); + } + + #[test] + fn given_tool_call_disconnects_once_when_calling_twice_then_manager_resets_and_next_call_succeeds() { + let runtime = Builder::new_current_thread() + .enable_all() + .build() + .expect("runtime"); + runtime.block_on(async { + let script_path = write_manager_mcp_server_script(); + let root = script_path.parent().expect("script parent"); + let log_path = root.join("tool-call-disconnect.log"); + let marker_path = root.join("tool-call-disconnect.marker"); + let servers = BTreeMap::from([( + "alpha".to_string(), + manager_server_config_with_env( + &script_path, + "alpha", + &log_path, + BTreeMap::from([ + ( + "MCP_FAIL_ONCE_MODE".to_string(), + "tool_call_disconnect".to_string(), + ), + ( + "MCP_FAIL_ONCE_MARKER".to_string(), + marker_path.to_string_lossy().into_owned(), + ), + ]), + ), + )]); + let mut manager = McpServerManager::from_servers(&servers); + + manager.discover_tools().await.expect("discover tools"); + let first_error = manager + .call_tool(&mcp_tool_name("alpha", "echo"), Some(json!({"text": "first"}))) + .await + .expect_err("first tool call should fail when transport drops"); + + match first_error { + McpServerManagerError::Transport { + server_name, + method, + source, + } => { + assert_eq!(server_name, "alpha"); + assert_eq!(method, "tools/call"); + assert_eq!(source.kind(), ErrorKind::UnexpectedEof); + } + other => panic!("expected transport error, got {other:?}"), + } + + let response = manager + .call_tool(&mcp_tool_name("alpha", "echo"), Some(json!({"text": "second"}))) + .await + .expect("second tool call should succeed after reset"); + + assert_eq!( + response + .result + .as_ref() + .and_then(|result| result.structured_content.as_ref()) + .and_then(|value| value.get("echoed")), + Some(&json!("second")) + ); + let log = fs::read_to_string(&log_path).expect("read log"); + assert_eq!( + log.lines().collect::>(), + vec![ + "initialize", + "tools/list", + "tools/call", + "tools/call-disconnect", + "initialize", + "tools/call", + ] + ); + + manager.shutdown().await.expect("shutdown"); + cleanup_script(&script_path); + }); + } + #[test] fn manager_records_unsupported_non_stdio_servers_without_panicking() { let servers = BTreeMap::from([