diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 33855b7..fc84570 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -1198,6 +1198,7 @@ dependencies = [ "pulldown-cmark", "runtime", "rustyline", + "serde", "serde_json", "syntect", "tokio", diff --git a/rust/crates/runtime/src/lib.rs b/rust/crates/runtime/src/lib.rs index 62e78d8..6f10ff8 100644 --- a/rust/crates/runtime/src/lib.rs +++ b/rust/crates/runtime/src/lib.rs @@ -55,11 +55,12 @@ pub use mcp_client::{ }; pub use mcp_stdio::{ spawn_mcp_stdio_process, JsonRpcError, JsonRpcId, JsonRpcRequest, JsonRpcResponse, - ManagedMcpTool, McpInitializeClientInfo, McpInitializeParams, McpInitializeResult, - McpInitializeServerInfo, McpListResourcesParams, McpListResourcesResult, McpListToolsParams, - McpListToolsResult, McpReadResourceParams, McpReadResourceResult, McpResource, - McpResourceContents, McpServerManager, McpServerManagerError, McpStdioProcess, McpTool, - McpToolCallContent, McpToolCallParams, McpToolCallResult, UnsupportedMcpServer, + ManagedMcpTool, McpDiscoveryFailure, McpInitializeClientInfo, McpInitializeParams, + McpInitializeResult, McpInitializeServerInfo, McpListResourcesParams, McpListResourcesResult, + McpListToolsParams, McpListToolsResult, McpReadResourceParams, McpReadResourceResult, + McpResource, McpResourceContents, McpServerManager, McpServerManagerError, McpStdioProcess, + McpTool, McpToolCallContent, McpToolCallParams, McpToolCallResult, McpToolDiscoveryReport, + UnsupportedMcpServer, }; pub use oauth::{ clear_oauth_credentials, code_challenge_s256, credentials_path, generate_pkce_pair, diff --git a/rust/crates/runtime/src/mcp_stdio.rs b/rust/crates/runtime/src/mcp_stdio.rs index 7f5456a..f850a87 100644 --- a/rust/crates/runtime/src/mcp_stdio.rs +++ b/rust/crates/runtime/src/mcp_stdio.rs @@ -230,6 +230,19 @@ pub struct UnsupportedMcpServer { pub reason: String, } +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct McpDiscoveryFailure { + pub server_name: String, + pub error: String, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct McpToolDiscoveryReport { + pub tools: Vec, + pub failed_servers: Vec, + pub unsupported_servers: Vec, +} + #[derive(Debug)] pub enum McpServerManagerError { Io(io::Error), @@ -397,6 +410,11 @@ impl McpServerManager { &self.unsupported_servers } + #[must_use] + pub fn server_names(&self) -> Vec { + self.servers.keys().cloned().collect() + } + pub async fn discover_tools(&mut self) -> Result, McpServerManagerError> { let server_names = self.servers.keys().cloned().collect::>(); let mut discovered_tools = Vec::new(); @@ -420,6 +438,43 @@ impl McpServerManager { Ok(discovered_tools) } + pub async fn discover_tools_best_effort(&mut self) -> McpToolDiscoveryReport { + let server_names = self.server_names(); + let mut discovered_tools = Vec::new(); + let mut failed_servers = Vec::new(); + + for server_name in server_names { + match self.discover_tools_for_server(&server_name).await { + Ok(server_tools) => { + self.clear_routes_for_server(&server_name); + for tool in server_tools { + self.tool_index.insert( + tool.qualified_name.clone(), + ToolRoute { + server_name: tool.server_name.clone(), + raw_name: tool.raw_name.clone(), + }, + ); + discovered_tools.push(tool); + } + } + Err(error) => { + self.clear_routes_for_server(&server_name); + failed_servers.push(McpDiscoveryFailure { + server_name, + error: error.to_string(), + }); + } + } + } + + McpToolDiscoveryReport { + tools: discovered_tools, + failed_servers, + unsupported_servers: self.unsupported_servers.clone(), + } + } + pub async fn call_tool( &mut self, qualified_tool_name: &str, @@ -437,30 +492,31 @@ impl McpServerManager { self.ensure_server_ready(&route.server_name).await?; let request_id = self.take_request_id(); - let response = { - let server = self.server_mut(&route.server_name)?; - let process = server.process.as_mut().ok_or_else(|| { - McpServerManagerError::InvalidResponse { - server_name: route.server_name.clone(), - method: "tools/call", - details: "server process missing after initialization".to_string(), - } - })?; - Self::run_process_request( - &route.server_name, - "tools/call", - timeout_ms, - process.call_tool( - request_id, - McpToolCallParams { - name: route.raw_name, - arguments, - meta: None, - }, - ), - ) - .await - }; + let response = + { + let server = self.server_mut(&route.server_name)?; + let process = server.process.as_mut().ok_or_else(|| { + McpServerManagerError::InvalidResponse { + server_name: route.server_name.clone(), + method: "tools/call", + details: "server process missing after initialization".to_string(), + } + })?; + Self::run_process_request( + &route.server_name, + "tools/call", + timeout_ms, + process.call_tool( + request_id, + McpToolCallParams { + name: route.raw_name, + arguments, + meta: None, + }, + ), + ) + .await + }; if let Err(error) = &response { if Self::should_reset_server(error) { @@ -471,6 +527,53 @@ impl McpServerManager { response } + pub async fn list_resources( + &mut self, + server_name: &str, + ) -> Result { + let mut attempts = 0; + + loop { + match self.list_resources_once(server_name).await { + Ok(resources) => return Ok(resources), + Err(error) if attempts == 0 && Self::is_retryable_error(&error) => { + self.reset_server(server_name).await?; + attempts += 1; + } + Err(error) => { + if Self::should_reset_server(&error) { + self.reset_server(server_name).await?; + } + return Err(error); + } + } + } + } + + pub async fn read_resource( + &mut self, + server_name: &str, + uri: &str, + ) -> Result { + let mut attempts = 0; + + loop { + match self.read_resource_once(server_name, uri).await { + Ok(resource) => return Ok(resource), + Err(error) if attempts == 0 && Self::is_retryable_error(&error) => { + self.reset_server(server_name).await?; + attempts += 1; + } + Err(error) => { + if Self::should_reset_server(&error) { + self.reset_server(server_name).await?; + } + return Err(error); + } + } + } + } + pub async fn shutdown(&mut self) -> Result<(), McpServerManagerError> { let server_names = self.servers.keys().cloned().collect::>(); for server_name in server_names { @@ -507,12 +610,12 @@ impl McpServerManager { } fn tool_call_timeout_ms(&self, server_name: &str) -> Result { - let server = self - .servers - .get(server_name) - .ok_or_else(|| McpServerManagerError::UnknownServer { - server_name: server_name.to_string(), - })?; + let server = + self.servers + .get(server_name) + .ok_or_else(|| McpServerManagerError::UnknownServer { + server_name: server_name.to_string(), + })?; match &server.bootstrap.transport { McpClientTransport::Stdio(transport) => Ok(transport.resolved_tool_call_timeout_ms()), other => Err(McpServerManagerError::InvalidResponse { @@ -595,14 +698,13 @@ impl McpServerManager { }); } - let result = - response - .result - .ok_or_else(|| McpServerManagerError::InvalidResponse { - server_name: server_name.to_string(), - method: "tools/list", - details: "missing result payload".to_string(), - })?; + let result = response + .result + .ok_or_else(|| McpServerManagerError::InvalidResponse { + server_name: server_name.to_string(), + method: "tools/list", + details: "missing result payload".to_string(), + })?; for tool in result.tools { let qualified_name = mcp_tool_name(server_name, &tool.name); @@ -623,6 +725,118 @@ impl McpServerManager { Ok(discovered_tools) } + async fn list_resources_once( + &mut self, + server_name: &str, + ) -> Result { + self.ensure_server_ready(server_name).await?; + + let mut resources = Vec::new(); + let mut cursor = None; + loop { + let request_id = self.take_request_id(); + let response = { + let server = self.server_mut(server_name)?; + let process = server.process.as_mut().ok_or_else(|| { + McpServerManagerError::InvalidResponse { + server_name: server_name.to_string(), + method: "resources/list", + details: "server process missing after initialization".to_string(), + } + })?; + Self::run_process_request( + server_name, + "resources/list", + MCP_LIST_TOOLS_TIMEOUT_MS, + process.list_resources( + request_id, + Some(McpListResourcesParams { + cursor: cursor.clone(), + }), + ), + ) + .await? + }; + + if let Some(error) = response.error { + return Err(McpServerManagerError::JsonRpc { + server_name: server_name.to_string(), + method: "resources/list", + error, + }); + } + + let result = response + .result + .ok_or_else(|| McpServerManagerError::InvalidResponse { + server_name: server_name.to_string(), + method: "resources/list", + details: "missing result payload".to_string(), + })?; + + resources.extend(result.resources); + + match result.next_cursor { + Some(next_cursor) => cursor = Some(next_cursor), + None => break, + } + } + + Ok(McpListResourcesResult { + resources, + next_cursor: None, + }) + } + + async fn read_resource_once( + &mut self, + server_name: &str, + uri: &str, + ) -> Result { + self.ensure_server_ready(server_name).await?; + + let request_id = self.take_request_id(); + let response = + { + let server = self.server_mut(server_name)?; + let process = server.process.as_mut().ok_or_else(|| { + McpServerManagerError::InvalidResponse { + server_name: server_name.to_string(), + method: "resources/read", + details: "server process missing after initialization".to_string(), + } + })?; + Self::run_process_request( + server_name, + "resources/read", + MCP_LIST_TOOLS_TIMEOUT_MS, + process.read_resource( + request_id, + McpReadResourceParams { + uri: uri.to_string(), + }, + ), + ) + .await? + }; + + if let Some(error) = response.error { + return Err(McpServerManagerError::JsonRpc { + server_name: server_name.to_string(), + method: "resources/read", + error, + }); + } + + response + .result + .ok_or_else(|| McpServerManagerError::InvalidResponse { + server_name: server_name.to_string(), + method: "resources/read", + details: "missing result payload".to_string(), + }) + } + async fn reset_server(&mut self, server_name: &str) -> Result<(), McpServerManagerError> { let mut process = { let server = self.server_mut(server_name)?; @@ -1614,10 +1828,7 @@ mod tests { let script_path = write_jsonrpc_script(); let transport = script_transport_with_env( &script_path, - BTreeMap::from([( - "MCP_LOWERCASE_CONTENT_LENGTH".to_string(), - "1".to_string(), - )]), + BTreeMap::from([("MCP_LOWERCASE_CONTENT_LENGTH".to_string(), "1".to_string())]), ); let mut process = McpStdioProcess::spawn(&transport).expect("spawn transport directly"); @@ -1657,10 +1868,7 @@ mod tests { let script_path = write_jsonrpc_script(); let transport = script_transport_with_env( &script_path, - BTreeMap::from([( - "MCP_MISMATCHED_RESPONSE_ID".to_string(), - "1".to_string(), - )]), + BTreeMap::from([("MCP_MISMATCHED_RESPONSE_ID".to_string(), "1".to_string())]), ); let mut process = McpStdioProcess::spawn(&transport).expect("spawn transport directly"); @@ -1971,7 +2179,10 @@ mod tests { manager.discover_tools().await.expect("discover tools"); let error = manager - .call_tool(&mcp_tool_name("slow", "echo"), Some(json!({"text": "slow"}))) + .call_tool( + &mcp_tool_name("slow", "echo"), + Some(json!({"text": "slow"})), + ) .await .expect_err("slow tool call should time out"); @@ -2036,7 +2247,9 @@ mod tests { } => { assert_eq!(server_name, "broken"); assert_eq!(method, "tools/call"); - assert!(details.contains("expected ident") || details.contains("expected value")); + assert!( + details.contains("expected ident") || details.contains("expected value") + ); } other => panic!("expected invalid response error, got {other:?}"), } @@ -2047,7 +2260,8 @@ mod tests { } #[test] - fn given_child_exits_after_discovery_when_calling_twice_then_second_call_succeeds_after_reset() { + fn given_child_exits_after_discovery_when_calling_twice_then_second_call_succeeds_after_reset() + { let runtime = Builder::new_current_thread() .enable_all() .build() @@ -2062,10 +2276,7 @@ mod tests { &script_path, "alpha", &log_path, - BTreeMap::from([( - "MCP_EXIT_AFTER_TOOLS_LIST".to_string(), - "1".to_string(), - )]), + BTreeMap::from([("MCP_EXIT_AFTER_TOOLS_LIST".to_string(), "1".to_string())]), ), )]); let mut manager = McpServerManager::from_servers(&servers); @@ -2150,7 +2361,10 @@ mod tests { )]); let mut manager = McpServerManager::from_servers(&servers); - let tools = manager.discover_tools().await.expect("discover tools after retry"); + let tools = manager + .discover_tools() + .await + .expect("discover tools after retry"); assert_eq!(tools.len(), 1); assert_eq!(tools[0].qualified_name, mcp_tool_name("alpha", "echo")); @@ -2166,7 +2380,8 @@ mod tests { } #[test] - fn given_tool_call_disconnects_once_when_calling_twice_then_manager_resets_and_next_call_succeeds() { + fn given_tool_call_disconnects_once_when_calling_twice_then_manager_resets_and_next_call_succeeds( + ) { let runtime = Builder::new_current_thread() .enable_all() .build() @@ -2198,7 +2413,10 @@ mod tests { manager.discover_tools().await.expect("discover tools"); let first_error = manager - .call_tool(&mcp_tool_name("alpha", "echo"), Some(json!({"text": "first"}))) + .call_tool( + &mcp_tool_name("alpha", "echo"), + Some(json!({"text": "first"})), + ) .await .expect_err("first tool call should fail when transport drops"); @@ -2216,7 +2434,10 @@ mod tests { } let response = manager - .call_tool(&mcp_tool_name("alpha", "echo"), Some(json!({"text": "second"}))) + .call_tool( + &mcp_tool_name("alpha", "echo"), + Some(json!({"text": "second"})), + ) .await .expect("second tool call should succeed after reset"); @@ -2246,6 +2467,103 @@ mod tests { }); } + #[test] + fn manager_lists_and_reads_resources_from_stdio_servers() { + let runtime = Builder::new_current_thread() + .enable_all() + .build() + .expect("runtime"); + runtime.block_on(async { + let script_path = write_mcp_server_script(); + let root = script_path.parent().expect("script parent"); + let log_path = root.join("resources.log"); + let servers = BTreeMap::from([( + "alpha".to_string(), + manager_server_config(&script_path, "alpha", &log_path), + )]); + let mut manager = McpServerManager::from_servers(&servers); + + let listed = manager + .list_resources("alpha") + .await + .expect("list resources"); + assert_eq!(listed.resources.len(), 1); + assert_eq!(listed.resources[0].uri, "file://guide.txt"); + + let read = manager + .read_resource("alpha", "file://guide.txt") + .await + .expect("read resource"); + assert_eq!(read.contents.len(), 1); + assert_eq!( + read.contents[0].text.as_deref(), + Some("contents for file://guide.txt") + ); + + manager.shutdown().await.expect("shutdown"); + cleanup_script(&script_path); + }); + } + + #[test] + fn manager_discovery_report_keeps_healthy_servers_when_one_server_fails() { + let runtime = Builder::new_current_thread() + .enable_all() + .build() + .expect("runtime"); + runtime.block_on(async { + let script_path = write_manager_mcp_server_script(); + let root = script_path.parent().expect("script parent"); + let alpha_log = root.join("alpha.log"); + let servers = BTreeMap::from([ + ( + "alpha".to_string(), + manager_server_config(&script_path, "alpha", &alpha_log), + ), + ( + "broken".to_string(), + ScopedMcpServerConfig { + scope: ConfigSource::Local, + config: McpServerConfig::Stdio(McpStdioServerConfig { + command: "python3".to_string(), + args: vec!["-c".to_string(), "import sys; sys.exit(0)".to_string()], + env: BTreeMap::new(), + tool_call_timeout_ms: None, + }), + }, + ), + ]); + let mut manager = McpServerManager::from_servers(&servers); + + let report = manager.discover_tools_best_effort().await; + + assert_eq!(report.tools.len(), 1); + assert_eq!( + report.tools[0].qualified_name, + mcp_tool_name("alpha", "echo") + ); + assert_eq!(report.failed_servers.len(), 1); + assert_eq!(report.failed_servers[0].server_name, "broken"); + assert!(report.failed_servers[0].error.contains("initialize")); + + let response = manager + .call_tool(&mcp_tool_name("alpha", "echo"), Some(json!({"text": "ok"}))) + .await + .expect("healthy server should remain callable"); + assert_eq!( + response + .result + .as_ref() + .and_then(|result| result.structured_content.as_ref()) + .and_then(|value| value.get("echoed")), + Some(&json!("ok")) + ); + + manager.shutdown().await.expect("shutdown"); + cleanup_script(&script_path); + }); + } + #[test] fn manager_records_unsupported_non_stdio_servers_without_panicking() { let servers = BTreeMap::from([ diff --git a/rust/crates/rusty-claude-cli/Cargo.toml b/rust/crates/rusty-claude-cli/Cargo.toml index f10a41e..6cb7632 100644 --- a/rust/crates/rusty-claude-cli/Cargo.toml +++ b/rust/crates/rusty-claude-cli/Cargo.toml @@ -18,6 +18,7 @@ pulldown-cmark = "0.13" rustyline = "15" runtime = { path = "../runtime" } plugins = { path = "../plugins" } +serde = { version = "1", features = ["derive"] } serde_json.workspace = true syntect = "5" tokio = { version = "1", features = ["rt-multi-thread", "signal", "time"] } diff --git a/rust/crates/rusty-claude-cli/src/main.rs b/rust/crates/rusty-claude-cli/src/main.rs index fdb957c..40edf0e 100644 --- a/rust/crates/rusty-claude-cli/src/main.rs +++ b/rust/crates/rusty-claude-cli/src/main.rs @@ -42,12 +42,14 @@ use runtime::{ clear_oauth_credentials, generate_pkce_pair, generate_state, load_system_prompt, parse_oauth_callback_request_target, resolve_sandbox_status, save_oauth_credentials, ApiClient, ApiRequest, AssistantEvent, CompactionConfig, ConfigLoader, ConfigSource, ContentBlock, - ConversationMessage, ConversationRuntime, MessageRole, OAuthAuthorizationRequest, OAuthConfig, - OAuthTokenExchangeRequest, PermissionMode, PermissionPolicy, ProjectContext, PromptCacheEvent, - RuntimeError, Session, TokenUsage, ToolError, ToolExecutor, UsageTracker, + ConversationMessage, ConversationRuntime, McpServerManager, McpTool, MessageRole, + OAuthAuthorizationRequest, OAuthConfig, OAuthTokenExchangeRequest, PermissionMode, + PermissionPolicy, ProjectContext, PromptCacheEvent, RuntimeError, Session, TokenUsage, + ToolError, ToolExecutor, UsageTracker, }; +use serde::Deserialize; use serde_json::json; -use tools::GlobalToolRegistry; +use tools::{GlobalToolRegistry, RuntimeToolDefinition, ToolSearchOutput}; const DEFAULT_MODEL: &str = "claude-opus-4-6"; fn max_tokens_for_model(model: &str) -> u32 { @@ -576,11 +578,17 @@ fn current_tool_registry() -> Result { let cwd = env::current_dir().map_err(|error| error.to_string())?; let loader = ConfigLoader::default_for(&cwd); let runtime_config = loader.load().map_err(|error| error.to_string())?; - let plugin_manager = build_plugin_manager(&cwd, &loader, &runtime_config); - let plugin_tools = plugin_manager - .aggregated_tools() + let state = build_runtime_plugin_state_with_loader(&cwd, &loader, &runtime_config) .map_err(|error| error.to_string())?; - GlobalToolRegistry::with_plugin_tools(plugin_tools) + let registry = state.tool_registry.clone(); + if let Some(mcp_state) = state.mcp_state { + mcp_state + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .shutdown() + .map_err(|error| error.to_string())?; + } + Ok(registry) } fn parse_permission_mode_arg(value: &str) -> Result { @@ -1491,23 +1499,35 @@ struct RuntimePluginState { feature_config: runtime::RuntimeFeatureConfig, tool_registry: GlobalToolRegistry, plugin_registry: PluginRegistry, + mcp_state: Option>>, +} + +struct RuntimeMcpState { + runtime: tokio::runtime::Runtime, + manager: McpServerManager, + pending_servers: Vec, } struct BuiltRuntime { runtime: Option>, plugin_registry: PluginRegistry, plugins_active: bool, + mcp_state: Option>>, + mcp_active: bool, } impl BuiltRuntime { fn new( runtime: ConversationRuntime, plugin_registry: PluginRegistry, + mcp_state: Option>>, ) -> Self { Self { runtime: Some(runtime), plugin_registry, plugins_active: true, + mcp_state, + mcp_active: true, } } @@ -1527,6 +1547,19 @@ impl BuiltRuntime { } Ok(()) } + + fn shutdown_mcp(&mut self) -> Result<(), Box> { + if self.mcp_active { + if let Some(mcp_state) = &self.mcp_state { + mcp_state + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .shutdown()?; + } + self.mcp_active = false; + } + Ok(()) + } } impl Deref for BuiltRuntime { @@ -1549,10 +1582,284 @@ impl DerefMut for BuiltRuntime { impl Drop for BuiltRuntime { fn drop(&mut self) { + let _ = self.shutdown_mcp(); let _ = self.shutdown_plugins(); } } +#[derive(Debug, Deserialize)] +struct ToolSearchRequest { + query: String, + max_results: Option, +} + +#[derive(Debug, Deserialize)] +struct McpToolRequest { + #[serde(rename = "qualifiedName")] + qualified_name: Option, + tool: Option, + arguments: Option, +} + +#[derive(Debug, Deserialize)] +struct ListMcpResourcesRequest { + server: Option, +} + +#[derive(Debug, Deserialize)] +struct ReadMcpResourceRequest { + server: String, + uri: String, +} + +impl RuntimeMcpState { + fn new( + runtime_config: &runtime::RuntimeConfig, + ) -> Result, Box> { + let mut manager = McpServerManager::from_runtime_config(runtime_config); + if manager.server_names().is_empty() && manager.unsupported_servers().is_empty() { + return Ok(None); + } + + let runtime = tokio::runtime::Runtime::new()?; + let discovery = runtime.block_on(manager.discover_tools_best_effort()); + let pending_servers = discovery + .failed_servers + .iter() + .map(|failure| failure.server_name.clone()) + .chain( + discovery + .unsupported_servers + .iter() + .map(|server| server.server_name.clone()), + ) + .collect::>() + .into_iter() + .collect::>(); + + Ok(Some(( + Self { + runtime, + manager, + pending_servers, + }, + discovery, + ))) + } + + fn shutdown(&mut self) -> Result<(), Box> { + self.runtime.block_on(self.manager.shutdown())?; + Ok(()) + } + + fn pending_servers(&self) -> Option> { + (!self.pending_servers.is_empty()).then(|| self.pending_servers.clone()) + } + + fn server_names(&self) -> Vec { + self.manager.server_names() + } + + fn call_tool( + &mut self, + qualified_tool_name: &str, + arguments: Option, + ) -> Result { + let response = self + .runtime + .block_on(self.manager.call_tool(qualified_tool_name, arguments)) + .map_err(|error| ToolError::new(error.to_string()))?; + if let Some(error) = response.error { + return Err(ToolError::new(format!( + "MCP tool `{qualified_tool_name}` returned JSON-RPC error: {} ({})", + error.message, error.code + ))); + } + + let result = response.result.ok_or_else(|| { + ToolError::new(format!( + "MCP tool `{qualified_tool_name}` returned no result payload" + )) + })?; + serde_json::to_string_pretty(&result).map_err(|error| ToolError::new(error.to_string())) + } + + fn list_resources_for_server(&mut self, server_name: &str) -> Result { + let result = self + .runtime + .block_on(self.manager.list_resources(server_name)) + .map_err(|error| ToolError::new(error.to_string()))?; + serde_json::to_string_pretty(&json!({ + "server": server_name, + "resources": result.resources, + })) + .map_err(|error| ToolError::new(error.to_string())) + } + + fn list_resources_for_all_servers(&mut self) -> Result { + let mut resources = Vec::new(); + let mut failures = Vec::new(); + + for server_name in self.server_names() { + match self + .runtime + .block_on(self.manager.list_resources(&server_name)) + { + Ok(result) => resources.push(json!({ + "server": server_name, + "resources": result.resources, + })), + Err(error) => failures.push(json!({ + "server": server_name, + "error": error.to_string(), + })), + } + } + + if resources.is_empty() && !failures.is_empty() { + let message = failures + .iter() + .filter_map(|failure| failure.get("error").and_then(serde_json::Value::as_str)) + .collect::>() + .join("; "); + return Err(ToolError::new(message)); + } + + serde_json::to_string_pretty(&json!({ + "resources": resources, + "failures": failures, + })) + .map_err(|error| ToolError::new(error.to_string())) + } + + fn read_resource(&mut self, server_name: &str, uri: &str) -> Result { + let result = self + .runtime + .block_on(self.manager.read_resource(server_name, uri)) + .map_err(|error| ToolError::new(error.to_string()))?; + serde_json::to_string_pretty(&json!({ + "server": server_name, + "contents": result.contents, + })) + .map_err(|error| ToolError::new(error.to_string())) + } +} + +fn build_runtime_mcp_state( + runtime_config: &runtime::RuntimeConfig, +) -> Result< + ( + Option>>, + Vec, + ), + Box, +> { + let Some((mcp_state, discovery)) = RuntimeMcpState::new(runtime_config)? else { + return Ok((None, Vec::new())); + }; + + let mut runtime_tools = discovery + .tools + .iter() + .map(mcp_runtime_tool_definition) + .collect::>(); + if !mcp_state.server_names().is_empty() { + runtime_tools.extend(mcp_wrapper_tool_definitions()); + } + + Ok((Some(Arc::new(Mutex::new(mcp_state))), runtime_tools)) +} + +fn mcp_runtime_tool_definition(tool: &runtime::ManagedMcpTool) -> RuntimeToolDefinition { + RuntimeToolDefinition { + name: tool.qualified_name.clone(), + description: Some( + tool.tool + .description + .clone() + .unwrap_or_else(|| format!("Invoke MCP tool `{}`.", tool.qualified_name)), + ), + input_schema: tool + .tool + .input_schema + .clone() + .unwrap_or_else(|| json!({ "type": "object", "additionalProperties": true })), + required_permission: permission_mode_for_mcp_tool(&tool.tool), + } +} + +fn mcp_wrapper_tool_definitions() -> Vec { + vec![ + RuntimeToolDefinition { + name: "MCPTool".to_string(), + description: Some( + "Call a configured MCP tool by its qualified name and JSON arguments.".to_string(), + ), + input_schema: json!({ + "type": "object", + "properties": { + "qualifiedName": { "type": "string" }, + "arguments": {} + }, + "required": ["qualifiedName"], + "additionalProperties": false + }), + required_permission: PermissionMode::DangerFullAccess, + }, + RuntimeToolDefinition { + name: "ListMcpResourcesTool".to_string(), + description: Some( + "List MCP resources from one configured server or from every connected server." + .to_string(), + ), + input_schema: json!({ + "type": "object", + "properties": { + "server": { "type": "string" } + }, + "additionalProperties": false + }), + required_permission: PermissionMode::ReadOnly, + }, + RuntimeToolDefinition { + name: "ReadMcpResourceTool".to_string(), + description: Some("Read a specific MCP resource from a configured server.".to_string()), + input_schema: json!({ + "type": "object", + "properties": { + "server": { "type": "string" }, + "uri": { "type": "string" } + }, + "required": ["server", "uri"], + "additionalProperties": false + }), + required_permission: PermissionMode::ReadOnly, + }, + ] +} + +fn permission_mode_for_mcp_tool(tool: &McpTool) -> PermissionMode { + let read_only = mcp_annotation_flag(tool, "readOnlyHint"); + let destructive = mcp_annotation_flag(tool, "destructiveHint"); + let open_world = mcp_annotation_flag(tool, "openWorldHint"); + + if read_only && !destructive && !open_world { + PermissionMode::ReadOnly + } else if destructive || open_world { + PermissionMode::DangerFullAccess + } else { + PermissionMode::WorkspaceWrite + } +} + +fn mcp_annotation_flag(tool: &McpTool, key: &str) -> bool { + tool.annotations + .as_ref() + .and_then(|annotations| annotations.get(key)) + .and_then(serde_json::Value::as_bool) + .unwrap_or(false) +} + struct HookAbortMonitor { stop_tx: Option>, join_handle: Option>, @@ -3375,11 +3682,14 @@ fn build_runtime_plugin_state_with_loader( .feature_config() .clone() .with_hooks(runtime_config.hooks().merged(&plugin_hook_config)); - let tool_registry = GlobalToolRegistry::with_plugin_tools(plugin_registry.aggregated_tools()?)?; + let (mcp_state, runtime_tools) = build_runtime_mcp_state(runtime_config)?; + let tool_registry = GlobalToolRegistry::with_plugin_tools(plugin_registry.aggregated_tools()?)? + .with_runtime_tools(runtime_tools)?; Ok(RuntimePluginState { feature_config, tool_registry, plugin_registry, + mcp_state, }) } @@ -3801,6 +4111,7 @@ fn build_runtime_with_plugin_state( feature_config, tool_registry, plugin_registry, + mcp_state, } = runtime_plugin_state; plugin_registry.initialize()?; let mut runtime = ConversationRuntime::new_with_features( @@ -3814,7 +4125,12 @@ fn build_runtime_with_plugin_state( tool_registry.clone(), progress_reporter, )?, - CliToolExecutor::new(allowed_tools.clone(), emit_output, tool_registry.clone()), + CliToolExecutor::new( + allowed_tools.clone(), + emit_output, + tool_registry.clone(), + mcp_state.clone(), + ), permission_policy(permission_mode, &feature_config, &tool_registry) .map_err(std::io::Error::other)?, system_prompt, @@ -3823,7 +4139,7 @@ fn build_runtime_with_plugin_state( if emit_output { runtime = runtime.with_hook_progress_reporter(Box::new(CliHookProgressReporter)); } - Ok(BuiltRuntime::new(runtime, plugin_registry)) + Ok(BuiltRuntime::new(runtime, plugin_registry, mcp_state)) } struct CliHookProgressReporter; @@ -4758,6 +5074,7 @@ struct CliToolExecutor { emit_output: bool, allowed_tools: Option, tool_registry: GlobalToolRegistry, + mcp_state: Option>>, } impl CliToolExecutor { @@ -4765,12 +5082,72 @@ impl CliToolExecutor { allowed_tools: Option, emit_output: bool, tool_registry: GlobalToolRegistry, + mcp_state: Option>>, ) -> Self { Self { renderer: TerminalRenderer::new(), emit_output, allowed_tools, tool_registry, + mcp_state, + } + } + + fn execute_search_tool(&self, value: serde_json::Value) -> Result { + let input: ToolSearchRequest = serde_json::from_value(value) + .map_err(|error| ToolError::new(format!("invalid tool input JSON: {error}")))?; + let pending_mcp_servers = self.mcp_state.as_ref().and_then(|state| { + state + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .pending_servers() + }); + serde_json::to_string_pretty(&self.tool_registry.search( + &input.query, + input.max_results.unwrap_or(5), + pending_mcp_servers, + )) + .map_err(|error| ToolError::new(error.to_string())) + } + + fn execute_runtime_tool( + &self, + tool_name: &str, + value: serde_json::Value, + ) -> Result { + let Some(mcp_state) = &self.mcp_state else { + return Err(ToolError::new(format!( + "runtime tool `{tool_name}` is unavailable without configured MCP servers" + ))); + }; + let mut mcp_state = mcp_state + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + + match tool_name { + "MCPTool" => { + let input: McpToolRequest = serde_json::from_value(value) + .map_err(|error| ToolError::new(format!("invalid tool input JSON: {error}")))?; + let qualified_name = input + .qualified_name + .or(input.tool) + .ok_or_else(|| ToolError::new("missing required field `qualifiedName`"))?; + mcp_state.call_tool(&qualified_name, input.arguments) + } + "ListMcpResourcesTool" => { + let input: ListMcpResourcesRequest = serde_json::from_value(value) + .map_err(|error| ToolError::new(format!("invalid tool input JSON: {error}")))?; + match input.server { + Some(server_name) => mcp_state.list_resources_for_server(&server_name), + None => mcp_state.list_resources_for_all_servers(), + } + } + "ReadMcpResourceTool" => { + let input: ReadMcpResourceRequest = serde_json::from_value(value) + .map_err(|error| ToolError::new(format!("invalid tool input JSON: {error}")))?; + mcp_state.read_resource(&input.server, &input.uri) + } + _ => mcp_state.call_tool(tool_name, Some(value)), } } } @@ -4788,7 +5165,16 @@ impl ToolExecutor for CliToolExecutor { } let value = serde_json::from_str(input) .map_err(|error| ToolError::new(format!("invalid tool input JSON: {error}")))?; - match self.tool_registry.execute(tool_name, &value) { + let result = if tool_name == "ToolSearch" { + self.execute_search_tool(value) + } else if self.tool_registry.has_runtime_tool(tool_name) { + self.execute_runtime_tool(tool_name, value) + } else { + self.tool_registry + .execute(tool_name, &value) + .map_err(ToolError::new) + }; + match result { Ok(output) => { if self.emit_output { let markdown = format_tool_result(tool_name, &output, false); @@ -4800,12 +5186,12 @@ impl ToolExecutor for CliToolExecutor { } Err(error) => { if self.emit_output { - let markdown = format_tool_result(tool_name, &error, true); + let markdown = format_tool_result(tool_name, &error.to_string(), true); self.renderer .stream_markdown(&markdown, &mut io::stdout()) .map_err(|stream_error| ToolError::new(stream_error.to_string()))?; } - Err(ToolError::new(error)) + Err(error) } } } @@ -5006,8 +5392,9 @@ mod tests { resolve_model_alias, resolve_session_reference, response_to_events, resume_supported_slash_commands, run_resume_command, slash_command_completion_candidates_with_sessions, status_context, validate_no_args, - CliAction, CliOutputFormat, GitWorkspaceSummary, InternalPromptProgressEvent, - InternalPromptProgressState, LiveCli, SlashCommand, StatusUsage, DEFAULT_MODEL, + write_mcp_server_fixture, CliAction, CliOutputFormat, CliToolExecutor, GitWorkspaceSummary, + InternalPromptProgressEvent, InternalPromptProgressState, LiveCli, SlashCommand, + StatusUsage, DEFAULT_MODEL, }; use api::{MessageResponse, OutputContentBlock, Usage}; use plugins::{ @@ -5015,7 +5402,7 @@ mod tests { }; use runtime::{ AssistantEvent, ConfigLoader, ContentBlock, ConversationMessage, MessageRole, - PermissionMode, Session, + PermissionMode, Session, ToolExecutor, }; use serde_json::json; use std::fs; @@ -6226,7 +6613,11 @@ UU conflicted.rs", #[test] fn init_template_mentions_detected_rust_workspace() { - let rendered = crate::init::render_init_claude_md(std::path::Path::new(".")); + let _guard = cwd_lock() + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let workspace_root = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../.."); + let rendered = crate::init::render_init_claude_md(&workspace_root); assert!(rendered.contains("# CLAUDE.md")); assert!(rendered.contains("cargo clippy --workspace --all-targets -- -D warnings")); } @@ -6617,6 +7008,111 @@ UU conflicted.rs", let _ = fs::remove_dir_all(source_root); } + #[test] + fn build_runtime_plugin_state_discovers_mcp_tools_and_surfaces_pending_servers() { + let config_home = temp_dir(); + let workspace = temp_dir(); + fs::create_dir_all(&config_home).expect("config home"); + fs::create_dir_all(&workspace).expect("workspace"); + let script_path = workspace.join("fixture-mcp.py"); + write_mcp_server_fixture(&script_path); + fs::write( + config_home.join("settings.json"), + format!( + r#"{{ + "mcpServers": {{ + "alpha": {{ + "command": "python3", + "args": ["{}"] + }}, + "broken": {{ + "command": "python3", + "args": ["-c", "import sys; sys.exit(0)"] + }} + }} + }}"#, + script_path.to_string_lossy() + ), + ) + .expect("write mcp settings"); + + let loader = ConfigLoader::new(&workspace, &config_home); + let runtime_config = loader.load().expect("runtime config should load"); + let state = build_runtime_plugin_state_with_loader(&workspace, &loader, &runtime_config) + .expect("runtime plugin state should load"); + + let allowed = state + .tool_registry + .normalize_allowed_tools(&["mcp__alpha__echo".to_string(), "MCPTool".to_string()]) + .expect("mcp tools should be allow-listable") + .expect("allow-list should exist"); + assert!(allowed.contains("mcp__alpha__echo")); + assert!(allowed.contains("MCPTool")); + + let mut executor = CliToolExecutor::new( + None, + false, + state.tool_registry.clone(), + state.mcp_state.clone(), + ); + + let tool_output = executor + .execute("mcp__alpha__echo", r#"{"text":"hello"}"#) + .expect("discovered mcp tool should execute"); + let tool_json: serde_json::Value = + serde_json::from_str(&tool_output).expect("tool output should be json"); + assert_eq!(tool_json["structuredContent"]["echoed"], "hello"); + + let wrapped_output = executor + .execute( + "MCPTool", + r#"{"qualifiedName":"mcp__alpha__echo","arguments":{"text":"wrapped"}}"#, + ) + .expect("generic mcp wrapper should execute"); + let wrapped_json: serde_json::Value = + serde_json::from_str(&wrapped_output).expect("wrapped output should be json"); + assert_eq!(wrapped_json["structuredContent"]["echoed"], "wrapped"); + + let search_output = executor + .execute("ToolSearch", r#"{"query":"alpha echo","max_results":5}"#) + .expect("tool search should execute"); + let search_json: serde_json::Value = + serde_json::from_str(&search_output).expect("search output should be json"); + assert_eq!(search_json["matches"][0], "mcp__alpha__echo"); + assert_eq!(search_json["pending_mcp_servers"][0], "broken"); + + let listed = executor + .execute("ListMcpResourcesTool", r#"{"server":"alpha"}"#) + .expect("resources should list"); + let listed_json: serde_json::Value = + serde_json::from_str(&listed).expect("resource output should be json"); + assert_eq!(listed_json["resources"][0]["uri"], "file://guide.txt"); + + let read = executor + .execute( + "ReadMcpResourceTool", + r#"{"server":"alpha","uri":"file://guide.txt"}"#, + ) + .expect("resource should read"); + let read_json: serde_json::Value = + serde_json::from_str(&read).expect("resource read output should be json"); + assert_eq!( + read_json["contents"][0]["text"], + "contents for file://guide.txt" + ); + + if let Some(mcp_state) = state.mcp_state { + mcp_state + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .shutdown() + .expect("mcp shutdown should succeed"); + } + + let _ = fs::remove_dir_all(config_home); + let _ = fs::remove_dir_all(workspace); + } + #[test] fn build_runtime_runs_plugin_lifecycle_init_and_shutdown() { let config_home = temp_dir(); @@ -6671,6 +7167,105 @@ UU conflicted.rs", } } +fn write_mcp_server_fixture(script_path: &Path) { + let script = [ + "#!/usr/bin/env python3", + "import json, sys", + "", + "def read_message():", + " header = b''", + r" while not header.endswith(b'\r\n\r\n'):", + " chunk = sys.stdin.buffer.read(1)", + " if not chunk:", + " return None", + " header += chunk", + " length = 0", + r" for line in header.decode().split('\r\n'):", + r" if line.lower().startswith('content-length:'):", + " length = int(line.split(':', 1)[1].strip())", + " payload = sys.stdin.buffer.read(length)", + " return json.loads(payload.decode())", + "", + "def send_message(message):", + " payload = json.dumps(message).encode()", + r" sys.stdout.buffer.write(f'Content-Length: {len(payload)}\r\n\r\n'.encode() + payload)", + " sys.stdout.buffer.flush()", + "", + "while True:", + " request = read_message()", + " if request is None:", + " break", + " method = request['method']", + " if method == 'initialize':", + " send_message({", + " 'jsonrpc': '2.0',", + " 'id': request['id'],", + " 'result': {", + " 'protocolVersion': request['params']['protocolVersion'],", + " 'capabilities': {'tools': {}, 'resources': {}},", + " 'serverInfo': {'name': 'fixture', 'version': '1.0.0'}", + " }", + " })", + " elif method == 'tools/list':", + " send_message({", + " 'jsonrpc': '2.0',", + " 'id': request['id'],", + " 'result': {", + " 'tools': [", + " {", + " 'name': 'echo',", + " 'description': 'Echo from MCP fixture',", + " 'inputSchema': {", + " 'type': 'object',", + " 'properties': {'text': {'type': 'string'}},", + " 'required': ['text'],", + " 'additionalProperties': False", + " },", + " 'annotations': {'readOnlyHint': True}", + " }", + " ]", + " }", + " })", + " elif method == 'tools/call':", + " args = request['params'].get('arguments') or {}", + " send_message({", + " 'jsonrpc': '2.0',", + " 'id': request['id'],", + " 'result': {", + " 'content': [{'type': 'text', 'text': f\"echo:{args.get('text', '')}\"}],", + " 'structuredContent': {'echoed': args.get('text', '')},", + " 'isError': False", + " }", + " })", + " elif method == 'resources/list':", + " send_message({", + " 'jsonrpc': '2.0',", + " 'id': request['id'],", + " 'result': {", + " 'resources': [{'uri': 'file://guide.txt', 'name': 'guide', 'mimeType': 'text/plain'}]", + " }", + " })", + " elif method == 'resources/read':", + " uri = request['params']['uri']", + " send_message({", + " 'jsonrpc': '2.0',", + " 'id': request['id'],", + " 'result': {", + " 'contents': [{'uri': uri, 'mimeType': 'text/plain', 'text': f'contents for {uri}'}]", + " }", + " })", + " else:", + " send_message({", + " 'jsonrpc': '2.0',", + " 'id': request['id'],", + " 'error': {'code': -32601, 'message': method}", + " })", + "", + ] + .join("\n"); + fs::write(script_path, script).expect("mcp fixture script should write"); +} + #[cfg(test)] mod sandbox_report_tests { use super::{format_sandbox_report, HookAbortMonitor}; diff --git a/rust/crates/tools/src/lib.rs b/rust/crates/tools/src/lib.rs index c12e5a4..81e4121 100644 --- a/rust/crates/tools/src/lib.rs +++ b/rust/crates/tools/src/lib.rs @@ -59,6 +59,15 @@ pub struct ToolSpec { #[derive(Debug, Clone, PartialEq)] pub struct GlobalToolRegistry { plugin_tools: Vec, + runtime_tools: Vec, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct RuntimeToolDefinition { + pub name: String, + pub description: Option, + pub input_schema: Value, + pub required_permission: PermissionMode, } impl GlobalToolRegistry { @@ -66,6 +75,7 @@ impl GlobalToolRegistry { pub fn builtin() -> Self { Self { plugin_tools: Vec::new(), + runtime_tools: Vec::new(), } } @@ -88,7 +98,37 @@ impl GlobalToolRegistry { } } - Ok(Self { plugin_tools }) + Ok(Self { + plugin_tools, + runtime_tools: Vec::new(), + }) + } + + pub fn with_runtime_tools( + mut self, + runtime_tools: Vec, + ) -> Result { + let mut seen_names = mvp_tool_specs() + .into_iter() + .map(|spec| spec.name.to_string()) + .chain( + self.plugin_tools + .iter() + .map(|tool| tool.definition().name.clone()), + ) + .collect::>(); + + for tool in &runtime_tools { + if !seen_names.insert(tool.name.clone()) { + return Err(format!( + "runtime tool `{}` conflicts with an existing tool name", + tool.name + )); + } + } + + self.runtime_tools = runtime_tools; + Ok(self) } pub fn normalize_allowed_tools( @@ -108,6 +148,7 @@ impl GlobalToolRegistry { .iter() .map(|tool| tool.definition().name.clone()), ) + .chain(self.runtime_tools.iter().map(|tool| tool.name.clone())) .collect::>(); let mut name_map = canonical_names .iter() @@ -154,6 +195,15 @@ impl GlobalToolRegistry { description: Some(spec.description.to_string()), input_schema: spec.input_schema, }); + let runtime = self + .runtime_tools + .iter() + .filter(|tool| allowed_tools.is_none_or(|allowed| allowed.contains(tool.name.as_str()))) + .map(|tool| ToolDefinition { + name: tool.name.clone(), + description: tool.description.clone(), + input_schema: tool.input_schema.clone(), + }); let plugin = self .plugin_tools .iter() @@ -166,7 +216,7 @@ impl GlobalToolRegistry { description: tool.definition().description.clone(), input_schema: tool.definition().input_schema.clone(), }); - builtin.chain(plugin).collect() + builtin.chain(runtime).chain(plugin).collect() } pub fn permission_specs( @@ -177,6 +227,11 @@ impl GlobalToolRegistry { .into_iter() .filter(|spec| allowed_tools.is_none_or(|allowed| allowed.contains(spec.name))) .map(|spec| (spec.name.to_string(), spec.required_permission)); + let runtime = self + .runtime_tools + .iter() + .filter(|tool| allowed_tools.is_none_or(|allowed| allowed.contains(tool.name.as_str()))) + .map(|tool| (tool.name.clone(), tool.required_permission)); let plugin = self .plugin_tools .iter() @@ -189,7 +244,32 @@ impl GlobalToolRegistry { .map(|permission| (tool.definition().name.clone(), permission)) }) .collect::, _>>()?; - Ok(builtin.chain(plugin).collect()) + Ok(builtin.chain(runtime).chain(plugin).collect()) + } + + #[must_use] + pub fn has_runtime_tool(&self, name: &str) -> bool { + self.runtime_tools.iter().any(|tool| tool.name == name) + } + + #[must_use] + pub fn search( + &self, + query: &str, + max_results: usize, + pending_mcp_servers: Option>, + ) -> ToolSearchOutput { + let query = query.trim().to_string(); + let normalized_query = normalize_tool_search_query(&query); + let matches = search_tool_specs(&query, max_results.max(1), &self.searchable_tool_specs()); + + ToolSearchOutput { + matches, + query, + normalized_query, + total_deferred_tools: self.searchable_tool_specs().len(), + pending_mcp_servers, + } } pub fn execute(&self, name: &str, input: &Value) -> Result { @@ -203,6 +283,24 @@ impl GlobalToolRegistry { .execute(input) .map_err(|error| error.to_string()) } + + fn searchable_tool_specs(&self) -> Vec { + let builtin = deferred_tool_specs() + .into_iter() + .map(|spec| SearchableToolSpec { + name: spec.name.to_string(), + description: spec.description.to_string(), + }); + let runtime = self.runtime_tools.iter().map(|tool| SearchableToolSpec { + name: tool.name.clone(), + description: tool.description.clone().unwrap_or_default(), + }); + let plugin = self.plugin_tools.iter().map(|tool| SearchableToolSpec { + name: tool.definition().name.clone(), + description: tool.definition().description.clone().unwrap_or_default(), + }); + builtin.chain(runtime).chain(plugin).collect() + } } fn normalize_tool_name(value: &str) -> String { @@ -946,8 +1044,8 @@ struct AgentJob { allowed_tools: BTreeSet, } -#[derive(Debug, Serialize)] -struct ToolSearchOutput { +#[derive(Debug, Clone, Serialize, PartialEq, Eq)] +pub struct ToolSearchOutput { matches: Vec, query: String, normalized_query: String, @@ -1031,6 +1129,12 @@ struct PlanModeOutput { current_local_mode: Option, } +#[derive(Debug, Clone)] +struct SearchableToolSpec { + name: String, + description: String, +} + #[derive(Debug, Serialize)] struct StructuredOutputResult { data: String, @@ -2163,19 +2267,7 @@ fn final_assistant_text(summary: &runtime::TurnSummary) -> String { #[allow(clippy::needless_pass_by_value)] fn execute_tool_search(input: ToolSearchInput) -> ToolSearchOutput { - let deferred = deferred_tool_specs(); - let max_results = input.max_results.unwrap_or(5).max(1); - let query = input.query.trim().to_string(); - let normalized_query = normalize_tool_search_query(&query); - let matches = search_tool_specs(&query, max_results, &deferred); - - ToolSearchOutput { - matches, - query, - normalized_query, - total_deferred_tools: deferred.len(), - pending_mcp_servers: None, - } + GlobalToolRegistry::builtin().search(&input.query, input.max_results.unwrap_or(5), None) } fn deferred_tool_specs() -> Vec { @@ -2190,7 +2282,7 @@ fn deferred_tool_specs() -> Vec { .collect() } -fn search_tool_specs(query: &str, max_results: usize, specs: &[ToolSpec]) -> Vec { +fn search_tool_specs(query: &str, max_results: usize, specs: &[SearchableToolSpec]) -> Vec { let lowered = query.to_lowercase(); if let Some(selection) = lowered.strip_prefix("select:") { return selection @@ -2201,8 +2293,8 @@ fn search_tool_specs(query: &str, max_results: usize, specs: &[ToolSpec]) -> Vec let wanted = canonical_tool_token(wanted); specs .iter() - .find(|spec| canonical_tool_token(spec.name) == wanted) - .map(|spec| spec.name.to_string()) + .find(|spec| canonical_tool_token(&spec.name) == wanted) + .map(|spec| spec.name.clone()) }) .take(max_results) .collect(); @@ -2229,8 +2321,8 @@ fn search_tool_specs(query: &str, max_results: usize, specs: &[ToolSpec]) -> Vec .iter() .filter_map(|spec| { let name = spec.name.to_lowercase(); - let canonical_name = canonical_tool_token(spec.name); - let normalized_description = normalize_tool_search_query(spec.description); + let canonical_name = canonical_tool_token(&spec.name); + let normalized_description = normalize_tool_search_query(&spec.description); let haystack = format!( "{name} {} {canonical_name}", spec.description.to_lowercase() @@ -2263,7 +2355,7 @@ fn search_tool_specs(query: &str, max_results: usize, specs: &[ToolSpec]) -> Vec if score == 0 && !lowered.is_empty() { return None; } - Some((score, spec.name.to_string())) + Some((score, spec.name.clone())) }) .collect::>(); @@ -3424,7 +3516,7 @@ mod tests { use super::{ agent_permission_policy, allowed_tools_for_subagent, execute_agent_with_spawn, execute_tool, final_assistant_text, mvp_tool_specs, permission_mode_from_plugin, - persist_agent_terminal_state, push_output_block, AgentInput, AgentJob, + persist_agent_terminal_state, push_output_block, AgentInput, AgentJob, GlobalToolRegistry, SubagentToolExecutor, }; use api::OutputContentBlock; @@ -3486,6 +3578,48 @@ mod tests { assert!(empty_permission.contains("unsupported plugin permission: ")); } + #[test] + fn runtime_tools_extend_registry_definitions_permissions_and_search() { + let registry = GlobalToolRegistry::builtin() + .with_runtime_tools(vec![super::RuntimeToolDefinition { + name: "mcp__demo__echo".to_string(), + description: Some("Echo text from the demo MCP server".to_string()), + input_schema: json!({ + "type": "object", + "properties": { "text": { "type": "string" } }, + "additionalProperties": false + }), + required_permission: runtime::PermissionMode::ReadOnly, + }]) + .expect("runtime tools should register"); + + let allowed = registry + .normalize_allowed_tools(&["mcp__demo__echo".to_string()]) + .expect("runtime tool should be allow-listable") + .expect("allow-list should be populated"); + assert!(allowed.contains("mcp__demo__echo")); + + let definitions = registry.definitions(Some(&allowed)); + assert_eq!(definitions.len(), 1); + assert_eq!(definitions[0].name, "mcp__demo__echo"); + + let permissions = registry + .permission_specs(Some(&allowed)) + .expect("runtime tool permissions should resolve"); + assert_eq!( + permissions, + vec![( + "mcp__demo__echo".to_string(), + runtime::PermissionMode::ReadOnly + )] + ); + + let search = registry.search("demo echo", 5, Some(vec!["pending-server".to_string()])); + let output = serde_json::to_value(search).expect("search output should serialize"); + assert_eq!(output["matches"][0], "mcp__demo__echo"); + assert_eq!(output["pending_mcp_servers"][0], "pending-server"); + } + #[test] fn web_fetch_returns_prompt_aware_summary() { let server = TestServer::spawn(Arc::new(|request_line: &str| {