diff --git a/rust/crates/api/src/client.rs b/rust/crates/api/src/client.rs index 1881de8..a4ac1c0 100644 --- a/rust/crates/api/src/client.rs +++ b/rust/crates/api/src/client.rs @@ -1,994 +1,141 @@ -use std::collections::VecDeque; -use std::time::{Duration, SystemTime, UNIX_EPOCH}; - -use runtime::{ - load_oauth_credentials, save_oauth_credentials, OAuthConfig, OAuthRefreshRequest, - OAuthTokenExchangeRequest, -}; -use serde::Deserialize; - use crate::error::ApiError; -use crate::sse::SseParser; +use crate::providers::anthropic::{self, AnthropicClient, AuthSource}; +use crate::providers::openai_compat::{self, OpenAiCompatClient, OpenAiCompatConfig}; +use crate::providers::{self, Provider, ProviderKind}; use crate::types::{MessageRequest, MessageResponse, StreamEvent}; -const DEFAULT_BASE_URL: &str = "https://api.anthropic.com"; -const ANTHROPIC_VERSION: &str = "2023-06-01"; -const REQUEST_ID_HEADER: &str = "request-id"; -const ALT_REQUEST_ID_HEADER: &str = "x-request-id"; -const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_millis(200); -const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(2); -const DEFAULT_MAX_RETRIES: u32 = 2; - -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum AuthSource { - None, - ApiKey(String), - BearerToken(String), - ApiKeyAndBearer { - api_key: String, - bearer_token: String, - }, +async fn send_via_provider( + provider: &P, + request: &MessageRequest, +) -> Result { + provider.send_message(request).await } -impl AuthSource { - pub fn from_env() -> Result { - let api_key = read_env_non_empty("ANTHROPIC_API_KEY")?; - let auth_token = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")?; - match (api_key, auth_token) { - (Some(api_key), Some(bearer_token)) => Ok(Self::ApiKeyAndBearer { - api_key, - bearer_token, - }), - (Some(api_key), None) => Ok(Self::ApiKey(api_key)), - (None, Some(bearer_token)) => Ok(Self::BearerToken(bearer_token)), - (None, None) => Err(ApiError::MissingApiKey), - } - } - - #[must_use] - pub fn api_key(&self) -> Option<&str> { - match self { - Self::ApiKey(api_key) | Self::ApiKeyAndBearer { api_key, .. } => Some(api_key), - Self::None | Self::BearerToken(_) => None, - } - } - - #[must_use] - pub fn bearer_token(&self) -> Option<&str> { - match self { - Self::BearerToken(token) - | Self::ApiKeyAndBearer { - bearer_token: token, - .. - } => Some(token), - Self::None | Self::ApiKey(_) => None, - } - } - - #[must_use] - pub fn masked_authorization_header(&self) -> &'static str { - if self.bearer_token().is_some() { - "Bearer [REDACTED]" - } else { - "" - } - } - - pub fn apply(&self, mut request_builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder { - if let Some(api_key) = self.api_key() { - request_builder = request_builder.header("x-api-key", api_key); - } - if let Some(token) = self.bearer_token() { - request_builder = request_builder.bearer_auth(token); - } - request_builder - } -} - -#[derive(Debug, Clone, PartialEq, Eq, Deserialize)] -pub struct OAuthTokenSet { - pub access_token: String, - pub refresh_token: Option, - pub expires_at: Option, - #[serde(default)] - pub scopes: Vec, -} - -impl From for AuthSource { - fn from(value: OAuthTokenSet) -> Self { - Self::BearerToken(value.access_token) - } +async fn stream_via_provider( + provider: &P, + request: &MessageRequest, +) -> Result { + provider.stream_message(request).await } #[derive(Debug, Clone)] -pub struct ApiHttpClient { - http: reqwest::Client, - auth: AuthSource, - base_url: String, - max_retries: u32, - initial_backoff: Duration, - max_backoff: Duration, +pub enum ProviderClient { + Anthropic(AnthropicClient), + Xai(OpenAiCompatClient), + OpenAi(OpenAiCompatClient), } -impl ApiHttpClient { - #[must_use] - pub fn new(api_key: impl Into) -> Self { - Self { - http: reqwest::Client::new(), - auth: AuthSource::ApiKey(api_key.into()), - base_url: DEFAULT_BASE_URL.to_string(), - max_retries: DEFAULT_MAX_RETRIES, - initial_backoff: DEFAULT_INITIAL_BACKOFF, - max_backoff: DEFAULT_MAX_BACKOFF, +impl ProviderClient { + pub fn from_model(model: &str) -> Result { + Self::from_model_with_anthropic_auth(model, None) + } + + pub fn from_model_with_anthropic_auth( + model: &str, + anthropic_auth: Option, + ) -> Result { + let resolved_model = providers::resolve_model_alias(model); + match providers::detect_provider_kind(&resolved_model) { + ProviderKind::Anthropic => Ok(Self::Anthropic(match anthropic_auth { + Some(auth) => AnthropicClient::from_auth(auth), + None => AnthropicClient::from_env()?, + })), + ProviderKind::Xai => Ok(Self::Xai(OpenAiCompatClient::from_env( + OpenAiCompatConfig::xai(), + )?)), + ProviderKind::OpenAi => Ok(Self::OpenAi(OpenAiCompatClient::from_env( + OpenAiCompatConfig::openai(), + )?)), } } #[must_use] - pub fn from_auth(auth: AuthSource) -> Self { - Self { - http: reqwest::Client::new(), - auth, - base_url: DEFAULT_BASE_URL.to_string(), - max_retries: DEFAULT_MAX_RETRIES, - initial_backoff: DEFAULT_INITIAL_BACKOFF, - max_backoff: DEFAULT_MAX_BACKOFF, + pub const fn provider_kind(&self) -> ProviderKind { + match self { + Self::Anthropic(_) => ProviderKind::Anthropic, + Self::Xai(_) => ProviderKind::Xai, + Self::OpenAi(_) => ProviderKind::OpenAi, } } - pub fn from_env() -> Result { - Ok(Self::from_auth(AuthSource::from_env_or_saved()?).with_base_url(read_base_url())) - } - - #[must_use] - pub fn with_auth_source(mut self, auth: AuthSource) -> Self { - self.auth = auth; - self - } - - #[must_use] - pub fn with_auth_token(mut self, auth_token: Option) -> Self { - match ( - self.auth.api_key().map(ToOwned::to_owned), - auth_token.filter(|token| !token.is_empty()), - ) { - (Some(api_key), Some(bearer_token)) => { - self.auth = AuthSource::ApiKeyAndBearer { - api_key, - bearer_token, - }; - } - (Some(api_key), None) => { - self.auth = AuthSource::ApiKey(api_key); - } - (None, Some(bearer_token)) => { - self.auth = AuthSource::BearerToken(bearer_token); - } - (None, None) => { - self.auth = AuthSource::None; - } - } - self - } - - #[must_use] - pub fn with_base_url(mut self, base_url: impl Into) -> Self { - self.base_url = base_url.into(); - self - } - - #[must_use] - pub fn with_retry_policy( - mut self, - max_retries: u32, - initial_backoff: Duration, - max_backoff: Duration, - ) -> Self { - self.max_retries = max_retries; - self.initial_backoff = initial_backoff; - self.max_backoff = max_backoff; - self - } - - #[must_use] - pub fn auth_source(&self) -> &AuthSource { - &self.auth - } - pub async fn send_message( &self, request: &MessageRequest, ) -> Result { - let request = MessageRequest { - stream: false, - ..request.clone() - }; - let response = self.send_with_retry(&request).await?; - let request_id = request_id_from_headers(response.headers()); - let mut response = response - .json::() - .await - .map_err(ApiError::from)?; - if response.request_id.is_none() { - response.request_id = request_id; + match self { + Self::Anthropic(client) => send_via_provider(client, request).await, + Self::Xai(client) | Self::OpenAi(client) => send_via_provider(client, request).await, } - Ok(response) } pub async fn stream_message( &self, request: &MessageRequest, ) -> Result { - let response = self - .send_with_retry(&request.clone().with_streaming()) - .await?; - Ok(MessageStream { - request_id: request_id_from_headers(response.headers()), - response, - parser: SseParser::new(), - pending: VecDeque::new(), - done: false, - }) - } - - pub async fn exchange_oauth_code( - &self, - config: &OAuthConfig, - request: &OAuthTokenExchangeRequest, - ) -> Result { - let response = self - .http - .post(&config.token_url) - .header("content-type", "application/x-www-form-urlencoded") - .form(&request.form_params()) - .send() - .await - .map_err(ApiError::from)?; - let response = expect_success(response).await?; - response - .json::() - .await - .map_err(ApiError::from) - } - - pub async fn refresh_oauth_token( - &self, - config: &OAuthConfig, - request: &OAuthRefreshRequest, - ) -> Result { - let response = self - .http - .post(&config.token_url) - .header("content-type", "application/x-www-form-urlencoded") - .form(&request.form_params()) - .send() - .await - .map_err(ApiError::from)?; - let response = expect_success(response).await?; - response - .json::() - .await - .map_err(ApiError::from) - } - - async fn send_with_retry( - &self, - request: &MessageRequest, - ) -> Result { - let mut attempts = 0; - let mut last_error: Option; - - loop { - attempts += 1; - match self.send_raw_request(request).await { - Ok(response) => match expect_success(response).await { - Ok(response) => return Ok(response), - Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => { - last_error = Some(error); - } - Err(error) => return Err(error), - }, - Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => { - last_error = Some(error); - } - Err(error) => return Err(error), - } - - if attempts > self.max_retries { - break; - } - - tokio::time::sleep(self.backoff_for_attempt(attempts)?).await; - } - - Err(ApiError::RetriesExhausted { - attempts, - last_error: Box::new(last_error.expect("retry loop must capture an error")), - }) - } - - async fn send_raw_request( - &self, - request: &MessageRequest, - ) -> Result { - let request_url = format!("{}/v1/messages", self.base_url.trim_end_matches('/')); - let request_builder = self - .http - .post(&request_url) - .header("anthropic-version", ANTHROPIC_VERSION) - .header("content-type", "application/json"); - let mut request_builder = self.auth.apply(request_builder); - - request_builder = request_builder.json(request); - request_builder.send().await.map_err(ApiError::from) - } - - fn backoff_for_attempt(&self, attempt: u32) -> Result { - let Some(multiplier) = 1_u32.checked_shl(attempt.saturating_sub(1)) else { - return Err(ApiError::BackoffOverflow { - attempt, - base_delay: self.initial_backoff, - }); - }; - Ok(self - .initial_backoff - .checked_mul(multiplier) - .map_or(self.max_backoff, |delay| delay.min(self.max_backoff))) - } -} - -impl AuthSource { - pub fn from_env_or_saved() -> Result { - if let Some(api_key) = read_env_non_empty("ANTHROPIC_API_KEY")? { - return match read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? { - Some(bearer_token) => Ok(Self::ApiKeyAndBearer { - api_key, - bearer_token, - }), - None => Ok(Self::ApiKey(api_key)), - }; - } - if let Some(bearer_token) = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? { - return Ok(Self::BearerToken(bearer_token)); - } - match load_saved_oauth_token() { - Ok(Some(token_set)) if oauth_token_is_expired(&token_set) => { - if token_set.refresh_token.is_some() { - Err(ApiError::Auth( - "saved OAuth token is expired; load runtime OAuth config to refresh it" - .to_string(), - )) - } else { - Err(ApiError::ExpiredOAuthToken) - } - } - Ok(Some(token_set)) => Ok(Self::BearerToken(token_set.access_token)), - Ok(None) => Err(ApiError::MissingApiKey), - Err(error) => Err(error), + match self { + Self::Anthropic(client) => stream_via_provider(client, request) + .await + .map(MessageStream::Anthropic), + Self::Xai(client) | Self::OpenAi(client) => stream_via_provider(client, request) + .await + .map(MessageStream::OpenAiCompat), } } } -#[must_use] -pub fn oauth_token_is_expired(token_set: &OAuthTokenSet) -> bool { - token_set - .expires_at - .is_some_and(|expires_at| expires_at <= now_unix_timestamp()) -} - -pub fn resolve_saved_oauth_token(config: &OAuthConfig) -> Result, ApiError> { - let Some(token_set) = load_saved_oauth_token()? else { - return Ok(None); - }; - resolve_saved_oauth_token_set(config, token_set).map(Some) -} - -pub fn resolve_startup_auth_source(load_oauth_config: F) -> Result -where - F: FnOnce() -> Result, ApiError>, -{ - if let Some(api_key) = read_env_non_empty("ANTHROPIC_API_KEY")? { - return match read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? { - Some(bearer_token) => Ok(AuthSource::ApiKeyAndBearer { - api_key, - bearer_token, - }), - None => Ok(AuthSource::ApiKey(api_key)), - }; - } - if let Some(bearer_token) = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? { - return Ok(AuthSource::BearerToken(bearer_token)); - } - - let Some(token_set) = load_saved_oauth_token()? else { - return Err(ApiError::MissingApiKey); - }; - if !oauth_token_is_expired(&token_set) { - return Ok(AuthSource::BearerToken(token_set.access_token)); - } - if token_set.refresh_token.is_none() { - return Err(ApiError::ExpiredOAuthToken); - } - - let Some(config) = load_oauth_config()? else { - return Err(ApiError::Auth( - "saved OAuth token is expired; runtime OAuth config is missing".to_string(), - )); - }; - Ok(AuthSource::from(resolve_saved_oauth_token_set( - &config, token_set, - )?)) -} - -fn resolve_saved_oauth_token_set( - config: &OAuthConfig, - token_set: OAuthTokenSet, -) -> Result { - if !oauth_token_is_expired(&token_set) { - return Ok(token_set); - } - let Some(refresh_token) = token_set.refresh_token.clone() else { - return Err(ApiError::ExpiredOAuthToken); - }; - let client = ApiHttpClient::from_auth(AuthSource::None).with_base_url(read_base_url()); - let refreshed = client_runtime_block_on(async { - client - .refresh_oauth_token( - config, - &OAuthRefreshRequest::from_config( - config, - refresh_token, - Some(token_set.scopes.clone()), - ), - ) - .await - })?; - let resolved = OAuthTokenSet { - access_token: refreshed.access_token, - refresh_token: refreshed.refresh_token.or(token_set.refresh_token), - expires_at: refreshed.expires_at, - scopes: refreshed.scopes, - }; - save_oauth_credentials(&runtime::OAuthTokenSet { - access_token: resolved.access_token.clone(), - refresh_token: resolved.refresh_token.clone(), - expires_at: resolved.expires_at, - scopes: resolved.scopes.clone(), - }) - .map_err(ApiError::from)?; - Ok(resolved) -} - -fn client_runtime_block_on(future: F) -> Result -where - F: std::future::Future>, -{ - tokio::runtime::Runtime::new() - .map_err(ApiError::from)? - .block_on(future) -} - -fn load_saved_oauth_token() -> Result, ApiError> { - let token_set = load_oauth_credentials().map_err(ApiError::from)?; - Ok(token_set.map(|token_set| OAuthTokenSet { - access_token: token_set.access_token, - refresh_token: token_set.refresh_token, - expires_at: token_set.expires_at, - scopes: token_set.scopes, - })) -} - -fn now_unix_timestamp() -> u64 { - SystemTime::now() - .duration_since(UNIX_EPOCH) - .map_or(0, |duration| duration.as_secs()) -} - -fn read_env_non_empty(key: &str) -> Result, ApiError> { - match std::env::var(key) { - Ok(value) if !value.is_empty() => Ok(Some(value)), - Ok(_) | Err(std::env::VarError::NotPresent) => Ok(None), - Err(error) => Err(ApiError::from(error)), - } -} - -#[cfg(test)] -fn read_api_key() -> Result { - let auth = AuthSource::from_env_or_saved()?; - auth.api_key() - .or_else(|| auth.bearer_token()) - .map(ToOwned::to_owned) - .ok_or(ApiError::MissingApiKey) -} - -#[cfg(test)] -fn read_auth_token() -> Option { - read_env_non_empty("ANTHROPIC_AUTH_TOKEN") - .ok() - .and_then(std::convert::identity) -} - -#[must_use] -pub fn read_base_url() -> String { - std::env::var("ANTHROPIC_BASE_URL").unwrap_or_else(|_| DEFAULT_BASE_URL.to_string()) -} - -fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option { - headers - .get(REQUEST_ID_HEADER) - .or_else(|| headers.get(ALT_REQUEST_ID_HEADER)) - .and_then(|value| value.to_str().ok()) - .map(ToOwned::to_owned) -} - #[derive(Debug)] -pub struct MessageStream { - request_id: Option, - response: reqwest::Response, - parser: SseParser, - pending: VecDeque, - done: bool, +pub enum MessageStream { + Anthropic(anthropic::MessageStream), + OpenAiCompat(openai_compat::MessageStream), } impl MessageStream { #[must_use] pub fn request_id(&self) -> Option<&str> { - self.request_id.as_deref() + match self { + Self::Anthropic(stream) => stream.request_id(), + Self::OpenAiCompat(stream) => stream.request_id(), + } } pub async fn next_event(&mut self) -> Result, ApiError> { - loop { - if let Some(event) = self.pending.pop_front() { - return Ok(Some(event)); - } - - if self.done { - let remaining = self.parser.finish()?; - self.pending.extend(remaining); - if let Some(event) = self.pending.pop_front() { - return Ok(Some(event)); - } - return Ok(None); - } - - match self.response.chunk().await? { - Some(chunk) => { - self.pending.extend(self.parser.push(&chunk)?); - } - None => { - self.done = true; - } - } + match self { + Self::Anthropic(stream) => stream.next_event().await, + Self::OpenAiCompat(stream) => stream.next_event().await, } } } -async fn expect_success(response: reqwest::Response) -> Result { - let status = response.status(); - if status.is_success() { - return Ok(response); - } - - let body = response.text().await.unwrap_or_else(|_| String::new()); - let parsed_error = serde_json::from_str::(&body).ok(); - let retryable = is_retryable_status(status); - - Err(ApiError::Api { - status, - error_type: parsed_error - .as_ref() - .map(|error| error.error.error_type.clone()), - message: parsed_error - .as_ref() - .map(|error| error.error.message.clone()), - body, - retryable, - }) +pub use anthropic::{ + oauth_token_is_expired, resolve_saved_oauth_token, resolve_startup_auth_source, OAuthTokenSet, +}; +#[must_use] +pub fn read_base_url() -> String { + anthropic::read_base_url() } -const fn is_retryable_status(status: reqwest::StatusCode) -> bool { - matches!(status.as_u16(), 408 | 409 | 429 | 500 | 502 | 503 | 504) -} - -#[derive(Debug, Deserialize)] -struct AnthropicErrorEnvelope { - error: AnthropicErrorBody, -} - -#[derive(Debug, Deserialize)] -struct AnthropicErrorBody { - #[serde(rename = "type")] - error_type: String, - message: String, +#[must_use] +pub fn read_xai_base_url() -> String { + openai_compat::read_base_url(OpenAiCompatConfig::xai()) } #[cfg(test)] mod tests { - use super::{ALT_REQUEST_ID_HEADER, REQUEST_ID_HEADER}; - use std::io::{Read, Write}; - use std::net::TcpListener; - use std::sync::{Mutex, OnceLock}; - use std::thread; - use std::time::{Duration, SystemTime, UNIX_EPOCH}; + use crate::providers::{detect_provider_kind, resolve_model_alias, ProviderKind}; - use runtime::{clear_oauth_credentials, save_oauth_credentials, OAuthConfig}; - - use crate::client::{ - now_unix_timestamp, oauth_token_is_expired, resolve_saved_oauth_token, - resolve_startup_auth_source, ApiHttpClient, AuthSource, OAuthTokenSet, - }; - use crate::types::{ContentBlockDelta, MessageRequest}; - - fn env_lock() -> std::sync::MutexGuard<'static, ()> { - static LOCK: OnceLock> = OnceLock::new(); - LOCK.get_or_init(|| Mutex::new(())) - .lock() - .expect("env lock") - } - - fn temp_config_home() -> std::path::PathBuf { - std::env::temp_dir().join(format!( - "api-oauth-test-{}-{}", - std::process::id(), - SystemTime::now() - .duration_since(UNIX_EPOCH) - .expect("time") - .as_nanos() - )) - } - - fn sample_oauth_config(token_url: String) -> OAuthConfig { - OAuthConfig { - client_id: "runtime-client".to_string(), - authorize_url: "https://console.test/oauth/authorize".to_string(), - token_url, - callback_port: Some(4545), - manual_redirect_url: Some("https://console.test/oauth/callback".to_string()), - scopes: vec!["org:read".to_string(), "user:write".to_string()], - } - } - - fn spawn_token_server(response_body: &'static str) -> String { - let listener = TcpListener::bind("127.0.0.1:0").expect("bind listener"); - let address = listener.local_addr().expect("local addr"); - thread::spawn(move || { - let (mut stream, _) = listener.accept().expect("accept connection"); - let mut buffer = [0_u8; 4096]; - let _ = stream.read(&mut buffer).expect("read request"); - let response = format!( - "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\n\r\n{}", - response_body.len(), - response_body - ); - stream - .write_all(response.as_bytes()) - .expect("write response"); - }); - format!("http://{address}/oauth/token") + #[test] + fn resolves_existing_and_grok_aliases() { + assert_eq!(resolve_model_alias("opus"), "claude-opus-4-6"); + assert_eq!(resolve_model_alias("grok"), "grok-3"); + assert_eq!(resolve_model_alias("grok-mini"), "grok-3-mini"); } #[test] - fn read_api_key_requires_presence() { - let _guard = env_lock(); - std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); - std::env::remove_var("ANTHROPIC_API_KEY"); - std::env::remove_var("CLAW_CONFIG_HOME"); - let error = super::read_api_key().expect_err("missing key should error"); - assert!(matches!(error, crate::error::ApiError::MissingApiKey)); - } - - #[test] - fn read_api_key_requires_non_empty_value() { - let _guard = env_lock(); - std::env::set_var("ANTHROPIC_AUTH_TOKEN", ""); - std::env::remove_var("ANTHROPIC_API_KEY"); - let error = super::read_api_key().expect_err("empty key should error"); - assert!(matches!(error, crate::error::ApiError::MissingApiKey)); - std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); - } - - #[test] - fn read_api_key_prefers_api_key_env() { - let _guard = env_lock(); - std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token"); - std::env::set_var("ANTHROPIC_API_KEY", "legacy-key"); + fn provider_detection_prefers_model_family() { + assert_eq!(detect_provider_kind("grok-3"), ProviderKind::Xai); assert_eq!( - super::read_api_key().expect("api key should load"), - "legacy-key" - ); - std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); - std::env::remove_var("ANTHROPIC_API_KEY"); - } - - #[test] - fn read_auth_token_reads_auth_token_env() { - let _guard = env_lock(); - std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token"); - assert_eq!(super::read_auth_token().as_deref(), Some("auth-token")); - std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); - } - - #[test] - fn oauth_token_maps_to_bearer_auth_source() { - let auth = AuthSource::from(OAuthTokenSet { - access_token: "access-token".to_string(), - refresh_token: Some("refresh".to_string()), - expires_at: Some(123), - scopes: vec!["scope:a".to_string()], - }); - assert_eq!(auth.bearer_token(), Some("access-token")); - assert_eq!(auth.api_key(), None); - } - - #[test] - fn auth_source_from_env_combines_api_key_and_bearer_token() { - let _guard = env_lock(); - std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token"); - std::env::set_var("ANTHROPIC_API_KEY", "legacy-key"); - let auth = AuthSource::from_env().expect("env auth"); - assert_eq!(auth.api_key(), Some("legacy-key")); - assert_eq!(auth.bearer_token(), Some("auth-token")); - std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); - std::env::remove_var("ANTHROPIC_API_KEY"); - } - - #[test] - fn auth_source_from_saved_oauth_when_env_absent() { - let _guard = env_lock(); - let config_home = temp_config_home(); - std::env::set_var("CLAW_CONFIG_HOME", &config_home); - std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); - std::env::remove_var("ANTHROPIC_API_KEY"); - save_oauth_credentials(&runtime::OAuthTokenSet { - access_token: "saved-access-token".to_string(), - refresh_token: Some("refresh".to_string()), - expires_at: Some(now_unix_timestamp() + 300), - scopes: vec!["scope:a".to_string()], - }) - .expect("save oauth credentials"); - - let auth = AuthSource::from_env_or_saved().expect("saved auth"); - assert_eq!(auth.bearer_token(), Some("saved-access-token")); - - clear_oauth_credentials().expect("clear credentials"); - std::env::remove_var("CLAW_CONFIG_HOME"); - std::fs::remove_dir_all(config_home).expect("cleanup temp dir"); - } - - #[test] - fn oauth_token_expiry_uses_expires_at_timestamp() { - assert!(oauth_token_is_expired(&OAuthTokenSet { - access_token: "access-token".to_string(), - refresh_token: None, - expires_at: Some(1), - scopes: Vec::new(), - })); - assert!(!oauth_token_is_expired(&OAuthTokenSet { - access_token: "access-token".to_string(), - refresh_token: None, - expires_at: Some(now_unix_timestamp() + 60), - scopes: Vec::new(), - })); - } - - #[test] - fn resolve_saved_oauth_token_refreshes_expired_credentials() { - let _guard = env_lock(); - let config_home = temp_config_home(); - std::env::set_var("CLAW_CONFIG_HOME", &config_home); - std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); - std::env::remove_var("ANTHROPIC_API_KEY"); - save_oauth_credentials(&runtime::OAuthTokenSet { - access_token: "expired-access-token".to_string(), - refresh_token: Some("refresh-token".to_string()), - expires_at: Some(1), - scopes: vec!["scope:a".to_string()], - }) - .expect("save expired oauth credentials"); - - let token_url = spawn_token_server( - "{\"access_token\":\"refreshed-token\",\"refresh_token\":\"fresh-refresh\",\"expires_at\":9999999999,\"scopes\":[\"scope:a\"]}", - ); - let resolved = resolve_saved_oauth_token(&sample_oauth_config(token_url)) - .expect("resolve refreshed token") - .expect("token set present"); - assert_eq!(resolved.access_token, "refreshed-token"); - let stored = runtime::load_oauth_credentials() - .expect("load stored credentials") - .expect("stored token set"); - assert_eq!(stored.access_token, "refreshed-token"); - - clear_oauth_credentials().expect("clear credentials"); - std::env::remove_var("CLAW_CONFIG_HOME"); - std::fs::remove_dir_all(config_home).expect("cleanup temp dir"); - } - - #[test] - fn resolve_startup_auth_source_uses_saved_oauth_without_loading_config() { - let _guard = env_lock(); - let config_home = temp_config_home(); - std::env::set_var("CLAW_CONFIG_HOME", &config_home); - std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); - std::env::remove_var("ANTHROPIC_API_KEY"); - save_oauth_credentials(&runtime::OAuthTokenSet { - access_token: "saved-access-token".to_string(), - refresh_token: Some("refresh".to_string()), - expires_at: Some(now_unix_timestamp() + 300), - scopes: vec!["scope:a".to_string()], - }) - .expect("save oauth credentials"); - - let auth = resolve_startup_auth_source(|| panic!("config should not be loaded")) - .expect("startup auth"); - assert_eq!(auth.bearer_token(), Some("saved-access-token")); - - clear_oauth_credentials().expect("clear credentials"); - std::env::remove_var("CLAW_CONFIG_HOME"); - std::fs::remove_dir_all(config_home).expect("cleanup temp dir"); - } - - #[test] - fn resolve_startup_auth_source_errors_when_refreshable_token_lacks_config() { - let _guard = env_lock(); - let config_home = temp_config_home(); - std::env::set_var("CLAW_CONFIG_HOME", &config_home); - std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); - std::env::remove_var("ANTHROPIC_API_KEY"); - save_oauth_credentials(&runtime::OAuthTokenSet { - access_token: "expired-access-token".to_string(), - refresh_token: Some("refresh-token".to_string()), - expires_at: Some(1), - scopes: vec!["scope:a".to_string()], - }) - .expect("save expired oauth credentials"); - - let error = - resolve_startup_auth_source(|| Ok(None)).expect_err("missing config should error"); - assert!( - matches!(error, crate::error::ApiError::Auth(message) if message.contains("runtime OAuth config is missing")) - ); - - let stored = runtime::load_oauth_credentials() - .expect("load stored credentials") - .expect("stored token set"); - assert_eq!(stored.access_token, "expired-access-token"); - assert_eq!(stored.refresh_token.as_deref(), Some("refresh-token")); - - clear_oauth_credentials().expect("clear credentials"); - std::env::remove_var("CLAW_CONFIG_HOME"); - std::fs::remove_dir_all(config_home).expect("cleanup temp dir"); - } - - #[test] - fn resolve_saved_oauth_token_preserves_refresh_token_when_refresh_response_omits_it() { - let _guard = env_lock(); - let config_home = temp_config_home(); - std::env::set_var("CLAW_CONFIG_HOME", &config_home); - std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); - std::env::remove_var("ANTHROPIC_API_KEY"); - save_oauth_credentials(&runtime::OAuthTokenSet { - access_token: "expired-access-token".to_string(), - refresh_token: Some("refresh-token".to_string()), - expires_at: Some(1), - scopes: vec!["scope:a".to_string()], - }) - .expect("save expired oauth credentials"); - - let token_url = spawn_token_server( - "{\"access_token\":\"refreshed-token\",\"expires_at\":9999999999,\"scopes\":[\"scope:a\"]}", - ); - let resolved = resolve_saved_oauth_token(&sample_oauth_config(token_url)) - .expect("resolve refreshed token") - .expect("token set present"); - assert_eq!(resolved.access_token, "refreshed-token"); - assert_eq!(resolved.refresh_token.as_deref(), Some("refresh-token")); - let stored = runtime::load_oauth_credentials() - .expect("load stored credentials") - .expect("stored token set"); - assert_eq!(stored.refresh_token.as_deref(), Some("refresh-token")); - - clear_oauth_credentials().expect("clear credentials"); - std::env::remove_var("CLAW_CONFIG_HOME"); - std::fs::remove_dir_all(config_home).expect("cleanup temp dir"); - } - - #[test] - fn message_request_stream_helper_sets_stream_true() { - let request = MessageRequest { - model: "claude-opus-4-6".to_string(), - max_tokens: 64, - messages: vec![], - system: None, - tools: None, - tool_choice: None, - stream: false, - }; - - assert!(request.with_streaming().stream); - } - - #[test] - fn backoff_doubles_until_maximum() { - let client = ApiHttpClient::new("test-key").with_retry_policy( - 3, - Duration::from_millis(10), - Duration::from_millis(25), - ); - assert_eq!( - client.backoff_for_attempt(1).expect("attempt 1"), - Duration::from_millis(10) - ); - assert_eq!( - client.backoff_for_attempt(2).expect("attempt 2"), - Duration::from_millis(20) - ); - assert_eq!( - client.backoff_for_attempt(3).expect("attempt 3"), - Duration::from_millis(25) - ); - } - - #[test] - fn retryable_statuses_are_detected() { - assert!(super::is_retryable_status( - reqwest::StatusCode::TOO_MANY_REQUESTS - )); - assert!(super::is_retryable_status( - reqwest::StatusCode::INTERNAL_SERVER_ERROR - )); - assert!(!super::is_retryable_status( - reqwest::StatusCode::UNAUTHORIZED - )); - } - - #[test] - fn tool_delta_variant_round_trips() { - let delta = ContentBlockDelta::InputJsonDelta { - partial_json: "{\"city\":\"Paris\"}".to_string(), - }; - let encoded = serde_json::to_string(&delta).expect("delta should serialize"); - let decoded: ContentBlockDelta = - serde_json::from_str(&encoded).expect("delta should deserialize"); - assert_eq!(decoded, delta); - } - - #[test] - fn request_id_uses_primary_or_fallback_header() { - let mut headers = reqwest::header::HeaderMap::new(); - headers.insert(REQUEST_ID_HEADER, "req_primary".parse().expect("header")); - assert_eq!( - super::request_id_from_headers(&headers).as_deref(), - Some("req_primary") - ); - - headers.clear(); - headers.insert( - ALT_REQUEST_ID_HEADER, - "req_fallback".parse().expect("header"), - ); - assert_eq!( - super::request_id_from_headers(&headers).as_deref(), - Some("req_fallback") - ); - } - - #[test] - fn auth_source_applies_headers() { - let auth = AuthSource::ApiKeyAndBearer { - api_key: "test-key".to_string(), - bearer_token: "proxy-token".to_string(), - }; - let request = auth - .apply(reqwest::Client::new().post("https://example.test")) - .build() - .expect("request build"); - let headers = request.headers(); - assert_eq!( - headers.get("x-api-key").and_then(|v| v.to_str().ok()), - Some("test-key") - ); - assert_eq!( - headers.get("authorization").and_then(|v| v.to_str().ok()), - Some("Bearer proxy-token") + detect_provider_kind("claude-sonnet-4-6"), + ProviderKind::Anthropic ); } } diff --git a/rust/crates/api/src/error.rs b/rust/crates/api/src/error.rs index 2c31691..7649889 100644 --- a/rust/crates/api/src/error.rs +++ b/rust/crates/api/src/error.rs @@ -4,7 +4,10 @@ use std::time::Duration; #[derive(Debug)] pub enum ApiError { - MissingApiKey, + MissingCredentials { + provider: &'static str, + env_vars: &'static [&'static str], + }, ExpiredOAuthToken, Auth(String), InvalidApiKeyEnv(VarError), @@ -30,13 +33,21 @@ pub enum ApiError { } impl ApiError { + #[must_use] + pub const fn missing_credentials( + provider: &'static str, + env_vars: &'static [&'static str], + ) -> Self { + Self::MissingCredentials { provider, env_vars } + } + #[must_use] pub fn is_retryable(&self) -> bool { match self { Self::Http(error) => error.is_connect() || error.is_timeout() || error.is_request(), Self::Api { retryable, .. } => *retryable, Self::RetriesExhausted { last_error, .. } => last_error.is_retryable(), - Self::MissingApiKey + Self::MissingCredentials { .. } | Self::ExpiredOAuthToken | Self::Auth(_) | Self::InvalidApiKeyEnv(_) @@ -51,12 +62,11 @@ impl ApiError { impl Display for ApiError { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { - Self::MissingApiKey => { - write!( - f, - "ANTHROPIC_AUTH_TOKEN or ANTHROPIC_API_KEY is not set; export one before calling the Anthropic API" - ) - } + Self::MissingCredentials { provider, env_vars } => write!( + f, + "missing {provider} credentials; export {} before calling the {provider} API", + env_vars.join(" or ") + ), Self::ExpiredOAuthToken => { write!( f, @@ -65,10 +75,7 @@ impl Display for ApiError { } Self::Auth(message) => write!(f, "auth error: {message}"), Self::InvalidApiKeyEnv(error) => { - write!( - f, - "failed to read ANTHROPIC_AUTH_TOKEN / ANTHROPIC_API_KEY: {error}" - ) + write!(f, "failed to read credential environment variable: {error}") } Self::Http(error) => write!(f, "http error: {error}"), Self::Io(error) => write!(f, "io error: {error}"), @@ -81,20 +88,14 @@ impl Display for ApiError { .. } => match (error_type, message) { (Some(error_type), Some(message)) => { - write!( - f, - "anthropic api returned {status} ({error_type}): {message}" - ) + write!(f, "api returned {status} ({error_type}): {message}") } - _ => write!(f, "anthropic api returned {status}: {body}"), + _ => write!(f, "api returned {status}: {body}"), }, Self::RetriesExhausted { attempts, last_error, - } => write!( - f, - "anthropic api failed after {attempts} attempts: {last_error}" - ), + } => write!(f, "api failed after {attempts} attempts: {last_error}"), Self::InvalidSseFrame(message) => write!(f, "invalid sse frame: {message}"), Self::BackoffOverflow { attempt, diff --git a/rust/crates/api/src/lib.rs b/rust/crates/api/src/lib.rs index 4684af7..7702fee 100644 --- a/rust/crates/api/src/lib.rs +++ b/rust/crates/api/src/lib.rs @@ -1,13 +1,19 @@ mod client; mod error; +mod providers; mod sse; mod types; pub use client::{ - oauth_token_is_expired, read_base_url, resolve_saved_oauth_token, resolve_startup_auth_source, - ApiHttpClient, AuthSource, MessageStream, OAuthTokenSet, + oauth_token_is_expired, read_base_url, read_xai_base_url, resolve_saved_oauth_token, + resolve_startup_auth_source, MessageStream, OAuthTokenSet, ProviderClient, }; pub use error::ApiError; +pub use providers::anthropic::{AnthropicClient, AuthSource}; +pub use providers::openai_compat::{OpenAiCompatClient, OpenAiCompatConfig}; +pub use providers::{ + detect_provider_kind, max_tokens_for_model, resolve_model_alias, ProviderKind, +}; pub use sse::{parse_frame, SseParser}; pub use types::{ ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent, diff --git a/rust/crates/api/src/providers/anthropic.rs b/rust/crates/api/src/providers/anthropic.rs new file mode 100644 index 0000000..0883e60 --- /dev/null +++ b/rust/crates/api/src/providers/anthropic.rs @@ -0,0 +1,1038 @@ +use std::collections::VecDeque; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; + +use runtime::{ + load_oauth_credentials, save_oauth_credentials, OAuthConfig, OAuthRefreshRequest, + OAuthTokenExchangeRequest, +}; +use serde::Deserialize; + +use crate::error::ApiError; + +use super::{Provider, ProviderFuture}; +use crate::sse::SseParser; +use crate::types::{MessageRequest, MessageResponse, StreamEvent}; + +pub const DEFAULT_BASE_URL: &str = "https://api.anthropic.com"; +const ANTHROPIC_VERSION: &str = "2023-06-01"; +const REQUEST_ID_HEADER: &str = "request-id"; +const ALT_REQUEST_ID_HEADER: &str = "x-request-id"; +const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_millis(200); +const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(2); +const DEFAULT_MAX_RETRIES: u32 = 2; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum AuthSource { + None, + ApiKey(String), + BearerToken(String), + ApiKeyAndBearer { + api_key: String, + bearer_token: String, + }, +} + +impl AuthSource { + pub fn from_env() -> Result { + let api_key = read_env_non_empty("ANTHROPIC_API_KEY")?; + let auth_token = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")?; + match (api_key, auth_token) { + (Some(api_key), Some(bearer_token)) => Ok(Self::ApiKeyAndBearer { + api_key, + bearer_token, + }), + (Some(api_key), None) => Ok(Self::ApiKey(api_key)), + (None, Some(bearer_token)) => Ok(Self::BearerToken(bearer_token)), + (None, None) => Err(ApiError::missing_credentials( + "Anthropic", + &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"], + )), + } + } + + #[must_use] + pub fn api_key(&self) -> Option<&str> { + match self { + Self::ApiKey(api_key) | Self::ApiKeyAndBearer { api_key, .. } => Some(api_key), + Self::None | Self::BearerToken(_) => None, + } + } + + #[must_use] + pub fn bearer_token(&self) -> Option<&str> { + match self { + Self::BearerToken(token) + | Self::ApiKeyAndBearer { + bearer_token: token, + .. + } => Some(token), + Self::None | Self::ApiKey(_) => None, + } + } + + #[must_use] + pub fn masked_authorization_header(&self) -> &'static str { + if self.bearer_token().is_some() { + "Bearer [REDACTED]" + } else { + "" + } + } + + pub fn apply(&self, mut request_builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder { + if let Some(api_key) = self.api_key() { + request_builder = request_builder.header("x-api-key", api_key); + } + if let Some(token) = self.bearer_token() { + request_builder = request_builder.bearer_auth(token); + } + request_builder + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Deserialize)] +pub struct OAuthTokenSet { + pub access_token: String, + pub refresh_token: Option, + pub expires_at: Option, + #[serde(default)] + pub scopes: Vec, +} + +impl From for AuthSource { + fn from(value: OAuthTokenSet) -> Self { + Self::BearerToken(value.access_token) + } +} + +#[derive(Debug, Clone)] +pub struct AnthropicClient { + http: reqwest::Client, + auth: AuthSource, + base_url: String, + max_retries: u32, + initial_backoff: Duration, + max_backoff: Duration, +} + +impl AnthropicClient { + #[must_use] + pub fn new(api_key: impl Into) -> Self { + Self { + http: reqwest::Client::new(), + auth: AuthSource::ApiKey(api_key.into()), + base_url: DEFAULT_BASE_URL.to_string(), + max_retries: DEFAULT_MAX_RETRIES, + initial_backoff: DEFAULT_INITIAL_BACKOFF, + max_backoff: DEFAULT_MAX_BACKOFF, + } + } + + #[must_use] + pub fn from_auth(auth: AuthSource) -> Self { + Self { + http: reqwest::Client::new(), + auth, + base_url: DEFAULT_BASE_URL.to_string(), + max_retries: DEFAULT_MAX_RETRIES, + initial_backoff: DEFAULT_INITIAL_BACKOFF, + max_backoff: DEFAULT_MAX_BACKOFF, + } + } + + pub fn from_env() -> Result { + Ok(Self::from_auth(AuthSource::from_env_or_saved()?).with_base_url(read_base_url())) + } + + #[must_use] + pub fn with_auth_source(mut self, auth: AuthSource) -> Self { + self.auth = auth; + self + } + + #[must_use] + pub fn with_auth_token(mut self, auth_token: Option) -> Self { + match ( + self.auth.api_key().map(ToOwned::to_owned), + auth_token.filter(|token| !token.is_empty()), + ) { + (Some(api_key), Some(bearer_token)) => { + self.auth = AuthSource::ApiKeyAndBearer { + api_key, + bearer_token, + }; + } + (Some(api_key), None) => { + self.auth = AuthSource::ApiKey(api_key); + } + (None, Some(bearer_token)) => { + self.auth = AuthSource::BearerToken(bearer_token); + } + (None, None) => { + self.auth = AuthSource::None; + } + } + self + } + + #[must_use] + pub fn with_base_url(mut self, base_url: impl Into) -> Self { + self.base_url = base_url.into(); + self + } + + #[must_use] + pub fn with_retry_policy( + mut self, + max_retries: u32, + initial_backoff: Duration, + max_backoff: Duration, + ) -> Self { + self.max_retries = max_retries; + self.initial_backoff = initial_backoff; + self.max_backoff = max_backoff; + self + } + + #[must_use] + pub fn auth_source(&self) -> &AuthSource { + &self.auth + } + + pub async fn send_message( + &self, + request: &MessageRequest, + ) -> Result { + let request = MessageRequest { + stream: false, + ..request.clone() + }; + let response = self.send_with_retry(&request).await?; + let request_id = request_id_from_headers(response.headers()); + let mut response = response + .json::() + .await + .map_err(ApiError::from)?; + if response.request_id.is_none() { + response.request_id = request_id; + } + Ok(response) + } + + pub async fn stream_message( + &self, + request: &MessageRequest, + ) -> Result { + let response = self + .send_with_retry(&request.clone().with_streaming()) + .await?; + Ok(MessageStream { + request_id: request_id_from_headers(response.headers()), + response, + parser: SseParser::new(), + pending: VecDeque::new(), + done: false, + }) + } + + pub async fn exchange_oauth_code( + &self, + config: &OAuthConfig, + request: &OAuthTokenExchangeRequest, + ) -> Result { + let response = self + .http + .post(&config.token_url) + .header("content-type", "application/x-www-form-urlencoded") + .form(&request.form_params()) + .send() + .await + .map_err(ApiError::from)?; + let response = expect_success(response).await?; + response + .json::() + .await + .map_err(ApiError::from) + } + + pub async fn refresh_oauth_token( + &self, + config: &OAuthConfig, + request: &OAuthRefreshRequest, + ) -> Result { + let response = self + .http + .post(&config.token_url) + .header("content-type", "application/x-www-form-urlencoded") + .form(&request.form_params()) + .send() + .await + .map_err(ApiError::from)?; + let response = expect_success(response).await?; + response + .json::() + .await + .map_err(ApiError::from) + } + + async fn send_with_retry( + &self, + request: &MessageRequest, + ) -> Result { + let mut attempts = 0; + let mut last_error: Option; + + loop { + attempts += 1; + match self.send_raw_request(request).await { + Ok(response) => match expect_success(response).await { + Ok(response) => return Ok(response), + Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => { + last_error = Some(error); + } + Err(error) => return Err(error), + }, + Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => { + last_error = Some(error); + } + Err(error) => return Err(error), + } + + if attempts > self.max_retries { + break; + } + + tokio::time::sleep(self.backoff_for_attempt(attempts)?).await; + } + + Err(ApiError::RetriesExhausted { + attempts, + last_error: Box::new(last_error.expect("retry loop must capture an error")), + }) + } + + async fn send_raw_request( + &self, + request: &MessageRequest, + ) -> Result { + let request_url = format!("{}/v1/messages", self.base_url.trim_end_matches('/')); + let request_builder = self + .http + .post(&request_url) + .header("anthropic-version", ANTHROPIC_VERSION) + .header("content-type", "application/json"); + let mut request_builder = self.auth.apply(request_builder); + + request_builder = request_builder.json(request); + request_builder.send().await.map_err(ApiError::from) + } + + fn backoff_for_attempt(&self, attempt: u32) -> Result { + let Some(multiplier) = 1_u32.checked_shl(attempt.saturating_sub(1)) else { + return Err(ApiError::BackoffOverflow { + attempt, + base_delay: self.initial_backoff, + }); + }; + Ok(self + .initial_backoff + .checked_mul(multiplier) + .map_or(self.max_backoff, |delay| delay.min(self.max_backoff))) + } +} + +impl AuthSource { + pub fn from_env_or_saved() -> Result { + if let Some(api_key) = read_env_non_empty("ANTHROPIC_API_KEY")? { + return match read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? { + Some(bearer_token) => Ok(Self::ApiKeyAndBearer { + api_key, + bearer_token, + }), + None => Ok(Self::ApiKey(api_key)), + }; + } + if let Some(bearer_token) = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? { + return Ok(Self::BearerToken(bearer_token)); + } + match load_saved_oauth_token() { + Ok(Some(token_set)) if oauth_token_is_expired(&token_set) => { + if token_set.refresh_token.is_some() { + Err(ApiError::Auth( + "saved OAuth token is expired; load runtime OAuth config to refresh it" + .to_string(), + )) + } else { + Err(ApiError::ExpiredOAuthToken) + } + } + Ok(Some(token_set)) => Ok(Self::BearerToken(token_set.access_token)), + Ok(None) => Err(ApiError::missing_credentials( + "Anthropic", + &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"], + )), + Err(error) => Err(error), + } + } +} + +#[must_use] +pub fn oauth_token_is_expired(token_set: &OAuthTokenSet) -> bool { + token_set + .expires_at + .is_some_and(|expires_at| expires_at <= now_unix_timestamp()) +} + +pub fn resolve_saved_oauth_token(config: &OAuthConfig) -> Result, ApiError> { + let Some(token_set) = load_saved_oauth_token()? else { + return Ok(None); + }; + resolve_saved_oauth_token_set(config, token_set).map(Some) +} + +pub fn has_auth_from_env_or_saved() -> Result { + Ok(read_env_non_empty("ANTHROPIC_API_KEY")?.is_some() + || read_env_non_empty("ANTHROPIC_AUTH_TOKEN")?.is_some() + || load_saved_oauth_token()?.is_some()) +} + +pub fn resolve_startup_auth_source(load_oauth_config: F) -> Result +where + F: FnOnce() -> Result, ApiError>, +{ + if let Some(api_key) = read_env_non_empty("ANTHROPIC_API_KEY")? { + return match read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? { + Some(bearer_token) => Ok(AuthSource::ApiKeyAndBearer { + api_key, + bearer_token, + }), + None => Ok(AuthSource::ApiKey(api_key)), + }; + } + if let Some(bearer_token) = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? { + return Ok(AuthSource::BearerToken(bearer_token)); + } + + let Some(token_set) = load_saved_oauth_token()? else { + return Err(ApiError::missing_credentials( + "Anthropic", + &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"], + )); + }; + if !oauth_token_is_expired(&token_set) { + return Ok(AuthSource::BearerToken(token_set.access_token)); + } + if token_set.refresh_token.is_none() { + return Err(ApiError::ExpiredOAuthToken); + } + + let Some(config) = load_oauth_config()? else { + return Err(ApiError::Auth( + "saved OAuth token is expired; runtime OAuth config is missing".to_string(), + )); + }; + Ok(AuthSource::from(resolve_saved_oauth_token_set( + &config, token_set, + )?)) +} + +fn resolve_saved_oauth_token_set( + config: &OAuthConfig, + token_set: OAuthTokenSet, +) -> Result { + if !oauth_token_is_expired(&token_set) { + return Ok(token_set); + } + let Some(refresh_token) = token_set.refresh_token.clone() else { + return Err(ApiError::ExpiredOAuthToken); + }; + let client = AnthropicClient::from_auth(AuthSource::None).with_base_url(read_base_url()); + let refreshed = client_runtime_block_on(async { + client + .refresh_oauth_token( + config, + &OAuthRefreshRequest::from_config( + config, + refresh_token, + Some(token_set.scopes.clone()), + ), + ) + .await + })?; + let resolved = OAuthTokenSet { + access_token: refreshed.access_token, + refresh_token: refreshed.refresh_token.or(token_set.refresh_token), + expires_at: refreshed.expires_at, + scopes: refreshed.scopes, + }; + save_oauth_credentials(&runtime::OAuthTokenSet { + access_token: resolved.access_token.clone(), + refresh_token: resolved.refresh_token.clone(), + expires_at: resolved.expires_at, + scopes: resolved.scopes.clone(), + }) + .map_err(ApiError::from)?; + Ok(resolved) +} + +fn client_runtime_block_on(future: F) -> Result +where + F: std::future::Future>, +{ + tokio::runtime::Runtime::new() + .map_err(ApiError::from)? + .block_on(future) +} + +fn load_saved_oauth_token() -> Result, ApiError> { + let token_set = load_oauth_credentials().map_err(ApiError::from)?; + Ok(token_set.map(|token_set| OAuthTokenSet { + access_token: token_set.access_token, + refresh_token: token_set.refresh_token, + expires_at: token_set.expires_at, + scopes: token_set.scopes, + })) +} + +fn now_unix_timestamp() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .map_or(0, |duration| duration.as_secs()) +} + +fn read_env_non_empty(key: &str) -> Result, ApiError> { + match std::env::var(key) { + Ok(value) if !value.is_empty() => Ok(Some(value)), + Ok(_) | Err(std::env::VarError::NotPresent) => Ok(None), + Err(error) => Err(ApiError::from(error)), + } +} + +#[cfg(test)] +fn read_api_key() -> Result { + let auth = AuthSource::from_env_or_saved()?; + auth.api_key() + .or_else(|| auth.bearer_token()) + .map(ToOwned::to_owned) + .ok_or(ApiError::missing_credentials( + "Anthropic", + &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"], + )) +} + +#[cfg(test)] +fn read_auth_token() -> Option { + read_env_non_empty("ANTHROPIC_AUTH_TOKEN") + .ok() + .and_then(std::convert::identity) +} + +#[must_use] +pub fn read_base_url() -> String { + std::env::var("ANTHROPIC_BASE_URL").unwrap_or_else(|_| DEFAULT_BASE_URL.to_string()) +} + +fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option { + headers + .get(REQUEST_ID_HEADER) + .or_else(|| headers.get(ALT_REQUEST_ID_HEADER)) + .and_then(|value| value.to_str().ok()) + .map(ToOwned::to_owned) +} + +impl Provider for AnthropicClient { + type Stream = MessageStream; + + fn send_message<'a>( + &'a self, + request: &'a MessageRequest, + ) -> ProviderFuture<'a, MessageResponse> { + Box::pin(async move { self.send_message(request).await }) + } + + fn stream_message<'a>( + &'a self, + request: &'a MessageRequest, + ) -> ProviderFuture<'a, Self::Stream> { + Box::pin(async move { self.stream_message(request).await }) + } +} + +#[derive(Debug)] +pub struct MessageStream { + request_id: Option, + response: reqwest::Response, + parser: SseParser, + pending: VecDeque, + done: bool, +} + +impl MessageStream { + #[must_use] + pub fn request_id(&self) -> Option<&str> { + self.request_id.as_deref() + } + + pub async fn next_event(&mut self) -> Result, ApiError> { + loop { + if let Some(event) = self.pending.pop_front() { + return Ok(Some(event)); + } + + if self.done { + let remaining = self.parser.finish()?; + self.pending.extend(remaining); + if let Some(event) = self.pending.pop_front() { + return Ok(Some(event)); + } + return Ok(None); + } + + match self.response.chunk().await? { + Some(chunk) => { + self.pending.extend(self.parser.push(&chunk)?); + } + None => { + self.done = true; + } + } + } + } +} + +async fn expect_success(response: reqwest::Response) -> Result { + let status = response.status(); + if status.is_success() { + return Ok(response); + } + + let body = response.text().await.unwrap_or_else(|_| String::new()); + let parsed_error = serde_json::from_str::(&body).ok(); + let retryable = is_retryable_status(status); + + Err(ApiError::Api { + status, + error_type: parsed_error + .as_ref() + .map(|error| error.error.error_type.clone()), + message: parsed_error + .as_ref() + .map(|error| error.error.message.clone()), + body, + retryable, + }) +} + +const fn is_retryable_status(status: reqwest::StatusCode) -> bool { + matches!(status.as_u16(), 408 | 409 | 429 | 500 | 502 | 503 | 504) +} + +#[derive(Debug, Deserialize)] +struct AnthropicErrorEnvelope { + error: AnthropicErrorBody, +} + +#[derive(Debug, Deserialize)] +struct AnthropicErrorBody { + #[serde(rename = "type")] + error_type: String, + message: String, +} + +#[cfg(test)] +mod tests { + use super::{ALT_REQUEST_ID_HEADER, REQUEST_ID_HEADER}; + use std::io::{Read, Write}; + use std::net::TcpListener; + use std::sync::{Mutex, OnceLock}; + use std::thread; + use std::time::{Duration, SystemTime, UNIX_EPOCH}; + + use runtime::{clear_oauth_credentials, save_oauth_credentials, OAuthConfig}; + + use super::{ + now_unix_timestamp, oauth_token_is_expired, resolve_saved_oauth_token, + resolve_startup_auth_source, AnthropicClient, AuthSource, OAuthTokenSet, + }; + use crate::types::{ContentBlockDelta, MessageRequest}; + + fn env_lock() -> std::sync::MutexGuard<'static, ()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())) + .lock() + .expect("env lock") + } + + fn temp_config_home() -> std::path::PathBuf { + std::env::temp_dir().join(format!( + "api-oauth-test-{}-{}", + std::process::id(), + SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("time") + .as_nanos() + )) + } + + fn sample_oauth_config(token_url: String) -> OAuthConfig { + OAuthConfig { + client_id: "runtime-client".to_string(), + authorize_url: "https://console.test/oauth/authorize".to_string(), + token_url, + callback_port: Some(4545), + manual_redirect_url: Some("https://console.test/oauth/callback".to_string()), + scopes: vec!["org:read".to_string(), "user:write".to_string()], + } + } + + fn spawn_token_server(response_body: &'static str) -> String { + let listener = TcpListener::bind("127.0.0.1:0").expect("bind listener"); + let address = listener.local_addr().expect("local addr"); + thread::spawn(move || { + let (mut stream, _) = listener.accept().expect("accept connection"); + let mut buffer = [0_u8; 4096]; + let _ = stream.read(&mut buffer).expect("read request"); + let response = format!( + "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\n\r\n{}", + response_body.len(), + response_body + ); + stream + .write_all(response.as_bytes()) + .expect("write response"); + }); + format!("http://{address}/oauth/token") + } + + #[test] + fn read_api_key_requires_presence() { + let _guard = env_lock(); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + std::env::remove_var("ANTHROPIC_API_KEY"); + std::env::remove_var("CLAUDE_CONFIG_HOME"); + let error = super::read_api_key().expect_err("missing key should error"); + assert!(matches!( + error, + crate::error::ApiError::MissingCredentials { .. } + )); + } + + #[test] + fn read_api_key_requires_non_empty_value() { + let _guard = env_lock(); + std::env::set_var("ANTHROPIC_AUTH_TOKEN", ""); + std::env::remove_var("ANTHROPIC_API_KEY"); + let error = super::read_api_key().expect_err("empty key should error"); + assert!(matches!( + error, + crate::error::ApiError::MissingCredentials { .. } + )); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + } + + #[test] + fn read_api_key_prefers_api_key_env() { + let _guard = env_lock(); + std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token"); + std::env::set_var("ANTHROPIC_API_KEY", "legacy-key"); + assert_eq!( + super::read_api_key().expect("api key should load"), + "legacy-key" + ); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + std::env::remove_var("ANTHROPIC_API_KEY"); + } + + #[test] + fn read_auth_token_reads_auth_token_env() { + let _guard = env_lock(); + std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token"); + assert_eq!(super::read_auth_token().as_deref(), Some("auth-token")); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + } + + #[test] + fn oauth_token_maps_to_bearer_auth_source() { + let auth = AuthSource::from(OAuthTokenSet { + access_token: "access-token".to_string(), + refresh_token: Some("refresh".to_string()), + expires_at: Some(123), + scopes: vec!["scope:a".to_string()], + }); + assert_eq!(auth.bearer_token(), Some("access-token")); + assert_eq!(auth.api_key(), None); + } + + #[test] + fn auth_source_from_env_combines_api_key_and_bearer_token() { + let _guard = env_lock(); + std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token"); + std::env::set_var("ANTHROPIC_API_KEY", "legacy-key"); + let auth = AuthSource::from_env().expect("env auth"); + assert_eq!(auth.api_key(), Some("legacy-key")); + assert_eq!(auth.bearer_token(), Some("auth-token")); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + std::env::remove_var("ANTHROPIC_API_KEY"); + } + + #[test] + fn auth_source_from_saved_oauth_when_env_absent() { + let _guard = env_lock(); + let config_home = temp_config_home(); + std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + std::env::remove_var("ANTHROPIC_API_KEY"); + save_oauth_credentials(&runtime::OAuthTokenSet { + access_token: "saved-access-token".to_string(), + refresh_token: Some("refresh".to_string()), + expires_at: Some(now_unix_timestamp() + 300), + scopes: vec!["scope:a".to_string()], + }) + .expect("save oauth credentials"); + + let auth = AuthSource::from_env_or_saved().expect("saved auth"); + assert_eq!(auth.bearer_token(), Some("saved-access-token")); + + clear_oauth_credentials().expect("clear credentials"); + std::env::remove_var("CLAUDE_CONFIG_HOME"); + std::fs::remove_dir_all(config_home).expect("cleanup temp dir"); + } + + #[test] + fn oauth_token_expiry_uses_expires_at_timestamp() { + assert!(oauth_token_is_expired(&OAuthTokenSet { + access_token: "access-token".to_string(), + refresh_token: None, + expires_at: Some(1), + scopes: Vec::new(), + })); + assert!(!oauth_token_is_expired(&OAuthTokenSet { + access_token: "access-token".to_string(), + refresh_token: None, + expires_at: Some(now_unix_timestamp() + 60), + scopes: Vec::new(), + })); + } + + #[test] + fn resolve_saved_oauth_token_refreshes_expired_credentials() { + let _guard = env_lock(); + let config_home = temp_config_home(); + std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + std::env::remove_var("ANTHROPIC_API_KEY"); + save_oauth_credentials(&runtime::OAuthTokenSet { + access_token: "expired-access-token".to_string(), + refresh_token: Some("refresh-token".to_string()), + expires_at: Some(1), + scopes: vec!["scope:a".to_string()], + }) + .expect("save expired oauth credentials"); + + let token_url = spawn_token_server( + "{\"access_token\":\"refreshed-token\",\"refresh_token\":\"fresh-refresh\",\"expires_at\":9999999999,\"scopes\":[\"scope:a\"]}", + ); + let resolved = resolve_saved_oauth_token(&sample_oauth_config(token_url)) + .expect("resolve refreshed token") + .expect("token set present"); + assert_eq!(resolved.access_token, "refreshed-token"); + let stored = runtime::load_oauth_credentials() + .expect("load stored credentials") + .expect("stored token set"); + assert_eq!(stored.access_token, "refreshed-token"); + + clear_oauth_credentials().expect("clear credentials"); + std::env::remove_var("CLAUDE_CONFIG_HOME"); + std::fs::remove_dir_all(config_home).expect("cleanup temp dir"); + } + + #[test] + fn resolve_startup_auth_source_uses_saved_oauth_without_loading_config() { + let _guard = env_lock(); + let config_home = temp_config_home(); + std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + std::env::remove_var("ANTHROPIC_API_KEY"); + save_oauth_credentials(&runtime::OAuthTokenSet { + access_token: "saved-access-token".to_string(), + refresh_token: Some("refresh".to_string()), + expires_at: Some(now_unix_timestamp() + 300), + scopes: vec!["scope:a".to_string()], + }) + .expect("save oauth credentials"); + + let auth = resolve_startup_auth_source(|| panic!("config should not be loaded")) + .expect("startup auth"); + assert_eq!(auth.bearer_token(), Some("saved-access-token")); + + clear_oauth_credentials().expect("clear credentials"); + std::env::remove_var("CLAUDE_CONFIG_HOME"); + std::fs::remove_dir_all(config_home).expect("cleanup temp dir"); + } + + #[test] + fn resolve_startup_auth_source_errors_when_refreshable_token_lacks_config() { + let _guard = env_lock(); + let config_home = temp_config_home(); + std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + std::env::remove_var("ANTHROPIC_API_KEY"); + save_oauth_credentials(&runtime::OAuthTokenSet { + access_token: "expired-access-token".to_string(), + refresh_token: Some("refresh-token".to_string()), + expires_at: Some(1), + scopes: vec!["scope:a".to_string()], + }) + .expect("save expired oauth credentials"); + + let error = + resolve_startup_auth_source(|| Ok(None)).expect_err("missing config should error"); + assert!( + matches!(error, crate::error::ApiError::Auth(message) if message.contains("runtime OAuth config is missing")) + ); + + let stored = runtime::load_oauth_credentials() + .expect("load stored credentials") + .expect("stored token set"); + assert_eq!(stored.access_token, "expired-access-token"); + assert_eq!(stored.refresh_token.as_deref(), Some("refresh-token")); + + clear_oauth_credentials().expect("clear credentials"); + std::env::remove_var("CLAUDE_CONFIG_HOME"); + std::fs::remove_dir_all(config_home).expect("cleanup temp dir"); + } + + #[test] + fn resolve_saved_oauth_token_preserves_refresh_token_when_refresh_response_omits_it() { + let _guard = env_lock(); + let config_home = temp_config_home(); + std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + std::env::remove_var("ANTHROPIC_API_KEY"); + save_oauth_credentials(&runtime::OAuthTokenSet { + access_token: "expired-access-token".to_string(), + refresh_token: Some("refresh-token".to_string()), + expires_at: Some(1), + scopes: vec!["scope:a".to_string()], + }) + .expect("save expired oauth credentials"); + + let token_url = spawn_token_server( + "{\"access_token\":\"refreshed-token\",\"expires_at\":9999999999,\"scopes\":[\"scope:a\"]}", + ); + let resolved = resolve_saved_oauth_token(&sample_oauth_config(token_url)) + .expect("resolve refreshed token") + .expect("token set present"); + assert_eq!(resolved.access_token, "refreshed-token"); + assert_eq!(resolved.refresh_token.as_deref(), Some("refresh-token")); + let stored = runtime::load_oauth_credentials() + .expect("load stored credentials") + .expect("stored token set"); + assert_eq!(stored.refresh_token.as_deref(), Some("refresh-token")); + + clear_oauth_credentials().expect("clear credentials"); + std::env::remove_var("CLAUDE_CONFIG_HOME"); + std::fs::remove_dir_all(config_home).expect("cleanup temp dir"); + } + + #[test] + fn message_request_stream_helper_sets_stream_true() { + let request = MessageRequest { + model: "claude-opus-4-6".to_string(), + max_tokens: 64, + messages: vec![], + system: None, + tools: None, + tool_choice: None, + stream: false, + }; + + assert!(request.with_streaming().stream); + } + + #[test] + fn backoff_doubles_until_maximum() { + let client = AnthropicClient::new("test-key").with_retry_policy( + 3, + Duration::from_millis(10), + Duration::from_millis(25), + ); + assert_eq!( + client.backoff_for_attempt(1).expect("attempt 1"), + Duration::from_millis(10) + ); + assert_eq!( + client.backoff_for_attempt(2).expect("attempt 2"), + Duration::from_millis(20) + ); + assert_eq!( + client.backoff_for_attempt(3).expect("attempt 3"), + Duration::from_millis(25) + ); + } + + #[test] + fn retryable_statuses_are_detected() { + assert!(super::is_retryable_status( + reqwest::StatusCode::TOO_MANY_REQUESTS + )); + assert!(super::is_retryable_status( + reqwest::StatusCode::INTERNAL_SERVER_ERROR + )); + assert!(!super::is_retryable_status( + reqwest::StatusCode::UNAUTHORIZED + )); + } + + #[test] + fn tool_delta_variant_round_trips() { + let delta = ContentBlockDelta::InputJsonDelta { + partial_json: "{\"city\":\"Paris\"}".to_string(), + }; + let encoded = serde_json::to_string(&delta).expect("delta should serialize"); + let decoded: ContentBlockDelta = + serde_json::from_str(&encoded).expect("delta should deserialize"); + assert_eq!(decoded, delta); + } + + #[test] + fn request_id_uses_primary_or_fallback_header() { + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert(REQUEST_ID_HEADER, "req_primary".parse().expect("header")); + assert_eq!( + super::request_id_from_headers(&headers).as_deref(), + Some("req_primary") + ); + + headers.clear(); + headers.insert( + ALT_REQUEST_ID_HEADER, + "req_fallback".parse().expect("header"), + ); + assert_eq!( + super::request_id_from_headers(&headers).as_deref(), + Some("req_fallback") + ); + } + + #[test] + fn auth_source_applies_headers() { + let auth = AuthSource::ApiKeyAndBearer { + api_key: "test-key".to_string(), + bearer_token: "proxy-token".to_string(), + }; + let request = auth + .apply(reqwest::Client::new().post("https://example.test")) + .build() + .expect("request build"); + let headers = request.headers(); + assert_eq!( + headers.get("x-api-key").and_then(|v| v.to_str().ok()), + Some("test-key") + ); + assert_eq!( + headers.get("authorization").and_then(|v| v.to_str().ok()), + Some("Bearer proxy-token") + ); + } +} diff --git a/rust/crates/api/src/providers/mod.rs b/rust/crates/api/src/providers/mod.rs new file mode 100644 index 0000000..d28febd --- /dev/null +++ b/rust/crates/api/src/providers/mod.rs @@ -0,0 +1,216 @@ +use std::future::Future; +use std::pin::Pin; + +use crate::error::ApiError; +use crate::types::{MessageRequest, MessageResponse}; + +pub mod anthropic; +pub mod openai_compat; + +pub type ProviderFuture<'a, T> = Pin> + Send + 'a>>; + +pub trait Provider { + type Stream; + + fn send_message<'a>( + &'a self, + request: &'a MessageRequest, + ) -> ProviderFuture<'a, MessageResponse>; + + fn stream_message<'a>( + &'a self, + request: &'a MessageRequest, + ) -> ProviderFuture<'a, Self::Stream>; +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ProviderKind { + Anthropic, + Xai, + OpenAi, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct ProviderMetadata { + pub provider: ProviderKind, + pub auth_env: &'static str, + pub base_url_env: &'static str, + pub default_base_url: &'static str, +} + +const MODEL_REGISTRY: &[(&str, ProviderMetadata)] = &[ + ( + "opus", + ProviderMetadata { + provider: ProviderKind::Anthropic, + auth_env: "ANTHROPIC_API_KEY", + base_url_env: "ANTHROPIC_BASE_URL", + default_base_url: anthropic::DEFAULT_BASE_URL, + }, + ), + ( + "sonnet", + ProviderMetadata { + provider: ProviderKind::Anthropic, + auth_env: "ANTHROPIC_API_KEY", + base_url_env: "ANTHROPIC_BASE_URL", + default_base_url: anthropic::DEFAULT_BASE_URL, + }, + ), + ( + "haiku", + ProviderMetadata { + provider: ProviderKind::Anthropic, + auth_env: "ANTHROPIC_API_KEY", + base_url_env: "ANTHROPIC_BASE_URL", + default_base_url: anthropic::DEFAULT_BASE_URL, + }, + ), + ( + "grok", + ProviderMetadata { + provider: ProviderKind::Xai, + auth_env: "XAI_API_KEY", + base_url_env: "XAI_BASE_URL", + default_base_url: openai_compat::DEFAULT_XAI_BASE_URL, + }, + ), + ( + "grok-3", + ProviderMetadata { + provider: ProviderKind::Xai, + auth_env: "XAI_API_KEY", + base_url_env: "XAI_BASE_URL", + default_base_url: openai_compat::DEFAULT_XAI_BASE_URL, + }, + ), + ( + "grok-mini", + ProviderMetadata { + provider: ProviderKind::Xai, + auth_env: "XAI_API_KEY", + base_url_env: "XAI_BASE_URL", + default_base_url: openai_compat::DEFAULT_XAI_BASE_URL, + }, + ), + ( + "grok-3-mini", + ProviderMetadata { + provider: ProviderKind::Xai, + auth_env: "XAI_API_KEY", + base_url_env: "XAI_BASE_URL", + default_base_url: openai_compat::DEFAULT_XAI_BASE_URL, + }, + ), + ( + "grok-2", + ProviderMetadata { + provider: ProviderKind::Xai, + auth_env: "XAI_API_KEY", + base_url_env: "XAI_BASE_URL", + default_base_url: openai_compat::DEFAULT_XAI_BASE_URL, + }, + ), +]; + +#[must_use] +pub fn resolve_model_alias(model: &str) -> String { + let trimmed = model.trim(); + let lower = trimmed.to_ascii_lowercase(); + MODEL_REGISTRY + .iter() + .find_map(|(alias, metadata)| { + (*alias == lower).then_some(match metadata.provider { + ProviderKind::Anthropic => match *alias { + "opus" => "claude-opus-4-6", + "sonnet" => "claude-sonnet-4-6", + "haiku" => "claude-haiku-4-5-20251213", + _ => trimmed, + }, + ProviderKind::Xai => match *alias { + "grok" | "grok-3" => "grok-3", + "grok-mini" | "grok-3-mini" => "grok-3-mini", + "grok-2" => "grok-2", + _ => trimmed, + }, + ProviderKind::OpenAi => trimmed, + }) + }) + .map_or_else(|| trimmed.to_string(), ToOwned::to_owned) +} + +#[must_use] +pub fn metadata_for_model(model: &str) -> Option { + let canonical = resolve_model_alias(model); + if canonical.starts_with("claude") { + return Some(ProviderMetadata { + provider: ProviderKind::Anthropic, + auth_env: "ANTHROPIC_API_KEY", + base_url_env: "ANTHROPIC_BASE_URL", + default_base_url: anthropic::DEFAULT_BASE_URL, + }); + } + if canonical.starts_with("grok") { + return Some(ProviderMetadata { + provider: ProviderKind::Xai, + auth_env: "XAI_API_KEY", + base_url_env: "XAI_BASE_URL", + default_base_url: openai_compat::DEFAULT_XAI_BASE_URL, + }); + } + None +} + +#[must_use] +pub fn detect_provider_kind(model: &str) -> ProviderKind { + if let Some(metadata) = metadata_for_model(model) { + return metadata.provider; + } + if anthropic::has_auth_from_env_or_saved().unwrap_or(false) { + return ProviderKind::Anthropic; + } + if openai_compat::has_api_key("OPENAI_API_KEY") { + return ProviderKind::OpenAi; + } + if openai_compat::has_api_key("XAI_API_KEY") { + return ProviderKind::Xai; + } + ProviderKind::Anthropic +} + +#[must_use] +pub fn max_tokens_for_model(model: &str) -> u32 { + let canonical = resolve_model_alias(model); + if canonical.contains("opus") { + 32_000 + } else { + 64_000 + } +} + +#[cfg(test)] +mod tests { + use super::{detect_provider_kind, max_tokens_for_model, resolve_model_alias, ProviderKind}; + + #[test] + fn resolves_grok_aliases() { + assert_eq!(resolve_model_alias("grok"), "grok-3"); + assert_eq!(resolve_model_alias("grok-mini"), "grok-3-mini"); + assert_eq!(resolve_model_alias("grok-2"), "grok-2"); + } + + #[test] + fn detects_provider_from_model_name_first() { + assert_eq!(detect_provider_kind("grok"), ProviderKind::Xai); + assert_eq!( + detect_provider_kind("claude-sonnet-4-6"), + ProviderKind::Anthropic + ); + } + + #[test] + fn keeps_existing_max_token_heuristic() { + assert_eq!(max_tokens_for_model("opus"), 32_000); + assert_eq!(max_tokens_for_model("grok-3"), 64_000); + } +} diff --git a/rust/crates/api/src/providers/openai_compat.rs b/rust/crates/api/src/providers/openai_compat.rs new file mode 100644 index 0000000..e8210ae --- /dev/null +++ b/rust/crates/api/src/providers/openai_compat.rs @@ -0,0 +1,1050 @@ +use std::collections::{BTreeMap, VecDeque}; +use std::time::Duration; + +use serde::Deserialize; +use serde_json::{json, Value}; + +use crate::error::ApiError; +use crate::types::{ + ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent, + InputContentBlock, InputMessage, MessageDelta, MessageDeltaEvent, MessageRequest, + MessageResponse, MessageStartEvent, MessageStopEvent, OutputContentBlock, StreamEvent, + ToolChoice, ToolDefinition, ToolResultContentBlock, Usage, +}; + +use super::{Provider, ProviderFuture}; + +pub const DEFAULT_XAI_BASE_URL: &str = "https://api.x.ai/v1"; +pub const DEFAULT_OPENAI_BASE_URL: &str = "https://api.openai.com/v1"; +const REQUEST_ID_HEADER: &str = "request-id"; +const ALT_REQUEST_ID_HEADER: &str = "x-request-id"; +const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_millis(200); +const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(2); +const DEFAULT_MAX_RETRIES: u32 = 2; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct OpenAiCompatConfig { + pub provider_name: &'static str, + pub api_key_env: &'static str, + pub base_url_env: &'static str, + pub default_base_url: &'static str, +} + +const XAI_ENV_VARS: &[&str] = &["XAI_API_KEY"]; +const OPENAI_ENV_VARS: &[&str] = &["OPENAI_API_KEY"]; + +impl OpenAiCompatConfig { + #[must_use] + pub const fn xai() -> Self { + Self { + provider_name: "xAI", + api_key_env: "XAI_API_KEY", + base_url_env: "XAI_BASE_URL", + default_base_url: DEFAULT_XAI_BASE_URL, + } + } + + #[must_use] + pub const fn openai() -> Self { + Self { + provider_name: "OpenAI", + api_key_env: "OPENAI_API_KEY", + base_url_env: "OPENAI_BASE_URL", + default_base_url: DEFAULT_OPENAI_BASE_URL, + } + } + #[must_use] + pub fn credential_env_vars(self) -> &'static [&'static str] { + match self.provider_name { + "xAI" => XAI_ENV_VARS, + "OpenAI" => OPENAI_ENV_VARS, + _ => &[], + } + } +} + +#[derive(Debug, Clone)] +pub struct OpenAiCompatClient { + http: reqwest::Client, + api_key: String, + base_url: String, + max_retries: u32, + initial_backoff: Duration, + max_backoff: Duration, +} + +impl OpenAiCompatClient { + #[must_use] + pub fn new(api_key: impl Into, config: OpenAiCompatConfig) -> Self { + Self { + http: reqwest::Client::new(), + api_key: api_key.into(), + base_url: read_base_url(config), + max_retries: DEFAULT_MAX_RETRIES, + initial_backoff: DEFAULT_INITIAL_BACKOFF, + max_backoff: DEFAULT_MAX_BACKOFF, + } + } + + pub fn from_env(config: OpenAiCompatConfig) -> Result { + let Some(api_key) = read_env_non_empty(config.api_key_env)? else { + return Err(ApiError::missing_credentials( + config.provider_name, + config.credential_env_vars(), + )); + }; + Ok(Self::new(api_key, config)) + } + + #[must_use] + pub fn with_base_url(mut self, base_url: impl Into) -> Self { + self.base_url = base_url.into(); + self + } + + #[must_use] + pub fn with_retry_policy( + mut self, + max_retries: u32, + initial_backoff: Duration, + max_backoff: Duration, + ) -> Self { + self.max_retries = max_retries; + self.initial_backoff = initial_backoff; + self.max_backoff = max_backoff; + self + } + + pub async fn send_message( + &self, + request: &MessageRequest, + ) -> Result { + let request = MessageRequest { + stream: false, + ..request.clone() + }; + let response = self.send_with_retry(&request).await?; + let request_id = request_id_from_headers(response.headers()); + let payload = response.json::().await?; + let mut normalized = normalize_response(&request.model, payload)?; + if normalized.request_id.is_none() { + normalized.request_id = request_id; + } + Ok(normalized) + } + + pub async fn stream_message( + &self, + request: &MessageRequest, + ) -> Result { + let response = self + .send_with_retry(&request.clone().with_streaming()) + .await?; + Ok(MessageStream { + request_id: request_id_from_headers(response.headers()), + response, + parser: OpenAiSseParser::new(), + pending: VecDeque::new(), + done: false, + state: StreamState::new(request.model.clone()), + }) + } + + async fn send_with_retry( + &self, + request: &MessageRequest, + ) -> Result { + let mut attempts = 0; + + let last_error = loop { + attempts += 1; + let retryable_error = match self.send_raw_request(request).await { + Ok(response) => match expect_success(response).await { + Ok(response) => return Ok(response), + Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => error, + Err(error) => return Err(error), + }, + Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => error, + Err(error) => return Err(error), + }; + + if attempts > self.max_retries { + break retryable_error; + } + + tokio::time::sleep(self.backoff_for_attempt(attempts)?).await; + }; + + Err(ApiError::RetriesExhausted { + attempts, + last_error: Box::new(last_error), + }) + } + + async fn send_raw_request( + &self, + request: &MessageRequest, + ) -> Result { + let request_url = chat_completions_endpoint(&self.base_url); + self.http + .post(&request_url) + .header("content-type", "application/json") + .bearer_auth(&self.api_key) + .json(&build_chat_completion_request(request)) + .send() + .await + .map_err(ApiError::from) + } + + fn backoff_for_attempt(&self, attempt: u32) -> Result { + let Some(multiplier) = 1_u32.checked_shl(attempt.saturating_sub(1)) else { + return Err(ApiError::BackoffOverflow { + attempt, + base_delay: self.initial_backoff, + }); + }; + Ok(self + .initial_backoff + .checked_mul(multiplier) + .map_or(self.max_backoff, |delay| delay.min(self.max_backoff))) + } +} + +impl Provider for OpenAiCompatClient { + type Stream = MessageStream; + + fn send_message<'a>( + &'a self, + request: &'a MessageRequest, + ) -> ProviderFuture<'a, MessageResponse> { + Box::pin(async move { self.send_message(request).await }) + } + + fn stream_message<'a>( + &'a self, + request: &'a MessageRequest, + ) -> ProviderFuture<'a, Self::Stream> { + Box::pin(async move { self.stream_message(request).await }) + } +} + +#[derive(Debug)] +pub struct MessageStream { + request_id: Option, + response: reqwest::Response, + parser: OpenAiSseParser, + pending: VecDeque, + done: bool, + state: StreamState, +} + +impl MessageStream { + #[must_use] + pub fn request_id(&self) -> Option<&str> { + self.request_id.as_deref() + } + + pub async fn next_event(&mut self) -> Result, ApiError> { + loop { + if let Some(event) = self.pending.pop_front() { + return Ok(Some(event)); + } + + if self.done { + self.pending.extend(self.state.finish()?); + if let Some(event) = self.pending.pop_front() { + return Ok(Some(event)); + } + return Ok(None); + } + + match self.response.chunk().await? { + Some(chunk) => { + for parsed in self.parser.push(&chunk)? { + self.pending.extend(self.state.ingest_chunk(parsed)?); + } + } + None => { + self.done = true; + } + } + } + } +} + +#[derive(Debug, Default)] +struct OpenAiSseParser { + buffer: Vec, +} + +impl OpenAiSseParser { + fn new() -> Self { + Self::default() + } + + fn push(&mut self, chunk: &[u8]) -> Result, ApiError> { + self.buffer.extend_from_slice(chunk); + let mut events = Vec::new(); + + while let Some(frame) = next_sse_frame(&mut self.buffer) { + if let Some(event) = parse_sse_frame(&frame)? { + events.push(event); + } + } + + Ok(events) + } +} + +#[derive(Debug)] +struct StreamState { + model: String, + message_started: bool, + text_started: bool, + text_finished: bool, + finished: bool, + stop_reason: Option, + usage: Option, + tool_calls: BTreeMap, +} + +impl StreamState { + fn new(model: String) -> Self { + Self { + model, + message_started: false, + text_started: false, + text_finished: false, + finished: false, + stop_reason: None, + usage: None, + tool_calls: BTreeMap::new(), + } + } + + fn ingest_chunk(&mut self, chunk: ChatCompletionChunk) -> Result, ApiError> { + let mut events = Vec::new(); + if !self.message_started { + self.message_started = true; + events.push(StreamEvent::MessageStart(MessageStartEvent { + message: MessageResponse { + id: chunk.id.clone(), + kind: "message".to_string(), + role: "assistant".to_string(), + content: Vec::new(), + model: chunk.model.clone().unwrap_or_else(|| self.model.clone()), + stop_reason: None, + stop_sequence: None, + usage: Usage { + input_tokens: 0, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + output_tokens: 0, + }, + request_id: None, + }, + })); + } + + if let Some(usage) = chunk.usage { + self.usage = Some(Usage { + input_tokens: usage.prompt_tokens, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + output_tokens: usage.completion_tokens, + }); + } + + for choice in chunk.choices { + if let Some(content) = choice.delta.content.filter(|value| !value.is_empty()) { + if !self.text_started { + self.text_started = true; + events.push(StreamEvent::ContentBlockStart(ContentBlockStartEvent { + index: 0, + content_block: OutputContentBlock::Text { + text: String::new(), + }, + })); + } + events.push(StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent { + index: 0, + delta: ContentBlockDelta::TextDelta { text: content }, + })); + } + + for tool_call in choice.delta.tool_calls { + let state = self.tool_calls.entry(tool_call.index).or_default(); + state.apply(tool_call); + let block_index = state.block_index(); + if !state.started { + if let Some(start_event) = state.start_event()? { + state.started = true; + events.push(StreamEvent::ContentBlockStart(start_event)); + } else { + continue; + } + } + if let Some(delta_event) = state.delta_event() { + events.push(StreamEvent::ContentBlockDelta(delta_event)); + } + if choice.finish_reason.as_deref() == Some("tool_calls") && !state.stopped { + state.stopped = true; + events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent { + index: block_index, + })); + } + } + + if let Some(finish_reason) = choice.finish_reason { + self.stop_reason = Some(normalize_finish_reason(&finish_reason)); + if finish_reason == "tool_calls" { + for state in self.tool_calls.values_mut() { + if state.started && !state.stopped { + state.stopped = true; + events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent { + index: state.block_index(), + })); + } + } + } + } + } + + Ok(events) + } + + fn finish(&mut self) -> Result, ApiError> { + if self.finished { + return Ok(Vec::new()); + } + self.finished = true; + + let mut events = Vec::new(); + if self.text_started && !self.text_finished { + self.text_finished = true; + events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent { + index: 0, + })); + } + + for state in self.tool_calls.values_mut() { + if !state.started { + if let Some(start_event) = state.start_event()? { + state.started = true; + events.push(StreamEvent::ContentBlockStart(start_event)); + if let Some(delta_event) = state.delta_event() { + events.push(StreamEvent::ContentBlockDelta(delta_event)); + } + } + } + if state.started && !state.stopped { + state.stopped = true; + events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent { + index: state.block_index(), + })); + } + } + + if self.message_started { + events.push(StreamEvent::MessageDelta(MessageDeltaEvent { + delta: MessageDelta { + stop_reason: Some( + self.stop_reason + .clone() + .unwrap_or_else(|| "end_turn".to_string()), + ), + stop_sequence: None, + }, + usage: self.usage.clone().unwrap_or(Usage { + input_tokens: 0, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + output_tokens: 0, + }), + })); + events.push(StreamEvent::MessageStop(MessageStopEvent {})); + } + Ok(events) + } +} + +#[derive(Debug, Default)] +struct ToolCallState { + openai_index: u32, + id: Option, + name: Option, + arguments: String, + emitted_len: usize, + started: bool, + stopped: bool, +} + +impl ToolCallState { + fn apply(&mut self, tool_call: DeltaToolCall) { + self.openai_index = tool_call.index; + if let Some(id) = tool_call.id { + self.id = Some(id); + } + if let Some(name) = tool_call.function.name { + self.name = Some(name); + } + if let Some(arguments) = tool_call.function.arguments { + self.arguments.push_str(&arguments); + } + } + + const fn block_index(&self) -> u32 { + self.openai_index + 1 + } + + fn start_event(&self) -> Result, ApiError> { + let Some(name) = self.name.clone() else { + return Ok(None); + }; + let id = self + .id + .clone() + .unwrap_or_else(|| format!("tool_call_{}", self.openai_index)); + Ok(Some(ContentBlockStartEvent { + index: self.block_index(), + content_block: OutputContentBlock::ToolUse { + id, + name, + input: json!({}), + }, + })) + } + + fn delta_event(&mut self) -> Option { + if self.emitted_len >= self.arguments.len() { + return None; + } + let delta = self.arguments[self.emitted_len..].to_string(); + self.emitted_len = self.arguments.len(); + Some(ContentBlockDeltaEvent { + index: self.block_index(), + delta: ContentBlockDelta::InputJsonDelta { + partial_json: delta, + }, + }) + } +} + +#[derive(Debug, Deserialize)] +struct ChatCompletionResponse { + id: String, + model: String, + choices: Vec, + #[serde(default)] + usage: Option, +} + +#[derive(Debug, Deserialize)] +struct ChatChoice { + message: ChatMessage, + #[serde(default)] + finish_reason: Option, +} + +#[derive(Debug, Deserialize)] +struct ChatMessage { + role: String, + #[serde(default)] + content: Option, + #[serde(default)] + tool_calls: Vec, +} + +#[derive(Debug, Deserialize)] +struct ResponseToolCall { + id: String, + function: ResponseToolFunction, +} + +#[derive(Debug, Deserialize)] +struct ResponseToolFunction { + name: String, + arguments: String, +} + +#[derive(Debug, Deserialize)] +struct OpenAiUsage { + #[serde(default)] + prompt_tokens: u32, + #[serde(default)] + completion_tokens: u32, +} + +#[derive(Debug, Deserialize)] +struct ChatCompletionChunk { + id: String, + #[serde(default)] + model: Option, + #[serde(default)] + choices: Vec, + #[serde(default)] + usage: Option, +} + +#[derive(Debug, Deserialize)] +struct ChunkChoice { + delta: ChunkDelta, + #[serde(default)] + finish_reason: Option, +} + +#[derive(Debug, Default, Deserialize)] +struct ChunkDelta { + #[serde(default)] + content: Option, + #[serde(default)] + tool_calls: Vec, +} + +#[derive(Debug, Deserialize)] +struct DeltaToolCall { + #[serde(default)] + index: u32, + #[serde(default)] + id: Option, + #[serde(default)] + function: DeltaFunction, +} + +#[derive(Debug, Default, Deserialize)] +struct DeltaFunction { + #[serde(default)] + name: Option, + #[serde(default)] + arguments: Option, +} + +#[derive(Debug, Deserialize)] +struct ErrorEnvelope { + error: ErrorBody, +} + +#[derive(Debug, Deserialize)] +struct ErrorBody { + #[serde(rename = "type")] + error_type: Option, + message: Option, +} + +fn build_chat_completion_request(request: &MessageRequest) -> Value { + let mut messages = Vec::new(); + if let Some(system) = request.system.as_ref().filter(|value| !value.is_empty()) { + messages.push(json!({ + "role": "system", + "content": system, + })); + } + for message in &request.messages { + messages.extend(translate_message(message)); + } + + let mut payload = json!({ + "model": request.model, + "max_tokens": request.max_tokens, + "messages": messages, + "stream": request.stream, + }); + + if let Some(tools) = &request.tools { + payload["tools"] = + Value::Array(tools.iter().map(openai_tool_definition).collect::>()); + } + if let Some(tool_choice) = &request.tool_choice { + payload["tool_choice"] = openai_tool_choice(tool_choice); + } + + payload +} + +fn translate_message(message: &InputMessage) -> Vec { + match message.role.as_str() { + "assistant" => { + let mut text = String::new(); + let mut tool_calls = Vec::new(); + for block in &message.content { + match block { + InputContentBlock::Text { text: value } => text.push_str(value), + InputContentBlock::ToolUse { id, name, input } => tool_calls.push(json!({ + "id": id, + "type": "function", + "function": { + "name": name, + "arguments": input.to_string(), + } + })), + InputContentBlock::ToolResult { .. } => {} + } + } + if text.is_empty() && tool_calls.is_empty() { + Vec::new() + } else { + vec![json!({ + "role": "assistant", + "content": (!text.is_empty()).then_some(text), + "tool_calls": tool_calls, + })] + } + } + _ => message + .content + .iter() + .filter_map(|block| match block { + InputContentBlock::Text { text } => Some(json!({ + "role": "user", + "content": text, + })), + InputContentBlock::ToolResult { + tool_use_id, + content, + is_error, + } => Some(json!({ + "role": "tool", + "tool_call_id": tool_use_id, + "content": flatten_tool_result_content(content), + "is_error": is_error, + })), + InputContentBlock::ToolUse { .. } => None, + }) + .collect(), + } +} + +fn flatten_tool_result_content(content: &[ToolResultContentBlock]) -> String { + content + .iter() + .map(|block| match block { + ToolResultContentBlock::Text { text } => text.clone(), + ToolResultContentBlock::Json { value } => value.to_string(), + }) + .collect::>() + .join("\n") +} + +fn openai_tool_definition(tool: &ToolDefinition) -> Value { + json!({ + "type": "function", + "function": { + "name": tool.name, + "description": tool.description, + "parameters": tool.input_schema, + } + }) +} + +fn openai_tool_choice(tool_choice: &ToolChoice) -> Value { + match tool_choice { + ToolChoice::Auto => Value::String("auto".to_string()), + ToolChoice::Any => Value::String("required".to_string()), + ToolChoice::Tool { name } => json!({ + "type": "function", + "function": { "name": name }, + }), + } +} + +fn normalize_response( + model: &str, + response: ChatCompletionResponse, +) -> Result { + let choice = response + .choices + .into_iter() + .next() + .ok_or(ApiError::InvalidSseFrame( + "chat completion response missing choices", + ))?; + let mut content = Vec::new(); + if let Some(text) = choice.message.content.filter(|value| !value.is_empty()) { + content.push(OutputContentBlock::Text { text }); + } + for tool_call in choice.message.tool_calls { + content.push(OutputContentBlock::ToolUse { + id: tool_call.id, + name: tool_call.function.name, + input: parse_tool_arguments(&tool_call.function.arguments), + }); + } + + Ok(MessageResponse { + id: response.id, + kind: "message".to_string(), + role: choice.message.role, + content, + model: response.model.if_empty_then(model.to_string()), + stop_reason: choice + .finish_reason + .map(|value| normalize_finish_reason(&value)), + stop_sequence: None, + usage: Usage { + input_tokens: response + .usage + .as_ref() + .map_or(0, |usage| usage.prompt_tokens), + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + output_tokens: response + .usage + .as_ref() + .map_or(0, |usage| usage.completion_tokens), + }, + request_id: None, + }) +} + +fn parse_tool_arguments(arguments: &str) -> Value { + serde_json::from_str(arguments).unwrap_or_else(|_| json!({ "raw": arguments })) +} + +fn next_sse_frame(buffer: &mut Vec) -> Option { + let separator = buffer + .windows(2) + .position(|window| window == b"\n\n") + .map(|position| (position, 2)) + .or_else(|| { + buffer + .windows(4) + .position(|window| window == b"\r\n\r\n") + .map(|position| (position, 4)) + })?; + + let (position, separator_len) = separator; + let frame = buffer.drain(..position + separator_len).collect::>(); + let frame_len = frame.len().saturating_sub(separator_len); + Some(String::from_utf8_lossy(&frame[..frame_len]).into_owned()) +} + +fn parse_sse_frame(frame: &str) -> Result, ApiError> { + let trimmed = frame.trim(); + if trimmed.is_empty() { + return Ok(None); + } + + let mut data_lines = Vec::new(); + for line in trimmed.lines() { + if line.starts_with(':') { + continue; + } + if let Some(data) = line.strip_prefix("data:") { + data_lines.push(data.trim_start()); + } + } + if data_lines.is_empty() { + return Ok(None); + } + let payload = data_lines.join("\n"); + if payload == "[DONE]" { + return Ok(None); + } + serde_json::from_str(&payload) + .map(Some) + .map_err(ApiError::from) +} + +fn read_env_non_empty(key: &str) -> Result, ApiError> { + match std::env::var(key) { + Ok(value) if !value.is_empty() => Ok(Some(value)), + Ok(_) | Err(std::env::VarError::NotPresent) => Ok(None), + Err(error) => Err(ApiError::from(error)), + } +} + +#[must_use] +pub fn has_api_key(key: &str) -> bool { + read_env_non_empty(key) + .ok() + .and_then(std::convert::identity) + .is_some() +} + +#[must_use] +pub fn read_base_url(config: OpenAiCompatConfig) -> String { + std::env::var(config.base_url_env).unwrap_or_else(|_| config.default_base_url.to_string()) +} + +fn chat_completions_endpoint(base_url: &str) -> String { + let trimmed = base_url.trim_end_matches('/'); + if trimmed.ends_with("/chat/completions") { + trimmed.to_string() + } else { + format!("{trimmed}/chat/completions") + } +} + +fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option { + headers + .get(REQUEST_ID_HEADER) + .or_else(|| headers.get(ALT_REQUEST_ID_HEADER)) + .and_then(|value| value.to_str().ok()) + .map(ToOwned::to_owned) +} + +async fn expect_success(response: reqwest::Response) -> Result { + let status = response.status(); + if status.is_success() { + return Ok(response); + } + + let body = response.text().await.unwrap_or_default(); + let parsed_error = serde_json::from_str::(&body).ok(); + let retryable = is_retryable_status(status); + + Err(ApiError::Api { + status, + error_type: parsed_error + .as_ref() + .and_then(|error| error.error.error_type.clone()), + message: parsed_error + .as_ref() + .and_then(|error| error.error.message.clone()), + body, + retryable, + }) +} + +const fn is_retryable_status(status: reqwest::StatusCode) -> bool { + matches!(status.as_u16(), 408 | 409 | 429 | 500 | 502 | 503 | 504) +} + +fn normalize_finish_reason(value: &str) -> String { + match value { + "stop" => "end_turn", + "tool_calls" => "tool_use", + other => other, + } + .to_string() +} + +trait StringExt { + fn if_empty_then(self, fallback: String) -> String; +} + +impl StringExt for String { + fn if_empty_then(self, fallback: String) -> String { + if self.is_empty() { + fallback + } else { + self + } + } +} + +#[cfg(test)] +mod tests { + use super::{ + build_chat_completion_request, chat_completions_endpoint, normalize_finish_reason, + openai_tool_choice, parse_tool_arguments, OpenAiCompatClient, OpenAiCompatConfig, + }; + use crate::error::ApiError; + use crate::types::{ + InputContentBlock, InputMessage, MessageRequest, ToolChoice, ToolDefinition, + ToolResultContentBlock, + }; + use serde_json::json; + use std::sync::{Mutex, OnceLock}; + + #[test] + fn request_translation_uses_openai_compatible_shape() { + let payload = build_chat_completion_request(&MessageRequest { + model: "grok-3".to_string(), + max_tokens: 64, + messages: vec![InputMessage { + role: "user".to_string(), + content: vec![ + InputContentBlock::Text { + text: "hello".to_string(), + }, + InputContentBlock::ToolResult { + tool_use_id: "tool_1".to_string(), + content: vec![ToolResultContentBlock::Json { + value: json!({"ok": true}), + }], + is_error: false, + }, + ], + }], + system: Some("be helpful".to_string()), + tools: Some(vec![ToolDefinition { + name: "weather".to_string(), + description: Some("Get weather".to_string()), + input_schema: json!({"type": "object"}), + }]), + tool_choice: Some(ToolChoice::Auto), + stream: false, + }); + + assert_eq!(payload["messages"][0]["role"], json!("system")); + assert_eq!(payload["messages"][1]["role"], json!("user")); + assert_eq!(payload["messages"][2]["role"], json!("tool")); + assert_eq!(payload["tools"][0]["type"], json!("function")); + assert_eq!(payload["tool_choice"], json!("auto")); + } + + #[test] + fn tool_choice_translation_supports_required_function() { + assert_eq!(openai_tool_choice(&ToolChoice::Any), json!("required")); + assert_eq!( + openai_tool_choice(&ToolChoice::Tool { + name: "weather".to_string(), + }), + json!({"type": "function", "function": {"name": "weather"}}) + ); + } + + #[test] + fn parses_tool_arguments_fallback() { + assert_eq!( + parse_tool_arguments("{\"city\":\"Paris\"}"), + json!({"city": "Paris"}) + ); + assert_eq!(parse_tool_arguments("not-json"), json!({"raw": "not-json"})); + } + + #[test] + fn missing_xai_api_key_is_provider_specific() { + let _lock = env_lock(); + std::env::remove_var("XAI_API_KEY"); + let error = OpenAiCompatClient::from_env(OpenAiCompatConfig::xai()) + .expect_err("missing key should error"); + assert!(matches!( + error, + ApiError::MissingCredentials { + provider: "xAI", + .. + } + )); + } + + #[test] + fn endpoint_builder_accepts_base_urls_and_full_endpoints() { + assert_eq!( + chat_completions_endpoint("https://api.x.ai/v1"), + "https://api.x.ai/v1/chat/completions" + ); + assert_eq!( + chat_completions_endpoint("https://api.x.ai/v1/"), + "https://api.x.ai/v1/chat/completions" + ); + assert_eq!( + chat_completions_endpoint("https://api.x.ai/v1/chat/completions"), + "https://api.x.ai/v1/chat/completions" + ); + } + + fn env_lock() -> std::sync::MutexGuard<'static, ()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())) + .lock() + .expect("env lock") + } + + #[test] + fn normalizes_stop_reasons() { + assert_eq!(normalize_finish_reason("stop"), "end_turn"); + assert_eq!(normalize_finish_reason("tool_calls"), "tool_use"); + } +} diff --git a/rust/crates/api/tests/client_integration.rs b/rust/crates/api/tests/client_integration.rs index e2eaef6..b52f890 100644 --- a/rust/crates/api/tests/client_integration.rs +++ b/rust/crates/api/tests/client_integration.rs @@ -3,9 +3,9 @@ use std::sync::Arc; use std::time::Duration; use api::{ - ApiHttpClient, ApiError, ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, - InputContentBlock, InputMessage, MessageDeltaEvent, MessageRequest, OutputContentBlock, - StreamEvent, ToolChoice, ToolDefinition, + AnthropicClient, ApiError, AuthSource, ContentBlockDelta, ContentBlockDeltaEvent, + ContentBlockStartEvent, InputContentBlock, InputMessage, MessageDeltaEvent, MessageRequest, + OutputContentBlock, ProviderClient, StreamEvent, ToolChoice, ToolDefinition, }; use serde_json::json; use tokio::io::{AsyncReadExt, AsyncWriteExt}; @@ -34,7 +34,7 @@ async fn send_message_posts_json_and_parses_response() { ) .await; - let client = ApiHttpClient::new("test-key") + let client = AnthropicClient::new("test-key") .with_auth_token(Some("proxy-token".to_string())) .with_base_url(server.base_url()); let response = client @@ -75,48 +75,6 @@ async fn send_message_posts_json_and_parses_response() { assert_eq!(body["tool_choice"]["type"], json!("auto")); } -#[tokio::test] -async fn send_message_parses_response_with_thinking_blocks() { - let state = Arc::new(Mutex::new(Vec::::new())); - let body = concat!( - "{", - "\"id\":\"msg_thinking\",", - "\"type\":\"message\",", - "\"role\":\"assistant\",", - "\"content\":[", - "{\"type\":\"thinking\",\"thinking\":\"step 1\",\"signature\":\"sig_123\"},", - "{\"type\":\"text\",\"text\":\"Final answer\"}", - "],", - "\"model\":\"claude-3-7-sonnet-latest\",", - "\"stop_reason\":\"end_turn\",", - "\"stop_sequence\":null,", - "\"usage\":{\"input_tokens\":12,\"output_tokens\":4}", - "}" - ); - let server = spawn_server( - state, - vec![http_response("200 OK", "application/json", body)], - ) - .await; - - let client = ApiHttpClient::new("test-key").with_base_url(server.base_url()); - let response = client - .send_message(&sample_request(false)) - .await - .expect("request should succeed"); - - assert_eq!(response.content.len(), 2); - assert!(matches!( - &response.content[0], - OutputContentBlock::Thinking { thinking, signature } - if thinking == "step 1" && signature.as_deref() == Some("sig_123") - )); - assert!(matches!( - &response.content[1], - OutputContentBlock::Text { text } if text == "Final answer" - )); -} - #[tokio::test] async fn stream_message_parses_sse_events_with_tool_use() { let state = Arc::new(Mutex::new(Vec::::new())); @@ -146,7 +104,7 @@ async fn stream_message_parses_sse_events_with_tool_use() { ) .await; - let client = ApiHttpClient::new("test-key") + let client = AnthropicClient::new("test-key") .with_auth_token(Some("proxy-token".to_string())) .with_base_url(server.base_url()); let mut stream = client @@ -204,85 +162,6 @@ async fn stream_message_parses_sse_events_with_tool_use() { assert!(request.body.contains("\"stream\":true")); } -#[tokio::test] -async fn stream_message_parses_sse_events_with_thinking_blocks() { - let state = Arc::new(Mutex::new(Vec::::new())); - let sse = concat!( - "event: message_start\n", - "data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_stream_thinking\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":8,\"output_tokens\":0}}}\n\n", - "event: content_block_start\n", - "data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"thinking\",\"thinking\":\"\"}}\n\n", - "event: content_block_delta\n", - "data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"thinking_delta\",\"thinking\":\"step 1\"}}\n\n", - "event: content_block_delta\n", - "data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"signature_delta\",\"signature\":\"sig_123\"}}\n\n", - "event: content_block_stop\n", - "data: {\"type\":\"content_block_stop\",\"index\":0}\n\n", - "event: content_block_start\n", - "data: {\"type\":\"content_block_start\",\"index\":1,\"content_block\":{\"type\":\"text\",\"text\":\"Final answer\"}}\n\n", - "event: content_block_stop\n", - "data: {\"type\":\"content_block_stop\",\"index\":1}\n\n", - "event: message_delta\n", - "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\",\"stop_sequence\":null},\"usage\":{\"input_tokens\":8,\"output_tokens\":1}}\n\n", - "event: message_stop\n", - "data: {\"type\":\"message_stop\"}\n\n", - "data: [DONE]\n\n" - ); - let server = spawn_server( - state, - vec![http_response("200 OK", "text/event-stream", sse)], - ) - .await; - - let client = ApiHttpClient::new("test-key").with_base_url(server.base_url()); - let mut stream = client - .stream_message(&sample_request(false)) - .await - .expect("stream should start"); - - let mut events = Vec::new(); - while let Some(event) = stream - .next_event() - .await - .expect("stream event should parse") - { - events.push(event); - } - - assert_eq!(events.len(), 9); - assert!(matches!( - &events[1], - StreamEvent::ContentBlockStart(ContentBlockStartEvent { - content_block: OutputContentBlock::Thinking { thinking, signature }, - .. - }) if thinking.is_empty() && signature.is_none() - )); - assert!(matches!( - &events[2], - StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent { - delta: ContentBlockDelta::ThinkingDelta { thinking }, - .. - }) if thinking == "step 1" - )); - assert!(matches!( - &events[3], - StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent { - delta: ContentBlockDelta::SignatureDelta { signature }, - .. - }) if signature == "sig_123" - )); - assert!(matches!( - &events[5], - StreamEvent::ContentBlockStart(ContentBlockStartEvent { - content_block: OutputContentBlock::Text { text }, - .. - }) if text == "Final answer" - )); - assert!(matches!(events[6], StreamEvent::ContentBlockStop(_))); - assert!(matches!(events[7], StreamEvent::MessageDelta(_))); - assert!(matches!(events[8], StreamEvent::MessageStop(_))); -} - #[tokio::test] async fn retries_retryable_failures_before_succeeding() { let state = Arc::new(Mutex::new(Vec::::new())); @@ -303,7 +182,7 @@ async fn retries_retryable_failures_before_succeeding() { ) .await; - let client = ApiHttpClient::new("test-key") + let client = AnthropicClient::new("test-key") .with_base_url(server.base_url()) .with_retry_policy(2, Duration::from_millis(1), Duration::from_millis(2)); @@ -316,6 +195,47 @@ async fn retries_retryable_failures_before_succeeding() { assert_eq!(state.lock().await.len(), 2); } +#[tokio::test] +async fn provider_client_dispatches_anthropic_requests() { + let state = Arc::new(Mutex::new(Vec::::new())); + let server = spawn_server( + state.clone(), + vec![http_response( + "200 OK", + "application/json", + "{\"id\":\"msg_provider\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Dispatched\"}],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"output_tokens\":2}}", + )], + ) + .await; + + let client = ProviderClient::from_model_with_anthropic_auth( + "claude-sonnet-4-6", + Some(AuthSource::ApiKey("test-key".to_string())), + ) + .expect("anthropic provider client should be constructed"); + let client = match client { + ProviderClient::Anthropic(client) => { + ProviderClient::Anthropic(client.with_base_url(server.base_url())) + } + other => panic!("expected anthropic provider, got {other:?}"), + }; + + let response = client + .send_message(&sample_request(false)) + .await + .expect("provider-dispatched request should succeed"); + + assert_eq!(response.total_tokens(), 5); + + let captured = state.lock().await; + let request = captured.first().expect("server should capture request"); + assert_eq!(request.path, "/v1/messages"); + assert_eq!( + request.headers.get("x-api-key").map(String::as_str), + Some("test-key") + ); +} + #[tokio::test] async fn surfaces_retry_exhaustion_for_persistent_retryable_errors() { let state = Arc::new(Mutex::new(Vec::::new())); @@ -336,7 +256,7 @@ async fn surfaces_retry_exhaustion_for_persistent_retryable_errors() { ) .await; - let client = ApiHttpClient::new("test-key") + let client = AnthropicClient::new("test-key") .with_base_url(server.base_url()) .with_retry_policy(1, Duration::from_millis(1), Duration::from_millis(2)); @@ -367,7 +287,7 @@ async fn surfaces_retry_exhaustion_for_persistent_retryable_errors() { #[tokio::test] #[ignore = "requires ANTHROPIC_API_KEY and network access"] async fn live_stream_smoke_test() { - let client = ApiHttpClient::from_env().expect("ANTHROPIC_API_KEY must be set"); + let client = AnthropicClient::from_env().expect("ANTHROPIC_API_KEY must be set"); let mut stream = client .stream_message(&MessageRequest { model: std::env::var("ANTHROPIC_MODEL") diff --git a/rust/crates/api/tests/openai_compat_integration.rs b/rust/crates/api/tests/openai_compat_integration.rs new file mode 100644 index 0000000..b345b1f --- /dev/null +++ b/rust/crates/api/tests/openai_compat_integration.rs @@ -0,0 +1,415 @@ +use std::collections::HashMap; +use std::ffi::OsString; +use std::sync::Arc; +use std::sync::{Mutex as StdMutex, OnceLock}; + +use api::{ + ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent, + InputContentBlock, InputMessage, MessageRequest, OpenAiCompatClient, OpenAiCompatConfig, + OutputContentBlock, ProviderClient, StreamEvent, ToolChoice, ToolDefinition, +}; +use serde_json::json; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpListener; +use tokio::sync::Mutex; + +#[tokio::test] +async fn send_message_uses_openai_compatible_endpoint_and_auth() { + let state = Arc::new(Mutex::new(Vec::::new())); + let body = concat!( + "{", + "\"id\":\"chatcmpl_test\",", + "\"model\":\"grok-3\",", + "\"choices\":[{", + "\"message\":{\"role\":\"assistant\",\"content\":\"Hello from Grok\",\"tool_calls\":[]},", + "\"finish_reason\":\"stop\"", + "}],", + "\"usage\":{\"prompt_tokens\":11,\"completion_tokens\":5}", + "}" + ); + let server = spawn_server( + state.clone(), + vec![http_response("200 OK", "application/json", body)], + ) + .await; + + let client = OpenAiCompatClient::new("xai-test-key", OpenAiCompatConfig::xai()) + .with_base_url(server.base_url()); + let response = client + .send_message(&sample_request(false)) + .await + .expect("request should succeed"); + + assert_eq!(response.model, "grok-3"); + assert_eq!(response.total_tokens(), 16); + assert_eq!( + response.content, + vec![OutputContentBlock::Text { + text: "Hello from Grok".to_string(), + }] + ); + + let captured = state.lock().await; + let request = captured.first().expect("server should capture request"); + assert_eq!(request.path, "/chat/completions"); + assert_eq!( + request.headers.get("authorization").map(String::as_str), + Some("Bearer xai-test-key") + ); + let body: serde_json::Value = serde_json::from_str(&request.body).expect("json body"); + assert_eq!(body["model"], json!("grok-3")); + assert_eq!(body["messages"][0]["role"], json!("system")); + assert_eq!(body["tools"][0]["type"], json!("function")); +} + +#[tokio::test] +async fn send_message_accepts_full_chat_completions_endpoint_override() { + let state = Arc::new(Mutex::new(Vec::::new())); + let body = concat!( + "{", + "\"id\":\"chatcmpl_full_endpoint\",", + "\"model\":\"grok-3\",", + "\"choices\":[{", + "\"message\":{\"role\":\"assistant\",\"content\":\"Endpoint override works\",\"tool_calls\":[]},", + "\"finish_reason\":\"stop\"", + "}],", + "\"usage\":{\"prompt_tokens\":7,\"completion_tokens\":3}", + "}" + ); + let server = spawn_server( + state.clone(), + vec![http_response("200 OK", "application/json", body)], + ) + .await; + + let endpoint_url = format!("{}/chat/completions", server.base_url()); + let client = OpenAiCompatClient::new("xai-test-key", OpenAiCompatConfig::xai()) + .with_base_url(endpoint_url); + let response = client + .send_message(&sample_request(false)) + .await + .expect("request should succeed"); + + assert_eq!(response.total_tokens(), 10); + + let captured = state.lock().await; + let request = captured.first().expect("server should capture request"); + assert_eq!(request.path, "/chat/completions"); +} + +#[tokio::test] +async fn stream_message_normalizes_text_and_multiple_tool_calls() { + let state = Arc::new(Mutex::new(Vec::::new())); + let sse = concat!( + "data: {\"id\":\"chatcmpl_stream\",\"model\":\"grok-3\",\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\n\n", + "data: {\"id\":\"chatcmpl_stream\",\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_1\",\"function\":{\"name\":\"weather\",\"arguments\":\"{\\\"city\\\":\\\"Paris\\\"}\"}},{\"index\":1,\"id\":\"call_2\",\"function\":{\"name\":\"clock\",\"arguments\":\"{\\\"zone\\\":\\\"UTC\\\"}\"}}]}}]}\n\n", + "data: {\"id\":\"chatcmpl_stream\",\"choices\":[{\"delta\":{},\"finish_reason\":\"tool_calls\"}]}\n\n", + "data: [DONE]\n\n" + ); + let server = spawn_server( + state.clone(), + vec![http_response_with_headers( + "200 OK", + "text/event-stream", + sse, + &[("x-request-id", "req_grok_stream")], + )], + ) + .await; + + let client = OpenAiCompatClient::new("xai-test-key", OpenAiCompatConfig::xai()) + .with_base_url(server.base_url()); + let mut stream = client + .stream_message(&sample_request(false)) + .await + .expect("stream should start"); + + assert_eq!(stream.request_id(), Some("req_grok_stream")); + + let mut events = Vec::new(); + while let Some(event) = stream.next_event().await.expect("event should parse") { + events.push(event); + } + + assert!(matches!(events[0], StreamEvent::MessageStart(_))); + assert!(matches!( + events[1], + StreamEvent::ContentBlockStart(ContentBlockStartEvent { + content_block: OutputContentBlock::Text { .. }, + .. + }) + )); + assert!(matches!( + events[2], + StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent { + delta: ContentBlockDelta::TextDelta { .. }, + .. + }) + )); + assert!(matches!( + events[3], + StreamEvent::ContentBlockStart(ContentBlockStartEvent { + index: 1, + content_block: OutputContentBlock::ToolUse { .. }, + }) + )); + assert!(matches!( + events[4], + StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent { + index: 1, + delta: ContentBlockDelta::InputJsonDelta { .. }, + }) + )); + assert!(matches!( + events[5], + StreamEvent::ContentBlockStart(ContentBlockStartEvent { + index: 2, + content_block: OutputContentBlock::ToolUse { .. }, + }) + )); + assert!(matches!( + events[6], + StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent { + index: 2, + delta: ContentBlockDelta::InputJsonDelta { .. }, + }) + )); + assert!(matches!( + events[7], + StreamEvent::ContentBlockStop(ContentBlockStopEvent { index: 1 }) + )); + assert!(matches!( + events[8], + StreamEvent::ContentBlockStop(ContentBlockStopEvent { index: 2 }) + )); + assert!(matches!( + events[9], + StreamEvent::ContentBlockStop(ContentBlockStopEvent { index: 0 }) + )); + assert!(matches!(events[10], StreamEvent::MessageDelta(_))); + assert!(matches!(events[11], StreamEvent::MessageStop(_))); + + let captured = state.lock().await; + let request = captured.first().expect("captured request"); + assert_eq!(request.path, "/chat/completions"); + assert!(request.body.contains("\"stream\":true")); +} + +#[tokio::test] +async fn provider_client_dispatches_xai_requests_from_env() { + let _lock = env_lock(); + let _api_key = ScopedEnvVar::set("XAI_API_KEY", "xai-test-key"); + + let state = Arc::new(Mutex::new(Vec::::new())); + let server = spawn_server( + state.clone(), + vec![http_response( + "200 OK", + "application/json", + "{\"id\":\"chatcmpl_provider\",\"model\":\"grok-3\",\"choices\":[{\"message\":{\"role\":\"assistant\",\"content\":\"Through provider client\",\"tool_calls\":[]},\"finish_reason\":\"stop\"}],\"usage\":{\"prompt_tokens\":9,\"completion_tokens\":4}}", + )], + ) + .await; + let _base_url = ScopedEnvVar::set("XAI_BASE_URL", server.base_url()); + + let client = + ProviderClient::from_model("grok").expect("xAI provider client should be constructed"); + assert!(matches!(client, ProviderClient::Xai(_))); + + let response = client + .send_message(&sample_request(false)) + .await + .expect("provider-dispatched request should succeed"); + + assert_eq!(response.total_tokens(), 13); + + let captured = state.lock().await; + let request = captured.first().expect("captured request"); + assert_eq!(request.path, "/chat/completions"); + assert_eq!( + request.headers.get("authorization").map(String::as_str), + Some("Bearer xai-test-key") + ); +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct CapturedRequest { + path: String, + headers: HashMap, + body: String, +} + +struct TestServer { + base_url: String, + join_handle: tokio::task::JoinHandle<()>, +} + +impl TestServer { + fn base_url(&self) -> String { + self.base_url.clone() + } +} + +impl Drop for TestServer { + fn drop(&mut self) { + self.join_handle.abort(); + } +} + +async fn spawn_server( + state: Arc>>, + responses: Vec, +) -> TestServer { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("listener should bind"); + let address = listener.local_addr().expect("listener addr"); + let join_handle = tokio::spawn(async move { + for response in responses { + let (mut socket, _) = listener.accept().await.expect("accept"); + let mut buffer = Vec::new(); + let mut header_end = None; + loop { + let mut chunk = [0_u8; 1024]; + let read = socket.read(&mut chunk).await.expect("read request"); + if read == 0 { + break; + } + buffer.extend_from_slice(&chunk[..read]); + if let Some(position) = find_header_end(&buffer) { + header_end = Some(position); + break; + } + } + + let header_end = header_end.expect("headers should exist"); + let (header_bytes, remaining) = buffer.split_at(header_end); + let header_text = String::from_utf8(header_bytes.to_vec()).expect("utf8 headers"); + let mut lines = header_text.split("\r\n"); + let request_line = lines.next().expect("request line"); + let path = request_line + .split_whitespace() + .nth(1) + .expect("path") + .to_string(); + let mut headers = HashMap::new(); + let mut content_length = 0_usize; + for line in lines { + if line.is_empty() { + continue; + } + let (name, value) = line.split_once(':').expect("header"); + let value = value.trim().to_string(); + if name.eq_ignore_ascii_case("content-length") { + content_length = value.parse().expect("content length"); + } + headers.insert(name.to_ascii_lowercase(), value); + } + + let mut body = remaining[4..].to_vec(); + while body.len() < content_length { + let mut chunk = vec![0_u8; content_length - body.len()]; + let read = socket.read(&mut chunk).await.expect("read body"); + if read == 0 { + break; + } + body.extend_from_slice(&chunk[..read]); + } + + state.lock().await.push(CapturedRequest { + path, + headers, + body: String::from_utf8(body).expect("utf8 body"), + }); + + socket + .write_all(response.as_bytes()) + .await + .expect("write response"); + } + }); + + TestServer { + base_url: format!("http://{address}"), + join_handle, + } +} + +fn find_header_end(bytes: &[u8]) -> Option { + bytes.windows(4).position(|window| window == b"\r\n\r\n") +} + +fn http_response(status: &str, content_type: &str, body: &str) -> String { + http_response_with_headers(status, content_type, body, &[]) +} + +fn http_response_with_headers( + status: &str, + content_type: &str, + body: &str, + headers: &[(&str, &str)], +) -> String { + let mut extra_headers = String::new(); + for (name, value) in headers { + use std::fmt::Write as _; + write!(&mut extra_headers, "{name}: {value}\r\n").expect("header write"); + } + format!( + "HTTP/1.1 {status}\r\ncontent-type: {content_type}\r\n{extra_headers}content-length: {}\r\nconnection: close\r\n\r\n{body}", + body.len() + ) +} + +fn sample_request(stream: bool) -> MessageRequest { + MessageRequest { + model: "grok-3".to_string(), + max_tokens: 64, + messages: vec![InputMessage { + role: "user".to_string(), + content: vec![InputContentBlock::Text { + text: "Say hello".to_string(), + }], + }], + system: Some("Use tools when needed".to_string()), + tools: Some(vec![ToolDefinition { + name: "weather".to_string(), + description: Some("Fetches weather".to_string()), + input_schema: json!({ + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"] + }), + }]), + tool_choice: Some(ToolChoice::Auto), + stream, + } +} + +fn env_lock() -> std::sync::MutexGuard<'static, ()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| StdMutex::new(())) + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +struct ScopedEnvVar { + key: &'static str, + previous: Option, +} + +impl ScopedEnvVar { + fn set(key: &'static str, value: impl AsRef) -> Self { + let previous = std::env::var_os(key); + std::env::set_var(key, value); + Self { key, previous } + } +} + +impl Drop for ScopedEnvVar { + fn drop(&mut self) { + match &self.previous { + Some(value) => std::env::set_var(self.key, value), + None => std::env::remove_var(self.key), + } + } +} diff --git a/rust/crates/api/tests/provider_client_integration.rs b/rust/crates/api/tests/provider_client_integration.rs new file mode 100644 index 0000000..204bf35 --- /dev/null +++ b/rust/crates/api/tests/provider_client_integration.rs @@ -0,0 +1,86 @@ +use std::ffi::OsString; +use std::sync::{Mutex, OnceLock}; + +use api::{read_xai_base_url, ApiError, AuthSource, ProviderClient, ProviderKind}; + +#[test] +fn provider_client_routes_grok_aliases_through_xai() { + let _lock = env_lock(); + let _xai_api_key = EnvVarGuard::set("XAI_API_KEY", Some("xai-test-key")); + + let client = ProviderClient::from_model("grok-mini").expect("grok alias should resolve"); + + assert_eq!(client.provider_kind(), ProviderKind::Xai); +} + +#[test] +fn provider_client_reports_missing_xai_credentials_for_grok_models() { + let _lock = env_lock(); + let _xai_api_key = EnvVarGuard::set("XAI_API_KEY", None); + + let error = ProviderClient::from_model("grok-3") + .expect_err("grok requests without XAI_API_KEY should fail fast"); + + match error { + ApiError::MissingCredentials { provider, env_vars } => { + assert_eq!(provider, "xAI"); + assert_eq!(env_vars, &["XAI_API_KEY"]); + } + other => panic!("expected missing xAI credentials, got {other:?}"), + } +} + +#[test] +fn provider_client_uses_explicit_anthropic_auth_without_env_lookup() { + let _lock = env_lock(); + let _anthropic_api_key = EnvVarGuard::set("ANTHROPIC_API_KEY", None); + let _anthropic_auth_token = EnvVarGuard::set("ANTHROPIC_AUTH_TOKEN", None); + + let client = ProviderClient::from_model_with_anthropic_auth( + "claude-sonnet-4-6", + Some(AuthSource::ApiKey("anthropic-test-key".to_string())), + ) + .expect("explicit anthropic auth should avoid env lookup"); + + assert_eq!(client.provider_kind(), ProviderKind::Anthropic); +} + +#[test] +fn read_xai_base_url_prefers_env_override() { + let _lock = env_lock(); + let _xai_base_url = EnvVarGuard::set("XAI_BASE_URL", Some("https://example.xai.test/v1")); + + assert_eq!(read_xai_base_url(), "https://example.xai.test/v1"); +} + +fn env_lock() -> std::sync::MutexGuard<'static, ()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())) + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +struct EnvVarGuard { + key: &'static str, + original: Option, +} + +impl EnvVarGuard { + fn set(key: &'static str, value: Option<&str>) -> Self { + let original = std::env::var_os(key); + match value { + Some(value) => std::env::set_var(key, value), + None => std::env::remove_var(key), + } + Self { key, original } + } +} + +impl Drop for EnvVarGuard { + fn drop(&mut self) { + match &self.original { + Some(value) => std::env::set_var(self.key, value), + None => std::env::remove_var(self.key), + } + } +} diff --git a/rust/crates/runtime/src/conversation.rs b/rust/crates/runtime/src/conversation.rs index a73f2f4..1abdce4 100644 --- a/rust/crates/runtime/src/conversation.rs +++ b/rust/crates/runtime/src/conversation.rs @@ -1,20 +1,15 @@ use std::collections::BTreeMap; use std::fmt::{Display, Formatter}; -use plugins::{HookRunner as PluginHookRunner, PluginRegistry}; - use crate::compact::{ compact_session, estimate_session_tokens, CompactionConfig, CompactionResult, }; use crate::config::RuntimeFeatureConfig; -use crate::hooks::HookRunner; +use crate::hooks::{HookRunResult, HookRunner}; use crate::permissions::{PermissionOutcome, PermissionPolicy, PermissionPrompter}; use crate::session::{ContentBlock, ConversationMessage, Session}; use crate::usage::{TokenUsage, UsageTracker}; -const DEFAULT_AUTO_COMPACTION_INPUT_TOKENS_THRESHOLD: u32 = 200_000; -const AUTO_COMPACTION_THRESHOLD_ENV_VAR: &str = "CLAUDE_CODE_AUTO_COMPACT_INPUT_TOKENS"; - #[derive(Debug, Clone, PartialEq, Eq)] pub struct ApiRequest { pub system_prompt: Vec, @@ -91,12 +86,6 @@ pub struct TurnSummary { pub tool_results: Vec, pub iterations: usize, pub usage: TokenUsage, - pub auto_compaction: Option, -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub struct AutoCompactionEvent { - pub removed_message_count: usize, } pub struct ConversationRuntime { @@ -108,25 +97,6 @@ pub struct ConversationRuntime { max_iterations: usize, usage_tracker: UsageTracker, hook_runner: HookRunner, - auto_compaction_input_tokens_threshold: u32, - plugin_hook_runner: Option, - plugin_registry: Option, - plugins_shutdown: bool, -} - -impl ConversationRuntime { - fn shutdown_registered_plugins(&mut self) -> Result<(), RuntimeError> { - if self.plugins_shutdown { - return Ok(()); - } - if let Some(registry) = &self.plugin_registry { - registry - .shutdown() - .map_err(|error| RuntimeError::new(format!("plugin shutdown failed: {error}")))?; - } - self.plugins_shutdown = true; - Ok(()) - } } impl ConversationRuntime @@ -148,19 +118,18 @@ where tool_executor, permission_policy, system_prompt, - RuntimeFeatureConfig::default(), + &RuntimeFeatureConfig::default(), ) } #[must_use] - #[allow(clippy::needless_pass_by_value)] pub fn new_with_features( session: Session, api_client: C, tool_executor: T, permission_policy: PermissionPolicy, system_prompt: Vec, - feature_config: RuntimeFeatureConfig, + feature_config: &RuntimeFeatureConfig, ) -> Self { let usage_tracker = UsageTracker::from_session(&session); Self { @@ -171,57 +140,16 @@ where system_prompt, max_iterations: usize::MAX, usage_tracker, - hook_runner: HookRunner::from_feature_config(&feature_config), - auto_compaction_input_tokens_threshold: auto_compaction_threshold_from_env(), - plugin_hook_runner: None, - plugin_registry: None, - plugins_shutdown: false, + hook_runner: HookRunner::from_feature_config(feature_config), } } - #[allow(clippy::needless_pass_by_value)] - pub fn new_with_plugins( - session: Session, - api_client: C, - tool_executor: T, - permission_policy: PermissionPolicy, - system_prompt: Vec, - feature_config: RuntimeFeatureConfig, - plugin_registry: PluginRegistry, - ) -> Result { - let plugin_hook_runner = - PluginHookRunner::from_registry(&plugin_registry).map_err(|error| { - RuntimeError::new(format!("plugin hook registration failed: {error}")) - })?; - plugin_registry - .initialize() - .map_err(|error| RuntimeError::new(format!("plugin initialization failed: {error}")))?; - let mut runtime = Self::new_with_features( - session, - api_client, - tool_executor, - permission_policy, - system_prompt, - feature_config, - ); - runtime.plugin_hook_runner = Some(plugin_hook_runner); - runtime.plugin_registry = Some(plugin_registry); - Ok(runtime) - } - #[must_use] pub fn with_max_iterations(mut self, max_iterations: usize) -> Self { self.max_iterations = max_iterations; self } - #[must_use] - pub fn with_auto_compaction_input_tokens_threshold(mut self, threshold: u32) -> Self { - self.auto_compaction_input_tokens_threshold = threshold; - self - } - - #[allow(clippy::too_many_lines)] pub fn run_turn( &mut self, user_input: impl Into, @@ -234,7 +162,6 @@ where let mut assistant_messages = Vec::new(); let mut tool_results = Vec::new(); let mut iterations = 0; - let mut max_turn_input_tokens = 0; loop { iterations += 1; @@ -251,7 +178,6 @@ where let events = self.api_client.stream(request)?; let (assistant_message, usage) = build_assistant_message(events)?; if let Some(usage) = usage { - max_turn_input_tokens = max_turn_input_tokens.max(usage.input_tokens); self.usage_tracker.record(usage); } let pending_tool_uses = assistant_message @@ -288,74 +214,35 @@ where ConversationMessage::tool_result( tool_use_id, tool_name, - format_hook_message(pre_hook_result.messages(), &deny_message), + format_hook_message(&pre_hook_result, &deny_message), true, ) } else { - let plugin_pre_hook_result = - self.run_plugin_pre_tool_use(&tool_name, &input); - if plugin_pre_hook_result.is_denied() { - let deny_message = - format!("PreToolUse hook denied tool `{tool_name}`"); - let mut messages = pre_hook_result.messages().to_vec(); - messages.extend(plugin_pre_hook_result.messages().iter().cloned()); - ConversationMessage::tool_result( - tool_use_id, - tool_name, - format_hook_message(&messages, &deny_message), - true, - ) - } else { - let (mut output, mut is_error) = - match self.tool_executor.execute(&tool_name, &input) { - Ok(output) => (output, false), - Err(error) => (error.to_string(), true), - }; - output = - merge_hook_feedback(pre_hook_result.messages(), output, false); - output = merge_hook_feedback( - plugin_pre_hook_result.messages(), - output, - false, - ); + let (mut output, mut is_error) = + match self.tool_executor.execute(&tool_name, &input) { + Ok(output) => (output, false), + Err(error) => (error.to_string(), true), + }; + output = merge_hook_feedback(pre_hook_result.messages(), output, false); - let hook_output = output.clone(); - let post_hook_result = self.hook_runner.run_post_tool_use( - &tool_name, - &input, - &hook_output, - is_error, - ); - let plugin_post_hook_result = self.run_plugin_post_tool_use( - &tool_name, - &input, - &hook_output, - is_error, - ); - if post_hook_result.is_denied() { - is_error = true; - } - if plugin_post_hook_result.is_denied() { - is_error = true; - } - output = merge_hook_feedback( - post_hook_result.messages(), - output, - post_hook_result.is_denied(), - ); - output = merge_hook_feedback( - plugin_post_hook_result.messages(), - output, - plugin_post_hook_result.is_denied(), - ); - - ConversationMessage::tool_result( - tool_use_id, - tool_name, - output, - is_error, - ) + let post_hook_result = self + .hook_runner + .run_post_tool_use(&tool_name, &input, &output, is_error); + if post_hook_result.is_denied() { + is_error = true; } + output = merge_hook_feedback( + post_hook_result.messages(), + output, + post_hook_result.is_denied(), + ); + + ConversationMessage::tool_result( + tool_use_id, + tool_name, + output, + is_error, + ) } } PermissionOutcome::Deny { reason } => { @@ -367,14 +254,11 @@ where } } - let auto_compaction = self.maybe_auto_compact(max_turn_input_tokens); - Ok(TurnSummary { assistant_messages, tool_results, iterations, usage: self.usage_tracker.cumulative_usage(), - auto_compaction, }) } @@ -399,81 +283,9 @@ where } #[must_use] - pub fn into_session(mut self) -> Session { - let _ = self.shutdown_registered_plugins(); - std::mem::take(&mut self.session) + pub fn into_session(self) -> Session { + self.session } - - pub fn shutdown_plugins(&mut self) -> Result<(), RuntimeError> { - self.shutdown_registered_plugins() - } - - fn run_plugin_pre_tool_use(&self, tool_name: &str, input: &str) -> plugins::HookRunResult { - self.plugin_hook_runner.as_ref().map_or_else( - || plugins::HookRunResult::allow(Vec::new()), - |runner| runner.run_pre_tool_use(tool_name, input), - ) - } - - fn run_plugin_post_tool_use( - &self, - tool_name: &str, - input: &str, - output: &str, - is_error: bool, - ) -> plugins::HookRunResult { - self.plugin_hook_runner.as_ref().map_or_else( - || plugins::HookRunResult::allow(Vec::new()), - |runner| runner.run_post_tool_use(tool_name, input, output, is_error), - ) - } - - fn maybe_auto_compact(&mut self, turn_input_tokens: u32) -> Option { - if turn_input_tokens < self.auto_compaction_input_tokens_threshold { - return None; - } - - let result = compact_session( - &self.session, - CompactionConfig { - max_estimated_tokens: usize::try_from(self.auto_compaction_input_tokens_threshold) - .unwrap_or(usize::MAX), - ..CompactionConfig::default() - }, - ); - - if result.removed_message_count == 0 { - return None; - } - - self.session = result.compacted_session; - Some(AutoCompactionEvent { - removed_message_count: result.removed_message_count, - }) - } -} - -impl Drop for ConversationRuntime { - fn drop(&mut self) { - let _ = self.shutdown_registered_plugins(); - } -} - -#[must_use] -pub fn auto_compaction_threshold_from_env() -> u32 { - parse_auto_compaction_threshold( - std::env::var(AUTO_COMPACTION_THRESHOLD_ENV_VAR) - .ok() - .as_deref(), - ) -} - -#[must_use] -fn parse_auto_compaction_threshold(value: Option<&str>) -> u32 { - value - .and_then(|raw| raw.trim().parse::().ok()) - .filter(|threshold| *threshold > 0) - .unwrap_or(DEFAULT_AUTO_COMPACTION_INPUT_TOKENS_THRESHOLD) } fn build_assistant_message( @@ -523,11 +335,11 @@ fn flush_text_block(text: &mut String, blocks: &mut Vec) { } } -fn format_hook_message(messages: &[String], fallback: &str) -> String { - if messages.is_empty() { +fn format_hook_message(result: &HookRunResult, fallback: &str) -> String { + if result.messages().is_empty() { fallback.to_string() } else { - messages.join("\n") + result.messages().join("\n") } } @@ -584,9 +396,8 @@ impl ToolExecutor for StaticToolExecutor { #[cfg(test)] mod tests { use super::{ - parse_auto_compaction_threshold, ApiClient, ApiRequest, AssistantEvent, - AutoCompactionEvent, ConversationRuntime, RuntimeError, StaticToolExecutor, - DEFAULT_AUTO_COMPACTION_INPUT_TOKENS_THRESHOLD, + ApiClient, ApiRequest, AssistantEvent, ConversationRuntime, RuntimeError, + StaticToolExecutor, }; use crate::compact::CompactionConfig; use crate::config::{RuntimeFeatureConfig, RuntimeHookConfig}; @@ -597,13 +408,7 @@ mod tests { use crate::prompt::{ProjectContext, SystemPromptBuilder}; use crate::session::{ContentBlock, MessageRole, Session}; use crate::usage::TokenUsage; - use plugins::{PluginManager, PluginManagerConfig}; - use std::fs; - #[cfg(unix)] - use std::os::unix::fs::PermissionsExt; - use std::path::Path; use std::path::PathBuf; - use std::time::{SystemTime, UNIX_EPOCH}; struct ScriptedApiClient { call_count: usize, @@ -665,68 +470,6 @@ mod tests { } } - fn temp_dir(label: &str) -> PathBuf { - let nanos = SystemTime::now() - .duration_since(UNIX_EPOCH) - .expect("time should be after epoch") - .as_nanos(); - std::env::temp_dir().join(format!("runtime-plugin-{label}-{nanos}")) - } - - fn write_lifecycle_plugin(root: &Path, name: &str) -> PathBuf { - fs::create_dir_all(root.join(".claude-plugin")).expect("manifest dir"); - fs::create_dir_all(root.join("lifecycle")).expect("lifecycle dir"); - let log_path = root.join("lifecycle.log"); - fs::write( - root.join("lifecycle").join("init.sh"), - "#!/bin/sh\nprintf 'init\\n' >> lifecycle.log\n", - ) - .expect("write init script"); - fs::write( - root.join("lifecycle").join("shutdown.sh"), - "#!/bin/sh\nprintf 'shutdown\\n' >> lifecycle.log\n", - ) - .expect("write shutdown script"); - fs::write( - root.join(".claude-plugin").join("plugin.json"), - format!( - "{{\n \"name\": \"{name}\",\n \"version\": \"1.0.0\",\n \"description\": \"runtime lifecycle plugin\",\n \"lifecycle\": {{\n \"Init\": [\"./lifecycle/init.sh\"],\n \"Shutdown\": [\"./lifecycle/shutdown.sh\"]\n }}\n}}" - ), - ) - .expect("write plugin manifest"); - log_path - } - - fn write_hook_plugin(root: &Path, name: &str, pre_message: &str, post_message: &str) { - fs::create_dir_all(root.join(".claude-plugin")).expect("manifest dir"); - fs::create_dir_all(root.join("hooks")).expect("hooks dir"); - fs::write( - root.join("hooks").join("pre.sh"), - format!("#!/bin/sh\nprintf '%s\\n' '{pre_message}'\n"), - ) - .expect("write pre hook"); - fs::write( - root.join("hooks").join("post.sh"), - format!("#!/bin/sh\nprintf '%s\\n' '{post_message}'\n"), - ) - .expect("write post hook"); - #[cfg(unix)] - { - let exec_mode = fs::Permissions::from_mode(0o755); - fs::set_permissions(root.join("hooks").join("pre.sh"), exec_mode.clone()) - .expect("chmod pre hook"); - fs::set_permissions(root.join("hooks").join("post.sh"), exec_mode) - .expect("chmod post hook"); - } - fs::write( - root.join(".claude-plugin").join("plugin.json"), - format!( - "{{\n \"name\": \"{name}\",\n \"version\": \"1.0.0\",\n \"description\": \"runtime hook plugin\",\n \"hooks\": {{\n \"PreToolUse\": [\"./hooks/pre.sh\"],\n \"PostToolUse\": [\"./hooks/post.sh\"]\n }}\n}}" - ), - ) - .expect("write plugin manifest"); - } - #[test] fn runs_user_to_tool_to_result_loop_end_to_end_and_tracks_usage() { let api_client = ScriptedApiClient { call_count: 0 }; @@ -765,7 +508,6 @@ mod tests { assert_eq!(summary.tool_results.len(), 1); assert_eq!(runtime.session().messages.len(), 4); assert_eq!(summary.usage.output_tokens, 10); - assert_eq!(summary.auto_compaction, None); assert!(matches!( runtime.session().messages[1].blocks[1], ContentBlock::ToolUse { .. } @@ -867,7 +609,7 @@ mod tests { }), PermissionPolicy::new(PermissionMode::DangerFullAccess), vec!["system".to_string()], - RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new( + &RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new( vec![shell_snippet("printf 'blocked by hook'; exit 2")], Vec::new(), )), @@ -933,7 +675,7 @@ mod tests { StaticToolExecutor::new().register("add", |_input| Ok("4".to_string())), PermissionPolicy::new(PermissionMode::DangerFullAccess), vec!["system".to_string()], - RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new( + &RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new( vec![shell_snippet("printf 'pre hook ran'")], vec![shell_snippet("printf 'post hook ran'")], )), @@ -968,153 +710,6 @@ mod tests { ); } - #[test] - fn initializes_and_shuts_down_plugins_with_runtime_lifecycle() { - let config_home = temp_dir("config"); - let source_root = temp_dir("source"); - let _ = write_lifecycle_plugin(&source_root, "runtime-lifecycle"); - - let mut manager = PluginManager::new(PluginManagerConfig::new(&config_home)); - let install = manager - .install(source_root.to_str().expect("utf8 path")) - .expect("install should succeed"); - let log_path = install.install_path.join("lifecycle.log"); - let registry = manager.plugin_registry().expect("registry should load"); - - { - let runtime = ConversationRuntime::new_with_plugins( - Session::new(), - ScriptedApiClient { call_count: 0 }, - StaticToolExecutor::new().register("add", |_input| Ok("4".to_string())), - PermissionPolicy::new(PermissionMode::WorkspaceWrite), - vec!["system".to_string()], - RuntimeFeatureConfig::default(), - registry, - ) - .expect("runtime should initialize plugins"); - - let log = fs::read_to_string(&log_path).expect("init log should exist"); - assert_eq!(log, "init\n"); - drop(runtime); - } - - let log = fs::read_to_string(&log_path).expect("shutdown log should exist"); - assert_eq!(log, "init\nshutdown\n"); - - let _ = fs::remove_dir_all(config_home); - let _ = fs::remove_dir_all(source_root); - } - - #[test] - fn executes_hooks_from_installed_plugins_during_tool_use() { - struct TwoCallApiClient { - calls: usize, - } - - impl ApiClient for TwoCallApiClient { - fn stream(&mut self, request: ApiRequest) -> Result, RuntimeError> { - self.calls += 1; - match self.calls { - 1 => Ok(vec![ - AssistantEvent::ToolUse { - id: "tool-1".to_string(), - name: "add".to_string(), - input: r#"{"lhs":2,"rhs":2}"#.to_string(), - }, - AssistantEvent::MessageStop, - ]), - 2 => { - assert!(request - .messages - .iter() - .any(|message| message.role == MessageRole::Tool)); - Ok(vec![ - AssistantEvent::TextDelta("done".to_string()), - AssistantEvent::MessageStop, - ]) - } - _ => Err(RuntimeError::new("unexpected extra API call")), - } - } - } - - let config_home = temp_dir("hook-config"); - let first_source_root = temp_dir("hook-source-a"); - let second_source_root = temp_dir("hook-source-b"); - write_hook_plugin( - &first_source_root, - "first", - "plugin pre one", - "plugin post one", - ); - write_hook_plugin( - &second_source_root, - "second", - "plugin pre two", - "plugin post two", - ); - - let mut manager = PluginManager::new(PluginManagerConfig::new(&config_home)); - manager - .install(first_source_root.to_str().expect("utf8 path")) - .expect("first plugin install should succeed"); - manager - .install(second_source_root.to_str().expect("utf8 path")) - .expect("second plugin install should succeed"); - let registry = manager.plugin_registry().expect("registry should load"); - - let mut runtime = ConversationRuntime::new_with_plugins( - Session::new(), - TwoCallApiClient { calls: 0 }, - StaticToolExecutor::new().register("add", |_input| Ok("4".to_string())), - PermissionPolicy::new(PermissionMode::DangerFullAccess), - vec!["system".to_string()], - RuntimeFeatureConfig::default(), - registry, - ) - .expect("runtime should load plugin hooks"); - - let summary = runtime - .run_turn("use add", None) - .expect("tool loop succeeds"); - - assert_eq!(summary.tool_results.len(), 1); - let ContentBlock::ToolResult { - is_error, output, .. - } = &summary.tool_results[0].blocks[0] - else { - panic!("expected tool result block"); - }; - assert!( - !*is_error, - "plugin hooks should not force an error: {output:?}" - ); - assert!( - output.contains('4'), - "tool output missing value: {output:?}" - ); - assert!( - output.contains("plugin pre one"), - "tool output missing first pre hook feedback: {output:?}" - ); - assert!( - output.contains("plugin pre two"), - "tool output missing second pre hook feedback: {output:?}" - ); - assert!( - output.contains("plugin post one"), - "tool output missing first post hook feedback: {output:?}" - ); - assert!( - output.contains("plugin post two"), - "tool output missing second post hook feedback: {output:?}" - ); - - let _ = fs::remove_dir_all(config_home); - let _ = fs::remove_dir_all(first_source_root); - let _ = fs::remove_dir_all(second_source_root); - } - #[test] fn reconstructs_usage_tracker_from_restored_session() { struct SimpleApi; @@ -1203,177 +798,4 @@ mod tests { fn shell_snippet(script: &str) -> String { script.to_string() } - - #[test] - fn auto_compacts_when_turn_input_threshold_is_crossed() { - struct SimpleApi; - impl ApiClient for SimpleApi { - fn stream( - &mut self, - _request: ApiRequest, - ) -> Result, RuntimeError> { - Ok(vec![ - AssistantEvent::TextDelta("done".to_string()), - AssistantEvent::Usage(TokenUsage { - input_tokens: 120_000, - output_tokens: 4, - cache_creation_input_tokens: 0, - cache_read_input_tokens: 0, - }), - AssistantEvent::MessageStop, - ]) - } - } - - let session = Session { - version: 1, - messages: vec![ - crate::session::ConversationMessage::user_text("one ".repeat(30_000)), - crate::session::ConversationMessage::assistant(vec![ContentBlock::Text { - text: "two ".repeat(30_000), - }]), - crate::session::ConversationMessage::user_text("three ".repeat(30_000)), - crate::session::ConversationMessage::assistant(vec![ContentBlock::Text { - text: "four ".repeat(30_000), - }]), - ], - }; - - let mut runtime = ConversationRuntime::new( - session, - SimpleApi, - StaticToolExecutor::new(), - PermissionPolicy::new(PermissionMode::DangerFullAccess), - vec!["system".to_string()], - ) - .with_auto_compaction_input_tokens_threshold(100_000); - - let summary = runtime - .run_turn("trigger", None) - .expect("turn should succeed"); - - assert_eq!( - summary.auto_compaction, - Some(AutoCompactionEvent { - removed_message_count: 2, - }) - ); - assert_eq!(runtime.session().messages[0].role, MessageRole::System); - } - - #[test] - fn auto_compaction_does_not_repeat_after_context_is_already_compacted() { - struct SequentialUsageApi { - call_count: usize, - } - - impl ApiClient for SequentialUsageApi { - fn stream( - &mut self, - _request: ApiRequest, - ) -> Result, RuntimeError> { - self.call_count += 1; - let input_tokens = if self.call_count == 1 { 120_000 } else { 64 }; - Ok(vec![ - AssistantEvent::TextDelta("done".to_string()), - AssistantEvent::Usage(TokenUsage { - input_tokens, - output_tokens: 4, - cache_creation_input_tokens: 0, - cache_read_input_tokens: 0, - }), - AssistantEvent::MessageStop, - ]) - } - } - - let session = Session { - version: 1, - messages: vec![ - crate::session::ConversationMessage::user_text("one ".repeat(30_000)), - crate::session::ConversationMessage::assistant(vec![ContentBlock::Text { - text: "two ".repeat(30_000), - }]), - crate::session::ConversationMessage::user_text("three ".repeat(30_000)), - crate::session::ConversationMessage::assistant(vec![ContentBlock::Text { - text: "four ".repeat(30_000), - }]), - ], - }; - - let mut runtime = ConversationRuntime::new( - session, - SequentialUsageApi { call_count: 0 }, - StaticToolExecutor::new(), - PermissionPolicy::new(PermissionMode::DangerFullAccess), - vec!["system".to_string()], - ) - .with_auto_compaction_input_tokens_threshold(100_000); - - let first = runtime - .run_turn("trigger", None) - .expect("first turn should succeed"); - assert_eq!( - first.auto_compaction, - Some(AutoCompactionEvent { - removed_message_count: 2, - }) - ); - - let second = runtime - .run_turn("continue", None) - .expect("second turn should succeed"); - assert_eq!(second.auto_compaction, None); - assert_eq!(runtime.session().messages[0].role, MessageRole::System); - } - - #[test] - fn skips_auto_compaction_below_threshold() { - struct SimpleApi; - impl ApiClient for SimpleApi { - fn stream( - &mut self, - _request: ApiRequest, - ) -> Result, RuntimeError> { - Ok(vec![ - AssistantEvent::TextDelta("done".to_string()), - AssistantEvent::Usage(TokenUsage { - input_tokens: 99_999, - output_tokens: 4, - cache_creation_input_tokens: 0, - cache_read_input_tokens: 0, - }), - AssistantEvent::MessageStop, - ]) - } - } - - let mut runtime = ConversationRuntime::new( - Session::new(), - SimpleApi, - StaticToolExecutor::new(), - PermissionPolicy::new(PermissionMode::DangerFullAccess), - vec!["system".to_string()], - ) - .with_auto_compaction_input_tokens_threshold(100_000); - - let summary = runtime - .run_turn("trigger", None) - .expect("turn should succeed"); - assert_eq!(summary.auto_compaction, None); - assert_eq!(runtime.session().messages.len(), 2); - } - - #[test] - fn auto_compaction_threshold_defaults_and_parses_values() { - assert_eq!( - parse_auto_compaction_threshold(None), - DEFAULT_AUTO_COMPACTION_INPUT_TOKENS_THRESHOLD - ); - assert_eq!(parse_auto_compaction_threshold(Some("4321")), 4321); - assert_eq!( - parse_auto_compaction_threshold(Some("not-a-number")), - DEFAULT_AUTO_COMPACTION_INPUT_TOKENS_THRESHOLD - ); - } } diff --git a/rust/crates/runtime/src/hooks.rs b/rust/crates/runtime/src/hooks.rs index 4aff002..63ef9ff 100644 --- a/rust/crates/runtime/src/hooks.rs +++ b/rust/crates/runtime/src/hooks.rs @@ -1,5 +1,4 @@ use std::ffi::OsStr; -use std::path::Path; use std::process::Command; use serde_json::json; @@ -52,6 +51,16 @@ pub struct HookRunner { config: RuntimeHookConfig, } +#[derive(Debug, Clone, Copy)] +struct HookCommandRequest<'a> { + event: HookEvent, + tool_name: &'a str, + tool_input: &'a str, + tool_output: Option<&'a str>, + is_error: bool, + payload: &'a str, +} + impl HookRunner { #[must_use] pub fn new(config: RuntimeHookConfig) -> Self { @@ -119,14 +128,16 @@ impl HookRunner { let mut messages = Vec::new(); for command in commands { - match self.run_command( + match Self::run_command( command, - event, - tool_name, - tool_input, - tool_output, - is_error, - &payload, + HookCommandRequest { + event, + tool_name, + tool_input, + tool_output, + is_error, + payload: &payload, + }, ) { HookCommandOutcome::Allow { message } => { if let Some(message) = message { @@ -150,30 +161,23 @@ impl HookRunner { HookRunResult::allow(messages) } - #[allow(clippy::too_many_arguments, clippy::unused_self)] - fn run_command( - &self, - command: &str, - event: HookEvent, - tool_name: &str, - tool_input: &str, - tool_output: Option<&str>, - is_error: bool, - payload: &str, - ) -> HookCommandOutcome { + fn run_command(command: &str, request: HookCommandRequest<'_>) -> HookCommandOutcome { let mut child = shell_command(command); child.stdin(std::process::Stdio::piped()); child.stdout(std::process::Stdio::piped()); child.stderr(std::process::Stdio::piped()); - child.env("HOOK_EVENT", event.as_str()); - child.env("HOOK_TOOL_NAME", tool_name); - child.env("HOOK_TOOL_INPUT", tool_input); - child.env("HOOK_TOOL_IS_ERROR", if is_error { "1" } else { "0" }); - if let Some(tool_output) = tool_output { + child.env("HOOK_EVENT", request.event.as_str()); + child.env("HOOK_TOOL_NAME", request.tool_name); + child.env("HOOK_TOOL_INPUT", request.tool_input); + child.env( + "HOOK_TOOL_IS_ERROR", + if request.is_error { "1" } else { "0" }, + ); + if let Some(tool_output) = request.tool_output { child.env("HOOK_TOOL_OUTPUT", tool_output); } - match child.output_with_stdin(payload.as_bytes()) { + match child.output_with_stdin(request.payload.as_bytes()) { Ok(output) => { let stdout = String::from_utf8_lossy(&output.stdout).trim().to_string(); let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string(); @@ -191,16 +195,18 @@ impl HookRunner { }, None => HookCommandOutcome::Warn { message: format!( - "{} hook `{command}` terminated by signal while handling `{tool_name}`", - event.as_str() + "{} hook `{command}` terminated by signal while handling `{}`", + request.event.as_str(), + request.tool_name ), }, } } Err(error) => HookCommandOutcome::Warn { message: format!( - "{} hook `{command}` failed to start for `{tool_name}`: {error}", - event.as_str() + "{} hook `{command}` failed to start for `{}`: {error}", + request.event.as_str(), + request.tool_name ), }, } @@ -239,11 +245,7 @@ fn shell_command(command: &str) -> CommandWithStdin { }; #[cfg(not(windows))] - let command_builder = if Path::new(command).exists() { - let mut command_builder = Command::new("sh"); - command_builder.arg(command); - CommandWithStdin::new(command_builder) - } else { + let command_builder = { let mut command_builder = Command::new("sh"); command_builder.arg("-lc").arg(command); CommandWithStdin::new(command_builder) diff --git a/rust/crates/rusty-claude-cli/src/main.rs b/rust/crates/rusty-claude-cli/src/main.rs index 3c96284..847f94f 100644 --- a/rust/crates/rusty-claude-cli/src/main.rs +++ b/rust/crates/rusty-claude-cli/src/main.rs @@ -2,32 +2,27 @@ mod init; mod input; mod render; -use std::collections::BTreeSet; +use std::collections::{BTreeMap, BTreeSet}; use std::env; -use std::fmt::Write as _; use std::fs; use std::io::{self, Read, Write}; use std::net::TcpListener; use std::path::{Path, PathBuf}; use std::process::Command; -use std::sync::mpsc::{self, RecvTimeoutError}; -use std::sync::{Arc, Mutex}; -use std::thread; -use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; +use std::time::{SystemTime, UNIX_EPOCH}; use api::{ - resolve_startup_auth_source, ApiHttpClient, AuthSource, ContentBlockDelta, InputContentBlock, - InputMessage, MessageRequest, MessageResponse, OutputContentBlock, + detect_provider_kind, max_tokens_for_model, resolve_model_alias, resolve_startup_auth_source, + AnthropicClient, AuthSource, ContentBlockDelta, InputContentBlock, InputMessage, + MessageRequest, MessageResponse, OutputContentBlock, ProviderClient, ProviderKind, StreamEvent as ApiStreamEvent, ToolChoice, ToolDefinition, ToolResultContentBlock, }; use commands::{ - handle_plugins_slash_command, render_slash_command_help, resume_supported_slash_commands, - slash_command_specs, SlashCommand, + render_slash_command_help, resume_supported_slash_commands, slash_command_specs, SlashCommand, }; use compat_harness::{extract_manifest, UpstreamPaths}; use init::initialize_repo; -use plugins::{PluginManager, PluginManagerConfig, PluginRegistry}; use render::{MarkdownStreamState, Spinner, TerminalRenderer}; use runtime::{ clear_oauth_credentials, generate_pkce_pair, generate_state, load_system_prompt, @@ -38,22 +33,14 @@ use runtime::{ Session, TokenUsage, ToolError, ToolExecutor, UsageTracker, }; use serde_json::json; -use tools::GlobalToolRegistry; +use tools::{execute_tool, mvp_tool_specs, ToolSpec}; const DEFAULT_MODEL: &str = "claude-opus-4-6"; -fn max_tokens_for_model(model: &str) -> u32 { - if model.contains("opus") { - 32_000 - } else { - 64_000 - } -} const DEFAULT_DATE: &str = "2026-03-31"; const DEFAULT_OAUTH_CALLBACK_PORT: u16 = 4545; const VERSION: &str = env!("CARGO_PKG_VERSION"); const BUILD_TARGET: Option<&str> = option_env!("TARGET"); const GIT_SHA: Option<&str> = option_env!("GIT_SHA"); -const INTERNAL_PROGRESS_HEARTBEAT_INTERVAL: Duration = Duration::from_secs(3); type AllowedToolSet = BTreeSet; @@ -204,7 +191,7 @@ fn parse_args(args: &[String]) -> Result { index += 1; } "-p" => { - // Claw Code compat: -p "prompt" = one-shot prompt + // Claude Code compat: -p "prompt" = one-shot prompt let prompt = args[index + 1..].join(" "); if prompt.trim().is_empty() { return Err("-p requires a prompt string".to_string()); @@ -218,7 +205,7 @@ fn parse_args(args: &[String]) -> Result { }); } "--print" => { - // Claw Code compat: --print makes output non-interactive + // Claude Code compat: --print makes output non-interactive output_format = CliOutputFormat::Text; index += 1; } @@ -295,34 +282,52 @@ fn parse_args(args: &[String]) -> Result { } } -fn resolve_model_alias(model: &str) -> &str { - match model { - "opus" => "claude-opus-4-6", - "sonnet" => "claude-sonnet-4-6", - "haiku" => "claude-haiku-4-5-20251213", - _ => model, - } -} - fn normalize_allowed_tools(values: &[String]) -> Result, String> { if values.is_empty() { return Ok(None); } - match current_tool_registry() { - Ok(registry) => registry.normalize_allowed_tools(values), - Err(_) => GlobalToolRegistry::builtin().normalize_allowed_tools(values), + + let canonical_names = mvp_tool_specs() + .into_iter() + .map(|spec| spec.name.to_string()) + .collect::>(); + let mut name_map = canonical_names + .iter() + .map(|name| (normalize_tool_name(name), name.clone())) + .collect::>(); + + for (alias, canonical) in [ + ("read", "read_file"), + ("write", "write_file"), + ("edit", "edit_file"), + ("glob", "glob_search"), + ("grep", "grep_search"), + ] { + name_map.insert(alias.to_string(), canonical.to_string()); } + + let mut allowed = AllowedToolSet::new(); + for value in values { + for token in value + .split(|ch: char| ch == ',' || ch.is_whitespace()) + .filter(|token| !token.is_empty()) + { + let normalized = normalize_tool_name(token); + let canonical = name_map.get(&normalized).ok_or_else(|| { + format!( + "unsupported tool in --allowedTools: {token} (expected one of: {})", + canonical_names.join(", ") + ) + })?; + allowed.insert(canonical.clone()); + } + } + + Ok(Some(allowed)) } -fn current_tool_registry() -> Result { - let cwd = env::current_dir().map_err(|error| error.to_string())?; - let loader = ConfigLoader::default_for(&cwd); - let runtime_config = loader.load().map_err(|error| error.to_string())?; - let plugin_manager = build_plugin_manager(&cwd, &loader, &runtime_config); - let plugin_tools = plugin_manager - .aggregated_tools() - .map_err(|error| error.to_string())?; - GlobalToolRegistry::with_plugin_tools(plugin_tools) +fn normalize_tool_name(value: &str) -> String { + value.trim().replace('-', "_").to_ascii_lowercase() } fn parse_permission_mode_arg(value: &str) -> Result { @@ -352,11 +357,11 @@ fn default_permission_mode() -> PermissionMode { .map_or(PermissionMode::DangerFullAccess, permission_mode_from_label) } -fn filter_tool_specs( - tool_registry: &GlobalToolRegistry, - allowed_tools: Option<&AllowedToolSet>, -) -> Vec { - tool_registry.definitions(allowed_tools) +fn filter_tool_specs(allowed_tools: Option<&AllowedToolSet>) -> Vec { + mvp_tool_specs() + .into_iter() + .filter(|spec| allowed_tools.is_none_or(|allowed| allowed.contains(spec.name))) + .collect() } fn parse_system_prompt_args(args: &[String]) -> Result { @@ -479,7 +484,7 @@ fn run_login() -> Result<(), Box> { return Err(io::Error::new(io::ErrorKind::InvalidData, "oauth state mismatch").into()); } - let client = ApiHttpClient::from_auth(AuthSource::None).with_base_url(api::read_base_url()); + let client = AnthropicClient::from_auth(AuthSource::None).with_base_url(api::read_base_url()); let exchange_request = OAuthTokenExchangeRequest::from_config(oauth, code, state, pkce.verifier, redirect_uri); let runtime = tokio::runtime::Runtime::new()?; @@ -755,10 +760,6 @@ fn format_compact_report(removed: usize, resulting_messages: usize, skipped: boo } } -fn format_auto_compaction_notice(removed: usize) -> String { - format!("[auto-compacted: removed {removed} messages]") -} - fn parse_git_status_metadata(status: Option<&str>) -> (Option, Option) { let Some(status) = status else { return (None, None); @@ -897,18 +898,10 @@ fn run_resume_command( )), }) } - SlashCommand::Bughunter { .. } - | SlashCommand::Commit - | SlashCommand::Pr { .. } - | SlashCommand::Issue { .. } - | SlashCommand::Ultraplan { .. } - | SlashCommand::Teleport { .. } - | SlashCommand::DebugToolCall - | SlashCommand::Resume { .. } + SlashCommand::Resume { .. } | SlashCommand::Model { .. } | SlashCommand::Permissions { .. } | SlashCommand::Session { .. } - | SlashCommand::Plugins { .. } | SlashCommand::Unknown(_) => Err("unsupported resumed slash command".into()), } } @@ -972,7 +965,7 @@ struct LiveCli { allowed_tools: Option, permission_mode: PermissionMode, system_prompt: Vec, - runtime: ConversationRuntime, + runtime: ConversationRuntime, session: SessionHandle, } @@ -993,7 +986,6 @@ impl LiveCli { true, allowed_tools.clone(), permission_mode, - None, )?; let cli = Self { model, @@ -1043,19 +1035,13 @@ impl LiveCli { let mut permission_prompter = CliPermissionPrompter::new(self.permission_mode); let result = self.runtime.run_turn(input, Some(&mut permission_prompter)); match result { - Ok(summary) => { + Ok(_) => { spinner.finish( "✨ Done", TerminalRenderer::new().color_theme(), &mut stdout, )?; println!(); - if let Some(event) = summary.auto_compaction { - println!( - "{}", - format_auto_compaction_notice(event.removed_message_count) - ); - } self.persist_session()?; Ok(()) } @@ -1091,7 +1077,6 @@ impl LiveCli { false, self.allowed_tools.clone(), self.permission_mode, - None, )?; let mut permission_prompter = CliPermissionPrompter::new(self.permission_mode); let summary = runtime.run_turn(input, Some(&mut permission_prompter))?; @@ -1103,10 +1088,6 @@ impl LiveCli { "message": final_assistant_text(&summary), "model": self.model, "iterations": summary.iterations, - "auto_compaction": summary.auto_compaction.map(|event| json!({ - "removed_messages": event.removed_message_count, - "notice": format_auto_compaction_notice(event.removed_message_count), - })), "tool_uses": collect_tool_uses(&summary), "tool_results": collect_tool_results(&summary), "usage": { @@ -1133,34 +1114,6 @@ impl LiveCli { self.print_status(); false } - SlashCommand::Bughunter { scope } => { - self.run_bughunter(scope.as_deref())?; - false - } - SlashCommand::Commit => { - self.run_commit()?; - true - } - SlashCommand::Pr { context } => { - self.run_pr(context.as_deref())?; - false - } - SlashCommand::Issue { context } => { - self.run_issue(context.as_deref())?; - false - } - SlashCommand::Ultraplan { task } => { - self.run_ultraplan(task.as_deref())?; - false - } - SlashCommand::Teleport { target } => { - self.run_teleport(target.as_deref())?; - false - } - SlashCommand::DebugToolCall => { - self.run_debug_tool_call()?; - false - } SlashCommand::Compact => { self.compact()?; false @@ -1200,9 +1153,6 @@ impl LiveCli { SlashCommand::Session { action, target } => { self.handle_session_command(action.as_deref(), target.as_deref())? } - SlashCommand::Plugins { action, target } => { - self.handle_plugins_command(action.as_deref(), target.as_deref())? - } SlashCommand::Unknown(name) => { eprintln!("unknown slash command: /{name}"); false @@ -1273,7 +1223,6 @@ impl LiveCli { true, self.allowed_tools.clone(), self.permission_mode, - None, )?; self.model.clone_from(&model); println!( @@ -1317,7 +1266,6 @@ impl LiveCli { true, self.allowed_tools.clone(), self.permission_mode, - None, )?; println!( "{}", @@ -1343,7 +1291,6 @@ impl LiveCli { true, self.allowed_tools.clone(), self.permission_mode, - None, )?; println!( "Session cleared\n Mode fresh session\n Preserved model {}\n Permission mode {}\n Session {}", @@ -1379,7 +1326,6 @@ impl LiveCli { true, self.allowed_tools.clone(), self.permission_mode, - None, )?; self.session = handle; println!( @@ -1452,7 +1398,6 @@ impl LiveCli { true, self.allowed_tools.clone(), self.permission_mode, - None, )?; self.session = handle; println!( @@ -1470,37 +1415,6 @@ impl LiveCli { } } - fn handle_plugins_command( - &mut self, - action: Option<&str>, - target: Option<&str>, - ) -> Result> { - let cwd = env::current_dir()?; - let loader = ConfigLoader::default_for(&cwd); - let runtime_config = loader.load()?; - let mut manager = build_plugin_manager(&cwd, &loader, &runtime_config); - let result = handle_plugins_slash_command(action, target, &mut manager)?; - println!("{}", result.message); - if result.reload_runtime { - self.reload_runtime_features()?; - } - Ok(false) - } - - fn reload_runtime_features(&mut self) -> Result<(), Box> { - self.runtime = build_runtime( - self.runtime.session().clone(), - self.model.clone(), - self.system_prompt.clone(), - true, - true, - self.allowed_tools.clone(), - self.permission_mode, - None, - )?; - self.persist_session() - } - fn compact(&mut self) -> Result<(), Box> { let result = self.runtime.compact(CompactionConfig::default()); let removed = result.removed_message_count; @@ -1514,196 +1428,16 @@ impl LiveCli { true, self.allowed_tools.clone(), self.permission_mode, - None, )?; self.persist_session()?; println!("{}", format_compact_report(removed, kept, skipped)); Ok(()) } - - fn run_internal_prompt_text_with_progress( - &self, - prompt: &str, - enable_tools: bool, - progress: Option, - ) -> Result> { - let session = self.runtime.session().clone(); - let mut runtime = build_runtime( - session, - self.model.clone(), - self.system_prompt.clone(), - enable_tools, - false, - self.allowed_tools.clone(), - self.permission_mode, - progress, - )?; - let mut permission_prompter = CliPermissionPrompter::new(self.permission_mode); - let summary = runtime.run_turn(prompt, Some(&mut permission_prompter))?; - Ok(final_assistant_text(&summary).trim().to_string()) - } - - fn run_internal_prompt_text( - &self, - prompt: &str, - enable_tools: bool, - ) -> Result> { - self.run_internal_prompt_text_with_progress(prompt, enable_tools, None) - } - - fn run_bughunter(&self, scope: Option<&str>) -> Result<(), Box> { - let scope = scope.unwrap_or("the current repository"); - let prompt = format!( - "You are /bughunter. Inspect {scope} and identify the most likely bugs or correctness issues. Prioritize concrete findings with file paths, severity, and suggested fixes. Use tools if needed." - ); - println!("{}", self.run_internal_prompt_text(&prompt, true)?); - Ok(()) - } - - fn run_ultraplan(&self, task: Option<&str>) -> Result<(), Box> { - let task = task.unwrap_or("the current repo work"); - let prompt = format!( - "You are /ultraplan. Produce a deep multi-step execution plan for {task}. Include goals, risks, implementation sequence, verification steps, and rollback considerations. Use tools if needed." - ); - let mut progress = InternalPromptProgressRun::start_ultraplan(task); - match self.run_internal_prompt_text_with_progress( - &prompt, - true, - Some(progress.reporter()), - ) { - Ok(plan) => { - progress.finish_success(); - println!("{plan}"); - Ok(()) - } - Err(error) => { - progress.finish_failure(&error.to_string()); - Err(error) - } - } - } - - #[allow(clippy::unused_self)] - fn run_teleport(&self, target: Option<&str>) -> Result<(), Box> { - let Some(target) = target.map(str::trim).filter(|value| !value.is_empty()) else { - println!("Usage: /teleport "); - return Ok(()); - }; - - println!("{}", render_teleport_report(target)?); - Ok(()) - } - - fn run_debug_tool_call(&self) -> Result<(), Box> { - println!("{}", render_last_tool_debug_report(self.runtime.session())?); - Ok(()) - } - - fn run_commit(&mut self) -> Result<(), Box> { - let status = git_output(&["status", "--short"])?; - if status.trim().is_empty() { - println!("Commit\n Result skipped\n Reason no workspace changes"); - return Ok(()); - } - - git_status_ok(&["add", "-A"])?; - let staged_stat = git_output(&["diff", "--cached", "--stat"])?; - let prompt = format!( - "Generate a git commit message in plain text Lore format only. Base it on this staged diff summary:\n\n{}\n\nRecent conversation context:\n{}", - truncate_for_prompt(&staged_stat, 8_000), - recent_user_context(self.runtime.session(), 6) - ); - let message = sanitize_generated_message(&self.run_internal_prompt_text(&prompt, false)?); - if message.trim().is_empty() { - return Err("generated commit message was empty".into()); - } - - let path = write_temp_text_file("claw-commit-message.txt", &message)?; - let output = Command::new("git") - .args(["commit", "--file"]) - .arg(&path) - .current_dir(env::current_dir()?) - .output()?; - if !output.status.success() { - let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string(); - return Err(format!("git commit failed: {stderr}").into()); - } - - println!( - "Commit\n Result created\n Message file {}\n\n{}", - path.display(), - message.trim() - ); - Ok(()) - } - - fn run_pr(&self, context: Option<&str>) -> Result<(), Box> { - let staged = git_output(&["diff", "--stat"])?; - let prompt = format!( - "Generate a pull request title and body from this conversation and diff summary. Output plain text in this format exactly:\nTITLE: \nBODY:\n<body markdown>\n\nContext hint: {}\n\nDiff summary:\n{}", - context.unwrap_or("none"), - truncate_for_prompt(&staged, 10_000) - ); - let draft = sanitize_generated_message(&self.run_internal_prompt_text(&prompt, false)?); - let (title, body) = parse_titled_body(&draft) - .ok_or_else(|| "failed to parse generated PR title/body".to_string())?; - - if command_exists("gh") { - let body_path = write_temp_text_file("claw-pr-body.md", &body)?; - let output = Command::new("gh") - .args(["pr", "create", "--title", &title, "--body-file"]) - .arg(&body_path) - .current_dir(env::current_dir()?) - .output()?; - if output.status.success() { - let stdout = String::from_utf8_lossy(&output.stdout).trim().to_string(); - println!( - "PR\n Result created\n Title {title}\n URL {}", - if stdout.is_empty() { "<unknown>" } else { &stdout } - ); - return Ok(()); - } - } - - println!("PR draft\n Title {title}\n\n{body}"); - Ok(()) - } - - fn run_issue(&self, context: Option<&str>) -> Result<(), Box<dyn std::error::Error>> { - let prompt = format!( - "Generate a GitHub issue title and body from this conversation. Output plain text in this format exactly:\nTITLE: <title>\nBODY:\n<body markdown>\n\nContext hint: {}\n\nConversation context:\n{}", - context.unwrap_or("none"), - truncate_for_prompt(&recent_user_context(self.runtime.session(), 10), 10_000) - ); - let draft = sanitize_generated_message(&self.run_internal_prompt_text(&prompt, false)?); - let (title, body) = parse_titled_body(&draft) - .ok_or_else(|| "failed to parse generated issue title/body".to_string())?; - - if command_exists("gh") { - let body_path = write_temp_text_file("claw-issue-body.md", &body)?; - let output = Command::new("gh") - .args(["issue", "create", "--title", &title, "--body-file"]) - .arg(&body_path) - .current_dir(env::current_dir()?) - .output()?; - if output.status.success() { - let stdout = String::from_utf8_lossy(&output.stdout).trim().to_string(); - println!( - "Issue\n Result created\n Title {title}\n URL {}", - if stdout.is_empty() { "<unknown>" } else { &stdout } - ); - return Ok(()); - } - } - - println!("Issue draft\n Title {title}\n\n{body}"); - Ok(()) - } } fn sessions_dir() -> Result<PathBuf, Box<dyn std::error::Error>> { let cwd = env::current_dir()?; - let path = cwd.join(".claw").join("sessions"); + let path = cwd.join(".claude").join("sessions"); fs::create_dir_all(&path)?; Ok(path) } @@ -1942,12 +1676,9 @@ fn render_config_report(section: Option<&str>) -> Result<String, Box<dyn std::er "env" => runtime_config.get("env"), "hooks" => runtime_config.get("hooks"), "model" => runtime_config.get("model"), - "plugins" => runtime_config - .get("plugins") - .or_else(|| runtime_config.get("enabledPlugins")), other => { lines.push(format!( - " Unsupported config section '{other}'. Use env, hooks, model, or plugins." + " Unsupported config section '{other}'. Use env, hooks, or model." )); return Ok(lines.join( " @@ -2053,206 +1784,6 @@ fn render_diff_report() -> Result<String, Box<dyn std::error::Error>> { Ok(format!("Diff\n\n{}", diff.trim_end())) } -fn render_teleport_report(target: &str) -> Result<String, Box<dyn std::error::Error>> { - let cwd = env::current_dir()?; - - let file_list = Command::new("rg") - .args(["--files"]) - .current_dir(&cwd) - .output()?; - let file_matches = if file_list.status.success() { - String::from_utf8(file_list.stdout)? - .lines() - .filter(|line| line.contains(target)) - .take(10) - .map(ToOwned::to_owned) - .collect::<Vec<_>>() - } else { - Vec::new() - }; - - let content_output = Command::new("rg") - .args(["-n", "-S", "--color", "never", target, "."]) - .current_dir(&cwd) - .output()?; - - let mut lines = vec![format!("Teleport\n Target {target}")]; - if !file_matches.is_empty() { - lines.push(String::new()); - lines.push("File matches".to_string()); - lines.extend(file_matches.into_iter().map(|path| format!(" {path}"))); - } - - if content_output.status.success() { - let matches = String::from_utf8(content_output.stdout)?; - if !matches.trim().is_empty() { - lines.push(String::new()); - lines.push("Content matches".to_string()); - lines.push(truncate_for_prompt(&matches, 4_000)); - } - } - - if lines.len() == 1 { - lines.push(" Result no matches found".to_string()); - } - - Ok(lines.join("\n")) -} - -fn render_last_tool_debug_report(session: &Session) -> Result<String, Box<dyn std::error::Error>> { - let last_tool_use = session - .messages - .iter() - .rev() - .find_map(|message| { - message.blocks.iter().rev().find_map(|block| match block { - ContentBlock::ToolUse { id, name, input } => { - Some((id.clone(), name.clone(), input.clone())) - } - _ => None, - }) - }) - .ok_or_else(|| "no prior tool call found in session".to_string())?; - - let tool_result = session.messages.iter().rev().find_map(|message| { - message.blocks.iter().rev().find_map(|block| match block { - ContentBlock::ToolResult { - tool_use_id, - tool_name, - output, - is_error, - } if tool_use_id == &last_tool_use.0 => { - Some((tool_name.clone(), output.clone(), *is_error)) - } - _ => None, - }) - }); - - let mut lines = vec![ - "Debug tool call".to_string(), - format!(" Tool id {}", last_tool_use.0), - format!(" Tool name {}", last_tool_use.1), - " Input".to_string(), - indent_block(&last_tool_use.2, 4), - ]; - - match tool_result { - Some((tool_name, output, is_error)) => { - lines.push(" Result".to_string()); - lines.push(format!(" name {tool_name}")); - lines.push(format!( - " status {}", - if is_error { "error" } else { "ok" } - )); - lines.push(indent_block(&output, 4)); - } - None => lines.push(" Result missing tool result".to_string()), - } - - Ok(lines.join("\n")) -} - -fn indent_block(value: &str, spaces: usize) -> String { - let indent = " ".repeat(spaces); - value - .lines() - .map(|line| format!("{indent}{line}")) - .collect::<Vec<_>>() - .join("\n") -} - -fn git_output(args: &[&str]) -> Result<String, Box<dyn std::error::Error>> { - let output = Command::new("git") - .args(args) - .current_dir(env::current_dir()?) - .output()?; - if !output.status.success() { - let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string(); - return Err(format!("git {} failed: {stderr}", args.join(" ")).into()); - } - Ok(String::from_utf8(output.stdout)?) -} - -fn git_status_ok(args: &[&str]) -> Result<(), Box<dyn std::error::Error>> { - let output = Command::new("git") - .args(args) - .current_dir(env::current_dir()?) - .output()?; - if !output.status.success() { - let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string(); - return Err(format!("git {} failed: {stderr}", args.join(" ")).into()); - } - Ok(()) -} - -fn command_exists(name: &str) -> bool { - Command::new("which") - .arg(name) - .output() - .map(|output| output.status.success()) - .unwrap_or(false) -} - -fn write_temp_text_file( - filename: &str, - contents: &str, -) -> Result<PathBuf, Box<dyn std::error::Error>> { - let path = env::temp_dir().join(filename); - fs::write(&path, contents)?; - Ok(path) -} - -fn recent_user_context(session: &Session, limit: usize) -> String { - let requests = session - .messages - .iter() - .filter(|message| message.role == MessageRole::User) - .filter_map(|message| { - message.blocks.iter().find_map(|block| match block { - ContentBlock::Text { text } => Some(text.trim().to_string()), - _ => None, - }) - }) - .rev() - .take(limit) - .collect::<Vec<_>>(); - - if requests.is_empty() { - "<no prior user messages>".to_string() - } else { - requests - .into_iter() - .rev() - .enumerate() - .map(|(index, text)| format!("{}. {}", index + 1, text)) - .collect::<Vec<_>>() - .join("\n") - } -} - -fn truncate_for_prompt(value: &str, limit: usize) -> String { - if value.chars().count() <= limit { - value.trim().to_string() - } else { - let truncated = value.chars().take(limit).collect::<String>(); - format!("{}\n…[truncated]", truncated.trim_end()) - } -} - -fn sanitize_generated_message(value: &str) -> String { - value.trim().trim_matches('`').trim().replace("\r\n", "\n") -} - -fn parse_titled_body(value: &str) -> Option<(String, String)> { - let normalized = sanitize_generated_message(value); - let title = normalized - .lines() - .find_map(|line| line.strip_prefix("TITLE:").map(str::trim))?; - let body_start = normalized.find("BODY:")?; - let body = normalized[body_start + "BODY:".len()..].trim(); - Some((title.to_string(), body.to_string())) -} - fn render_version_report() -> String { let git_sha = GIT_SHA.unwrap_or("unknown"); let target = BUILD_TARGET.unwrap_or("unknown"); @@ -2357,388 +1888,15 @@ fn build_system_prompt() -> Result<Vec<String>, Box<dyn std::error::Error>> { )?) } -fn build_runtime_plugin_state() -> Result< - ( - runtime::RuntimeFeatureConfig, - PluginRegistry, - GlobalToolRegistry, - ), - Box<dyn std::error::Error>, -> { +fn build_runtime_feature_config( +) -> Result<runtime::RuntimeFeatureConfig, Box<dyn std::error::Error>> { let cwd = env::current_dir()?; - let loader = ConfigLoader::default_for(&cwd); - let runtime_config = loader.load()?; - let plugin_manager = build_plugin_manager(&cwd, &loader, &runtime_config); - let plugin_registry = plugin_manager.plugin_registry()?; - let tool_registry = GlobalToolRegistry::with_plugin_tools(plugin_registry.aggregated_tools()?)?; - Ok(( - runtime_config.feature_config().clone(), - plugin_registry, - tool_registry, - )) + Ok(ConfigLoader::default_for(cwd) + .load()? + .feature_config() + .clone()) } -fn build_plugin_manager( - cwd: &Path, - loader: &ConfigLoader, - runtime_config: &runtime::RuntimeConfig, -) -> PluginManager { - let plugin_settings = runtime_config.plugins(); - let mut plugin_config = PluginManagerConfig::new(loader.config_home().to_path_buf()); - plugin_config.enabled_plugins = plugin_settings.enabled_plugins().clone(); - plugin_config.external_dirs = plugin_settings - .external_directories() - .iter() - .map(|path| resolve_plugin_path(cwd, loader.config_home(), path)) - .collect(); - plugin_config.install_root = plugin_settings - .install_root() - .map(|path| resolve_plugin_path(cwd, loader.config_home(), path)); - plugin_config.registry_path = plugin_settings - .registry_path() - .map(|path| resolve_plugin_path(cwd, loader.config_home(), path)); - plugin_config.bundled_root = plugin_settings - .bundled_root() - .map(|path| resolve_plugin_path(cwd, loader.config_home(), path)); - PluginManager::new(plugin_config) -} - -fn resolve_plugin_path(cwd: &Path, config_home: &Path, value: &str) -> PathBuf { - let path = PathBuf::from(value); - if path.is_absolute() { - path - } else if value.starts_with('.') { - cwd.join(path) - } else { - config_home.join(path) - } -} - -#[derive(Debug, Clone, PartialEq, Eq)] -struct InternalPromptProgressState { - command_label: &'static str, - task_label: String, - step: usize, - phase: String, - detail: Option<String>, - saw_final_text: bool, -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -enum InternalPromptProgressEvent { - Started, - Update, - Heartbeat, - Complete, - Failed, -} - -#[derive(Debug)] -struct InternalPromptProgressShared { - state: Mutex<InternalPromptProgressState>, - output_lock: Mutex<()>, - started_at: Instant, -} - -#[derive(Debug, Clone)] -struct InternalPromptProgressReporter { - shared: Arc<InternalPromptProgressShared>, -} - -#[derive(Debug)] -struct InternalPromptProgressRun { - reporter: InternalPromptProgressReporter, - heartbeat_stop: Option<mpsc::Sender<()>>, - heartbeat_handle: Option<thread::JoinHandle<()>>, -} - -impl InternalPromptProgressReporter { - fn ultraplan(task: &str) -> Self { - Self { - shared: Arc::new(InternalPromptProgressShared { - state: Mutex::new(InternalPromptProgressState { - command_label: "Ultraplan", - task_label: task.to_string(), - step: 0, - phase: "planning started".to_string(), - detail: Some(format!("task: {task}")), - saw_final_text: false, - }), - output_lock: Mutex::new(()), - started_at: Instant::now(), - }), - } - } - - fn emit(&self, event: InternalPromptProgressEvent, error: Option<&str>) { - let snapshot = self.snapshot(); - let line = - format_internal_prompt_progress_line(event, &snapshot, self.elapsed(), error); - self.write_line(&line); - } - - fn mark_model_phase(&self) { - let snapshot = { - let mut state = self - .shared - .state - .lock() - .expect("internal prompt progress state poisoned"); - state.step += 1; - state.phase = if state.step == 1 { - "analyzing request".to_string() - } else { - "reviewing findings".to_string() - }; - state.detail = Some(format!("task: {}", state.task_label)); - state.clone() - }; - self.write_line(&format_internal_prompt_progress_line( - InternalPromptProgressEvent::Update, - &snapshot, - self.elapsed(), - None, - )); - } - - fn mark_tool_phase(&self, name: &str, input: &str) { - let detail = describe_tool_progress(name, input); - let snapshot = { - let mut state = self - .shared - .state - .lock() - .expect("internal prompt progress state poisoned"); - state.step += 1; - state.phase = format!("running {name}"); - state.detail = Some(detail); - state.clone() - }; - self.write_line(&format_internal_prompt_progress_line( - InternalPromptProgressEvent::Update, - &snapshot, - self.elapsed(), - None, - )); - } - - fn mark_text_phase(&self, text: &str) { - let trimmed = text.trim(); - if trimmed.is_empty() { - return; - } - let detail = truncate_for_summary(first_visible_line(trimmed), 120); - let snapshot = { - let mut state = self - .shared - .state - .lock() - .expect("internal prompt progress state poisoned"); - if state.saw_final_text { - return; - } - state.saw_final_text = true; - state.step += 1; - state.phase = "drafting final plan".to_string(); - state.detail = (!detail.is_empty()).then_some(detail); - state.clone() - }; - self.write_line(&format_internal_prompt_progress_line( - InternalPromptProgressEvent::Update, - &snapshot, - self.elapsed(), - None, - )); - } - - fn emit_heartbeat(&self) { - let snapshot = self.snapshot(); - self.write_line(&format_internal_prompt_progress_line( - InternalPromptProgressEvent::Heartbeat, - &snapshot, - self.elapsed(), - None, - )); - } - - fn snapshot(&self) -> InternalPromptProgressState { - self.shared - .state - .lock() - .expect("internal prompt progress state poisoned") - .clone() - } - - fn elapsed(&self) -> Duration { - self.shared.started_at.elapsed() - } - - fn write_line(&self, line: &str) { - let _guard = self - .shared - .output_lock - .lock() - .expect("internal prompt progress output lock poisoned"); - let mut stdout = io::stdout(); - let _ = writeln!(stdout, "{line}"); - let _ = stdout.flush(); - } -} - -impl InternalPromptProgressRun { - fn start_ultraplan(task: &str) -> Self { - let reporter = InternalPromptProgressReporter::ultraplan(task); - reporter.emit(InternalPromptProgressEvent::Started, None); - - let (heartbeat_stop, heartbeat_rx) = mpsc::channel(); - let heartbeat_reporter = reporter.clone(); - let heartbeat_handle = thread::spawn(move || { - loop { - match heartbeat_rx.recv_timeout(INTERNAL_PROGRESS_HEARTBEAT_INTERVAL) { - Ok(()) | Err(RecvTimeoutError::Disconnected) => break, - Err(RecvTimeoutError::Timeout) => heartbeat_reporter.emit_heartbeat(), - } - } - }); - - Self { - reporter, - heartbeat_stop: Some(heartbeat_stop), - heartbeat_handle: Some(heartbeat_handle), - } - } - - fn reporter(&self) -> InternalPromptProgressReporter { - self.reporter.clone() - } - - fn finish_success(&mut self) { - self.stop_heartbeat(); - self.reporter.emit(InternalPromptProgressEvent::Complete, None); - } - - fn finish_failure(&mut self, error: &str) { - self.stop_heartbeat(); - self.reporter - .emit(InternalPromptProgressEvent::Failed, Some(error)); - } - - fn stop_heartbeat(&mut self) { - if let Some(sender) = self.heartbeat_stop.take() { - let _ = sender.send(()); - } - if let Some(handle) = self.heartbeat_handle.take() { - let _ = handle.join(); - } - } -} - -impl Drop for InternalPromptProgressRun { - fn drop(&mut self) { - self.stop_heartbeat(); - } -} - -fn format_internal_prompt_progress_line( - event: InternalPromptProgressEvent, - snapshot: &InternalPromptProgressState, - elapsed: Duration, - error: Option<&str>, -) -> String { - let elapsed_seconds = elapsed.as_secs(); - let step_label = if snapshot.step == 0 { - "current step pending".to_string() - } else { - format!("current step {}", snapshot.step) - }; - let mut status_bits = vec![step_label, format!("phase {}", snapshot.phase)]; - if let Some(detail) = snapshot.detail.as_deref().filter(|detail| !detail.is_empty()) { - status_bits.push(detail.to_string()); - } - let status = status_bits.join(" · "); - match event { - InternalPromptProgressEvent::Started => { - format!("🧭 {} status · planning started · {status}", snapshot.command_label) - } - InternalPromptProgressEvent::Update => { - format!("… {} status · {status}", snapshot.command_label) - } - InternalPromptProgressEvent::Heartbeat => format!( - "… {} heartbeat · {elapsed_seconds}s elapsed · {status}", - snapshot.command_label - ), - InternalPromptProgressEvent::Complete => format!( - "✔ {} status · completed · {elapsed_seconds}s elapsed · {} steps total", - snapshot.command_label, - snapshot.step - ), - InternalPromptProgressEvent::Failed => format!( - "✘ {} status · failed · {elapsed_seconds}s elapsed · {}", - snapshot.command_label, - error.unwrap_or("unknown error") - ), - } -} - -fn describe_tool_progress(name: &str, input: &str) -> String { - let parsed: serde_json::Value = - serde_json::from_str(input).unwrap_or(serde_json::Value::String(input.to_string())); - match name { - "bash" | "Bash" => { - let command = parsed - .get("command") - .and_then(|value| value.as_str()) - .unwrap_or_default(); - if command.is_empty() { - "running shell command".to_string() - } else { - format!("command {}", truncate_for_summary(command.trim(), 100)) - } - } - "read_file" | "Read" => format!("reading {}", extract_tool_path(&parsed)), - "write_file" | "Write" => format!("writing {}", extract_tool_path(&parsed)), - "edit_file" | "Edit" => format!("editing {}", extract_tool_path(&parsed)), - "glob_search" | "Glob" => { - let pattern = parsed - .get("pattern") - .and_then(|value| value.as_str()) - .unwrap_or("?"); - let scope = parsed - .get("path") - .and_then(|value| value.as_str()) - .unwrap_or("."); - format!("glob `{pattern}` in {scope}") - } - "grep_search" | "Grep" => { - let pattern = parsed - .get("pattern") - .and_then(|value| value.as_str()) - .unwrap_or("?"); - let scope = parsed - .get("path") - .and_then(|value| value.as_str()) - .unwrap_or("."); - format!("grep `{pattern}` in {scope}") - } - "web_search" | "WebSearch" => parsed - .get("query") - .and_then(|value| value.as_str()) - .map_or_else( - || "running web search".to_string(), - |query| format!("query {}", truncate_for_summary(query, 100)), - ), - _ => { - let summary = summarize_tool_payload(input); - if summary.is_empty() { - format!("running {name}") - } else { - format!("{name}: {summary}") - } - } - } -} - -#[allow(clippy::needless_pass_by_value)] fn build_runtime( session: Session, model: String, @@ -2747,26 +1905,17 @@ fn build_runtime( emit_output: bool, allowed_tools: Option<AllowedToolSet>, permission_mode: PermissionMode, - progress_reporter: Option<InternalPromptProgressReporter>, -) -> Result<ConversationRuntime<AnthropicRuntimeClient, CliToolExecutor>, Box<dyn std::error::Error>> +) -> Result<ConversationRuntime<ProviderRuntimeClient, CliToolExecutor>, Box<dyn std::error::Error>> { - let (feature_config, plugin_registry, tool_registry) = build_runtime_plugin_state()?; - Ok(ConversationRuntime::new_with_plugins( + let feature_config = build_runtime_feature_config()?; + Ok(ConversationRuntime::new_with_features( session, - AnthropicRuntimeClient::new( - model, - enable_tools, - emit_output, - allowed_tools.clone(), - tool_registry.clone(), - progress_reporter, - )?, - CliToolExecutor::new(allowed_tools.clone(), emit_output, tool_registry.clone()), - permission_policy(permission_mode, &tool_registry), + ProviderRuntimeClient::new(model, enable_tools, emit_output, allowed_tools.clone())?, + CliToolExecutor::new(allowed_tools, emit_output), + permission_policy(permission_mode), system_prompt, - feature_config, - plugin_registry, - )?) + &feature_config, + )) } struct CliPermissionPrompter { @@ -2815,36 +1964,37 @@ impl runtime::PermissionPrompter for CliPermissionPrompter { } } -struct AnthropicRuntimeClient { +struct ProviderRuntimeClient { runtime: tokio::runtime::Runtime, - client: ApiHttpClient, + client: ProviderClient, model: String, enable_tools: bool, emit_output: bool, allowed_tools: Option<AllowedToolSet>, - tool_registry: GlobalToolRegistry, - progress_reporter: Option<InternalPromptProgressReporter>, } -impl AnthropicRuntimeClient { +impl ProviderRuntimeClient { fn new( model: String, enable_tools: bool, emit_output: bool, allowed_tools: Option<AllowedToolSet>, - tool_registry: GlobalToolRegistry, - progress_reporter: Option<InternalPromptProgressReporter>, ) -> Result<Self, Box<dyn std::error::Error>> { + let model = resolve_model_alias(&model).to_string(); + let client = match detect_provider_kind(&model) { + ProviderKind::Anthropic => ProviderClient::from_model_with_anthropic_auth( + &model, + Some(resolve_cli_auth_source()?), + )?, + ProviderKind::Xai | ProviderKind::OpenAi => ProviderClient::from_model(&model)?, + }; Ok(Self { runtime: tokio::runtime::Runtime::new()?, - client: ApiHttpClient::from_auth(resolve_cli_auth_source()?) - .with_base_url(api::read_base_url()), + client, model, enable_tools, emit_output, allowed_tools, - tool_registry, - progress_reporter, }) } } @@ -2859,20 +2009,24 @@ fn resolve_cli_auth_source() -> Result<AuthSource, Box<dyn std::error::Error>> { })?) } -impl ApiClient for AnthropicRuntimeClient { +impl ApiClient for ProviderRuntimeClient { #[allow(clippy::too_many_lines)] fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> { - if let Some(progress_reporter) = &self.progress_reporter { - progress_reporter.mark_model_phase(); - } let message_request = MessageRequest { model: self.model.clone(), max_tokens: max_tokens_for_model(&self.model), messages: convert_messages(&request.messages), system: (!request.system_prompt.is_empty()).then(|| request.system_prompt.join("\n\n")), - tools: self - .enable_tools - .then(|| filter_tool_specs(&self.tool_registry, self.allowed_tools.as_ref())), + tools: self.enable_tools.then(|| { + filter_tool_specs(self.allowed_tools.as_ref()) + .into_iter() + .map(|spec| ToolDefinition { + name: spec.name.to_string(), + description: Some(spec.description.to_string()), + input_schema: spec.input_schema, + }) + .collect() + }), tool_choice: self.enable_tools.then_some(ToolChoice::Auto), stream: true, }; @@ -2893,7 +2047,7 @@ impl ApiClient for AnthropicRuntimeClient { let renderer = TerminalRenderer::new(); let mut markdown_stream = MarkdownStreamState::default(); let mut events = Vec::new(); - let mut pending_tool: Option<(String, String, String)> = None; + let mut pending_tools: BTreeMap<u32, (String, String, String)> = BTreeMap::new(); let mut saw_stop = false; while let Some(event) = stream @@ -2904,24 +2058,29 @@ impl ApiClient for AnthropicRuntimeClient { match event { ApiStreamEvent::MessageStart(start) => { for block in start.message.content { - push_output_block(block, out, &mut events, &mut pending_tool, true)?; + push_output_block( + block, + 0, + out, + &mut events, + &mut pending_tools, + true, + )?; } } ApiStreamEvent::ContentBlockStart(start) => { push_output_block( start.content_block, + start.index, out, &mut events, - &mut pending_tool, + &mut pending_tools, true, )?; } ApiStreamEvent::ContentBlockDelta(delta) => match delta.delta { ContentBlockDelta::TextDelta { text } => { if !text.is_empty() { - if let Some(progress_reporter) = &self.progress_reporter { - progress_reporter.mark_text_phase(&text); - } if let Some(rendered) = markdown_stream.push(&renderer, &text) { write!(out, "{rendered}") .and_then(|()| out.flush()) @@ -2931,23 +2090,18 @@ impl ApiClient for AnthropicRuntimeClient { } } ContentBlockDelta::InputJsonDelta { partial_json } => { - if let Some((_, _, input)) = &mut pending_tool { + if let Some((_, _, input)) = pending_tools.get_mut(&delta.index) { input.push_str(&partial_json); } } - ContentBlockDelta::ThinkingDelta { .. } - | ContentBlockDelta::SignatureDelta { .. } => {} }, - ApiStreamEvent::ContentBlockStop(_) => { + ApiStreamEvent::ContentBlockStop(stop) => { if let Some(rendered) = markdown_stream.flush(&renderer) { write!(out, "{rendered}") .and_then(|()| out.flush()) .map_err(|error| RuntimeError::new(error.to_string()))?; } - if let Some((id, name, input)) = pending_tool.take() { - if let Some(progress_reporter) = &self.progress_reporter { - progress_reporter.mark_tool_phase(&name, &input); - } + if let Some((id, name, input)) = pending_tools.remove(&stop.index) { // Display tool call now that input is fully accumulated writeln!(out, "\n{}", format_tool_call_start(&name, &input)) .and_then(|()| out.flush()) @@ -3144,17 +2298,13 @@ fn format_tool_result(name: &str, output: &str, is_error: bool) -> String { "edit_file" | "Edit" => format_edit_result(icon, &parsed), "glob_search" | "Glob" => format_glob_result(icon, &parsed), "grep_search" | "Grep" => format_grep_result(icon, &parsed), - _ => format_generic_tool_result(icon, name, &parsed), + _ => { + let summary = truncate_for_summary(output.trim(), 200); + format!("{icon} \x1b[38;5;245m{name}:\x1b[0m {summary}") + } } } -const DISPLAY_TRUNCATION_NOTICE: &str = - "\x1b[2m… output truncated for display; full result preserved in session.\x1b[0m"; -const READ_DISPLAY_MAX_LINES: usize = 80; -const READ_DISPLAY_MAX_CHARS: usize = 6_000; -const TOOL_OUTPUT_DISPLAY_MAX_LINES: usize = 60; -const TOOL_OUTPUT_DISPLAY_MAX_CHARS: usize = 4_000; - fn extract_tool_path(parsed: &serde_json::Value) -> String { parsed .get("file_path") @@ -3215,34 +2365,23 @@ fn format_bash_result(icon: &str, parsed: &serde_json::Value) -> String { .get("backgroundTaskId") .and_then(|value| value.as_str()) { - write!(&mut lines[0], " backgrounded ({task_id})").expect("write to string"); + lines[0].push_str(&format!(" backgrounded ({task_id})")); } else if let Some(status) = parsed .get("returnCodeInterpretation") .and_then(|value| value.as_str()) .filter(|status| !status.is_empty()) { - write!(&mut lines[0], " {status}").expect("write to string"); + lines[0].push_str(&format!(" {status}")); } if let Some(stdout) = parsed.get("stdout").and_then(|value| value.as_str()) { if !stdout.trim().is_empty() { - lines.push(truncate_output_for_display( - stdout, - TOOL_OUTPUT_DISPLAY_MAX_LINES, - TOOL_OUTPUT_DISPLAY_MAX_CHARS, - )); + lines.push(stdout.trim_end().to_string()); } } if let Some(stderr) = parsed.get("stderr").and_then(|value| value.as_str()) { if !stderr.trim().is_empty() { - lines.push(format!( - "\x1b[38;5;203m{}\x1b[0m", - truncate_output_for_display( - stderr, - TOOL_OUTPUT_DISPLAY_MAX_LINES, - TOOL_OUTPUT_DISPLAY_MAX_CHARS, - ) - )); + lines.push(format!("\x1b[38;5;203m{}\x1b[0m", stderr.trim_end())); } } @@ -3254,15 +2393,15 @@ fn format_read_result(icon: &str, parsed: &serde_json::Value) -> String { let path = extract_tool_path(file); let start_line = file .get("startLine") - .and_then(serde_json::Value::as_u64) + .and_then(|value| value.as_u64()) .unwrap_or(1); let num_lines = file .get("numLines") - .and_then(serde_json::Value::as_u64) + .and_then(|value| value.as_u64()) .unwrap_or(0); let total_lines = file .get("totalLines") - .and_then(serde_json::Value::as_u64) + .and_then(|value| value.as_u64()) .unwrap_or(num_lines); let content = file .get("content") @@ -3275,7 +2414,7 @@ fn format_read_result(icon: &str, parsed: &serde_json::Value) -> String { start_line, end_line.max(start_line), total_lines, - truncate_output_for_display(content, READ_DISPLAY_MAX_LINES, READ_DISPLAY_MAX_CHARS) + content ) } @@ -3288,7 +2427,8 @@ fn format_write_result(icon: &str, parsed: &serde_json::Value) -> String { let line_count = parsed .get("content") .and_then(|value| value.as_str()) - .map_or(0, |content| content.lines().count()); + .map(|content| content.lines().count()) + .unwrap_or(0); format!( "{icon} \x1b[1;32m✏️ {} {path}\x1b[0m \x1b[2m({line_count} lines)\x1b[0m", if kind == "create" { "Wrote" } else { "Updated" }, @@ -3319,7 +2459,7 @@ fn format_edit_result(icon: &str, parsed: &serde_json::Value) -> String { let path = extract_tool_path(parsed); let suffix = if parsed .get("replaceAll") - .and_then(serde_json::Value::as_bool) + .and_then(|value| value.as_bool()) .unwrap_or(false) { " (replace all)" @@ -3347,7 +2487,7 @@ fn format_edit_result(icon: &str, parsed: &serde_json::Value) -> String { fn format_glob_result(icon: &str, parsed: &serde_json::Value) -> String { let num_files = parsed .get("numFiles") - .and_then(serde_json::Value::as_u64) + .and_then(|value| value.as_u64()) .unwrap_or(0); let filenames = parsed .get("filenames") @@ -3371,11 +2511,11 @@ fn format_glob_result(icon: &str, parsed: &serde_json::Value) -> String { fn format_grep_result(icon: &str, parsed: &serde_json::Value) -> String { let num_matches = parsed .get("numMatches") - .and_then(serde_json::Value::as_u64) + .and_then(|value| value.as_u64()) .unwrap_or(0); let num_files = parsed .get("numFiles") - .and_then(serde_json::Value::as_u64) + .and_then(|value| value.as_u64()) .unwrap_or(0); let content = parsed .get("content") @@ -3397,14 +2537,7 @@ fn format_grep_result(icon: &str, parsed: &serde_json::Value) -> String { "{icon} \x1b[38;5;245mgrep_search\x1b[0m {num_matches} matches across {num_files} files" ); if !content.trim().is_empty() { - format!( - "{summary}\n{}", - truncate_output_for_display( - content, - TOOL_OUTPUT_DISPLAY_MAX_LINES, - TOOL_OUTPUT_DISPLAY_MAX_CHARS, - ) - ) + format!("{summary}\n{}", content.trim_end()) } else if !filenames.is_empty() { format!("{summary}\n{filenames}") } else { @@ -3412,30 +2545,6 @@ fn format_grep_result(icon: &str, parsed: &serde_json::Value) -> String { } } -fn format_generic_tool_result(icon: &str, name: &str, parsed: &serde_json::Value) -> String { - let rendered_output = match parsed { - serde_json::Value::String(text) => text.clone(), - serde_json::Value::Null => String::new(), - serde_json::Value::Object(_) | serde_json::Value::Array(_) => { - serde_json::to_string_pretty(parsed).unwrap_or_else(|_| parsed.to_string()) - } - _ => parsed.to_string(), - }; - let preview = truncate_output_for_display( - &rendered_output, - TOOL_OUTPUT_DISPLAY_MAX_LINES, - TOOL_OUTPUT_DISPLAY_MAX_CHARS, - ); - - if preview.is_empty() { - format!("{icon} \x1b[38;5;245m{name}\x1b[0m") - } else if preview.contains('\n') { - format!("{icon} \x1b[38;5;245m{name}\x1b[0m\n{preview}") - } else { - format!("{icon} \x1b[38;5;245m{name}:\x1b[0m {preview}") - } -} - fn summarize_tool_payload(payload: &str) -> String { let compact = match serde_json::from_str::<serde_json::Value>(payload) { Ok(value) => value.to_string(), @@ -3454,55 +2563,12 @@ fn truncate_for_summary(value: &str, limit: usize) -> String { } } -fn truncate_output_for_display(content: &str, max_lines: usize, max_chars: usize) -> String { - let original = content.trim_end_matches('\n'); - if original.is_empty() { - return String::new(); - } - - let mut preview_lines = Vec::new(); - let mut used_chars = 0usize; - let mut truncated = false; - - for (index, line) in original.lines().enumerate() { - if index >= max_lines { - truncated = true; - break; - } - - let newline_cost = usize::from(!preview_lines.is_empty()); - let available = max_chars.saturating_sub(used_chars + newline_cost); - if available == 0 { - truncated = true; - break; - } - - let line_chars = line.chars().count(); - if line_chars > available { - preview_lines.push(line.chars().take(available).collect::<String>()); - truncated = true; - break; - } - - preview_lines.push(line.to_string()); - used_chars += newline_cost + line_chars; - } - - let mut preview = preview_lines.join("\n"); - if truncated { - if !preview.is_empty() { - preview.push('\n'); - } - preview.push_str(DISPLAY_TRUNCATION_NOTICE); - } - preview -} - fn push_output_block( block: OutputContentBlock, + block_index: u32, out: &mut (impl Write + ?Sized), events: &mut Vec<AssistantEvent>, - pending_tool: &mut Option<(String, String, String)>, + pending_tools: &mut BTreeMap<u32, (String, String, String)>, streaming_tool_input: bool, ) -> Result<(), RuntimeError> { match block { @@ -3527,9 +2593,8 @@ fn push_output_block( } else { input.to_string() }; - *pending_tool = Some((id, name, initial_input)); + pending_tools.insert(block_index, (id, name, initial_input)); } - OutputContentBlock::Thinking { .. } | OutputContentBlock::RedactedThinking { .. } => {} } Ok(()) } @@ -3539,11 +2604,13 @@ fn response_to_events( out: &mut (impl Write + ?Sized), ) -> Result<Vec<AssistantEvent>, RuntimeError> { let mut events = Vec::new(); - let mut pending_tool = None; + let mut pending_tools = BTreeMap::new(); - for block in response.content { - push_output_block(block, out, &mut events, &mut pending_tool, false)?; - if let Some((id, name, input)) = pending_tool.take() { + for (index, block) in response.content.into_iter().enumerate() { + let index = + u32::try_from(index).map_err(|_| RuntimeError::new("response block index overflow"))?; + push_output_block(block, index, out, &mut events, &mut pending_tools, false)?; + if let Some((id, name, input)) = pending_tools.remove(&index) { events.push(AssistantEvent::ToolUse { id, name, input }); } } @@ -3562,20 +2629,14 @@ struct CliToolExecutor { renderer: TerminalRenderer, emit_output: bool, allowed_tools: Option<AllowedToolSet>, - tool_registry: GlobalToolRegistry, } impl CliToolExecutor { - fn new( - allowed_tools: Option<AllowedToolSet>, - emit_output: bool, - tool_registry: GlobalToolRegistry, - ) -> Self { + fn new(allowed_tools: Option<AllowedToolSet>, emit_output: bool) -> Self { Self { renderer: TerminalRenderer::new(), emit_output, allowed_tools, - tool_registry, } } } @@ -3593,7 +2654,7 @@ impl ToolExecutor for CliToolExecutor { } let value = serde_json::from_str(input) .map_err(|error| ToolError::new(format!("invalid tool input JSON: {error}")))?; - match self.tool_registry.execute(tool_name, &value) { + match execute_tool(tool_name, &value) { Ok(output) => { if self.emit_output { let markdown = format_tool_result(tool_name, &output, false); @@ -3616,13 +2677,16 @@ impl ToolExecutor for CliToolExecutor { } } -fn permission_policy(mode: PermissionMode, tool_registry: &GlobalToolRegistry) -> PermissionPolicy { - tool_registry.permission_specs(None).into_iter().fold( - PermissionPolicy::new(mode), - |policy, (name, required_permission)| { - policy.with_tool_requirement(name, required_permission) - }, - ) +fn permission_policy(mode: PermissionMode) -> PermissionPolicy { + tool_permission_specs() + .into_iter() + .fold(PermissionPolicy::new(mode), |policy, spec| { + policy.with_tool_requirement(spec.name, spec.required_permission) + }) +} + +fn tool_permission_specs() -> Vec<ToolSpec> { + mvp_tool_specs() } fn convert_messages(messages: &[ConversationMessage]) -> Vec<InputMessage> { @@ -3761,48 +2825,19 @@ fn print_help() { #[cfg(test)] mod tests { use super::{ - describe_tool_progress, filter_tool_specs, format_compact_report, format_cost_report, - format_internal_prompt_progress_line, format_model_report, - format_model_switch_report, format_permissions_report, - format_permissions_switch_report, format_resume_report, format_status_report, - format_tool_call_start, format_tool_result, normalize_permission_mode, parse_args, - parse_git_status_metadata, permission_policy, print_help_to, push_output_block, - render_config_report, render_memory_report, render_repl_help, resolve_model_alias, - response_to_events, resume_supported_slash_commands, status_context, CliAction, - CliOutputFormat, InternalPromptProgressEvent, InternalPromptProgressState, SlashCommand, - StatusUsage, DEFAULT_MODEL, + filter_tool_specs, format_compact_report, format_cost_report, format_model_report, + format_model_switch_report, format_permissions_report, format_permissions_switch_report, + format_resume_report, format_status_report, format_tool_call_start, format_tool_result, + normalize_permission_mode, parse_args, parse_git_status_metadata, print_help_to, + push_output_block, render_config_report, render_memory_report, render_repl_help, + resolve_model_alias, response_to_events, resume_supported_slash_commands, status_context, + CliAction, CliOutputFormat, SlashCommand, StatusUsage, DEFAULT_MODEL, }; use api::{MessageResponse, OutputContentBlock, Usage}; - use plugins::{PluginTool, PluginToolDefinition, PluginToolPermission}; use runtime::{AssistantEvent, ContentBlock, ConversationMessage, MessageRole, PermissionMode}; use serde_json::json; + use std::collections::BTreeMap; use std::path::PathBuf; - use std::time::Duration; - use tools::GlobalToolRegistry; - - fn registry_with_plugin_tool() -> GlobalToolRegistry { - GlobalToolRegistry::with_plugin_tools(vec![PluginTool::new( - "plugin-demo@external", - "plugin-demo", - PluginToolDefinition { - name: "plugin_echo".to_string(), - description: Some("Echo plugin payload".to_string()), - input_schema: json!({ - "type": "object", - "properties": { - "message": { "type": "string" } - }, - "required": ["message"], - "additionalProperties": false - }), - }, - "echo".to_string(), - Vec::new(), - PluginToolPermission::WorkspaceWrite, - None, - )]) - .expect("plugin tool registry should build") - } #[test] fn defaults_to_repl_when_no_args() { @@ -3881,6 +2916,9 @@ mod tests { assert_eq!(resolve_model_alias("opus"), "claude-opus-4-6"); assert_eq!(resolve_model_alias("sonnet"), "claude-sonnet-4-6"); assert_eq!(resolve_model_alias("haiku"), "claude-haiku-4-5-20251213"); + assert_eq!(resolve_model_alias("grok"), "grok-3"); + assert_eq!(resolve_model_alias("grok-mini"), "grok-3-mini"); + assert_eq!(resolve_model_alias("grok-2"), "grok-2"); assert_eq!(resolve_model_alias("claude-opus"), "claude-opus"); } @@ -4016,7 +3054,7 @@ mod tests { .into_iter() .map(str::to_string) .collect(); - let filtered = filter_tool_specs(&GlobalToolRegistry::builtin(), Some(&allowed)); + let filtered = filter_tool_specs(Some(&allowed)); let names = filtered .into_iter() .map(|spec| spec.name) @@ -4024,24 +3062,6 @@ mod tests { assert_eq!(names, vec!["read_file", "grep_search"]); } - #[test] - fn filtered_tool_specs_include_plugin_tools() { - let filtered = filter_tool_specs(®istry_with_plugin_tool(), None); - let names = filtered - .into_iter() - .map(|definition| definition.name) - .collect::<Vec<_>>(); - assert!(names.contains(&"bash".to_string())); - assert!(names.contains(&"plugin_echo".to_string())); - } - - #[test] - fn permission_policy_uses_plugin_tool_permissions() { - let policy = permission_policy(PermissionMode::ReadOnly, ®istry_with_plugin_tool()); - let required = policy.required_mode_for("plugin_echo"); - assert_eq!(required, PermissionMode::WorkspaceWrite); - } - #[test] fn shared_help_uses_resume_annotation_copy() { let help = commands::render_slash_command_help(); @@ -4060,16 +3080,13 @@ mod tests { assert!(help.contains("/clear [--confirm]")); assert!(help.contains("/cost")); assert!(help.contains("/resume <session-path>")); - assert!(help.contains("/config [env|hooks|model|plugins]")); + assert!(help.contains("/config [env|hooks|model]")); assert!(help.contains("/memory")); assert!(help.contains("/init")); assert!(help.contains("/diff")); assert!(help.contains("/version")); assert!(help.contains("/export [file]")); assert!(help.contains("/session [list|switch <session-id>]")); - assert!(help.contains( - "/plugins [list|install <path>|enable <name>|disable <name>|uninstall <id>|update <id>]" - )); assert!(help.contains("/exit")); } @@ -4220,9 +3237,6 @@ mod tests { fn config_report_supports_section_views() { let report = render_config_report(Some("env")).expect("config report should render"); assert!(report.contains("Merged section: env")); - let plugins_report = - render_config_report(Some("plugins")).expect("plugins config report should render"); - assert!(plugins_report.contains("Merged section: plugins")); } #[test] @@ -4368,106 +3382,20 @@ mod tests { assert!(done.contains("hello")); } - #[test] - fn tool_rendering_truncates_large_read_output_for_display_only() { - let content = (0..200) - .map(|index| format!("line {index:03}")) - .collect::<Vec<_>>() - .join("\n"); - let output = json!({ - "file": { - "filePath": "src/main.rs", - "content": content, - "numLines": 200, - "startLine": 1, - "totalLines": 200 - } - }) - .to_string(); - - let rendered = format_tool_result("read_file", &output, false); - - assert!(rendered.contains("line 000")); - assert!(rendered.contains("line 079")); - assert!(!rendered.contains("line 199")); - assert!(rendered.contains("full result preserved in session")); - assert!(output.contains("line 199")); - } - - #[test] - fn tool_rendering_truncates_large_bash_output_for_display_only() { - let stdout = (0..120) - .map(|index| format!("stdout {index:03}")) - .collect::<Vec<_>>() - .join("\n"); - let output = json!({ - "stdout": stdout, - "stderr": "", - "returnCodeInterpretation": "completed successfully" - }) - .to_string(); - - let rendered = format_tool_result("bash", &output, false); - - assert!(rendered.contains("stdout 000")); - assert!(rendered.contains("stdout 059")); - assert!(!rendered.contains("stdout 119")); - assert!(rendered.contains("full result preserved in session")); - assert!(output.contains("stdout 119")); - } - - #[test] - fn tool_rendering_truncates_generic_long_output_for_display_only() { - let items = (0..120) - .map(|index| format!("payload {index:03}")) - .collect::<Vec<_>>(); - let output = json!({ - "summary": "plugin payload", - "items": items, - }) - .to_string(); - - let rendered = format_tool_result("plugin_echo", &output, false); - - assert!(rendered.contains("plugin_echo")); - assert!(rendered.contains("payload 000")); - assert!(rendered.contains("payload 040")); - assert!(!rendered.contains("payload 080")); - assert!(!rendered.contains("payload 119")); - assert!(rendered.contains("full result preserved in session")); - assert!(output.contains("payload 119")); - } - - #[test] - fn tool_rendering_truncates_raw_generic_output_for_display_only() { - let output = (0..120) - .map(|index| format!("raw {index:03}")) - .collect::<Vec<_>>() - .join("\n"); - - let rendered = format_tool_result("plugin_echo", &output, false); - - assert!(rendered.contains("plugin_echo")); - assert!(rendered.contains("raw 000")); - assert!(rendered.contains("raw 059")); - assert!(!rendered.contains("raw 119")); - assert!(rendered.contains("full result preserved in session")); - assert!(output.contains("raw 119")); - } - #[test] fn push_output_block_renders_markdown_text() { let mut out = Vec::new(); let mut events = Vec::new(); - let mut pending_tool = None; + let mut pending_tools = BTreeMap::new(); push_output_block( OutputContentBlock::Text { text: "# Heading".to_string(), }, + 0, &mut out, &mut events, - &mut pending_tool, + &mut pending_tools, false, ) .expect("text block should render"); @@ -4481,7 +3409,7 @@ mod tests { fn push_output_block_skips_empty_object_prefix_for_tool_streams() { let mut out = Vec::new(); let mut events = Vec::new(); - let mut pending_tool = None; + let mut pending_tools = BTreeMap::new(); push_output_block( OutputContentBlock::ToolUse { @@ -4489,20 +3417,83 @@ mod tests { name: "read_file".to_string(), input: json!({}), }, + 1, &mut out, &mut events, - &mut pending_tool, + &mut pending_tools, true, ) .expect("tool block should accumulate"); assert!(events.is_empty()); assert_eq!( - pending_tool, + pending_tools.remove(&1), Some(("tool-1".to_string(), "read_file".to_string(), String::new(),)) ); } + #[test] + fn pending_tools_preserve_multiple_streaming_tool_calls_by_index() { + let mut out = Vec::new(); + let mut events = Vec::new(); + let mut pending_tools = BTreeMap::new(); + + push_output_block( + OutputContentBlock::ToolUse { + id: "tool-1".to_string(), + name: "read_file".to_string(), + input: json!({}), + }, + 1, + &mut out, + &mut events, + &mut pending_tools, + true, + ) + .expect("first tool should accumulate"); + push_output_block( + OutputContentBlock::ToolUse { + id: "tool-2".to_string(), + name: "grep_search".to_string(), + input: json!({}), + }, + 2, + &mut out, + &mut events, + &mut pending_tools, + true, + ) + .expect("second tool should accumulate"); + + pending_tools + .get_mut(&1) + .expect("first tool pending") + .2 + .push_str("{\"path\":\"src/main.rs\"}"); + pending_tools + .get_mut(&2) + .expect("second tool pending") + .2 + .push_str("{\"pattern\":\"TODO\"}"); + + assert_eq!( + pending_tools.remove(&1), + Some(( + "tool-1".to_string(), + "read_file".to_string(), + "{\"path\":\"src/main.rs\"}".to_string(), + )) + ); + assert_eq!( + pending_tools.remove(&2), + Some(( + "tool-2".to_string(), + "grep_search".to_string(), + "{\"pattern\":\"TODO\"}".to_string(), + )) + ); + } + #[test] fn response_to_events_preserves_empty_object_json_input_outside_streaming() { let mut out = Vec::new(); @@ -4572,43 +3563,4 @@ mod tests { if name == "read_file" && input == "{\"path\":\"rust/Cargo.toml\"}" )); } - - #[test] - fn response_to_events_ignores_thinking_blocks() { - let mut out = Vec::new(); - let events = response_to_events( - MessageResponse { - id: "msg-3".to_string(), - kind: "message".to_string(), - model: "claude-opus-4-6".to_string(), - role: "assistant".to_string(), - content: vec![ - OutputContentBlock::Thinking { - thinking: "step 1".to_string(), - signature: Some("sig_123".to_string()), - }, - OutputContentBlock::Text { - text: "Final answer".to_string(), - }, - ], - stop_reason: Some("end_turn".to_string()), - stop_sequence: None, - usage: Usage { - input_tokens: 1, - output_tokens: 1, - cache_creation_input_tokens: 0, - cache_read_input_tokens: 0, - }, - request_id: None, - }, - &mut out, - ) - .expect("response conversion should succeed"); - - assert!(matches!( - &events[0], - AssistantEvent::TextDelta(text) if text == "Final answer" - )); - assert!(!String::from_utf8(out).expect("utf8").contains("step 1")); - } } diff --git a/rust/crates/tools/src/lib.rs b/rust/crates/tools/src/lib.rs index a6e1ba4..63be324 100644 --- a/rust/crates/tools/src/lib.rs +++ b/rust/crates/tools/src/lib.rs @@ -4,17 +4,16 @@ use std::process::Command; use std::time::{Duration, Instant}; use api::{ - read_base_url, ApiHttpClient, ContentBlockDelta, InputContentBlock, InputMessage, - MessageRequest, MessageResponse, OutputContentBlock, StreamEvent as ApiStreamEvent, ToolChoice, - ToolDefinition, ToolResultContentBlock, + max_tokens_for_model, resolve_model_alias, ContentBlockDelta, InputContentBlock, InputMessage, + MessageRequest, MessageResponse, OutputContentBlock, ProviderClient, + StreamEvent as ApiStreamEvent, ToolChoice, ToolDefinition, ToolResultContentBlock, }; -use plugins::{PluginManager, PluginManagerConfig, PluginTool}; use reqwest::blocking::Client; use runtime::{ edit_file, execute_bash, glob_search, grep_search, load_system_prompt, read_file, write_file, - ApiClient, ApiRequest, AssistantEvent, BashCommandInput, ConfigLoader, ContentBlock, - ConversationMessage, ConversationRuntime, GrepSearchInput, MessageRole, PermissionMode, - PermissionPolicy, RuntimeConfig, RuntimeError, Session, TokenUsage, ToolError, ToolExecutor, + ApiClient, ApiRequest, AssistantEvent, BashCommandInput, ContentBlock, ConversationMessage, + ConversationRuntime, GrepSearchInput, MessageRole, PermissionMode, PermissionPolicy, + RuntimeError, Session, TokenUsage, ToolError, ToolExecutor, }; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; @@ -56,239 +55,6 @@ pub struct ToolSpec { pub required_permission: PermissionMode, } -#[derive(Debug, Clone, PartialEq)] -pub struct RegisteredTool { - pub definition: ToolDefinition, - pub required_permission: PermissionMode, - handler: RegisteredToolHandler, -} - -#[allow(clippy::large_enum_variant)] -#[derive(Debug, Clone, PartialEq)] -enum RegisteredToolHandler { - Builtin, - Plugin(PluginTool), -} - -#[derive(Debug, Clone, PartialEq)] -pub struct GlobalToolRegistry { - entries: Vec<RegisteredTool>, -} - -impl GlobalToolRegistry { - #[must_use] - pub fn builtin() -> Self { - Self { - entries: mvp_tool_specs() - .into_iter() - .map(|spec| RegisteredTool { - definition: ToolDefinition { - name: spec.name.to_string(), - description: Some(spec.description.to_string()), - input_schema: spec.input_schema, - }, - required_permission: spec.required_permission, - handler: RegisteredToolHandler::Builtin, - }) - .collect(), - } - } - - pub fn with_plugin_tools(plugin_tools: Vec<PluginTool>) -> Result<Self, String> { - let mut registry = Self::builtin(); - let mut seen = registry - .entries - .iter() - .map(|entry| { - ( - normalize_registry_tool_name(&entry.definition.name), - entry.definition.name.clone(), - ) - }) - .collect::<BTreeMap<_, _>>(); - - for tool in plugin_tools { - let normalized = normalize_registry_tool_name(&tool.definition().name); - if let Some(existing) = seen.get(&normalized) { - return Err(format!( - "plugin tool `{}` from `{}` conflicts with already-registered tool `{existing}`", - tool.definition().name, - tool.plugin_id() - )); - } - seen.insert(normalized, tool.definition().name.clone()); - registry.entries.push(RegisteredTool { - definition: ToolDefinition { - name: tool.definition().name.clone(), - description: tool.definition().description.clone(), - input_schema: tool.definition().input_schema.clone(), - }, - required_permission: permission_mode_from_plugin_tool(tool.required_permission())?, - handler: RegisteredToolHandler::Plugin(tool), - }); - } - - Ok(registry) - } - - #[must_use] - pub fn entries(&self) -> &[RegisteredTool] { - &self.entries - } - - fn find_entry(&self, name: &str) -> Option<&RegisteredTool> { - let normalized = normalize_registry_tool_name(name); - self.entries.iter().find(|entry| { - normalize_registry_tool_name(entry.definition.name.as_str()) == normalized - }) - } - - #[must_use] - pub fn definitions(&self, allowed_tools: Option<&BTreeSet<String>>) -> Vec<ToolDefinition> { - self.entries - .iter() - .filter(|entry| { - allowed_tools.is_none_or(|allowed| allowed.contains(entry.definition.name.as_str())) - }) - .map(|entry| entry.definition.clone()) - .collect() - } - - #[must_use] - pub fn permission_specs( - &self, - allowed_tools: Option<&BTreeSet<String>>, - ) -> Vec<(String, PermissionMode)> { - self.entries - .iter() - .filter(|entry| { - allowed_tools.is_none_or(|allowed| allowed.contains(entry.definition.name.as_str())) - }) - .map(|entry| (entry.definition.name.clone(), entry.required_permission)) - .collect() - } - - pub fn normalize_allowed_tools( - &self, - values: &[String], - ) -> Result<Option<BTreeSet<String>>, String> { - if values.is_empty() { - return Ok(None); - } - - let canonical_names = self - .entries - .iter() - .map(|entry| entry.definition.name.clone()) - .collect::<Vec<_>>(); - let mut name_map = canonical_names - .iter() - .map(|name| (normalize_registry_tool_name(name), name.clone())) - .collect::<BTreeMap<_, _>>(); - - for (alias, canonical) in [ - ("read", "read_file"), - ("write", "write_file"), - ("edit", "edit_file"), - ("glob", "glob_search"), - ("grep", "grep_search"), - ] { - if canonical_names.iter().any(|name| name == canonical) { - name_map.insert(alias.to_string(), canonical.to_string()); - } - } - - let mut allowed = BTreeSet::new(); - for value in values { - for token in value - .split(|ch: char| ch == ',' || ch.is_whitespace()) - .filter(|token| !token.is_empty()) - { - let normalized = normalize_registry_tool_name(token); - let canonical = name_map.get(&normalized).ok_or_else(|| { - format!( - "unsupported tool in --allowedTools: {token} (expected one of: {})", - canonical_names.join(", ") - ) - })?; - allowed.insert(canonical.clone()); - } - } - - Ok(Some(allowed)) - } - - pub fn execute(&self, name: &str, input: &Value) -> Result<String, String> { - let entry = self - .find_entry(name) - .ok_or_else(|| format!("unsupported tool: {name}"))?; - match &entry.handler { - RegisteredToolHandler::Builtin => execute_tool(&entry.definition.name, input), - RegisteredToolHandler::Plugin(tool) => { - tool.execute(input).map_err(|error| error.to_string()) - } - } - } -} - -impl Default for GlobalToolRegistry { - fn default() -> Self { - Self::builtin() - } -} - -fn normalize_registry_tool_name(value: &str) -> String { - let trimmed = value.trim(); - let chars = trimmed.chars().collect::<Vec<_>>(); - let mut normalized = String::new(); - - for (index, ch) in chars.iter().copied().enumerate() { - if matches!(ch, '-' | ' ' | '\t' | '\n') { - if !normalized.ends_with('_') { - normalized.push('_'); - } - continue; - } - - if ch == '_' { - if !normalized.ends_with('_') { - normalized.push('_'); - } - continue; - } - - if ch.is_uppercase() { - let prev = chars.get(index.wrapping_sub(1)).copied(); - let next = chars.get(index + 1).copied(); - let needs_separator = index > 0 - && !normalized.ends_with('_') - && (prev.is_some_and(|prev| prev.is_lowercase() || prev.is_ascii_digit()) - || (prev.is_some_and(char::is_uppercase) - && next.is_some_and(char::is_lowercase))); - if needs_separator { - normalized.push('_'); - } - normalized.extend(ch.to_lowercase()); - continue; - } - - normalized.push(ch.to_ascii_lowercase()); - } - - normalized.trim_matches('_').to_string() -} - -fn permission_mode_from_plugin_tool(value: &str) -> Result<PermissionMode, String> { - match value { - "read-only" => Ok(PermissionMode::ReadOnly), - "workspace-write" => Ok(PermissionMode::WorkspaceWrite), - "danger-full-access" => Ok(PermissionMode::DangerFullAccess), - other => Err(format!( - "unsupported plugin tool permission `{other}` (expected read-only, workspace-write, or danger-full-access)" - )), - } -} - #[must_use] #[allow(clippy::too_many_lines)] pub fn mvp_tool_specs() -> Vec<ToolSpec> { @@ -557,7 +323,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> { }, ToolSpec { name: "Config", - description: "Get or set Claw Code settings.", + description: "Get or set Claude Code settings.", input_schema: json!({ "type": "object", "properties": { @@ -1542,11 +1308,6 @@ fn resolve_skill_path(skill: &str) -> Result<std::path::PathBuf, String> { if let Ok(codex_home) = std::env::var("CODEX_HOME") { candidates.push(std::path::PathBuf::from(codex_home).join("skills")); } - if let Ok(home) = std::env::var("HOME") { - let home = std::path::PathBuf::from(home); - candidates.push(home.join(".agents").join("skills")); - candidates.push(home.join(".codex").join("skills")); - } candidates.push(std::path::PathBuf::from("/home/bellman/.codex/skills")); for root in candidates { @@ -1698,22 +1459,20 @@ fn run_agent_job(job: &AgentJob) -> Result<(), String> { fn build_agent_runtime( job: &AgentJob, -) -> Result<ConversationRuntime<AnthropicRuntimeClient, SubagentToolExecutor>, String> { +) -> Result<ConversationRuntime<ProviderRuntimeClient, SubagentToolExecutor>, String> { let model = job .manifest .model .clone() .unwrap_or_else(|| DEFAULT_AGENT_MODEL.to_string()); let allowed_tools = job.allowed_tools.clone(); - let tool_registry = current_tool_registry()?; - let api_client = - AnthropicRuntimeClient::new(model, allowed_tools.clone(), tool_registry.clone())?; - let tool_executor = SubagentToolExecutor::new(allowed_tools, tool_registry.clone()); + let api_client = ProviderRuntimeClient::new(model, allowed_tools.clone())?; + let tool_executor = SubagentToolExecutor::new(allowed_tools); Ok(ConversationRuntime::new( Session::new(), api_client, tool_executor, - agent_permission_policy(&tool_registry), + agent_permission_policy(), job.system_prompt.clone(), )) } @@ -1778,7 +1537,7 @@ fn allowed_tools_for_subagent(subagent_type: &str) -> BTreeSet<String> { "SendUserMessage", "PowerShell", ], - "claw-code-guide" => vec![ + "claude-code-guide" => vec![ "read_file", "glob_search", "grep_search", @@ -1822,12 +1581,10 @@ fn allowed_tools_for_subagent(subagent_type: &str) -> BTreeSet<String> { tools.into_iter().map(str::to_string).collect() } -fn agent_permission_policy(tool_registry: &GlobalToolRegistry) -> PermissionPolicy { - tool_registry.permission_specs(None).into_iter().fold( +fn agent_permission_policy() -> PermissionPolicy { + mvp_tool_specs().into_iter().fold( PermissionPolicy::new(PermissionMode::DangerFullAccess), - |policy, (name, required_permission)| { - policy.with_tool_requirement(name, required_permission) - }, + |policy, spec| policy.with_tool_requirement(spec.name, spec.required_permission), ) } @@ -1878,39 +1635,39 @@ fn format_agent_terminal_output(status: &str, result: Option<&str>, error: Optio sections.join("") } -struct AnthropicRuntimeClient { +struct ProviderRuntimeClient { runtime: tokio::runtime::Runtime, - client: ApiHttpClient, + client: ProviderClient, model: String, allowed_tools: BTreeSet<String>, - tool_registry: GlobalToolRegistry, } -impl AnthropicRuntimeClient { - fn new( - model: String, - allowed_tools: BTreeSet<String>, - tool_registry: GlobalToolRegistry, - ) -> Result<Self, String> { - let client = ApiHttpClient::from_env() - .map_err(|error| error.to_string())? - .with_base_url(read_base_url()); +impl ProviderRuntimeClient { + fn new(model: String, allowed_tools: BTreeSet<String>) -> Result<Self, String> { + let model = resolve_model_alias(&model).to_string(); + let client = ProviderClient::from_model(&model).map_err(|error| error.to_string())?; Ok(Self { runtime: tokio::runtime::Runtime::new().map_err(|error| error.to_string())?, client, model, allowed_tools, - tool_registry, }) } } -impl ApiClient for AnthropicRuntimeClient { +impl ApiClient for ProviderRuntimeClient { fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> { - let tools = self.tool_registry.definitions(Some(&self.allowed_tools)); + let tools = tool_specs_for_allowed_tools(Some(&self.allowed_tools)) + .into_iter() + .map(|spec| ToolDefinition { + name: spec.name.to_string(), + description: Some(spec.description.to_string()), + input_schema: spec.input_schema, + }) + .collect::<Vec<_>>(); let message_request = MessageRequest { model: self.model.clone(), - max_tokens: 32_000, + max_tokens: max_tokens_for_model(&self.model), messages: convert_messages(&request.messages), system: (!request.system_prompt.is_empty()).then(|| request.system_prompt.join("\n\n")), tools: (!tools.is_empty()).then_some(tools), @@ -1925,7 +1682,7 @@ impl ApiClient for AnthropicRuntimeClient { .await .map_err(|error| RuntimeError::new(error.to_string()))?; let mut events = Vec::new(); - let mut pending_tool: Option<(String, String, String)> = None; + let mut pending_tools: BTreeMap<u32, (String, String, String)> = BTreeMap::new(); let mut saw_stop = false; while let Some(event) = stream @@ -1936,14 +1693,15 @@ impl ApiClient for AnthropicRuntimeClient { match event { ApiStreamEvent::MessageStart(start) => { for block in start.message.content { - push_output_block(block, &mut events, &mut pending_tool, true); + push_output_block(block, 0, &mut events, &mut pending_tools, true); } } ApiStreamEvent::ContentBlockStart(start) => { push_output_block( start.content_block, + start.index, &mut events, - &mut pending_tool, + &mut pending_tools, true, ); } @@ -1954,15 +1712,13 @@ impl ApiClient for AnthropicRuntimeClient { } } ContentBlockDelta::InputJsonDelta { partial_json } => { - if let Some((_, _, input)) = &mut pending_tool { + if let Some((_, _, input)) = pending_tools.get_mut(&delta.index) { input.push_str(&partial_json); } } - ContentBlockDelta::ThinkingDelta { .. } - | ContentBlockDelta::SignatureDelta { .. } => {} }, - ApiStreamEvent::ContentBlockStop(_) => { - if let Some((id, name, input)) = pending_tool.take() { + ApiStreamEvent::ContentBlockStop(stop) => { + if let Some((id, name, input)) = pending_tools.remove(&stop.index) { events.push(AssistantEvent::ToolUse { id, name, input }); } } @@ -2012,82 +1768,32 @@ impl ApiClient for AnthropicRuntimeClient { struct SubagentToolExecutor { allowed_tools: BTreeSet<String>, - tool_registry: GlobalToolRegistry, } impl SubagentToolExecutor { - fn new(allowed_tools: BTreeSet<String>, tool_registry: GlobalToolRegistry) -> Self { - Self { - allowed_tools, - tool_registry, - } + fn new(allowed_tools: BTreeSet<String>) -> Self { + Self { allowed_tools } } } impl ToolExecutor for SubagentToolExecutor { fn execute(&mut self, tool_name: &str, input: &str) -> Result<String, ToolError> { - let entry = self - .tool_registry - .find_entry(tool_name) - .ok_or_else(|| ToolError::new(format!("unsupported tool: {tool_name}")))?; - if !self.allowed_tools.contains(entry.definition.name.as_str()) { + if !self.allowed_tools.contains(tool_name) { return Err(ToolError::new(format!( "tool `{tool_name}` is not enabled for this sub-agent" ))); } let value = serde_json::from_str(input) .map_err(|error| ToolError::new(format!("invalid tool input JSON: {error}")))?; - self.tool_registry - .execute(tool_name, &value) - .map_err(ToolError::new) + execute_tool(tool_name, &value).map_err(ToolError::new) } } -fn current_tool_registry() -> Result<GlobalToolRegistry, String> { - let cwd = std::env::current_dir().map_err(|error| error.to_string())?; - let loader = ConfigLoader::default_for(&cwd); - let runtime_config = loader.load().map_err(|error| error.to_string())?; - let plugin_manager = build_plugin_manager(&cwd, &loader, &runtime_config); - let plugin_tools = plugin_manager - .aggregated_tools() - .map_err(|error| error.to_string())?; - GlobalToolRegistry::with_plugin_tools(plugin_tools) -} - -fn build_plugin_manager( - cwd: &Path, - loader: &ConfigLoader, - runtime_config: &RuntimeConfig, -) -> PluginManager { - let plugin_settings = runtime_config.plugins(); - let mut plugin_config = PluginManagerConfig::new(loader.config_home().to_path_buf()); - plugin_config.enabled_plugins = plugin_settings.enabled_plugins().clone(); - plugin_config.external_dirs = plugin_settings - .external_directories() - .iter() - .map(|path| resolve_plugin_path(cwd, loader.config_home(), path)) - .collect(); - plugin_config.install_root = plugin_settings - .install_root() - .map(|path| resolve_plugin_path(cwd, loader.config_home(), path)); - plugin_config.registry_path = plugin_settings - .registry_path() - .map(|path| resolve_plugin_path(cwd, loader.config_home(), path)); - plugin_config.bundled_root = plugin_settings - .bundled_root() - .map(|path| resolve_plugin_path(cwd, loader.config_home(), path)); - PluginManager::new(plugin_config) -} - -fn resolve_plugin_path(cwd: &Path, config_home: &Path, value: &str) -> PathBuf { - let path = PathBuf::from(value); - if path.is_absolute() { - path - } else if value.starts_with('.') { - cwd.join(path) - } else { - config_home.join(path) - } +fn tool_specs_for_allowed_tools(allowed_tools: Option<&BTreeSet<String>>) -> Vec<ToolSpec> { + mvp_tool_specs() + .into_iter() + .filter(|spec| allowed_tools.is_none_or(|allowed| allowed.contains(spec.name))) + .collect() } fn convert_messages(messages: &[ConversationMessage]) -> Vec<InputMessage> { @@ -2133,8 +1839,9 @@ fn convert_messages(messages: &[ConversationMessage]) -> Vec<InputMessage> { fn push_output_block( block: OutputContentBlock, + block_index: u32, events: &mut Vec<AssistantEvent>, - pending_tool: &mut Option<(String, String, String)>, + pending_tools: &mut BTreeMap<u32, (String, String, String)>, streaming_tool_input: bool, ) { match block { @@ -2152,19 +1859,19 @@ fn push_output_block( } else { input.to_string() }; - *pending_tool = Some((id, name, initial_input)); + pending_tools.insert(block_index, (id, name, initial_input)); } - OutputContentBlock::Thinking { .. } | OutputContentBlock::RedactedThinking { .. } => {} } } fn response_to_events(response: MessageResponse) -> Vec<AssistantEvent> { let mut events = Vec::new(); - let mut pending_tool = None; + let mut pending_tools = BTreeMap::new(); - for block in response.content { - push_output_block(block, &mut events, &mut pending_tool, false); - if let Some((id, name, input)) = pending_tool.take() { + for (index, block) in response.content.into_iter().enumerate() { + let index = u32::try_from(index).expect("response block index overflow"); + push_output_block(block, index, &mut events, &mut pending_tools, false); + if let Some((id, name, input)) = pending_tools.remove(&index) { events.push(AssistantEvent::ToolUse { id, name, input }); } } @@ -2382,7 +2089,7 @@ fn normalize_subagent_type(subagent_type: Option<&str>) -> String { "verification" | "verificationagent" | "verify" | "verifier" => { String::from("Verification") } - "claudecodeguide" | "claudecodeguideagent" | "guide" => String::from("claw-code-guide"), + "claudecodeguide" | "claudecodeguideagent" | "guide" => String::from("claude-code-guide"), "statusline" | "statuslinesetup" => String::from("statusline-setup"), _ => trimmed.to_string(), } @@ -2882,16 +2589,16 @@ fn config_file_for_scope(scope: ConfigScope) -> Result<PathBuf, String> { let cwd = std::env::current_dir().map_err(|error| error.to_string())?; Ok(match scope { ConfigScope::Global => config_home_dir()?.join("settings.json"), - ConfigScope::Settings => cwd.join(".claw").join("settings.local.json"), + ConfigScope::Settings => cwd.join(".claude").join("settings.local.json"), }) } fn config_home_dir() -> Result<PathBuf, String> { - if let Ok(path) = std::env::var("CLAW_CONFIG_HOME") { + if let Ok(path) = std::env::var("CLAUDE_CONFIG_HOME") { return Ok(PathBuf::from(path)); } let home = std::env::var("HOME").map_err(|_| String::from("HOME is not set"))?; - Ok(PathBuf::from(home).join(".claw")) + Ok(PathBuf::from(home).join(".claude")) } fn read_json_object(path: &Path) -> Result<serde_json::Map<String, Value>, String> { @@ -3188,6 +2895,7 @@ fn parse_skill_description(contents: &str) -> Option<String> { #[cfg(test)] mod tests { + use std::collections::BTreeMap; use std::collections::BTreeSet; use std::fs; use std::io::{Read, Write}; @@ -3200,13 +2908,10 @@ mod tests { use super::{ agent_permission_policy, allowed_tools_for_subagent, execute_agent_with_spawn, execute_tool, final_assistant_text, mvp_tool_specs, persist_agent_terminal_state, - response_to_events, AgentInput, AgentJob, GlobalToolRegistry, SubagentToolExecutor, - }; - use api::{MessageResponse, OutputContentBlock, Usage}; - use plugins::{PluginTool, PluginToolDefinition, PluginToolPermission}; - use runtime::{ - ApiRequest, AssistantEvent, ConversationRuntime, RuntimeError, Session, ToolExecutor, + push_output_block, AgentInput, AgentJob, SubagentToolExecutor, }; + use api::OutputContentBlock; + use runtime::{ApiRequest, AssistantEvent, ConversationRuntime, RuntimeError, Session}; use serde_json::json; fn env_lock() -> &'static Mutex<()> { @@ -3222,17 +2927,6 @@ mod tests { std::env::temp_dir().join(format!("clawd-tools-{unique}-{name}")) } - fn make_executable(path: &PathBuf) { - #[cfg(unix)] - { - use std::os::unix::fs::PermissionsExt; - - let mut permissions = std::fs::metadata(path).expect("metadata").permissions(); - permissions.set_mode(0o755); - std::fs::set_permissions(path, permissions).expect("chmod"); - } - } - #[test] fn exposes_mvp_tools() { let names = mvp_tool_specs() @@ -3262,170 +2956,6 @@ mod tests { assert!(error.contains("unsupported tool")); } - #[test] - fn global_registry_registers_and_executes_plugin_tools() { - let script = temp_path("plugin-tool.sh"); - std::fs::write( - &script, - "#!/bin/sh\nINPUT=$(cat)\nprintf '{\"plugin\":\"%s\",\"tool\":\"%s\",\"input\":%s}\\n' \"$CLAWD_PLUGIN_ID\" \"$CLAWD_TOOL_NAME\" \"$INPUT\"\n", - ) - .expect("write script"); - make_executable(&script); - - let registry = GlobalToolRegistry::with_plugin_tools(vec![PluginTool::new( - "demo@external", - "demo", - PluginToolDefinition { - name: "plugin_echo".to_string(), - description: Some("Echo plugin input".to_string()), - input_schema: json!({ - "type": "object", - "properties": { "message": { "type": "string" } }, - "required": ["message"], - "additionalProperties": false - }), - }, - "sh".to_string(), - vec![script.display().to_string()], - PluginToolPermission::WorkspaceWrite, - script.parent().map(PathBuf::from), - )]) - .expect("registry should build"); - - let names = registry - .definitions(None) - .into_iter() - .map(|definition| definition.name) - .collect::<Vec<_>>(); - assert!(names.contains(&"bash".to_string())); - assert!(names.contains(&"plugin_echo".to_string())); - - let output = registry - .execute("plugin_echo", &json!({ "message": "hello" })) - .expect("plugin tool should execute"); - let payload: serde_json::Value = serde_json::from_str(&output).expect("valid json"); - assert_eq!(payload["plugin"], "demo@external"); - assert_eq!(payload["tool"], "plugin_echo"); - assert_eq!(payload["input"]["message"], "hello"); - - let _ = std::fs::remove_file(script); - } - - #[test] - fn global_registry_normalizes_plugin_tool_names_for_allowlists_and_execution() { - let script = temp_path("plugin-tool-normalized.sh"); - std::fs::write( - &script, - "#!/bin/sh\nINPUT=$(cat)\nprintf '{\"tool\":\"%s\",\"input\":%s}\\n' \"$CLAWD_TOOL_NAME\" \"$INPUT\"\n", - ) - .expect("write script"); - make_executable(&script); - - let registry = GlobalToolRegistry::with_plugin_tools(vec![PluginTool::new( - "demo@external", - "demo", - PluginToolDefinition { - name: "plugin_echo".to_string(), - description: Some("Echo plugin input".to_string()), - input_schema: json!({ - "type": "object", - "properties": { "message": { "type": "string" } }, - "required": ["message"], - "additionalProperties": false - }), - }, - script.display().to_string(), - Vec::new(), - PluginToolPermission::WorkspaceWrite, - script.parent().map(PathBuf::from), - )]) - .expect("registry should build"); - - let allowed = registry - .normalize_allowed_tools(&[String::from("PLUGIN-ECHO")]) - .expect("plugin tool allowlist should normalize") - .expect("allowlist should be present"); - assert!(allowed.contains("plugin_echo")); - - let output = registry - .execute("plugin-echo", &json!({ "message": "hello" })) - .expect("normalized plugin tool name should execute"); - let payload: serde_json::Value = serde_json::from_str(&output).expect("valid json"); - assert_eq!(payload["tool"], "plugin_echo"); - assert_eq!(payload["input"]["message"], "hello"); - - let builtin_output = GlobalToolRegistry::builtin() - .execute("structured-output", &json!({ "ok": true })) - .expect("normalized builtin tool name should execute"); - let builtin_payload: serde_json::Value = - serde_json::from_str(&builtin_output).expect("valid json"); - assert_eq!(builtin_payload["structured_output"]["ok"], true); - - let _ = std::fs::remove_file(script); - } - - #[test] - fn subagent_executor_executes_allowed_plugin_tools() { - let script = temp_path("subagent-plugin-tool.sh"); - std::fs::write( - &script, - "#!/bin/sh\nINPUT=$(cat)\nprintf '{\"tool\":\"%s\",\"input\":%s}\\n' \"$CLAWD_TOOL_NAME\" \"$INPUT\"\n", - ) - .expect("write script"); - make_executable(&script); - - let registry = GlobalToolRegistry::with_plugin_tools(vec![PluginTool::new( - "demo@external", - "demo", - PluginToolDefinition { - name: "plugin_echo".to_string(), - description: Some("Echo plugin input".to_string()), - input_schema: json!({ - "type": "object", - "properties": { "message": { "type": "string" } }, - "required": ["message"], - "additionalProperties": false - }), - }, - script.display().to_string(), - Vec::new(), - PluginToolPermission::WorkspaceWrite, - script.parent().map(PathBuf::from), - )]) - .expect("registry should build"); - - let mut executor = - SubagentToolExecutor::new(BTreeSet::from([String::from("plugin_echo")]), registry); - let output = executor - .execute("plugin-echo", r#"{"message":"hello"}"#) - .expect("plugin tool should execute for subagent"); - let payload: serde_json::Value = serde_json::from_str(&output).expect("valid json"); - assert_eq!(payload["tool"], "plugin_echo"); - assert_eq!(payload["input"]["message"], "hello"); - - let _ = std::fs::remove_file(script); - } - - #[test] - fn global_registry_rejects_conflicting_plugin_tool_names() { - let error = GlobalToolRegistry::with_plugin_tools(vec![PluginTool::new( - "demo@external", - "demo", - PluginToolDefinition { - name: "read-file".to_string(), - description: Some("Conflicts with builtin".to_string()), - input_schema: json!({ "type": "object" }), - }, - "echo".to_string(), - Vec::new(), - PluginToolPermission::ReadOnly, - None, - )]) - .expect_err("conflicting plugin tool should fail"); - - assert!(error.contains("conflicts with already-registered tool `read_file`")); - } - #[test] fn web_fetch_returns_prompt_aware_summary() { let server = TestServer::spawn(Arc::new(|request_line: &str| { @@ -3595,6 +3125,63 @@ mod tests { assert!(error.contains("relative URL without a base") || error.contains("empty host")); } + #[test] + fn pending_tools_preserve_multiple_streaming_tool_calls_by_index() { + let mut events = Vec::new(); + let mut pending_tools = BTreeMap::new(); + + push_output_block( + OutputContentBlock::ToolUse { + id: "tool-1".to_string(), + name: "read_file".to_string(), + input: json!({}), + }, + 1, + &mut events, + &mut pending_tools, + true, + ); + push_output_block( + OutputContentBlock::ToolUse { + id: "tool-2".to_string(), + name: "grep_search".to_string(), + input: json!({}), + }, + 2, + &mut events, + &mut pending_tools, + true, + ); + + pending_tools + .get_mut(&1) + .expect("first tool pending") + .2 + .push_str("{\"path\":\"src/main.rs\"}"); + pending_tools + .get_mut(&2) + .expect("second tool pending") + .2 + .push_str("{\"pattern\":\"TODO\"}"); + + assert_eq!( + pending_tools.remove(&1), + Some(( + "tool-1".to_string(), + "read_file".to_string(), + "{\"path\":\"src/main.rs\"}".to_string(), + )) + ); + assert_eq!( + pending_tools.remove(&2), + Some(( + "tool-2".to_string(), + "grep_search".to_string(), + "{\"pattern\":\"TODO\"}".to_string(), + )) + ); + } + #[test] fn todo_write_persists_and_returns_previous_state() { let _guard = env_lock() @@ -4005,11 +3592,8 @@ mod tests { calls: 0, input_path: path.display().to_string(), }, - SubagentToolExecutor::new( - BTreeSet::from([String::from("read_file")]), - GlobalToolRegistry::builtin(), - ), - agent_permission_policy(&GlobalToolRegistry::builtin()), + SubagentToolExecutor::new(BTreeSet::from([String::from("read_file")])), + agent_permission_policy(), vec![String::from("system prompt")], ); @@ -4035,42 +3619,6 @@ mod tests { let _ = std::fs::remove_file(path); } - #[test] - fn response_to_events_ignores_thinking_blocks() { - let events = response_to_events(MessageResponse { - id: "msg-1".to_string(), - kind: "message".to_string(), - model: "claude-opus-4-6".to_string(), - role: "assistant".to_string(), - content: vec![ - OutputContentBlock::Thinking { - thinking: "step 1".to_string(), - signature: Some("sig_123".to_string()), - }, - OutputContentBlock::Text { - text: "Final answer".to_string(), - }, - ], - stop_reason: Some("end_turn".to_string()), - stop_sequence: None, - usage: Usage { - input_tokens: 1, - output_tokens: 1, - cache_creation_input_tokens: 0, - cache_read_input_tokens: 0, - }, - request_id: None, - }); - - assert!(matches!( - &events[0], - AssistantEvent::TextDelta(text) if text == "Final answer" - )); - assert!(!events - .iter() - .any(|event| matches!(event, AssistantEvent::ToolUse { .. }))); - } - #[test] fn agent_rejects_blank_required_fields() { let missing_description = execute_tool( @@ -4495,19 +4043,19 @@ mod tests { )); let home = root.join("home"); let cwd = root.join("cwd"); - std::fs::create_dir_all(home.join(".claw")).expect("home dir"); - std::fs::create_dir_all(cwd.join(".claw")).expect("cwd dir"); + std::fs::create_dir_all(home.join(".claude")).expect("home dir"); + std::fs::create_dir_all(cwd.join(".claude")).expect("cwd dir"); std::fs::write( - home.join(".claw").join("settings.json"), + home.join(".claude").join("settings.json"), r#"{"verbose":false}"#, ) .expect("write global settings"); let original_home = std::env::var("HOME").ok(); - let original_claw_home = std::env::var("CLAW_CONFIG_HOME").ok(); + let original_claude_home = std::env::var("CLAUDE_CONFIG_HOME").ok(); let original_dir = std::env::current_dir().expect("cwd"); std::env::set_var("HOME", &home); - std::env::remove_var("CLAW_CONFIG_HOME"); + std::env::remove_var("CLAUDE_CONFIG_HOME"); std::env::set_current_dir(&cwd).expect("set cwd"); let get = execute_tool("Config", &json!({"setting": "verbose"})).expect("get config"); @@ -4540,9 +4088,9 @@ mod tests { Some(value) => std::env::set_var("HOME", value), None => std::env::remove_var("HOME"), } - match original_claw_home { - Some(value) => std::env::set_var("CLAW_CONFIG_HOME", value), - None => std::env::remove_var("CLAW_CONFIG_HOME"), + match original_claude_home { + Some(value) => std::env::set_var("CLAUDE_CONFIG_HOME", value), + None => std::env::remove_var("CLAUDE_CONFIG_HOME"), } let _ = std::fs::remove_dir_all(root); }