use crate::error::ApiError; use crate::prompt_cache::{PromptCache, PromptCacheRecord, PromptCacheStats}; use crate::providers::anthropic::{self, AnthropicClient, AuthSource}; use crate::providers::openai_compat::{self, OpenAiCompatClient, OpenAiCompatConfig}; use crate::providers::{self, ProviderKind}; use crate::types::{MessageRequest, MessageResponse, StreamEvent}; #[allow(clippy::large_enum_variant)] #[derive(Debug, Clone)] pub enum ProviderClient { Anthropic(AnthropicClient), Xai(OpenAiCompatClient), OpenAi(OpenAiCompatClient), } impl ProviderClient { pub fn from_model(model: &str) -> Result { Self::from_model_with_anthropic_auth(model, None) } pub fn from_model_with_anthropic_auth( model: &str, anthropic_auth: Option, ) -> Result { let resolved_model = providers::resolve_model_alias(model); match providers::detect_provider_kind(&resolved_model) { ProviderKind::Anthropic => Ok(Self::Anthropic(match anthropic_auth { Some(auth) => AnthropicClient::from_auth(auth), None => AnthropicClient::from_env()?, })), ProviderKind::Xai => Ok(Self::Xai(OpenAiCompatClient::from_env( OpenAiCompatConfig::xai(), )?)), ProviderKind::OpenAi => { // DashScope models (qwen-*) also return ProviderKind::OpenAi because they // speak the OpenAI wire format, but they need the DashScope config which // reads DASHSCOPE_API_KEY and points at dashscope.aliyuncs.com. let config = match providers::metadata_for_model(&resolved_model) { Some(meta) if meta.auth_env == "DASHSCOPE_API_KEY" => { OpenAiCompatConfig::dashscope() } _ => OpenAiCompatConfig::openai(), }; Ok(Self::OpenAi(OpenAiCompatClient::from_env(config)?)) } } } #[must_use] pub const fn provider_kind(&self) -> ProviderKind { match self { Self::Anthropic(_) => ProviderKind::Anthropic, Self::Xai(_) => ProviderKind::Xai, Self::OpenAi(_) => ProviderKind::OpenAi, } } #[must_use] pub fn with_prompt_cache(self, prompt_cache: PromptCache) -> Self { match self { Self::Anthropic(client) => Self::Anthropic(client.with_prompt_cache(prompt_cache)), other => other, } } #[must_use] pub fn prompt_cache_stats(&self) -> Option { match self { Self::Anthropic(client) => client.prompt_cache_stats(), Self::Xai(_) | Self::OpenAi(_) => None, } } #[must_use] pub fn take_last_prompt_cache_record(&self) -> Option { match self { Self::Anthropic(client) => client.take_last_prompt_cache_record(), Self::Xai(_) | Self::OpenAi(_) => None, } } pub async fn send_message( &self, request: &MessageRequest, ) -> Result { match self { Self::Anthropic(client) => client.send_message(request).await, Self::Xai(client) | Self::OpenAi(client) => client.send_message(request).await, } } pub async fn stream_message( &self, request: &MessageRequest, ) -> Result { match self { Self::Anthropic(client) => client .stream_message(request) .await .map(MessageStream::Anthropic), Self::Xai(client) | Self::OpenAi(client) => client .stream_message(request) .await .map(MessageStream::OpenAiCompat), } } } #[derive(Debug)] pub enum MessageStream { Anthropic(anthropic::MessageStream), OpenAiCompat(openai_compat::MessageStream), } impl MessageStream { #[must_use] pub fn request_id(&self) -> Option<&str> { match self { Self::Anthropic(stream) => stream.request_id(), Self::OpenAiCompat(stream) => stream.request_id(), } } pub async fn next_event(&mut self) -> Result, ApiError> { match self { Self::Anthropic(stream) => stream.next_event().await, Self::OpenAiCompat(stream) => stream.next_event().await, } } } pub use anthropic::{ oauth_token_is_expired, resolve_saved_oauth_token, resolve_startup_auth_source, OAuthTokenSet, }; #[must_use] pub fn read_base_url() -> String { anthropic::read_base_url() } #[must_use] pub fn read_xai_base_url() -> String { openai_compat::read_base_url(OpenAiCompatConfig::xai()) } #[cfg(test)] mod tests { use std::sync::{Mutex, OnceLock}; use super::ProviderClient; use crate::providers::{detect_provider_kind, resolve_model_alias, ProviderKind}; /// Serializes every test in this module that mutates process-wide /// environment variables so concurrent test threads cannot observe /// each other's partially-applied state. fn env_lock() -> std::sync::MutexGuard<'static, ()> { static LOCK: OnceLock> = OnceLock::new(); LOCK.get_or_init(|| Mutex::new(())) .lock() .unwrap_or_else(std::sync::PoisonError::into_inner) } #[test] fn resolves_existing_and_grok_aliases() { assert_eq!(resolve_model_alias("opus"), "claude-opus-4-6"); assert_eq!(resolve_model_alias("grok"), "grok-3"); assert_eq!(resolve_model_alias("grok-mini"), "grok-3-mini"); } #[test] fn provider_detection_prefers_model_family() { assert_eq!(detect_provider_kind("grok-3"), ProviderKind::Xai); assert_eq!( detect_provider_kind("claude-sonnet-4-6"), ProviderKind::Anthropic ); } /// Snapshot-restore guard for a single environment variable. Mirrors /// the pattern used in `providers/mod.rs` tests: captures the original /// value on construction, applies the override, and restores on drop so /// tests leave the process env untouched even when they panic. struct EnvVarGuard { key: &'static str, original: Option, } impl EnvVarGuard { fn set(key: &'static str, value: Option<&str>) -> Self { let original = std::env::var_os(key); match value { Some(value) => std::env::set_var(key, value), None => std::env::remove_var(key), } Self { key, original } } } impl Drop for EnvVarGuard { fn drop(&mut self) { match self.original.take() { Some(value) => std::env::set_var(self.key, value), None => std::env::remove_var(self.key), } } } #[test] fn dashscope_model_uses_dashscope_config_not_openai() { // Regression: qwen-plus was being routed to OpenAiCompatConfig::openai() // which reads OPENAI_API_KEY and points at api.openai.com, when it should // use OpenAiCompatConfig::dashscope() which reads DASHSCOPE_API_KEY and // points at dashscope.aliyuncs.com. let _lock = env_lock(); let _dashscope = EnvVarGuard::set("DASHSCOPE_API_KEY", Some("test-dashscope-key")); let _openai = EnvVarGuard::set("OPENAI_API_KEY", None); let client = ProviderClient::from_model("qwen-plus"); // Must succeed (not fail with "missing OPENAI_API_KEY") assert!( client.is_ok(), "qwen-plus with DASHSCOPE_API_KEY set should build successfully, got: {:?}", client.err() ); // Verify it's the OpenAi variant pointed at the DashScope base URL. match client.unwrap() { ProviderClient::OpenAi(openai_client) => { assert!( openai_client.base_url().contains("dashscope.aliyuncs.com"), "qwen-plus should route to DashScope base URL (contains 'dashscope.aliyuncs.com'), got: {}", openai_client.base_url() ); } other => panic!( "Expected ProviderClient::OpenAi for qwen-plus, got: {:?}", other ), } } }