commit 55a106196864ee0613064c1c8054a2fd74e8c189 Author: Yeachan-Heo Date: Wed Apr 1 20:36:06 2026 +0900 initial commit scaffold diff --git a/rust/crates/api/Cargo.toml b/rust/crates/api/Cargo.toml new file mode 100644 index 0000000..b9923a8 --- /dev/null +++ b/rust/crates/api/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "api" +version.workspace = true +edition.workspace = true +license.workspace = true +publish.workspace = true + +[dependencies] +reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls"] } +runtime = { path = "../runtime" } +serde = { version = "1", features = ["derive"] } +serde_json.workspace = true +tokio = { version = "1", features = ["io-util", "macros", "net", "rt-multi-thread", "time"] } + +[lints] +workspace = true diff --git a/rust/crates/api/src/client.rs b/rust/crates/api/src/client.rs new file mode 100644 index 0000000..8a9c286 --- /dev/null +++ b/rust/crates/api/src/client.rs @@ -0,0 +1,141 @@ +use crate::error::ApiError; +use crate::providers::claw_provider::{self, ClawApiClient, AuthSource}; +use crate::providers::openai_compat::{self, OpenAiCompatClient, OpenAiCompatConfig}; +use crate::providers::{self, Provider, ProviderKind}; +use crate::types::{MessageRequest, MessageResponse, StreamEvent}; + +async fn send_via_provider( + provider: &P, + request: &MessageRequest, +) -> Result { + provider.send_message(request).await +} + +async fn stream_via_provider( + provider: &P, + request: &MessageRequest, +) -> Result { + provider.stream_message(request).await +} + +#[derive(Debug, Clone)] +pub enum ProviderClient { + ClawApi(ClawApiClient), + Xai(OpenAiCompatClient), + OpenAi(OpenAiCompatClient), +} + +impl ProviderClient { + pub fn from_model(model: &str) -> Result { + Self::from_model_with_default_auth(model, None) + } + + pub fn from_model_with_default_auth( + model: &str, + default_auth: Option, + ) -> Result { + let resolved_model = providers::resolve_model_alias(model); + match providers::detect_provider_kind(&resolved_model) { + ProviderKind::ClawApi => Ok(Self::ClawApi(match default_auth { + Some(auth) => ClawApiClient::from_auth(auth), + None => ClawApiClient::from_env()?, + })), + ProviderKind::Xai => Ok(Self::Xai(OpenAiCompatClient::from_env( + OpenAiCompatConfig::xai(), + )?)), + ProviderKind::OpenAi => Ok(Self::OpenAi(OpenAiCompatClient::from_env( + OpenAiCompatConfig::openai(), + )?)), + } + } + + #[must_use] + pub const fn provider_kind(&self) -> ProviderKind { + match self { + Self::ClawApi(_) => ProviderKind::ClawApi, + Self::Xai(_) => ProviderKind::Xai, + Self::OpenAi(_) => ProviderKind::OpenAi, + } + } + + pub async fn send_message( + &self, + request: &MessageRequest, + ) -> Result { + match self { + Self::ClawApi(client) => send_via_provider(client, request).await, + Self::Xai(client) | Self::OpenAi(client) => send_via_provider(client, request).await, + } + } + + pub async fn stream_message( + &self, + request: &MessageRequest, + ) -> Result { + match self { + Self::ClawApi(client) => stream_via_provider(client, request) + .await + .map(MessageStream::ClawApi), + Self::Xai(client) | Self::OpenAi(client) => stream_via_provider(client, request) + .await + .map(MessageStream::OpenAiCompat), + } + } +} + +#[derive(Debug)] +pub enum MessageStream { + ClawApi(claw_provider::MessageStream), + OpenAiCompat(openai_compat::MessageStream), +} + +impl MessageStream { + #[must_use] + pub fn request_id(&self) -> Option<&str> { + match self { + Self::ClawApi(stream) => stream.request_id(), + Self::OpenAiCompat(stream) => stream.request_id(), + } + } + + pub async fn next_event(&mut self) -> Result, ApiError> { + match self { + Self::ClawApi(stream) => stream.next_event().await, + Self::OpenAiCompat(stream) => stream.next_event().await, + } + } +} + +pub use claw_provider::{ + oauth_token_is_expired, resolve_saved_oauth_token, resolve_startup_auth_source, OAuthTokenSet, +}; +#[must_use] +pub fn read_base_url() -> String { + claw_provider::read_base_url() +} + +#[must_use] +pub fn read_xai_base_url() -> String { + openai_compat::read_base_url(OpenAiCompatConfig::xai()) +} + +#[cfg(test)] +mod tests { + use crate::providers::{detect_provider_kind, resolve_model_alias, ProviderKind}; + + #[test] + fn resolves_existing_and_grok_aliases() { + assert_eq!(resolve_model_alias("opus"), "claude-opus-4-6"); + assert_eq!(resolve_model_alias("grok"), "grok-3"); + assert_eq!(resolve_model_alias("grok-mini"), "grok-3-mini"); + } + + #[test] + fn provider_detection_prefers_model_family() { + assert_eq!(detect_provider_kind("grok-3"), ProviderKind::Xai); + assert_eq!( + detect_provider_kind("claude-sonnet-4-6"), + ProviderKind::ClawApi + ); + } +} diff --git a/rust/crates/api/src/error.rs b/rust/crates/api/src/error.rs new file mode 100644 index 0000000..7649889 --- /dev/null +++ b/rust/crates/api/src/error.rs @@ -0,0 +1,135 @@ +use std::env::VarError; +use std::fmt::{Display, Formatter}; +use std::time::Duration; + +#[derive(Debug)] +pub enum ApiError { + MissingCredentials { + provider: &'static str, + env_vars: &'static [&'static str], + }, + ExpiredOAuthToken, + Auth(String), + InvalidApiKeyEnv(VarError), + Http(reqwest::Error), + Io(std::io::Error), + Json(serde_json::Error), + Api { + status: reqwest::StatusCode, + error_type: Option, + message: Option, + body: String, + retryable: bool, + }, + RetriesExhausted { + attempts: u32, + last_error: Box, + }, + InvalidSseFrame(&'static str), + BackoffOverflow { + attempt: u32, + base_delay: Duration, + }, +} + +impl ApiError { + #[must_use] + pub const fn missing_credentials( + provider: &'static str, + env_vars: &'static [&'static str], + ) -> Self { + Self::MissingCredentials { provider, env_vars } + } + + #[must_use] + pub fn is_retryable(&self) -> bool { + match self { + Self::Http(error) => error.is_connect() || error.is_timeout() || error.is_request(), + Self::Api { retryable, .. } => *retryable, + Self::RetriesExhausted { last_error, .. } => last_error.is_retryable(), + Self::MissingCredentials { .. } + | Self::ExpiredOAuthToken + | Self::Auth(_) + | Self::InvalidApiKeyEnv(_) + | Self::Io(_) + | Self::Json(_) + | Self::InvalidSseFrame(_) + | Self::BackoffOverflow { .. } => false, + } + } +} + +impl Display for ApiError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Self::MissingCredentials { provider, env_vars } => write!( + f, + "missing {provider} credentials; export {} before calling the {provider} API", + env_vars.join(" or ") + ), + Self::ExpiredOAuthToken => { + write!( + f, + "saved OAuth token is expired and no refresh token is available" + ) + } + Self::Auth(message) => write!(f, "auth error: {message}"), + Self::InvalidApiKeyEnv(error) => { + write!(f, "failed to read credential environment variable: {error}") + } + Self::Http(error) => write!(f, "http error: {error}"), + Self::Io(error) => write!(f, "io error: {error}"), + Self::Json(error) => write!(f, "json error: {error}"), + Self::Api { + status, + error_type, + message, + body, + .. + } => match (error_type, message) { + (Some(error_type), Some(message)) => { + write!(f, "api returned {status} ({error_type}): {message}") + } + _ => write!(f, "api returned {status}: {body}"), + }, + Self::RetriesExhausted { + attempts, + last_error, + } => write!(f, "api failed after {attempts} attempts: {last_error}"), + Self::InvalidSseFrame(message) => write!(f, "invalid sse frame: {message}"), + Self::BackoffOverflow { + attempt, + base_delay, + } => write!( + f, + "retry backoff overflowed on attempt {attempt} with base delay {base_delay:?}" + ), + } + } +} + +impl std::error::Error for ApiError {} + +impl From for ApiError { + fn from(value: reqwest::Error) -> Self { + Self::Http(value) + } +} + +impl From for ApiError { + fn from(value: std::io::Error) -> Self { + Self::Io(value) + } +} + +impl From for ApiError { + fn from(value: serde_json::Error) -> Self { + Self::Json(value) + } +} + +impl From for ApiError { + fn from(value: VarError) -> Self { + Self::InvalidApiKeyEnv(value) + } +} diff --git a/rust/crates/api/src/lib.rs b/rust/crates/api/src/lib.rs new file mode 100644 index 0000000..2b2584a --- /dev/null +++ b/rust/crates/api/src/lib.rs @@ -0,0 +1,23 @@ +mod client; +mod error; +mod providers; +mod sse; +mod types; + +pub use client::{ + oauth_token_is_expired, read_base_url, read_xai_base_url, resolve_saved_oauth_token, + resolve_startup_auth_source, MessageStream, OAuthTokenSet, ProviderClient, +}; +pub use error::ApiError; +pub use providers::claw_provider::{ClawApiClient, ClawApiClient as ApiClient, AuthSource}; +pub use providers::openai_compat::{OpenAiCompatClient, OpenAiCompatConfig}; +pub use providers::{ + detect_provider_kind, max_tokens_for_model, resolve_model_alias, ProviderKind, +}; +pub use sse::{parse_frame, SseParser}; +pub use types::{ + ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent, + InputContentBlock, InputMessage, MessageDelta, MessageDeltaEvent, MessageRequest, + MessageResponse, MessageStartEvent, MessageStopEvent, OutputContentBlock, StreamEvent, + ToolChoice, ToolDefinition, ToolResultContentBlock, Usage, +}; diff --git a/rust/crates/api/src/providers/claw_provider.rs b/rust/crates/api/src/providers/claw_provider.rs new file mode 100644 index 0000000..55d7f37 --- /dev/null +++ b/rust/crates/api/src/providers/claw_provider.rs @@ -0,0 +1,1046 @@ +use std::collections::VecDeque; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; + +use runtime::{ + load_oauth_credentials, save_oauth_credentials, OAuthConfig, OAuthRefreshRequest, + OAuthTokenExchangeRequest, +}; +use serde::Deserialize; + +use crate::error::ApiError; + +use super::{Provider, ProviderFuture}; +use crate::sse::SseParser; +use crate::types::{MessageRequest, MessageResponse, StreamEvent}; + +pub const DEFAULT_BASE_URL: &str = "https://api.anthropic.com"; +const ANTHROPIC_VERSION: &str = "2023-06-01"; +const REQUEST_ID_HEADER: &str = "request-id"; +const ALT_REQUEST_ID_HEADER: &str = "x-request-id"; +const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_millis(200); +const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(2); +const DEFAULT_MAX_RETRIES: u32 = 2; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum AuthSource { + None, + ApiKey(String), + BearerToken(String), + ApiKeyAndBearer { + api_key: String, + bearer_token: String, + }, +} + +impl AuthSource { + pub fn from_env() -> Result { + let api_key = read_env_non_empty("ANTHROPIC_API_KEY")?; + let auth_token = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")?; + match (api_key, auth_token) { + (Some(api_key), Some(bearer_token)) => Ok(Self::ApiKeyAndBearer { + api_key, + bearer_token, + }), + (Some(api_key), None) => Ok(Self::ApiKey(api_key)), + (None, Some(bearer_token)) => Ok(Self::BearerToken(bearer_token)), + (None, None) => Err(ApiError::missing_credentials( + "Claw", + &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"], + )), + } + } + + #[must_use] + pub fn api_key(&self) -> Option<&str> { + match self { + Self::ApiKey(api_key) | Self::ApiKeyAndBearer { api_key, .. } => Some(api_key), + Self::None | Self::BearerToken(_) => None, + } + } + + #[must_use] + pub fn bearer_token(&self) -> Option<&str> { + match self { + Self::BearerToken(token) + | Self::ApiKeyAndBearer { + bearer_token: token, + .. + } => Some(token), + Self::None | Self::ApiKey(_) => None, + } + } + + #[must_use] + pub fn masked_authorization_header(&self) -> &'static str { + if self.bearer_token().is_some() { + "Bearer [REDACTED]" + } else { + "" + } + } + + pub fn apply(&self, mut request_builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder { + if let Some(api_key) = self.api_key() { + request_builder = request_builder.header("x-api-key", api_key); + } + if let Some(token) = self.bearer_token() { + request_builder = request_builder.bearer_auth(token); + } + request_builder + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Deserialize)] +pub struct OAuthTokenSet { + pub access_token: String, + pub refresh_token: Option, + pub expires_at: Option, + #[serde(default)] + pub scopes: Vec, +} + +impl From for AuthSource { + fn from(value: OAuthTokenSet) -> Self { + Self::BearerToken(value.access_token) + } +} + +#[derive(Debug, Clone)] +pub struct ClawApiClient { + http: reqwest::Client, + auth: AuthSource, + base_url: String, + max_retries: u32, + initial_backoff: Duration, + max_backoff: Duration, +} + +impl ClawApiClient { + #[must_use] + pub fn new(api_key: impl Into) -> Self { + Self { + http: reqwest::Client::new(), + auth: AuthSource::ApiKey(api_key.into()), + base_url: DEFAULT_BASE_URL.to_string(), + max_retries: DEFAULT_MAX_RETRIES, + initial_backoff: DEFAULT_INITIAL_BACKOFF, + max_backoff: DEFAULT_MAX_BACKOFF, + } + } + + #[must_use] + pub fn from_auth(auth: AuthSource) -> Self { + Self { + http: reqwest::Client::new(), + auth, + base_url: DEFAULT_BASE_URL.to_string(), + max_retries: DEFAULT_MAX_RETRIES, + initial_backoff: DEFAULT_INITIAL_BACKOFF, + max_backoff: DEFAULT_MAX_BACKOFF, + } + } + + pub fn from_env() -> Result { + Ok(Self::from_auth(AuthSource::from_env_or_saved()?).with_base_url(read_base_url())) + } + + #[must_use] + pub fn with_auth_source(mut self, auth: AuthSource) -> Self { + self.auth = auth; + self + } + + #[must_use] + pub fn with_auth_token(mut self, auth_token: Option) -> Self { + match ( + self.auth.api_key().map(ToOwned::to_owned), + auth_token.filter(|token| !token.is_empty()), + ) { + (Some(api_key), Some(bearer_token)) => { + self.auth = AuthSource::ApiKeyAndBearer { + api_key, + bearer_token, + }; + } + (Some(api_key), None) => { + self.auth = AuthSource::ApiKey(api_key); + } + (None, Some(bearer_token)) => { + self.auth = AuthSource::BearerToken(bearer_token); + } + (None, None) => { + self.auth = AuthSource::None; + } + } + self + } + + #[must_use] + pub fn with_base_url(mut self, base_url: impl Into) -> Self { + self.base_url = base_url.into(); + self + } + + #[must_use] + pub fn with_retry_policy( + mut self, + max_retries: u32, + initial_backoff: Duration, + max_backoff: Duration, + ) -> Self { + self.max_retries = max_retries; + self.initial_backoff = initial_backoff; + self.max_backoff = max_backoff; + self + } + + #[must_use] + pub fn auth_source(&self) -> &AuthSource { + &self.auth + } + + pub async fn send_message( + &self, + request: &MessageRequest, + ) -> Result { + let request = MessageRequest { + stream: false, + ..request.clone() + }; + let response = self.send_with_retry(&request).await?; + let request_id = request_id_from_headers(response.headers()); + let mut response = response + .json::() + .await + .map_err(ApiError::from)?; + if response.request_id.is_none() { + response.request_id = request_id; + } + Ok(response) + } + + pub async fn stream_message( + &self, + request: &MessageRequest, + ) -> Result { + let response = self + .send_with_retry(&request.clone().with_streaming()) + .await?; + Ok(MessageStream { + request_id: request_id_from_headers(response.headers()), + response, + parser: SseParser::new(), + pending: VecDeque::new(), + done: false, + }) + } + + pub async fn exchange_oauth_code( + &self, + config: &OAuthConfig, + request: &OAuthTokenExchangeRequest, + ) -> Result { + let response = self + .http + .post(&config.token_url) + .header("content-type", "application/x-www-form-urlencoded") + .form(&request.form_params()) + .send() + .await + .map_err(ApiError::from)?; + let response = expect_success(response).await?; + response + .json::() + .await + .map_err(ApiError::from) + } + + pub async fn refresh_oauth_token( + &self, + config: &OAuthConfig, + request: &OAuthRefreshRequest, + ) -> Result { + let response = self + .http + .post(&config.token_url) + .header("content-type", "application/x-www-form-urlencoded") + .form(&request.form_params()) + .send() + .await + .map_err(ApiError::from)?; + let response = expect_success(response).await?; + response + .json::() + .await + .map_err(ApiError::from) + } + + async fn send_with_retry( + &self, + request: &MessageRequest, + ) -> Result { + let mut attempts = 0; + let mut last_error: Option; + + loop { + attempts += 1; + match self.send_raw_request(request).await { + Ok(response) => match expect_success(response).await { + Ok(response) => return Ok(response), + Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => { + last_error = Some(error); + } + Err(error) => return Err(error), + }, + Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => { + last_error = Some(error); + } + Err(error) => return Err(error), + } + + if attempts > self.max_retries { + break; + } + + tokio::time::sleep(self.backoff_for_attempt(attempts)?).await; + } + + Err(ApiError::RetriesExhausted { + attempts, + last_error: Box::new(last_error.expect("retry loop must capture an error")), + }) + } + + async fn send_raw_request( + &self, + request: &MessageRequest, + ) -> Result { + let request_url = format!("{}/v1/messages", self.base_url.trim_end_matches('/')); + let request_builder = self + .http + .post(&request_url) + .header("anthropic-version", ANTHROPIC_VERSION) + .header("content-type", "application/json"); + let mut request_builder = self.auth.apply(request_builder); + + request_builder = request_builder.json(request); + request_builder.send().await.map_err(ApiError::from) + } + + fn backoff_for_attempt(&self, attempt: u32) -> Result { + let Some(multiplier) = 1_u32.checked_shl(attempt.saturating_sub(1)) else { + return Err(ApiError::BackoffOverflow { + attempt, + base_delay: self.initial_backoff, + }); + }; + Ok(self + .initial_backoff + .checked_mul(multiplier) + .map_or(self.max_backoff, |delay| delay.min(self.max_backoff))) + } +} + +impl AuthSource { + pub fn from_env_or_saved() -> Result { + if let Some(api_key) = read_env_non_empty("ANTHROPIC_API_KEY")? { + return match read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? { + Some(bearer_token) => Ok(Self::ApiKeyAndBearer { + api_key, + bearer_token, + }), + None => Ok(Self::ApiKey(api_key)), + }; + } + if let Some(bearer_token) = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? { + return Ok(Self::BearerToken(bearer_token)); + } + match load_saved_oauth_token() { + Ok(Some(token_set)) if oauth_token_is_expired(&token_set) => { + if token_set.refresh_token.is_some() { + Err(ApiError::Auth( + "saved OAuth token is expired; load runtime OAuth config to refresh it" + .to_string(), + )) + } else { + Err(ApiError::ExpiredOAuthToken) + } + } + Ok(Some(token_set)) => Ok(Self::BearerToken(token_set.access_token)), + Ok(None) => Err(ApiError::missing_credentials( + "Claw", + &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"], + )), + Err(error) => Err(error), + } + } +} + +#[must_use] +pub fn oauth_token_is_expired(token_set: &OAuthTokenSet) -> bool { + token_set + .expires_at + .is_some_and(|expires_at| expires_at <= now_unix_timestamp()) +} + +pub fn resolve_saved_oauth_token(config: &OAuthConfig) -> Result, ApiError> { + let Some(token_set) = load_saved_oauth_token()? else { + return Ok(None); + }; + resolve_saved_oauth_token_set(config, token_set).map(Some) +} + +pub fn has_auth_from_env_or_saved() -> Result { + Ok(read_env_non_empty("ANTHROPIC_API_KEY")?.is_some() + || read_env_non_empty("ANTHROPIC_AUTH_TOKEN")?.is_some() + || load_saved_oauth_token()?.is_some()) +} + +pub fn resolve_startup_auth_source(load_oauth_config: F) -> Result +where + F: FnOnce() -> Result, ApiError>, +{ + if let Some(api_key) = read_env_non_empty("ANTHROPIC_API_KEY")? { + return match read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? { + Some(bearer_token) => Ok(AuthSource::ApiKeyAndBearer { + api_key, + bearer_token, + }), + None => Ok(AuthSource::ApiKey(api_key)), + }; + } + if let Some(bearer_token) = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? { + return Ok(AuthSource::BearerToken(bearer_token)); + } + + let Some(token_set) = load_saved_oauth_token()? else { + return Err(ApiError::missing_credentials( + "Claw", + &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"], + )); + }; + if !oauth_token_is_expired(&token_set) { + return Ok(AuthSource::BearerToken(token_set.access_token)); + } + if token_set.refresh_token.is_none() { + return Err(ApiError::ExpiredOAuthToken); + } + + let Some(config) = load_oauth_config()? else { + return Err(ApiError::Auth( + "saved OAuth token is expired; runtime OAuth config is missing".to_string(), + )); + }; + Ok(AuthSource::from(resolve_saved_oauth_token_set( + &config, token_set, + )?)) +} + +fn resolve_saved_oauth_token_set( + config: &OAuthConfig, + token_set: OAuthTokenSet, +) -> Result { + if !oauth_token_is_expired(&token_set) { + return Ok(token_set); + } + let Some(refresh_token) = token_set.refresh_token.clone() else { + return Err(ApiError::ExpiredOAuthToken); + }; + let client = ClawApiClient::from_auth(AuthSource::None).with_base_url(read_base_url()); + let refreshed = client_runtime_block_on(async { + client + .refresh_oauth_token( + config, + &OAuthRefreshRequest::from_config( + config, + refresh_token, + Some(token_set.scopes.clone()), + ), + ) + .await + })?; + let resolved = OAuthTokenSet { + access_token: refreshed.access_token, + refresh_token: refreshed.refresh_token.or(token_set.refresh_token), + expires_at: refreshed.expires_at, + scopes: refreshed.scopes, + }; + save_oauth_credentials(&runtime::OAuthTokenSet { + access_token: resolved.access_token.clone(), + refresh_token: resolved.refresh_token.clone(), + expires_at: resolved.expires_at, + scopes: resolved.scopes.clone(), + }) + .map_err(ApiError::from)?; + Ok(resolved) +} + +fn client_runtime_block_on(future: F) -> Result +where + F: std::future::Future>, +{ + tokio::runtime::Runtime::new() + .map_err(ApiError::from)? + .block_on(future) +} + +fn load_saved_oauth_token() -> Result, ApiError> { + let token_set = load_oauth_credentials().map_err(ApiError::from)?; + Ok(token_set.map(|token_set| OAuthTokenSet { + access_token: token_set.access_token, + refresh_token: token_set.refresh_token, + expires_at: token_set.expires_at, + scopes: token_set.scopes, + })) +} + +fn now_unix_timestamp() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .map_or(0, |duration| duration.as_secs()) +} + +fn read_env_non_empty(key: &str) -> Result, ApiError> { + match std::env::var(key) { + Ok(value) if !value.is_empty() => Ok(Some(value)), + Ok(_) | Err(std::env::VarError::NotPresent) => Ok(None), + Err(error) => Err(ApiError::from(error)), + } +} + +#[cfg(test)] +fn read_api_key() -> Result { + let auth = AuthSource::from_env_or_saved()?; + auth.api_key() + .or_else(|| auth.bearer_token()) + .map(ToOwned::to_owned) + .ok_or(ApiError::missing_credentials( + "Claw", + &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"], + )) +} + +#[cfg(test)] +fn read_auth_token() -> Option { + read_env_non_empty("ANTHROPIC_AUTH_TOKEN") + .ok() + .and_then(std::convert::identity) +} + +#[must_use] +pub fn read_base_url() -> String { + std::env::var("ANTHROPIC_BASE_URL").unwrap_or_else(|_| DEFAULT_BASE_URL.to_string()) +} + +fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option { + headers + .get(REQUEST_ID_HEADER) + .or_else(|| headers.get(ALT_REQUEST_ID_HEADER)) + .and_then(|value| value.to_str().ok()) + .map(ToOwned::to_owned) +} + +impl Provider for ClawApiClient { + type Stream = MessageStream; + + fn send_message<'a>( + &'a self, + request: &'a MessageRequest, + ) -> ProviderFuture<'a, MessageResponse> { + Box::pin(async move { self.send_message(request).await }) + } + + fn stream_message<'a>( + &'a self, + request: &'a MessageRequest, + ) -> ProviderFuture<'a, Self::Stream> { + Box::pin(async move { self.stream_message(request).await }) + } +} + +#[derive(Debug)] +pub struct MessageStream { + request_id: Option, + response: reqwest::Response, + parser: SseParser, + pending: VecDeque, + done: bool, +} + +impl MessageStream { + #[must_use] + pub fn request_id(&self) -> Option<&str> { + self.request_id.as_deref() + } + + pub async fn next_event(&mut self) -> Result, ApiError> { + loop { + if let Some(event) = self.pending.pop_front() { + return Ok(Some(event)); + } + + if self.done { + let remaining = self.parser.finish()?; + self.pending.extend(remaining); + if let Some(event) = self.pending.pop_front() { + return Ok(Some(event)); + } + return Ok(None); + } + + match self.response.chunk().await? { + Some(chunk) => { + self.pending.extend(self.parser.push(&chunk)?); + } + None => { + self.done = true; + } + } + } + } +} + +async fn expect_success(response: reqwest::Response) -> Result { + let status = response.status(); + if status.is_success() { + return Ok(response); + } + + let body = response.text().await.unwrap_or_else(|_| String::new()); + let parsed_error = serde_json::from_str::(&body).ok(); + let retryable = is_retryable_status(status); + + Err(ApiError::Api { + status, + error_type: parsed_error + .as_ref() + .map(|error| error.error.error_type.clone()), + message: parsed_error + .as_ref() + .map(|error| error.error.message.clone()), + body, + retryable, + }) +} + +const fn is_retryable_status(status: reqwest::StatusCode) -> bool { + matches!(status.as_u16(), 408 | 409 | 429 | 500 | 502 | 503 | 504) +} + +#[derive(Debug, Deserialize)] +struct ApiErrorEnvelope { + error: ApiErrorBody, +} + +#[derive(Debug, Deserialize)] +struct ApiErrorBody { + #[serde(rename = "type")] + error_type: String, + message: String, +} + +#[cfg(test)] +mod tests { + use super::{ALT_REQUEST_ID_HEADER, REQUEST_ID_HEADER}; + use std::io::{Read, Write}; + use std::net::TcpListener; + use std::sync::{Mutex, OnceLock}; + use std::thread; + use std::time::{Duration, SystemTime, UNIX_EPOCH}; + + use runtime::{clear_oauth_credentials, save_oauth_credentials, OAuthConfig}; + + use super::{ + now_unix_timestamp, oauth_token_is_expired, resolve_saved_oauth_token, + resolve_startup_auth_source, ClawApiClient, AuthSource, OAuthTokenSet, + }; + use crate::types::{ContentBlockDelta, MessageRequest}; + + fn env_lock() -> std::sync::MutexGuard<'static, ()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())) + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + } + + fn temp_config_home() -> std::path::PathBuf { + std::env::temp_dir().join(format!( + "api-oauth-test-{}-{}", + std::process::id(), + SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("time") + .as_nanos() + )) + } + + fn cleanup_temp_config_home(config_home: &std::path::Path) { + match std::fs::remove_dir_all(config_home) { + Ok(()) => {} + Err(error) if error.kind() == std::io::ErrorKind::NotFound => {} + Err(error) => panic!("cleanup temp dir: {error}"), + } + } + + fn sample_oauth_config(token_url: String) -> OAuthConfig { + OAuthConfig { + client_id: "runtime-client".to_string(), + authorize_url: "https://console.test/oauth/authorize".to_string(), + token_url, + callback_port: Some(4545), + manual_redirect_url: Some("https://console.test/oauth/callback".to_string()), + scopes: vec!["org:read".to_string(), "user:write".to_string()], + } + } + + fn spawn_token_server(response_body: &'static str) -> String { + let listener = TcpListener::bind("127.0.0.1:0").expect("bind listener"); + let address = listener.local_addr().expect("local addr"); + thread::spawn(move || { + let (mut stream, _) = listener.accept().expect("accept connection"); + let mut buffer = [0_u8; 4096]; + let _ = stream.read(&mut buffer).expect("read request"); + let response = format!( + "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\n\r\n{}", + response_body.len(), + response_body + ); + stream + .write_all(response.as_bytes()) + .expect("write response"); + }); + format!("http://{address}/oauth/token") + } + + #[test] + fn read_api_key_requires_presence() { + let _guard = env_lock(); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + std::env::remove_var("ANTHROPIC_API_KEY"); + std::env::remove_var("CLAW_CONFIG_HOME"); + let error = super::read_api_key().expect_err("missing key should error"); + assert!(matches!( + error, + crate::error::ApiError::MissingCredentials { .. } + )); + } + + #[test] + fn read_api_key_requires_non_empty_value() { + let _guard = env_lock(); + std::env::set_var("ANTHROPIC_AUTH_TOKEN", ""); + std::env::remove_var("ANTHROPIC_API_KEY"); + let error = super::read_api_key().expect_err("empty key should error"); + assert!(matches!( + error, + crate::error::ApiError::MissingCredentials { .. } + )); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + } + + #[test] + fn read_api_key_prefers_api_key_env() { + let _guard = env_lock(); + std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token"); + std::env::set_var("ANTHROPIC_API_KEY", "legacy-key"); + assert_eq!( + super::read_api_key().expect("api key should load"), + "legacy-key" + ); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + std::env::remove_var("ANTHROPIC_API_KEY"); + } + + #[test] + fn read_auth_token_reads_auth_token_env() { + let _guard = env_lock(); + std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token"); + assert_eq!(super::read_auth_token().as_deref(), Some("auth-token")); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + } + + #[test] + fn oauth_token_maps_to_bearer_auth_source() { + let auth = AuthSource::from(OAuthTokenSet { + access_token: "access-token".to_string(), + refresh_token: Some("refresh".to_string()), + expires_at: Some(123), + scopes: vec!["scope:a".to_string()], + }); + assert_eq!(auth.bearer_token(), Some("access-token")); + assert_eq!(auth.api_key(), None); + } + + #[test] + fn auth_source_from_env_combines_api_key_and_bearer_token() { + let _guard = env_lock(); + std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token"); + std::env::set_var("ANTHROPIC_API_KEY", "legacy-key"); + let auth = AuthSource::from_env().expect("env auth"); + assert_eq!(auth.api_key(), Some("legacy-key")); + assert_eq!(auth.bearer_token(), Some("auth-token")); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + std::env::remove_var("ANTHROPIC_API_KEY"); + } + + #[test] + fn auth_source_from_saved_oauth_when_env_absent() { + let _guard = env_lock(); + let config_home = temp_config_home(); + std::env::set_var("CLAW_CONFIG_HOME", &config_home); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + std::env::remove_var("ANTHROPIC_API_KEY"); + save_oauth_credentials(&runtime::OAuthTokenSet { + access_token: "saved-access-token".to_string(), + refresh_token: Some("refresh".to_string()), + expires_at: Some(now_unix_timestamp() + 300), + scopes: vec!["scope:a".to_string()], + }) + .expect("save oauth credentials"); + + let auth = AuthSource::from_env_or_saved().expect("saved auth"); + assert_eq!(auth.bearer_token(), Some("saved-access-token")); + + clear_oauth_credentials().expect("clear credentials"); + std::env::remove_var("CLAW_CONFIG_HOME"); + cleanup_temp_config_home(&config_home); + } + + #[test] + fn oauth_token_expiry_uses_expires_at_timestamp() { + assert!(oauth_token_is_expired(&OAuthTokenSet { + access_token: "access-token".to_string(), + refresh_token: None, + expires_at: Some(1), + scopes: Vec::new(), + })); + assert!(!oauth_token_is_expired(&OAuthTokenSet { + access_token: "access-token".to_string(), + refresh_token: None, + expires_at: Some(now_unix_timestamp() + 60), + scopes: Vec::new(), + })); + } + + #[test] + fn resolve_saved_oauth_token_refreshes_expired_credentials() { + let _guard = env_lock(); + let config_home = temp_config_home(); + std::env::set_var("CLAW_CONFIG_HOME", &config_home); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + std::env::remove_var("ANTHROPIC_API_KEY"); + save_oauth_credentials(&runtime::OAuthTokenSet { + access_token: "expired-access-token".to_string(), + refresh_token: Some("refresh-token".to_string()), + expires_at: Some(1), + scopes: vec!["scope:a".to_string()], + }) + .expect("save expired oauth credentials"); + + let token_url = spawn_token_server( + "{\"access_token\":\"refreshed-token\",\"refresh_token\":\"fresh-refresh\",\"expires_at\":9999999999,\"scopes\":[\"scope:a\"]}", + ); + let resolved = resolve_saved_oauth_token(&sample_oauth_config(token_url)) + .expect("resolve refreshed token") + .expect("token set present"); + assert_eq!(resolved.access_token, "refreshed-token"); + let stored = runtime::load_oauth_credentials() + .expect("load stored credentials") + .expect("stored token set"); + assert_eq!(stored.access_token, "refreshed-token"); + + clear_oauth_credentials().expect("clear credentials"); + std::env::remove_var("CLAW_CONFIG_HOME"); + cleanup_temp_config_home(&config_home); + } + + #[test] + fn resolve_startup_auth_source_uses_saved_oauth_without_loading_config() { + let _guard = env_lock(); + let config_home = temp_config_home(); + std::env::set_var("CLAW_CONFIG_HOME", &config_home); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + std::env::remove_var("ANTHROPIC_API_KEY"); + save_oauth_credentials(&runtime::OAuthTokenSet { + access_token: "saved-access-token".to_string(), + refresh_token: Some("refresh".to_string()), + expires_at: Some(now_unix_timestamp() + 300), + scopes: vec!["scope:a".to_string()], + }) + .expect("save oauth credentials"); + + let auth = resolve_startup_auth_source(|| panic!("config should not be loaded")) + .expect("startup auth"); + assert_eq!(auth.bearer_token(), Some("saved-access-token")); + + clear_oauth_credentials().expect("clear credentials"); + std::env::remove_var("CLAW_CONFIG_HOME"); + cleanup_temp_config_home(&config_home); + } + + #[test] + fn resolve_startup_auth_source_errors_when_refreshable_token_lacks_config() { + let _guard = env_lock(); + let config_home = temp_config_home(); + std::env::set_var("CLAW_CONFIG_HOME", &config_home); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + std::env::remove_var("ANTHROPIC_API_KEY"); + save_oauth_credentials(&runtime::OAuthTokenSet { + access_token: "expired-access-token".to_string(), + refresh_token: Some("refresh-token".to_string()), + expires_at: Some(1), + scopes: vec!["scope:a".to_string()], + }) + .expect("save expired oauth credentials"); + + let error = + resolve_startup_auth_source(|| Ok(None)).expect_err("missing config should error"); + assert!( + matches!(error, crate::error::ApiError::Auth(message) if message.contains("runtime OAuth config is missing")) + ); + + let stored = runtime::load_oauth_credentials() + .expect("load stored credentials") + .expect("stored token set"); + assert_eq!(stored.access_token, "expired-access-token"); + assert_eq!(stored.refresh_token.as_deref(), Some("refresh-token")); + + clear_oauth_credentials().expect("clear credentials"); + std::env::remove_var("CLAW_CONFIG_HOME"); + cleanup_temp_config_home(&config_home); + } + + #[test] + fn resolve_saved_oauth_token_preserves_refresh_token_when_refresh_response_omits_it() { + let _guard = env_lock(); + let config_home = temp_config_home(); + std::env::set_var("CLAW_CONFIG_HOME", &config_home); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + std::env::remove_var("ANTHROPIC_API_KEY"); + save_oauth_credentials(&runtime::OAuthTokenSet { + access_token: "expired-access-token".to_string(), + refresh_token: Some("refresh-token".to_string()), + expires_at: Some(1), + scopes: vec!["scope:a".to_string()], + }) + .expect("save expired oauth credentials"); + + let token_url = spawn_token_server( + "{\"access_token\":\"refreshed-token\",\"expires_at\":9999999999,\"scopes\":[\"scope:a\"]}", + ); + let resolved = resolve_saved_oauth_token(&sample_oauth_config(token_url)) + .expect("resolve refreshed token") + .expect("token set present"); + assert_eq!(resolved.access_token, "refreshed-token"); + assert_eq!(resolved.refresh_token.as_deref(), Some("refresh-token")); + let stored = runtime::load_oauth_credentials() + .expect("load stored credentials") + .expect("stored token set"); + assert_eq!(stored.refresh_token.as_deref(), Some("refresh-token")); + + clear_oauth_credentials().expect("clear credentials"); + std::env::remove_var("CLAW_CONFIG_HOME"); + cleanup_temp_config_home(&config_home); + } + + #[test] + fn message_request_stream_helper_sets_stream_true() { + let request = MessageRequest { + model: "claude-opus-4-6".to_string(), + max_tokens: 64, + messages: vec![], + system: None, + tools: None, + tool_choice: None, + stream: false, + }; + + assert!(request.with_streaming().stream); + } + + #[test] + fn backoff_doubles_until_maximum() { + let client = ClawApiClient::new("test-key").with_retry_policy( + 3, + Duration::from_millis(10), + Duration::from_millis(25), + ); + assert_eq!( + client.backoff_for_attempt(1).expect("attempt 1"), + Duration::from_millis(10) + ); + assert_eq!( + client.backoff_for_attempt(2).expect("attempt 2"), + Duration::from_millis(20) + ); + assert_eq!( + client.backoff_for_attempt(3).expect("attempt 3"), + Duration::from_millis(25) + ); + } + + #[test] + fn retryable_statuses_are_detected() { + assert!(super::is_retryable_status( + reqwest::StatusCode::TOO_MANY_REQUESTS + )); + assert!(super::is_retryable_status( + reqwest::StatusCode::INTERNAL_SERVER_ERROR + )); + assert!(!super::is_retryable_status( + reqwest::StatusCode::UNAUTHORIZED + )); + } + + #[test] + fn tool_delta_variant_round_trips() { + let delta = ContentBlockDelta::InputJsonDelta { + partial_json: "{\"city\":\"Paris\"}".to_string(), + }; + let encoded = serde_json::to_string(&delta).expect("delta should serialize"); + let decoded: ContentBlockDelta = + serde_json::from_str(&encoded).expect("delta should deserialize"); + assert_eq!(decoded, delta); + } + + #[test] + fn request_id_uses_primary_or_fallback_header() { + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert(REQUEST_ID_HEADER, "req_primary".parse().expect("header")); + assert_eq!( + super::request_id_from_headers(&headers).as_deref(), + Some("req_primary") + ); + + headers.clear(); + headers.insert( + ALT_REQUEST_ID_HEADER, + "req_fallback".parse().expect("header"), + ); + assert_eq!( + super::request_id_from_headers(&headers).as_deref(), + Some("req_fallback") + ); + } + + #[test] + fn auth_source_applies_headers() { + let auth = AuthSource::ApiKeyAndBearer { + api_key: "test-key".to_string(), + bearer_token: "proxy-token".to_string(), + }; + let request = auth + .apply(reqwest::Client::new().post("https://example.test")) + .build() + .expect("request build"); + let headers = request.headers(); + assert_eq!( + headers.get("x-api-key").and_then(|v| v.to_str().ok()), + Some("test-key") + ); + assert_eq!( + headers.get("authorization").and_then(|v| v.to_str().ok()), + Some("Bearer proxy-token") + ); + } +} diff --git a/rust/crates/api/src/providers/mod.rs b/rust/crates/api/src/providers/mod.rs new file mode 100644 index 0000000..192afd6 --- /dev/null +++ b/rust/crates/api/src/providers/mod.rs @@ -0,0 +1,239 @@ +use std::future::Future; +use std::pin::Pin; + +use crate::error::ApiError; +use crate::types::{MessageRequest, MessageResponse}; + +pub mod claw_provider; +pub mod openai_compat; + +pub type ProviderFuture<'a, T> = Pin> + Send + 'a>>; + +pub trait Provider { + type Stream; + + fn send_message<'a>( + &'a self, + request: &'a MessageRequest, + ) -> ProviderFuture<'a, MessageResponse>; + + fn stream_message<'a>( + &'a self, + request: &'a MessageRequest, + ) -> ProviderFuture<'a, Self::Stream>; +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ProviderKind { + ClawApi, + Xai, + OpenAi, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct ProviderMetadata { + pub provider: ProviderKind, + pub auth_env: &'static str, + pub base_url_env: &'static str, + pub default_base_url: &'static str, +} + +const MODEL_REGISTRY: &[(&str, ProviderMetadata)] = &[ + ( + "opus", + ProviderMetadata { + provider: ProviderKind::ClawApi, + auth_env: "ANTHROPIC_API_KEY", + base_url_env: "ANTHROPIC_BASE_URL", + default_base_url: claw_provider::DEFAULT_BASE_URL, + }, + ), + ( + "sonnet", + ProviderMetadata { + provider: ProviderKind::ClawApi, + auth_env: "ANTHROPIC_API_KEY", + base_url_env: "ANTHROPIC_BASE_URL", + default_base_url: claw_provider::DEFAULT_BASE_URL, + }, + ), + ( + "haiku", + ProviderMetadata { + provider: ProviderKind::ClawApi, + auth_env: "ANTHROPIC_API_KEY", + base_url_env: "ANTHROPIC_BASE_URL", + default_base_url: claw_provider::DEFAULT_BASE_URL, + }, + ), + ( + "claude-opus-4-6", + ProviderMetadata { + provider: ProviderKind::ClawApi, + auth_env: "ANTHROPIC_API_KEY", + base_url_env: "ANTHROPIC_BASE_URL", + default_base_url: claw_provider::DEFAULT_BASE_URL, + }, + ), + ( + "claude-sonnet-4-6", + ProviderMetadata { + provider: ProviderKind::ClawApi, + auth_env: "ANTHROPIC_API_KEY", + base_url_env: "ANTHROPIC_BASE_URL", + default_base_url: claw_provider::DEFAULT_BASE_URL, + }, + ), + ( + "claude-haiku-4-5-20251213", + ProviderMetadata { + provider: ProviderKind::ClawApi, + auth_env: "ANTHROPIC_API_KEY", + base_url_env: "ANTHROPIC_BASE_URL", + default_base_url: claw_provider::DEFAULT_BASE_URL, + }, + ), + ( + "grok", + ProviderMetadata { + provider: ProviderKind::Xai, + auth_env: "XAI_API_KEY", + base_url_env: "XAI_BASE_URL", + default_base_url: openai_compat::DEFAULT_XAI_BASE_URL, + }, + ), + ( + "grok-3", + ProviderMetadata { + provider: ProviderKind::Xai, + auth_env: "XAI_API_KEY", + base_url_env: "XAI_BASE_URL", + default_base_url: openai_compat::DEFAULT_XAI_BASE_URL, + }, + ), + ( + "grok-mini", + ProviderMetadata { + provider: ProviderKind::Xai, + auth_env: "XAI_API_KEY", + base_url_env: "XAI_BASE_URL", + default_base_url: openai_compat::DEFAULT_XAI_BASE_URL, + }, + ), + ( + "grok-3-mini", + ProviderMetadata { + provider: ProviderKind::Xai, + auth_env: "XAI_API_KEY", + base_url_env: "XAI_BASE_URL", + default_base_url: openai_compat::DEFAULT_XAI_BASE_URL, + }, + ), + ( + "grok-2", + ProviderMetadata { + provider: ProviderKind::Xai, + auth_env: "XAI_API_KEY", + base_url_env: "XAI_BASE_URL", + default_base_url: openai_compat::DEFAULT_XAI_BASE_URL, + }, + ), +]; + +#[must_use] +pub fn resolve_model_alias(model: &str) -> String { + let trimmed = model.trim(); + let lower = trimmed.to_ascii_lowercase(); + MODEL_REGISTRY + .iter() + .find_map(|(alias, metadata)| { + (*alias == lower).then_some(match metadata.provider { + ProviderKind::ClawApi => match *alias { + "opus" => "claude-opus-4-6", + "sonnet" => "claude-sonnet-4-6", + "haiku" => "claude-haiku-4-5-20251213", + _ => trimmed, + }, + ProviderKind::Xai => match *alias { + "grok" | "grok-3" => "grok-3", + "grok-mini" | "grok-3-mini" => "grok-3-mini", + "grok-2" => "grok-2", + _ => trimmed, + }, + ProviderKind::OpenAi => trimmed, + }) + }) + .map_or_else(|| trimmed.to_string(), ToOwned::to_owned) +} + +#[must_use] +pub fn metadata_for_model(model: &str) -> Option { + let canonical = resolve_model_alias(model); + let lower = canonical.to_ascii_lowercase(); + if let Some((_, metadata)) = MODEL_REGISTRY.iter().find(|(alias, _)| *alias == lower) { + return Some(*metadata); + } + if lower.starts_with("grok") { + return Some(ProviderMetadata { + provider: ProviderKind::Xai, + auth_env: "XAI_API_KEY", + base_url_env: "XAI_BASE_URL", + default_base_url: openai_compat::DEFAULT_XAI_BASE_URL, + }); + } + None +} + +#[must_use] +pub fn detect_provider_kind(model: &str) -> ProviderKind { + if let Some(metadata) = metadata_for_model(model) { + return metadata.provider; + } + if claw_provider::has_auth_from_env_or_saved().unwrap_or(false) { + return ProviderKind::ClawApi; + } + if openai_compat::has_api_key("OPENAI_API_KEY") { + return ProviderKind::OpenAi; + } + if openai_compat::has_api_key("XAI_API_KEY") { + return ProviderKind::Xai; + } + ProviderKind::ClawApi +} + +#[must_use] +pub fn max_tokens_for_model(model: &str) -> u32 { + let canonical = resolve_model_alias(model); + if canonical.contains("opus") { + 32_000 + } else { + 64_000 + } +} + +#[cfg(test)] +mod tests { + use super::{detect_provider_kind, max_tokens_for_model, resolve_model_alias, ProviderKind}; + + #[test] + fn resolves_grok_aliases() { + assert_eq!(resolve_model_alias("grok"), "grok-3"); + assert_eq!(resolve_model_alias("grok-mini"), "grok-3-mini"); + assert_eq!(resolve_model_alias("grok-2"), "grok-2"); + } + + #[test] + fn detects_provider_from_model_name_first() { + assert_eq!(detect_provider_kind("grok"), ProviderKind::Xai); + assert_eq!( + detect_provider_kind("claude-sonnet-4-6"), + ProviderKind::ClawApi + ); + } + + #[test] + fn keeps_existing_max_token_heuristic() { + assert_eq!(max_tokens_for_model("opus"), 32_000); + assert_eq!(max_tokens_for_model("grok-3"), 64_000); + } +} diff --git a/rust/crates/api/src/providers/openai_compat.rs b/rust/crates/api/src/providers/openai_compat.rs new file mode 100644 index 0000000..e8210ae --- /dev/null +++ b/rust/crates/api/src/providers/openai_compat.rs @@ -0,0 +1,1050 @@ +use std::collections::{BTreeMap, VecDeque}; +use std::time::Duration; + +use serde::Deserialize; +use serde_json::{json, Value}; + +use crate::error::ApiError; +use crate::types::{ + ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent, + InputContentBlock, InputMessage, MessageDelta, MessageDeltaEvent, MessageRequest, + MessageResponse, MessageStartEvent, MessageStopEvent, OutputContentBlock, StreamEvent, + ToolChoice, ToolDefinition, ToolResultContentBlock, Usage, +}; + +use super::{Provider, ProviderFuture}; + +pub const DEFAULT_XAI_BASE_URL: &str = "https://api.x.ai/v1"; +pub const DEFAULT_OPENAI_BASE_URL: &str = "https://api.openai.com/v1"; +const REQUEST_ID_HEADER: &str = "request-id"; +const ALT_REQUEST_ID_HEADER: &str = "x-request-id"; +const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_millis(200); +const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(2); +const DEFAULT_MAX_RETRIES: u32 = 2; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct OpenAiCompatConfig { + pub provider_name: &'static str, + pub api_key_env: &'static str, + pub base_url_env: &'static str, + pub default_base_url: &'static str, +} + +const XAI_ENV_VARS: &[&str] = &["XAI_API_KEY"]; +const OPENAI_ENV_VARS: &[&str] = &["OPENAI_API_KEY"]; + +impl OpenAiCompatConfig { + #[must_use] + pub const fn xai() -> Self { + Self { + provider_name: "xAI", + api_key_env: "XAI_API_KEY", + base_url_env: "XAI_BASE_URL", + default_base_url: DEFAULT_XAI_BASE_URL, + } + } + + #[must_use] + pub const fn openai() -> Self { + Self { + provider_name: "OpenAI", + api_key_env: "OPENAI_API_KEY", + base_url_env: "OPENAI_BASE_URL", + default_base_url: DEFAULT_OPENAI_BASE_URL, + } + } + #[must_use] + pub fn credential_env_vars(self) -> &'static [&'static str] { + match self.provider_name { + "xAI" => XAI_ENV_VARS, + "OpenAI" => OPENAI_ENV_VARS, + _ => &[], + } + } +} + +#[derive(Debug, Clone)] +pub struct OpenAiCompatClient { + http: reqwest::Client, + api_key: String, + base_url: String, + max_retries: u32, + initial_backoff: Duration, + max_backoff: Duration, +} + +impl OpenAiCompatClient { + #[must_use] + pub fn new(api_key: impl Into, config: OpenAiCompatConfig) -> Self { + Self { + http: reqwest::Client::new(), + api_key: api_key.into(), + base_url: read_base_url(config), + max_retries: DEFAULT_MAX_RETRIES, + initial_backoff: DEFAULT_INITIAL_BACKOFF, + max_backoff: DEFAULT_MAX_BACKOFF, + } + } + + pub fn from_env(config: OpenAiCompatConfig) -> Result { + let Some(api_key) = read_env_non_empty(config.api_key_env)? else { + return Err(ApiError::missing_credentials( + config.provider_name, + config.credential_env_vars(), + )); + }; + Ok(Self::new(api_key, config)) + } + + #[must_use] + pub fn with_base_url(mut self, base_url: impl Into) -> Self { + self.base_url = base_url.into(); + self + } + + #[must_use] + pub fn with_retry_policy( + mut self, + max_retries: u32, + initial_backoff: Duration, + max_backoff: Duration, + ) -> Self { + self.max_retries = max_retries; + self.initial_backoff = initial_backoff; + self.max_backoff = max_backoff; + self + } + + pub async fn send_message( + &self, + request: &MessageRequest, + ) -> Result { + let request = MessageRequest { + stream: false, + ..request.clone() + }; + let response = self.send_with_retry(&request).await?; + let request_id = request_id_from_headers(response.headers()); + let payload = response.json::().await?; + let mut normalized = normalize_response(&request.model, payload)?; + if normalized.request_id.is_none() { + normalized.request_id = request_id; + } + Ok(normalized) + } + + pub async fn stream_message( + &self, + request: &MessageRequest, + ) -> Result { + let response = self + .send_with_retry(&request.clone().with_streaming()) + .await?; + Ok(MessageStream { + request_id: request_id_from_headers(response.headers()), + response, + parser: OpenAiSseParser::new(), + pending: VecDeque::new(), + done: false, + state: StreamState::new(request.model.clone()), + }) + } + + async fn send_with_retry( + &self, + request: &MessageRequest, + ) -> Result { + let mut attempts = 0; + + let last_error = loop { + attempts += 1; + let retryable_error = match self.send_raw_request(request).await { + Ok(response) => match expect_success(response).await { + Ok(response) => return Ok(response), + Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => error, + Err(error) => return Err(error), + }, + Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => error, + Err(error) => return Err(error), + }; + + if attempts > self.max_retries { + break retryable_error; + } + + tokio::time::sleep(self.backoff_for_attempt(attempts)?).await; + }; + + Err(ApiError::RetriesExhausted { + attempts, + last_error: Box::new(last_error), + }) + } + + async fn send_raw_request( + &self, + request: &MessageRequest, + ) -> Result { + let request_url = chat_completions_endpoint(&self.base_url); + self.http + .post(&request_url) + .header("content-type", "application/json") + .bearer_auth(&self.api_key) + .json(&build_chat_completion_request(request)) + .send() + .await + .map_err(ApiError::from) + } + + fn backoff_for_attempt(&self, attempt: u32) -> Result { + let Some(multiplier) = 1_u32.checked_shl(attempt.saturating_sub(1)) else { + return Err(ApiError::BackoffOverflow { + attempt, + base_delay: self.initial_backoff, + }); + }; + Ok(self + .initial_backoff + .checked_mul(multiplier) + .map_or(self.max_backoff, |delay| delay.min(self.max_backoff))) + } +} + +impl Provider for OpenAiCompatClient { + type Stream = MessageStream; + + fn send_message<'a>( + &'a self, + request: &'a MessageRequest, + ) -> ProviderFuture<'a, MessageResponse> { + Box::pin(async move { self.send_message(request).await }) + } + + fn stream_message<'a>( + &'a self, + request: &'a MessageRequest, + ) -> ProviderFuture<'a, Self::Stream> { + Box::pin(async move { self.stream_message(request).await }) + } +} + +#[derive(Debug)] +pub struct MessageStream { + request_id: Option, + response: reqwest::Response, + parser: OpenAiSseParser, + pending: VecDeque, + done: bool, + state: StreamState, +} + +impl MessageStream { + #[must_use] + pub fn request_id(&self) -> Option<&str> { + self.request_id.as_deref() + } + + pub async fn next_event(&mut self) -> Result, ApiError> { + loop { + if let Some(event) = self.pending.pop_front() { + return Ok(Some(event)); + } + + if self.done { + self.pending.extend(self.state.finish()?); + if let Some(event) = self.pending.pop_front() { + return Ok(Some(event)); + } + return Ok(None); + } + + match self.response.chunk().await? { + Some(chunk) => { + for parsed in self.parser.push(&chunk)? { + self.pending.extend(self.state.ingest_chunk(parsed)?); + } + } + None => { + self.done = true; + } + } + } + } +} + +#[derive(Debug, Default)] +struct OpenAiSseParser { + buffer: Vec, +} + +impl OpenAiSseParser { + fn new() -> Self { + Self::default() + } + + fn push(&mut self, chunk: &[u8]) -> Result, ApiError> { + self.buffer.extend_from_slice(chunk); + let mut events = Vec::new(); + + while let Some(frame) = next_sse_frame(&mut self.buffer) { + if let Some(event) = parse_sse_frame(&frame)? { + events.push(event); + } + } + + Ok(events) + } +} + +#[derive(Debug)] +struct StreamState { + model: String, + message_started: bool, + text_started: bool, + text_finished: bool, + finished: bool, + stop_reason: Option, + usage: Option, + tool_calls: BTreeMap, +} + +impl StreamState { + fn new(model: String) -> Self { + Self { + model, + message_started: false, + text_started: false, + text_finished: false, + finished: false, + stop_reason: None, + usage: None, + tool_calls: BTreeMap::new(), + } + } + + fn ingest_chunk(&mut self, chunk: ChatCompletionChunk) -> Result, ApiError> { + let mut events = Vec::new(); + if !self.message_started { + self.message_started = true; + events.push(StreamEvent::MessageStart(MessageStartEvent { + message: MessageResponse { + id: chunk.id.clone(), + kind: "message".to_string(), + role: "assistant".to_string(), + content: Vec::new(), + model: chunk.model.clone().unwrap_or_else(|| self.model.clone()), + stop_reason: None, + stop_sequence: None, + usage: Usage { + input_tokens: 0, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + output_tokens: 0, + }, + request_id: None, + }, + })); + } + + if let Some(usage) = chunk.usage { + self.usage = Some(Usage { + input_tokens: usage.prompt_tokens, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + output_tokens: usage.completion_tokens, + }); + } + + for choice in chunk.choices { + if let Some(content) = choice.delta.content.filter(|value| !value.is_empty()) { + if !self.text_started { + self.text_started = true; + events.push(StreamEvent::ContentBlockStart(ContentBlockStartEvent { + index: 0, + content_block: OutputContentBlock::Text { + text: String::new(), + }, + })); + } + events.push(StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent { + index: 0, + delta: ContentBlockDelta::TextDelta { text: content }, + })); + } + + for tool_call in choice.delta.tool_calls { + let state = self.tool_calls.entry(tool_call.index).or_default(); + state.apply(tool_call); + let block_index = state.block_index(); + if !state.started { + if let Some(start_event) = state.start_event()? { + state.started = true; + events.push(StreamEvent::ContentBlockStart(start_event)); + } else { + continue; + } + } + if let Some(delta_event) = state.delta_event() { + events.push(StreamEvent::ContentBlockDelta(delta_event)); + } + if choice.finish_reason.as_deref() == Some("tool_calls") && !state.stopped { + state.stopped = true; + events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent { + index: block_index, + })); + } + } + + if let Some(finish_reason) = choice.finish_reason { + self.stop_reason = Some(normalize_finish_reason(&finish_reason)); + if finish_reason == "tool_calls" { + for state in self.tool_calls.values_mut() { + if state.started && !state.stopped { + state.stopped = true; + events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent { + index: state.block_index(), + })); + } + } + } + } + } + + Ok(events) + } + + fn finish(&mut self) -> Result, ApiError> { + if self.finished { + return Ok(Vec::new()); + } + self.finished = true; + + let mut events = Vec::new(); + if self.text_started && !self.text_finished { + self.text_finished = true; + events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent { + index: 0, + })); + } + + for state in self.tool_calls.values_mut() { + if !state.started { + if let Some(start_event) = state.start_event()? { + state.started = true; + events.push(StreamEvent::ContentBlockStart(start_event)); + if let Some(delta_event) = state.delta_event() { + events.push(StreamEvent::ContentBlockDelta(delta_event)); + } + } + } + if state.started && !state.stopped { + state.stopped = true; + events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent { + index: state.block_index(), + })); + } + } + + if self.message_started { + events.push(StreamEvent::MessageDelta(MessageDeltaEvent { + delta: MessageDelta { + stop_reason: Some( + self.stop_reason + .clone() + .unwrap_or_else(|| "end_turn".to_string()), + ), + stop_sequence: None, + }, + usage: self.usage.clone().unwrap_or(Usage { + input_tokens: 0, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + output_tokens: 0, + }), + })); + events.push(StreamEvent::MessageStop(MessageStopEvent {})); + } + Ok(events) + } +} + +#[derive(Debug, Default)] +struct ToolCallState { + openai_index: u32, + id: Option, + name: Option, + arguments: String, + emitted_len: usize, + started: bool, + stopped: bool, +} + +impl ToolCallState { + fn apply(&mut self, tool_call: DeltaToolCall) { + self.openai_index = tool_call.index; + if let Some(id) = tool_call.id { + self.id = Some(id); + } + if let Some(name) = tool_call.function.name { + self.name = Some(name); + } + if let Some(arguments) = tool_call.function.arguments { + self.arguments.push_str(&arguments); + } + } + + const fn block_index(&self) -> u32 { + self.openai_index + 1 + } + + fn start_event(&self) -> Result, ApiError> { + let Some(name) = self.name.clone() else { + return Ok(None); + }; + let id = self + .id + .clone() + .unwrap_or_else(|| format!("tool_call_{}", self.openai_index)); + Ok(Some(ContentBlockStartEvent { + index: self.block_index(), + content_block: OutputContentBlock::ToolUse { + id, + name, + input: json!({}), + }, + })) + } + + fn delta_event(&mut self) -> Option { + if self.emitted_len >= self.arguments.len() { + return None; + } + let delta = self.arguments[self.emitted_len..].to_string(); + self.emitted_len = self.arguments.len(); + Some(ContentBlockDeltaEvent { + index: self.block_index(), + delta: ContentBlockDelta::InputJsonDelta { + partial_json: delta, + }, + }) + } +} + +#[derive(Debug, Deserialize)] +struct ChatCompletionResponse { + id: String, + model: String, + choices: Vec, + #[serde(default)] + usage: Option, +} + +#[derive(Debug, Deserialize)] +struct ChatChoice { + message: ChatMessage, + #[serde(default)] + finish_reason: Option, +} + +#[derive(Debug, Deserialize)] +struct ChatMessage { + role: String, + #[serde(default)] + content: Option, + #[serde(default)] + tool_calls: Vec, +} + +#[derive(Debug, Deserialize)] +struct ResponseToolCall { + id: String, + function: ResponseToolFunction, +} + +#[derive(Debug, Deserialize)] +struct ResponseToolFunction { + name: String, + arguments: String, +} + +#[derive(Debug, Deserialize)] +struct OpenAiUsage { + #[serde(default)] + prompt_tokens: u32, + #[serde(default)] + completion_tokens: u32, +} + +#[derive(Debug, Deserialize)] +struct ChatCompletionChunk { + id: String, + #[serde(default)] + model: Option, + #[serde(default)] + choices: Vec, + #[serde(default)] + usage: Option, +} + +#[derive(Debug, Deserialize)] +struct ChunkChoice { + delta: ChunkDelta, + #[serde(default)] + finish_reason: Option, +} + +#[derive(Debug, Default, Deserialize)] +struct ChunkDelta { + #[serde(default)] + content: Option, + #[serde(default)] + tool_calls: Vec, +} + +#[derive(Debug, Deserialize)] +struct DeltaToolCall { + #[serde(default)] + index: u32, + #[serde(default)] + id: Option, + #[serde(default)] + function: DeltaFunction, +} + +#[derive(Debug, Default, Deserialize)] +struct DeltaFunction { + #[serde(default)] + name: Option, + #[serde(default)] + arguments: Option, +} + +#[derive(Debug, Deserialize)] +struct ErrorEnvelope { + error: ErrorBody, +} + +#[derive(Debug, Deserialize)] +struct ErrorBody { + #[serde(rename = "type")] + error_type: Option, + message: Option, +} + +fn build_chat_completion_request(request: &MessageRequest) -> Value { + let mut messages = Vec::new(); + if let Some(system) = request.system.as_ref().filter(|value| !value.is_empty()) { + messages.push(json!({ + "role": "system", + "content": system, + })); + } + for message in &request.messages { + messages.extend(translate_message(message)); + } + + let mut payload = json!({ + "model": request.model, + "max_tokens": request.max_tokens, + "messages": messages, + "stream": request.stream, + }); + + if let Some(tools) = &request.tools { + payload["tools"] = + Value::Array(tools.iter().map(openai_tool_definition).collect::>()); + } + if let Some(tool_choice) = &request.tool_choice { + payload["tool_choice"] = openai_tool_choice(tool_choice); + } + + payload +} + +fn translate_message(message: &InputMessage) -> Vec { + match message.role.as_str() { + "assistant" => { + let mut text = String::new(); + let mut tool_calls = Vec::new(); + for block in &message.content { + match block { + InputContentBlock::Text { text: value } => text.push_str(value), + InputContentBlock::ToolUse { id, name, input } => tool_calls.push(json!({ + "id": id, + "type": "function", + "function": { + "name": name, + "arguments": input.to_string(), + } + })), + InputContentBlock::ToolResult { .. } => {} + } + } + if text.is_empty() && tool_calls.is_empty() { + Vec::new() + } else { + vec![json!({ + "role": "assistant", + "content": (!text.is_empty()).then_some(text), + "tool_calls": tool_calls, + })] + } + } + _ => message + .content + .iter() + .filter_map(|block| match block { + InputContentBlock::Text { text } => Some(json!({ + "role": "user", + "content": text, + })), + InputContentBlock::ToolResult { + tool_use_id, + content, + is_error, + } => Some(json!({ + "role": "tool", + "tool_call_id": tool_use_id, + "content": flatten_tool_result_content(content), + "is_error": is_error, + })), + InputContentBlock::ToolUse { .. } => None, + }) + .collect(), + } +} + +fn flatten_tool_result_content(content: &[ToolResultContentBlock]) -> String { + content + .iter() + .map(|block| match block { + ToolResultContentBlock::Text { text } => text.clone(), + ToolResultContentBlock::Json { value } => value.to_string(), + }) + .collect::>() + .join("\n") +} + +fn openai_tool_definition(tool: &ToolDefinition) -> Value { + json!({ + "type": "function", + "function": { + "name": tool.name, + "description": tool.description, + "parameters": tool.input_schema, + } + }) +} + +fn openai_tool_choice(tool_choice: &ToolChoice) -> Value { + match tool_choice { + ToolChoice::Auto => Value::String("auto".to_string()), + ToolChoice::Any => Value::String("required".to_string()), + ToolChoice::Tool { name } => json!({ + "type": "function", + "function": { "name": name }, + }), + } +} + +fn normalize_response( + model: &str, + response: ChatCompletionResponse, +) -> Result { + let choice = response + .choices + .into_iter() + .next() + .ok_or(ApiError::InvalidSseFrame( + "chat completion response missing choices", + ))?; + let mut content = Vec::new(); + if let Some(text) = choice.message.content.filter(|value| !value.is_empty()) { + content.push(OutputContentBlock::Text { text }); + } + for tool_call in choice.message.tool_calls { + content.push(OutputContentBlock::ToolUse { + id: tool_call.id, + name: tool_call.function.name, + input: parse_tool_arguments(&tool_call.function.arguments), + }); + } + + Ok(MessageResponse { + id: response.id, + kind: "message".to_string(), + role: choice.message.role, + content, + model: response.model.if_empty_then(model.to_string()), + stop_reason: choice + .finish_reason + .map(|value| normalize_finish_reason(&value)), + stop_sequence: None, + usage: Usage { + input_tokens: response + .usage + .as_ref() + .map_or(0, |usage| usage.prompt_tokens), + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + output_tokens: response + .usage + .as_ref() + .map_or(0, |usage| usage.completion_tokens), + }, + request_id: None, + }) +} + +fn parse_tool_arguments(arguments: &str) -> Value { + serde_json::from_str(arguments).unwrap_or_else(|_| json!({ "raw": arguments })) +} + +fn next_sse_frame(buffer: &mut Vec) -> Option { + let separator = buffer + .windows(2) + .position(|window| window == b"\n\n") + .map(|position| (position, 2)) + .or_else(|| { + buffer + .windows(4) + .position(|window| window == b"\r\n\r\n") + .map(|position| (position, 4)) + })?; + + let (position, separator_len) = separator; + let frame = buffer.drain(..position + separator_len).collect::>(); + let frame_len = frame.len().saturating_sub(separator_len); + Some(String::from_utf8_lossy(&frame[..frame_len]).into_owned()) +} + +fn parse_sse_frame(frame: &str) -> Result, ApiError> { + let trimmed = frame.trim(); + if trimmed.is_empty() { + return Ok(None); + } + + let mut data_lines = Vec::new(); + for line in trimmed.lines() { + if line.starts_with(':') { + continue; + } + if let Some(data) = line.strip_prefix("data:") { + data_lines.push(data.trim_start()); + } + } + if data_lines.is_empty() { + return Ok(None); + } + let payload = data_lines.join("\n"); + if payload == "[DONE]" { + return Ok(None); + } + serde_json::from_str(&payload) + .map(Some) + .map_err(ApiError::from) +} + +fn read_env_non_empty(key: &str) -> Result, ApiError> { + match std::env::var(key) { + Ok(value) if !value.is_empty() => Ok(Some(value)), + Ok(_) | Err(std::env::VarError::NotPresent) => Ok(None), + Err(error) => Err(ApiError::from(error)), + } +} + +#[must_use] +pub fn has_api_key(key: &str) -> bool { + read_env_non_empty(key) + .ok() + .and_then(std::convert::identity) + .is_some() +} + +#[must_use] +pub fn read_base_url(config: OpenAiCompatConfig) -> String { + std::env::var(config.base_url_env).unwrap_or_else(|_| config.default_base_url.to_string()) +} + +fn chat_completions_endpoint(base_url: &str) -> String { + let trimmed = base_url.trim_end_matches('/'); + if trimmed.ends_with("/chat/completions") { + trimmed.to_string() + } else { + format!("{trimmed}/chat/completions") + } +} + +fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option { + headers + .get(REQUEST_ID_HEADER) + .or_else(|| headers.get(ALT_REQUEST_ID_HEADER)) + .and_then(|value| value.to_str().ok()) + .map(ToOwned::to_owned) +} + +async fn expect_success(response: reqwest::Response) -> Result { + let status = response.status(); + if status.is_success() { + return Ok(response); + } + + let body = response.text().await.unwrap_or_default(); + let parsed_error = serde_json::from_str::(&body).ok(); + let retryable = is_retryable_status(status); + + Err(ApiError::Api { + status, + error_type: parsed_error + .as_ref() + .and_then(|error| error.error.error_type.clone()), + message: parsed_error + .as_ref() + .and_then(|error| error.error.message.clone()), + body, + retryable, + }) +} + +const fn is_retryable_status(status: reqwest::StatusCode) -> bool { + matches!(status.as_u16(), 408 | 409 | 429 | 500 | 502 | 503 | 504) +} + +fn normalize_finish_reason(value: &str) -> String { + match value { + "stop" => "end_turn", + "tool_calls" => "tool_use", + other => other, + } + .to_string() +} + +trait StringExt { + fn if_empty_then(self, fallback: String) -> String; +} + +impl StringExt for String { + fn if_empty_then(self, fallback: String) -> String { + if self.is_empty() { + fallback + } else { + self + } + } +} + +#[cfg(test)] +mod tests { + use super::{ + build_chat_completion_request, chat_completions_endpoint, normalize_finish_reason, + openai_tool_choice, parse_tool_arguments, OpenAiCompatClient, OpenAiCompatConfig, + }; + use crate::error::ApiError; + use crate::types::{ + InputContentBlock, InputMessage, MessageRequest, ToolChoice, ToolDefinition, + ToolResultContentBlock, + }; + use serde_json::json; + use std::sync::{Mutex, OnceLock}; + + #[test] + fn request_translation_uses_openai_compatible_shape() { + let payload = build_chat_completion_request(&MessageRequest { + model: "grok-3".to_string(), + max_tokens: 64, + messages: vec![InputMessage { + role: "user".to_string(), + content: vec![ + InputContentBlock::Text { + text: "hello".to_string(), + }, + InputContentBlock::ToolResult { + tool_use_id: "tool_1".to_string(), + content: vec![ToolResultContentBlock::Json { + value: json!({"ok": true}), + }], + is_error: false, + }, + ], + }], + system: Some("be helpful".to_string()), + tools: Some(vec![ToolDefinition { + name: "weather".to_string(), + description: Some("Get weather".to_string()), + input_schema: json!({"type": "object"}), + }]), + tool_choice: Some(ToolChoice::Auto), + stream: false, + }); + + assert_eq!(payload["messages"][0]["role"], json!("system")); + assert_eq!(payload["messages"][1]["role"], json!("user")); + assert_eq!(payload["messages"][2]["role"], json!("tool")); + assert_eq!(payload["tools"][0]["type"], json!("function")); + assert_eq!(payload["tool_choice"], json!("auto")); + } + + #[test] + fn tool_choice_translation_supports_required_function() { + assert_eq!(openai_tool_choice(&ToolChoice::Any), json!("required")); + assert_eq!( + openai_tool_choice(&ToolChoice::Tool { + name: "weather".to_string(), + }), + json!({"type": "function", "function": {"name": "weather"}}) + ); + } + + #[test] + fn parses_tool_arguments_fallback() { + assert_eq!( + parse_tool_arguments("{\"city\":\"Paris\"}"), + json!({"city": "Paris"}) + ); + assert_eq!(parse_tool_arguments("not-json"), json!({"raw": "not-json"})); + } + + #[test] + fn missing_xai_api_key_is_provider_specific() { + let _lock = env_lock(); + std::env::remove_var("XAI_API_KEY"); + let error = OpenAiCompatClient::from_env(OpenAiCompatConfig::xai()) + .expect_err("missing key should error"); + assert!(matches!( + error, + ApiError::MissingCredentials { + provider: "xAI", + .. + } + )); + } + + #[test] + fn endpoint_builder_accepts_base_urls_and_full_endpoints() { + assert_eq!( + chat_completions_endpoint("https://api.x.ai/v1"), + "https://api.x.ai/v1/chat/completions" + ); + assert_eq!( + chat_completions_endpoint("https://api.x.ai/v1/"), + "https://api.x.ai/v1/chat/completions" + ); + assert_eq!( + chat_completions_endpoint("https://api.x.ai/v1/chat/completions"), + "https://api.x.ai/v1/chat/completions" + ); + } + + fn env_lock() -> std::sync::MutexGuard<'static, ()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())) + .lock() + .expect("env lock") + } + + #[test] + fn normalizes_stop_reasons() { + assert_eq!(normalize_finish_reason("stop"), "end_turn"); + assert_eq!(normalize_finish_reason("tool_calls"), "tool_use"); + } +} diff --git a/rust/crates/api/src/sse.rs b/rust/crates/api/src/sse.rs new file mode 100644 index 0000000..5f54e50 --- /dev/null +++ b/rust/crates/api/src/sse.rs @@ -0,0 +1,279 @@ +use crate::error::ApiError; +use crate::types::StreamEvent; + +#[derive(Debug, Default)] +pub struct SseParser { + buffer: Vec, +} + +impl SseParser { + #[must_use] + pub fn new() -> Self { + Self::default() + } + + pub fn push(&mut self, chunk: &[u8]) -> Result, ApiError> { + self.buffer.extend_from_slice(chunk); + let mut events = Vec::new(); + + while let Some(frame) = self.next_frame() { + if let Some(event) = parse_frame(&frame)? { + events.push(event); + } + } + + Ok(events) + } + + pub fn finish(&mut self) -> Result, ApiError> { + if self.buffer.is_empty() { + return Ok(Vec::new()); + } + + let trailing = std::mem::take(&mut self.buffer); + match parse_frame(&String::from_utf8_lossy(&trailing))? { + Some(event) => Ok(vec![event]), + None => Ok(Vec::new()), + } + } + + fn next_frame(&mut self) -> Option { + let separator = self + .buffer + .windows(2) + .position(|window| window == b"\n\n") + .map(|position| (position, 2)) + .or_else(|| { + self.buffer + .windows(4) + .position(|window| window == b"\r\n\r\n") + .map(|position| (position, 4)) + })?; + + let (position, separator_len) = separator; + let frame = self + .buffer + .drain(..position + separator_len) + .collect::>(); + let frame_len = frame.len().saturating_sub(separator_len); + Some(String::from_utf8_lossy(&frame[..frame_len]).into_owned()) + } +} + +pub fn parse_frame(frame: &str) -> Result, ApiError> { + let trimmed = frame.trim(); + if trimmed.is_empty() { + return Ok(None); + } + + let mut data_lines = Vec::new(); + let mut event_name: Option<&str> = None; + + for line in trimmed.lines() { + if line.starts_with(':') { + continue; + } + if let Some(name) = line.strip_prefix("event:") { + event_name = Some(name.trim()); + continue; + } + if let Some(data) = line.strip_prefix("data:") { + data_lines.push(data.trim_start()); + } + } + + if matches!(event_name, Some("ping")) { + return Ok(None); + } + + if data_lines.is_empty() { + return Ok(None); + } + + let payload = data_lines.join("\n"); + if payload == "[DONE]" { + return Ok(None); + } + + serde_json::from_str::(&payload) + .map(Some) + .map_err(ApiError::from) +} + +#[cfg(test)] +mod tests { + use super::{parse_frame, SseParser}; + use crate::types::{ContentBlockDelta, MessageDelta, OutputContentBlock, StreamEvent, Usage}; + + #[test] + fn parses_single_frame() { + let frame = concat!( + "event: content_block_start\n", + "data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"Hi\"}}\n\n" + ); + + let event = parse_frame(frame).expect("frame should parse"); + assert_eq!( + event, + Some(StreamEvent::ContentBlockStart( + crate::types::ContentBlockStartEvent { + index: 0, + content_block: OutputContentBlock::Text { + text: "Hi".to_string(), + }, + }, + )) + ); + } + + #[test] + fn parses_chunked_stream() { + let mut parser = SseParser::new(); + let first = b"event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"Hel"; + let second = b"lo\"}}\n\n"; + + assert!(parser + .push(first) + .expect("first chunk should buffer") + .is_empty()); + let events = parser.push(second).expect("second chunk should parse"); + + assert_eq!( + events, + vec![StreamEvent::ContentBlockDelta( + crate::types::ContentBlockDeltaEvent { + index: 0, + delta: ContentBlockDelta::TextDelta { + text: "Hello".to_string(), + }, + } + )] + ); + } + + #[test] + fn ignores_ping_and_done() { + let mut parser = SseParser::new(); + let payload = concat!( + ": keepalive\n", + "event: ping\n", + "data: {\"type\":\"ping\"}\n\n", + "event: message_delta\n", + "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\",\"stop_sequence\":null},\"usage\":{\"input_tokens\":1,\"output_tokens\":2}}\n\n", + "event: message_stop\n", + "data: {\"type\":\"message_stop\"}\n\n", + "data: [DONE]\n\n" + ); + + let events = parser + .push(payload.as_bytes()) + .expect("parser should succeed"); + assert_eq!( + events, + vec![ + StreamEvent::MessageDelta(crate::types::MessageDeltaEvent { + delta: MessageDelta { + stop_reason: Some("tool_use".to_string()), + stop_sequence: None, + }, + usage: Usage { + input_tokens: 1, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + output_tokens: 2, + }, + }), + StreamEvent::MessageStop(crate::types::MessageStopEvent {}), + ] + ); + } + + #[test] + fn ignores_data_less_event_frames() { + let frame = "event: ping\n\n"; + let event = parse_frame(frame).expect("frame without data should be ignored"); + assert_eq!(event, None); + } + + #[test] + fn parses_split_json_across_data_lines() { + let frame = concat!( + "event: content_block_delta\n", + "data: {\"type\":\"content_block_delta\",\"index\":0,\n", + "data: \"delta\":{\"type\":\"text_delta\",\"text\":\"Hello\"}}\n\n" + ); + + let event = parse_frame(frame).expect("frame should parse"); + assert_eq!( + event, + Some(StreamEvent::ContentBlockDelta( + crate::types::ContentBlockDeltaEvent { + index: 0, + delta: ContentBlockDelta::TextDelta { + text: "Hello".to_string(), + }, + } + )) + ); + } + + #[test] + fn parses_thinking_content_block_start() { + let frame = concat!( + "event: content_block_start\n", + "data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"thinking\",\"thinking\":\"\",\"signature\":null}}\n\n" + ); + + let event = parse_frame(frame).expect("frame should parse"); + assert_eq!( + event, + Some(StreamEvent::ContentBlockStart( + crate::types::ContentBlockStartEvent { + index: 0, + content_block: OutputContentBlock::Thinking { + thinking: String::new(), + signature: None, + }, + }, + )) + ); + } + + #[test] + fn parses_thinking_related_deltas() { + let thinking = concat!( + "event: content_block_delta\n", + "data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"thinking_delta\",\"thinking\":\"step 1\"}}\n\n" + ); + let signature = concat!( + "event: content_block_delta\n", + "data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"signature_delta\",\"signature\":\"sig_123\"}}\n\n" + ); + + let thinking_event = parse_frame(thinking).expect("thinking delta should parse"); + let signature_event = parse_frame(signature).expect("signature delta should parse"); + + assert_eq!( + thinking_event, + Some(StreamEvent::ContentBlockDelta( + crate::types::ContentBlockDeltaEvent { + index: 0, + delta: ContentBlockDelta::ThinkingDelta { + thinking: "step 1".to_string(), + }, + } + )) + ); + assert_eq!( + signature_event, + Some(StreamEvent::ContentBlockDelta( + crate::types::ContentBlockDeltaEvent { + index: 0, + delta: ContentBlockDelta::SignatureDelta { + signature: "sig_123".to_string(), + }, + } + )) + ); + } +} diff --git a/rust/crates/api/src/types.rs b/rust/crates/api/src/types.rs new file mode 100644 index 0000000..c060be6 --- /dev/null +++ b/rust/crates/api/src/types.rs @@ -0,0 +1,223 @@ +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct MessageRequest { + pub model: String, + pub max_tokens: u32, + pub messages: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub system: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_choice: Option, + #[serde(default, skip_serializing_if = "std::ops::Not::not")] + pub stream: bool, +} + +impl MessageRequest { + #[must_use] + pub fn with_streaming(mut self) -> Self { + self.stream = true; + self + } +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct InputMessage { + pub role: String, + pub content: Vec, +} + +impl InputMessage { + #[must_use] + pub fn user_text(text: impl Into) -> Self { + Self { + role: "user".to_string(), + content: vec![InputContentBlock::Text { text: text.into() }], + } + } + + #[must_use] + pub fn user_tool_result( + tool_use_id: impl Into, + content: impl Into, + is_error: bool, + ) -> Self { + Self { + role: "user".to_string(), + content: vec![InputContentBlock::ToolResult { + tool_use_id: tool_use_id.into(), + content: vec![ToolResultContentBlock::Text { + text: content.into(), + }], + is_error, + }], + } + } +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum InputContentBlock { + Text { + text: String, + }, + ToolUse { + id: String, + name: String, + input: Value, + }, + ToolResult { + tool_use_id: String, + content: Vec, + #[serde(default, skip_serializing_if = "std::ops::Not::not")] + is_error: bool, + }, +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ToolResultContentBlock { + Text { text: String }, + Json { value: Value }, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct ToolDefinition { + pub name: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + pub input_schema: Value, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ToolChoice { + Auto, + Any, + Tool { name: String }, +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct MessageResponse { + pub id: String, + #[serde(rename = "type")] + pub kind: String, + pub role: String, + pub content: Vec, + pub model: String, + #[serde(default)] + pub stop_reason: Option, + #[serde(default)] + pub stop_sequence: Option, + pub usage: Usage, + #[serde(default)] + pub request_id: Option, +} + +impl MessageResponse { + #[must_use] + pub fn total_tokens(&self) -> u32 { + self.usage.total_tokens() + } +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum OutputContentBlock { + Text { + text: String, + }, + ToolUse { + id: String, + name: String, + input: Value, + }, + Thinking { + #[serde(default)] + thinking: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + signature: Option, + }, + RedactedThinking { + data: Value, + }, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct Usage { + pub input_tokens: u32, + #[serde(default)] + pub cache_creation_input_tokens: u32, + #[serde(default)] + pub cache_read_input_tokens: u32, + pub output_tokens: u32, +} + +impl Usage { + #[must_use] + pub const fn total_tokens(&self) -> u32 { + self.input_tokens + self.output_tokens + } +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct MessageStartEvent { + pub message: MessageResponse, +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct MessageDeltaEvent { + pub delta: MessageDelta, + pub usage: Usage, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct MessageDelta { + #[serde(default)] + pub stop_reason: Option, + #[serde(default)] + pub stop_sequence: Option, +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct ContentBlockStartEvent { + pub index: u32, + pub content_block: OutputContentBlock, +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct ContentBlockDeltaEvent { + pub index: u32, + pub delta: ContentBlockDelta, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ContentBlockDelta { + TextDelta { text: String }, + InputJsonDelta { partial_json: String }, + ThinkingDelta { thinking: String }, + SignatureDelta { signature: String }, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct ContentBlockStopEvent { + pub index: u32, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct MessageStopEvent {} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum StreamEvent { + MessageStart(MessageStartEvent), + MessageDelta(MessageDeltaEvent), + ContentBlockStart(ContentBlockStartEvent), + ContentBlockDelta(ContentBlockDeltaEvent), + ContentBlockStop(ContentBlockStopEvent), + MessageStop(MessageStopEvent), +} diff --git a/rust/crates/api/tests/client_integration.rs b/rust/crates/api/tests/client_integration.rs new file mode 100644 index 0000000..ae810b8 --- /dev/null +++ b/rust/crates/api/tests/client_integration.rs @@ -0,0 +1,484 @@ +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Duration; + +use api::{ + ApiClient, ApiError, AuthSource, ContentBlockDelta, ContentBlockDeltaEvent, + ContentBlockStartEvent, InputContentBlock, InputMessage, MessageDeltaEvent, MessageRequest, + OutputContentBlock, ProviderClient, StreamEvent, ToolChoice, ToolDefinition, +}; +use serde_json::json; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpListener; +use tokio::sync::Mutex; + +#[tokio::test] +async fn send_message_posts_json_and_parses_response() { + let state = Arc::new(Mutex::new(Vec::::new())); + let body = concat!( + "{", + "\"id\":\"msg_test\",", + "\"type\":\"message\",", + "\"role\":\"assistant\",", + "\"content\":[{\"type\":\"text\",\"text\":\"Hello from Claw\"}],", + "\"model\":\"claude-sonnet-4-6\",", + "\"stop_reason\":\"end_turn\",", + "\"stop_sequence\":null,", + "\"usage\":{\"input_tokens\":12,\"output_tokens\":4},", + "\"request_id\":\"req_body_123\"", + "}" + ); + let server = spawn_server( + state.clone(), + vec![http_response("200 OK", "application/json", body)], + ) + .await; + + let client = ApiClient::new("test-key") + .with_auth_token(Some("proxy-token".to_string())) + .with_base_url(server.base_url()); + let response = client + .send_message(&sample_request(false)) + .await + .expect("request should succeed"); + + assert_eq!(response.id, "msg_test"); + assert_eq!(response.total_tokens(), 16); + assert_eq!(response.request_id.as_deref(), Some("req_body_123")); + assert_eq!( + response.content, + vec![OutputContentBlock::Text { + text: "Hello from Claw".to_string(), + }] + ); + + let captured = state.lock().await; + let request = captured.first().expect("server should capture request"); + assert_eq!(request.method, "POST"); + assert_eq!(request.path, "/v1/messages"); + assert_eq!( + request.headers.get("x-api-key").map(String::as_str), + Some("test-key") + ); + assert_eq!( + request.headers.get("authorization").map(String::as_str), + Some("Bearer proxy-token") + ); + let body: serde_json::Value = + serde_json::from_str(&request.body).expect("request body should be json"); + assert_eq!( + body.get("model").and_then(serde_json::Value::as_str), + Some("claude-sonnet-4-6") + ); + assert!(body.get("stream").is_none()); + assert_eq!(body["tools"][0]["name"], json!("get_weather")); + assert_eq!(body["tool_choice"]["type"], json!("auto")); +} + +#[tokio::test] +async fn stream_message_parses_sse_events_with_tool_use() { + let state = Arc::new(Mutex::new(Vec::::new())); + let sse = concat!( + "event: message_start\n", + "data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_stream\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-sonnet-4-6\",\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":8,\"output_tokens\":0}}}\n\n", + "event: content_block_start\n", + "data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"tool_use\",\"id\":\"toolu_123\",\"name\":\"get_weather\",\"input\":{}}}\n\n", + "event: content_block_delta\n", + "data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"{\\\"city\\\":\\\"Paris\\\"}\"}}\n\n", + "event: content_block_stop\n", + "data: {\"type\":\"content_block_stop\",\"index\":0}\n\n", + "event: message_delta\n", + "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\",\"stop_sequence\":null},\"usage\":{\"input_tokens\":8,\"output_tokens\":1}}\n\n", + "event: message_stop\n", + "data: {\"type\":\"message_stop\"}\n\n", + "data: [DONE]\n\n" + ); + let server = spawn_server( + state.clone(), + vec![http_response_with_headers( + "200 OK", + "text/event-stream", + sse, + &[("request-id", "req_stream_456")], + )], + ) + .await; + + let client = ApiClient::new("test-key") + .with_auth_token(Some("proxy-token".to_string())) + .with_base_url(server.base_url()); + let mut stream = client + .stream_message(&sample_request(false)) + .await + .expect("stream should start"); + + assert_eq!(stream.request_id(), Some("req_stream_456")); + + let mut events = Vec::new(); + while let Some(event) = stream + .next_event() + .await + .expect("stream event should parse") + { + events.push(event); + } + + assert_eq!(events.len(), 6); + assert!(matches!(events[0], StreamEvent::MessageStart(_))); + assert!(matches!( + events[1], + StreamEvent::ContentBlockStart(ContentBlockStartEvent { + content_block: OutputContentBlock::ToolUse { .. }, + .. + }) + )); + assert!(matches!( + events[2], + StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent { + delta: ContentBlockDelta::InputJsonDelta { .. }, + .. + }) + )); + assert!(matches!(events[3], StreamEvent::ContentBlockStop(_))); + assert!(matches!( + events[4], + StreamEvent::MessageDelta(MessageDeltaEvent { .. }) + )); + assert!(matches!(events[5], StreamEvent::MessageStop(_))); + + match &events[1] { + StreamEvent::ContentBlockStart(ContentBlockStartEvent { + content_block: OutputContentBlock::ToolUse { name, input, .. }, + .. + }) => { + assert_eq!(name, "get_weather"); + assert_eq!(input, &json!({})); + } + other => panic!("expected tool_use block, got {other:?}"), + } + + let captured = state.lock().await; + let request = captured.first().expect("server should capture request"); + assert!(request.body.contains("\"stream\":true")); +} + +#[tokio::test] +async fn retries_retryable_failures_before_succeeding() { + let state = Arc::new(Mutex::new(Vec::::new())); + let server = spawn_server( + state.clone(), + vec![ + http_response( + "429 Too Many Requests", + "application/json", + "{\"type\":\"error\",\"error\":{\"type\":\"rate_limit_error\",\"message\":\"slow down\"}}", + ), + http_response( + "200 OK", + "application/json", + "{\"id\":\"msg_retry\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Recovered\"}],\"model\":\"claude-sonnet-4-6\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"output_tokens\":2}}", + ), + ], + ) + .await; + + let client = ApiClient::new("test-key") + .with_base_url(server.base_url()) + .with_retry_policy(2, Duration::from_millis(1), Duration::from_millis(2)); + + let response = client + .send_message(&sample_request(false)) + .await + .expect("retry should eventually succeed"); + + assert_eq!(response.total_tokens(), 5); + assert_eq!(state.lock().await.len(), 2); +} + +#[tokio::test] +async fn provider_client_dispatches_api_requests() { + let state = Arc::new(Mutex::new(Vec::::new())); + let server = spawn_server( + state.clone(), + vec![http_response( + "200 OK", + "application/json", + "{\"id\":\"msg_provider\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Dispatched\"}],\"model\":\"claude-sonnet-4-6\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"output_tokens\":2}}", + )], + ) + .await; + + let client = ProviderClient::from_model_with_default_auth( + "claude-sonnet-4-6", + Some(AuthSource::ApiKey("test-key".to_string())), + ) + .expect("api provider client should be constructed"); + let client = match client { + ProviderClient::ClawApi(client) => { + ProviderClient::ClawApi(client.with_base_url(server.base_url())) + } + other => panic!("expected default provider, got {other:?}"), + }; + + let response = client + .send_message(&sample_request(false)) + .await + .expect("provider-dispatched request should succeed"); + + assert_eq!(response.total_tokens(), 5); + + let captured = state.lock().await; + let request = captured.first().expect("server should capture request"); + assert_eq!(request.path, "/v1/messages"); + assert_eq!( + request.headers.get("x-api-key").map(String::as_str), + Some("test-key") + ); +} + +#[tokio::test] +async fn surfaces_retry_exhaustion_for_persistent_retryable_errors() { + let state = Arc::new(Mutex::new(Vec::::new())); + let server = spawn_server( + state.clone(), + vec![ + http_response( + "503 Service Unavailable", + "application/json", + "{\"type\":\"error\",\"error\":{\"type\":\"overloaded_error\",\"message\":\"busy\"}}", + ), + http_response( + "503 Service Unavailable", + "application/json", + "{\"type\":\"error\",\"error\":{\"type\":\"overloaded_error\",\"message\":\"still busy\"}}", + ), + ], + ) + .await; + + let client = ApiClient::new("test-key") + .with_base_url(server.base_url()) + .with_retry_policy(1, Duration::from_millis(1), Duration::from_millis(2)); + + let error = client + .send_message(&sample_request(false)) + .await + .expect_err("persistent 503 should fail"); + + match error { + ApiError::RetriesExhausted { + attempts, + last_error, + } => { + assert_eq!(attempts, 2); + assert!(matches!( + *last_error, + ApiError::Api { + status: reqwest::StatusCode::SERVICE_UNAVAILABLE, + retryable: true, + .. + } + )); + } + other => panic!("expected retries exhausted, got {other:?}"), + } +} + +#[tokio::test] +#[ignore = "requires ANTHROPIC_API_KEY and network access"] +async fn live_stream_smoke_test() { + let client = ApiClient::from_env().expect("ANTHROPIC_API_KEY must be set"); + let mut stream = client + .stream_message(&MessageRequest { + model: std::env::var("CLAW_MODEL") + .unwrap_or_else(|_| "claude-sonnet-4-6".to_string()), + max_tokens: 32, + messages: vec![InputMessage::user_text( + "Reply with exactly: hello from rust", + )], + system: None, + tools: None, + tool_choice: None, + stream: false, + }) + .await + .expect("live stream should start"); + + while let Some(_event) = stream + .next_event() + .await + .expect("live stream should yield events") + {} +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct CapturedRequest { + method: String, + path: String, + headers: HashMap, + body: String, +} + +struct TestServer { + base_url: String, + join_handle: tokio::task::JoinHandle<()>, +} + +impl TestServer { + fn base_url(&self) -> String { + self.base_url.clone() + } +} + +impl Drop for TestServer { + fn drop(&mut self) { + self.join_handle.abort(); + } +} + +async fn spawn_server( + state: Arc>>, + responses: Vec, +) -> TestServer { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("listener should bind"); + let address = listener + .local_addr() + .expect("listener should have local addr"); + let join_handle = tokio::spawn(async move { + for response in responses { + let (mut socket, _) = listener.accept().await.expect("server should accept"); + let mut buffer = Vec::new(); + let mut header_end = None; + + loop { + let mut chunk = [0_u8; 1024]; + let read = socket + .read(&mut chunk) + .await + .expect("request read should succeed"); + if read == 0 { + break; + } + buffer.extend_from_slice(&chunk[..read]); + if let Some(position) = find_header_end(&buffer) { + header_end = Some(position); + break; + } + } + + let header_end = header_end.expect("request should include headers"); + let (header_bytes, remaining) = buffer.split_at(header_end); + let header_text = + String::from_utf8(header_bytes.to_vec()).expect("headers should be utf8"); + let mut lines = header_text.split("\r\n"); + let request_line = lines.next().expect("request line should exist"); + let mut parts = request_line.split_whitespace(); + let method = parts.next().expect("method should exist").to_string(); + let path = parts.next().expect("path should exist").to_string(); + let mut headers = HashMap::new(); + let mut content_length = 0_usize; + for line in lines { + if line.is_empty() { + continue; + } + let (name, value) = line.split_once(':').expect("header should have colon"); + let value = value.trim().to_string(); + if name.eq_ignore_ascii_case("content-length") { + content_length = value.parse().expect("content length should parse"); + } + headers.insert(name.to_ascii_lowercase(), value); + } + + let mut body = remaining[4..].to_vec(); + while body.len() < content_length { + let mut chunk = vec![0_u8; content_length - body.len()]; + let read = socket + .read(&mut chunk) + .await + .expect("body read should succeed"); + if read == 0 { + break; + } + body.extend_from_slice(&chunk[..read]); + } + + state.lock().await.push(CapturedRequest { + method, + path, + headers, + body: String::from_utf8(body).expect("body should be utf8"), + }); + + socket + .write_all(response.as_bytes()) + .await + .expect("response write should succeed"); + } + }); + + TestServer { + base_url: format!("http://{address}"), + join_handle, + } +} + +fn find_header_end(bytes: &[u8]) -> Option { + bytes.windows(4).position(|window| window == b"\r\n\r\n") +} + +fn http_response(status: &str, content_type: &str, body: &str) -> String { + http_response_with_headers(status, content_type, body, &[]) +} + +fn http_response_with_headers( + status: &str, + content_type: &str, + body: &str, + headers: &[(&str, &str)], +) -> String { + let mut extra_headers = String::new(); + for (name, value) in headers { + use std::fmt::Write as _; + write!(&mut extra_headers, "{name}: {value}\r\n").expect("header write should succeed"); + } + format!( + "HTTP/1.1 {status}\r\ncontent-type: {content_type}\r\n{extra_headers}content-length: {}\r\nconnection: close\r\n\r\n{body}", + body.len() + ) +} + +fn sample_request(stream: bool) -> MessageRequest { + MessageRequest { + model: "claude-sonnet-4-6".to_string(), + max_tokens: 64, + messages: vec![InputMessage { + role: "user".to_string(), + content: vec![ + InputContentBlock::Text { + text: "Say hello".to_string(), + }, + InputContentBlock::ToolResult { + tool_use_id: "toolu_prev".to_string(), + content: vec![api::ToolResultContentBlock::Json { + value: json!({"forecast": "sunny"}), + }], + is_error: false, + }, + ], + }], + system: Some("Use tools when needed".to_string()), + tools: Some(vec![ToolDefinition { + name: "get_weather".to_string(), + description: Some("Fetches the weather".to_string()), + input_schema: json!({ + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"] + }), + }]), + tool_choice: Some(ToolChoice::Auto), + stream, + } +} diff --git a/rust/crates/api/tests/openai_compat_integration.rs b/rust/crates/api/tests/openai_compat_integration.rs new file mode 100644 index 0000000..b345b1f --- /dev/null +++ b/rust/crates/api/tests/openai_compat_integration.rs @@ -0,0 +1,415 @@ +use std::collections::HashMap; +use std::ffi::OsString; +use std::sync::Arc; +use std::sync::{Mutex as StdMutex, OnceLock}; + +use api::{ + ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent, + InputContentBlock, InputMessage, MessageRequest, OpenAiCompatClient, OpenAiCompatConfig, + OutputContentBlock, ProviderClient, StreamEvent, ToolChoice, ToolDefinition, +}; +use serde_json::json; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpListener; +use tokio::sync::Mutex; + +#[tokio::test] +async fn send_message_uses_openai_compatible_endpoint_and_auth() { + let state = Arc::new(Mutex::new(Vec::::new())); + let body = concat!( + "{", + "\"id\":\"chatcmpl_test\",", + "\"model\":\"grok-3\",", + "\"choices\":[{", + "\"message\":{\"role\":\"assistant\",\"content\":\"Hello from Grok\",\"tool_calls\":[]},", + "\"finish_reason\":\"stop\"", + "}],", + "\"usage\":{\"prompt_tokens\":11,\"completion_tokens\":5}", + "}" + ); + let server = spawn_server( + state.clone(), + vec![http_response("200 OK", "application/json", body)], + ) + .await; + + let client = OpenAiCompatClient::new("xai-test-key", OpenAiCompatConfig::xai()) + .with_base_url(server.base_url()); + let response = client + .send_message(&sample_request(false)) + .await + .expect("request should succeed"); + + assert_eq!(response.model, "grok-3"); + assert_eq!(response.total_tokens(), 16); + assert_eq!( + response.content, + vec![OutputContentBlock::Text { + text: "Hello from Grok".to_string(), + }] + ); + + let captured = state.lock().await; + let request = captured.first().expect("server should capture request"); + assert_eq!(request.path, "/chat/completions"); + assert_eq!( + request.headers.get("authorization").map(String::as_str), + Some("Bearer xai-test-key") + ); + let body: serde_json::Value = serde_json::from_str(&request.body).expect("json body"); + assert_eq!(body["model"], json!("grok-3")); + assert_eq!(body["messages"][0]["role"], json!("system")); + assert_eq!(body["tools"][0]["type"], json!("function")); +} + +#[tokio::test] +async fn send_message_accepts_full_chat_completions_endpoint_override() { + let state = Arc::new(Mutex::new(Vec::::new())); + let body = concat!( + "{", + "\"id\":\"chatcmpl_full_endpoint\",", + "\"model\":\"grok-3\",", + "\"choices\":[{", + "\"message\":{\"role\":\"assistant\",\"content\":\"Endpoint override works\",\"tool_calls\":[]},", + "\"finish_reason\":\"stop\"", + "}],", + "\"usage\":{\"prompt_tokens\":7,\"completion_tokens\":3}", + "}" + ); + let server = spawn_server( + state.clone(), + vec![http_response("200 OK", "application/json", body)], + ) + .await; + + let endpoint_url = format!("{}/chat/completions", server.base_url()); + let client = OpenAiCompatClient::new("xai-test-key", OpenAiCompatConfig::xai()) + .with_base_url(endpoint_url); + let response = client + .send_message(&sample_request(false)) + .await + .expect("request should succeed"); + + assert_eq!(response.total_tokens(), 10); + + let captured = state.lock().await; + let request = captured.first().expect("server should capture request"); + assert_eq!(request.path, "/chat/completions"); +} + +#[tokio::test] +async fn stream_message_normalizes_text_and_multiple_tool_calls() { + let state = Arc::new(Mutex::new(Vec::::new())); + let sse = concat!( + "data: {\"id\":\"chatcmpl_stream\",\"model\":\"grok-3\",\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\n\n", + "data: {\"id\":\"chatcmpl_stream\",\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_1\",\"function\":{\"name\":\"weather\",\"arguments\":\"{\\\"city\\\":\\\"Paris\\\"}\"}},{\"index\":1,\"id\":\"call_2\",\"function\":{\"name\":\"clock\",\"arguments\":\"{\\\"zone\\\":\\\"UTC\\\"}\"}}]}}]}\n\n", + "data: {\"id\":\"chatcmpl_stream\",\"choices\":[{\"delta\":{},\"finish_reason\":\"tool_calls\"}]}\n\n", + "data: [DONE]\n\n" + ); + let server = spawn_server( + state.clone(), + vec![http_response_with_headers( + "200 OK", + "text/event-stream", + sse, + &[("x-request-id", "req_grok_stream")], + )], + ) + .await; + + let client = OpenAiCompatClient::new("xai-test-key", OpenAiCompatConfig::xai()) + .with_base_url(server.base_url()); + let mut stream = client + .stream_message(&sample_request(false)) + .await + .expect("stream should start"); + + assert_eq!(stream.request_id(), Some("req_grok_stream")); + + let mut events = Vec::new(); + while let Some(event) = stream.next_event().await.expect("event should parse") { + events.push(event); + } + + assert!(matches!(events[0], StreamEvent::MessageStart(_))); + assert!(matches!( + events[1], + StreamEvent::ContentBlockStart(ContentBlockStartEvent { + content_block: OutputContentBlock::Text { .. }, + .. + }) + )); + assert!(matches!( + events[2], + StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent { + delta: ContentBlockDelta::TextDelta { .. }, + .. + }) + )); + assert!(matches!( + events[3], + StreamEvent::ContentBlockStart(ContentBlockStartEvent { + index: 1, + content_block: OutputContentBlock::ToolUse { .. }, + }) + )); + assert!(matches!( + events[4], + StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent { + index: 1, + delta: ContentBlockDelta::InputJsonDelta { .. }, + }) + )); + assert!(matches!( + events[5], + StreamEvent::ContentBlockStart(ContentBlockStartEvent { + index: 2, + content_block: OutputContentBlock::ToolUse { .. }, + }) + )); + assert!(matches!( + events[6], + StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent { + index: 2, + delta: ContentBlockDelta::InputJsonDelta { .. }, + }) + )); + assert!(matches!( + events[7], + StreamEvent::ContentBlockStop(ContentBlockStopEvent { index: 1 }) + )); + assert!(matches!( + events[8], + StreamEvent::ContentBlockStop(ContentBlockStopEvent { index: 2 }) + )); + assert!(matches!( + events[9], + StreamEvent::ContentBlockStop(ContentBlockStopEvent { index: 0 }) + )); + assert!(matches!(events[10], StreamEvent::MessageDelta(_))); + assert!(matches!(events[11], StreamEvent::MessageStop(_))); + + let captured = state.lock().await; + let request = captured.first().expect("captured request"); + assert_eq!(request.path, "/chat/completions"); + assert!(request.body.contains("\"stream\":true")); +} + +#[tokio::test] +async fn provider_client_dispatches_xai_requests_from_env() { + let _lock = env_lock(); + let _api_key = ScopedEnvVar::set("XAI_API_KEY", "xai-test-key"); + + let state = Arc::new(Mutex::new(Vec::::new())); + let server = spawn_server( + state.clone(), + vec![http_response( + "200 OK", + "application/json", + "{\"id\":\"chatcmpl_provider\",\"model\":\"grok-3\",\"choices\":[{\"message\":{\"role\":\"assistant\",\"content\":\"Through provider client\",\"tool_calls\":[]},\"finish_reason\":\"stop\"}],\"usage\":{\"prompt_tokens\":9,\"completion_tokens\":4}}", + )], + ) + .await; + let _base_url = ScopedEnvVar::set("XAI_BASE_URL", server.base_url()); + + let client = + ProviderClient::from_model("grok").expect("xAI provider client should be constructed"); + assert!(matches!(client, ProviderClient::Xai(_))); + + let response = client + .send_message(&sample_request(false)) + .await + .expect("provider-dispatched request should succeed"); + + assert_eq!(response.total_tokens(), 13); + + let captured = state.lock().await; + let request = captured.first().expect("captured request"); + assert_eq!(request.path, "/chat/completions"); + assert_eq!( + request.headers.get("authorization").map(String::as_str), + Some("Bearer xai-test-key") + ); +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct CapturedRequest { + path: String, + headers: HashMap, + body: String, +} + +struct TestServer { + base_url: String, + join_handle: tokio::task::JoinHandle<()>, +} + +impl TestServer { + fn base_url(&self) -> String { + self.base_url.clone() + } +} + +impl Drop for TestServer { + fn drop(&mut self) { + self.join_handle.abort(); + } +} + +async fn spawn_server( + state: Arc>>, + responses: Vec, +) -> TestServer { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("listener should bind"); + let address = listener.local_addr().expect("listener addr"); + let join_handle = tokio::spawn(async move { + for response in responses { + let (mut socket, _) = listener.accept().await.expect("accept"); + let mut buffer = Vec::new(); + let mut header_end = None; + loop { + let mut chunk = [0_u8; 1024]; + let read = socket.read(&mut chunk).await.expect("read request"); + if read == 0 { + break; + } + buffer.extend_from_slice(&chunk[..read]); + if let Some(position) = find_header_end(&buffer) { + header_end = Some(position); + break; + } + } + + let header_end = header_end.expect("headers should exist"); + let (header_bytes, remaining) = buffer.split_at(header_end); + let header_text = String::from_utf8(header_bytes.to_vec()).expect("utf8 headers"); + let mut lines = header_text.split("\r\n"); + let request_line = lines.next().expect("request line"); + let path = request_line + .split_whitespace() + .nth(1) + .expect("path") + .to_string(); + let mut headers = HashMap::new(); + let mut content_length = 0_usize; + for line in lines { + if line.is_empty() { + continue; + } + let (name, value) = line.split_once(':').expect("header"); + let value = value.trim().to_string(); + if name.eq_ignore_ascii_case("content-length") { + content_length = value.parse().expect("content length"); + } + headers.insert(name.to_ascii_lowercase(), value); + } + + let mut body = remaining[4..].to_vec(); + while body.len() < content_length { + let mut chunk = vec![0_u8; content_length - body.len()]; + let read = socket.read(&mut chunk).await.expect("read body"); + if read == 0 { + break; + } + body.extend_from_slice(&chunk[..read]); + } + + state.lock().await.push(CapturedRequest { + path, + headers, + body: String::from_utf8(body).expect("utf8 body"), + }); + + socket + .write_all(response.as_bytes()) + .await + .expect("write response"); + } + }); + + TestServer { + base_url: format!("http://{address}"), + join_handle, + } +} + +fn find_header_end(bytes: &[u8]) -> Option { + bytes.windows(4).position(|window| window == b"\r\n\r\n") +} + +fn http_response(status: &str, content_type: &str, body: &str) -> String { + http_response_with_headers(status, content_type, body, &[]) +} + +fn http_response_with_headers( + status: &str, + content_type: &str, + body: &str, + headers: &[(&str, &str)], +) -> String { + let mut extra_headers = String::new(); + for (name, value) in headers { + use std::fmt::Write as _; + write!(&mut extra_headers, "{name}: {value}\r\n").expect("header write"); + } + format!( + "HTTP/1.1 {status}\r\ncontent-type: {content_type}\r\n{extra_headers}content-length: {}\r\nconnection: close\r\n\r\n{body}", + body.len() + ) +} + +fn sample_request(stream: bool) -> MessageRequest { + MessageRequest { + model: "grok-3".to_string(), + max_tokens: 64, + messages: vec![InputMessage { + role: "user".to_string(), + content: vec![InputContentBlock::Text { + text: "Say hello".to_string(), + }], + }], + system: Some("Use tools when needed".to_string()), + tools: Some(vec![ToolDefinition { + name: "weather".to_string(), + description: Some("Fetches weather".to_string()), + input_schema: json!({ + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"] + }), + }]), + tool_choice: Some(ToolChoice::Auto), + stream, + } +} + +fn env_lock() -> std::sync::MutexGuard<'static, ()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| StdMutex::new(())) + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +struct ScopedEnvVar { + key: &'static str, + previous: Option, +} + +impl ScopedEnvVar { + fn set(key: &'static str, value: impl AsRef) -> Self { + let previous = std::env::var_os(key); + std::env::set_var(key, value); + Self { key, previous } + } +} + +impl Drop for ScopedEnvVar { + fn drop(&mut self) { + match &self.previous { + Some(value) => std::env::set_var(self.key, value), + None => std::env::remove_var(self.key), + } + } +} diff --git a/rust/crates/api/tests/provider_client_integration.rs b/rust/crates/api/tests/provider_client_integration.rs new file mode 100644 index 0000000..abeebdd --- /dev/null +++ b/rust/crates/api/tests/provider_client_integration.rs @@ -0,0 +1,86 @@ +use std::ffi::OsString; +use std::sync::{Mutex, OnceLock}; + +use api::{read_xai_base_url, ApiError, AuthSource, ProviderClient, ProviderKind}; + +#[test] +fn provider_client_routes_grok_aliases_through_xai() { + let _lock = env_lock(); + let _xai_api_key = EnvVarGuard::set("XAI_API_KEY", Some("xai-test-key")); + + let client = ProviderClient::from_model("grok-mini").expect("grok alias should resolve"); + + assert_eq!(client.provider_kind(), ProviderKind::Xai); +} + +#[test] +fn provider_client_reports_missing_xai_credentials_for_grok_models() { + let _lock = env_lock(); + let _xai_api_key = EnvVarGuard::set("XAI_API_KEY", None); + + let error = ProviderClient::from_model("grok-3") + .expect_err("grok requests without XAI_API_KEY should fail fast"); + + match error { + ApiError::MissingCredentials { provider, env_vars } => { + assert_eq!(provider, "xAI"); + assert_eq!(env_vars, &["XAI_API_KEY"]); + } + other => panic!("expected missing xAI credentials, got {other:?}"), + } +} + +#[test] +fn provider_client_uses_explicit_auth_without_env_lookup() { + let _lock = env_lock(); + let _api_key = EnvVarGuard::set("ANTHROPIC_API_KEY", None); + let _auth_token = EnvVarGuard::set("ANTHROPIC_AUTH_TOKEN", None); + + let client = ProviderClient::from_model_with_default_auth( + "claude-sonnet-4-6", + Some(AuthSource::ApiKey("claw-test-key".to_string())), + ) + .expect("explicit auth should avoid env lookup"); + + assert_eq!(client.provider_kind(), ProviderKind::ClawApi); +} + +#[test] +fn read_xai_base_url_prefers_env_override() { + let _lock = env_lock(); + let _xai_base_url = EnvVarGuard::set("XAI_BASE_URL", Some("https://example.xai.test/v1")); + + assert_eq!(read_xai_base_url(), "https://example.xai.test/v1"); +} + +fn env_lock() -> std::sync::MutexGuard<'static, ()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())) + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +struct EnvVarGuard { + key: &'static str, + original: Option, +} + +impl EnvVarGuard { + fn set(key: &'static str, value: Option<&str>) -> Self { + let original = std::env::var_os(key); + match value { + Some(value) => std::env::set_var(key, value), + None => std::env::remove_var(key), + } + Self { key, original } + } +} + +impl Drop for EnvVarGuard { + fn drop(&mut self) { + match &self.original { + Some(value) => std::env::set_var(self.key, value), + None => std::env::remove_var(self.key), + } + } +}