#![allow(clippy::cast_possible_truncation)] use std::future::Future; use std::pin::Pin; use serde::Serialize; use crate::error::ApiError; use crate::types::{MessageRequest, MessageResponse}; pub mod anthropic; pub mod openai_compat; #[allow(dead_code)] pub type ProviderFuture<'a, T> = Pin> + Send + 'a>>; #[allow(dead_code)] pub trait Provider { type Stream; fn send_message<'a>( &'a self, request: &'a MessageRequest, ) -> ProviderFuture<'a, MessageResponse>; fn stream_message<'a>( &'a self, request: &'a MessageRequest, ) -> ProviderFuture<'a, Self::Stream>; } #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum ProviderKind { Anthropic, Xai, OpenAi, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct ProviderMetadata { pub provider: ProviderKind, pub auth_env: &'static str, pub base_url_env: &'static str, pub default_base_url: &'static str, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct ModelTokenLimit { pub max_output_tokens: u32, pub context_window_tokens: u32, } const MODEL_REGISTRY: &[(&str, ProviderMetadata)] = &[ ( "opus", ProviderMetadata { provider: ProviderKind::Anthropic, auth_env: "ANTHROPIC_API_KEY", base_url_env: "ANTHROPIC_BASE_URL", default_base_url: anthropic::DEFAULT_BASE_URL, }, ), ( "sonnet", ProviderMetadata { provider: ProviderKind::Anthropic, auth_env: "ANTHROPIC_API_KEY", base_url_env: "ANTHROPIC_BASE_URL", default_base_url: anthropic::DEFAULT_BASE_URL, }, ), ( "haiku", ProviderMetadata { provider: ProviderKind::Anthropic, auth_env: "ANTHROPIC_API_KEY", base_url_env: "ANTHROPIC_BASE_URL", default_base_url: anthropic::DEFAULT_BASE_URL, }, ), ( "grok", ProviderMetadata { provider: ProviderKind::Xai, auth_env: "XAI_API_KEY", base_url_env: "XAI_BASE_URL", default_base_url: openai_compat::DEFAULT_XAI_BASE_URL, }, ), ( "grok-3", ProviderMetadata { provider: ProviderKind::Xai, auth_env: "XAI_API_KEY", base_url_env: "XAI_BASE_URL", default_base_url: openai_compat::DEFAULT_XAI_BASE_URL, }, ), ( "grok-mini", ProviderMetadata { provider: ProviderKind::Xai, auth_env: "XAI_API_KEY", base_url_env: "XAI_BASE_URL", default_base_url: openai_compat::DEFAULT_XAI_BASE_URL, }, ), ( "grok-3-mini", ProviderMetadata { provider: ProviderKind::Xai, auth_env: "XAI_API_KEY", base_url_env: "XAI_BASE_URL", default_base_url: openai_compat::DEFAULT_XAI_BASE_URL, }, ), ( "grok-2", ProviderMetadata { provider: ProviderKind::Xai, auth_env: "XAI_API_KEY", base_url_env: "XAI_BASE_URL", default_base_url: openai_compat::DEFAULT_XAI_BASE_URL, }, ), ]; #[must_use] pub fn resolve_model_alias(model: &str) -> String { let trimmed = model.trim(); let lower = trimmed.to_ascii_lowercase(); MODEL_REGISTRY .iter() .find_map(|(alias, metadata)| { (*alias == lower).then_some(match metadata.provider { ProviderKind::Anthropic => match *alias { "opus" => "claude-opus-4-6", "sonnet" => "claude-sonnet-4-6", "haiku" => "claude-haiku-4-5-20251213", _ => trimmed, }, ProviderKind::Xai => match *alias { "grok" | "grok-3" => "grok-3", "grok-mini" | "grok-3-mini" => "grok-3-mini", "grok-2" => "grok-2", _ => trimmed, }, ProviderKind::OpenAi => trimmed, }) }) .map_or_else(|| trimmed.to_string(), ToOwned::to_owned) } #[must_use] pub fn metadata_for_model(model: &str) -> Option { let canonical = resolve_model_alias(model); if canonical.starts_with("claude") { return Some(ProviderMetadata { provider: ProviderKind::Anthropic, auth_env: "ANTHROPIC_API_KEY", base_url_env: "ANTHROPIC_BASE_URL", default_base_url: anthropic::DEFAULT_BASE_URL, }); } if canonical.starts_with("grok") { return Some(ProviderMetadata { provider: ProviderKind::Xai, auth_env: "XAI_API_KEY", base_url_env: "XAI_BASE_URL", default_base_url: openai_compat::DEFAULT_XAI_BASE_URL, }); } None } #[must_use] pub fn detect_provider_kind(model: &str) -> ProviderKind { if let Some(metadata) = metadata_for_model(model) { return metadata.provider; } if anthropic::has_auth_from_env_or_saved().unwrap_or(false) { return ProviderKind::Anthropic; } if openai_compat::has_api_key("OPENAI_API_KEY") { return ProviderKind::OpenAi; } if openai_compat::has_api_key("XAI_API_KEY") { return ProviderKind::Xai; } ProviderKind::Anthropic } #[must_use] pub fn max_tokens_for_model(model: &str) -> u32 { model_token_limit(model).map_or_else( || { let canonical = resolve_model_alias(model); if canonical.contains("opus") { 32_000 } else { 64_000 } }, |limit| limit.max_output_tokens, ) } #[must_use] pub fn model_token_limit(model: &str) -> Option { let canonical = resolve_model_alias(model); match canonical.as_str() { "claude-opus-4-6" => Some(ModelTokenLimit { max_output_tokens: 32_000, context_window_tokens: 200_000, }), "claude-sonnet-4-6" | "claude-haiku-4-5-20251213" => Some(ModelTokenLimit { max_output_tokens: 64_000, context_window_tokens: 200_000, }), "grok-3" | "grok-3-mini" => Some(ModelTokenLimit { max_output_tokens: 64_000, context_window_tokens: 131_072, }), _ => None, } } pub fn preflight_message_request(request: &MessageRequest) -> Result<(), ApiError> { let Some(limit) = model_token_limit(&request.model) else { return Ok(()); }; let estimated_input_tokens = estimate_message_request_input_tokens(request); let estimated_total_tokens = estimated_input_tokens.saturating_add(request.max_tokens); if estimated_total_tokens > limit.context_window_tokens { return Err(ApiError::ContextWindowExceeded { model: resolve_model_alias(&request.model), estimated_input_tokens, requested_output_tokens: request.max_tokens, estimated_total_tokens, context_window_tokens: limit.context_window_tokens, }); } Ok(()) } fn estimate_message_request_input_tokens(request: &MessageRequest) -> u32 { let mut estimate = estimate_serialized_tokens(&request.messages); estimate = estimate.saturating_add(estimate_serialized_tokens(&request.system)); estimate = estimate.saturating_add(estimate_serialized_tokens(&request.tools)); estimate = estimate.saturating_add(estimate_serialized_tokens(&request.tool_choice)); estimate } fn estimate_serialized_tokens(value: &T) -> u32 { serde_json::to_vec(value) .ok() .map_or(0, |bytes| (bytes.len() / 4 + 1) as u32) } /// Parse a `.env` file body into key/value pairs using a minimal `KEY=VALUE` /// grammar. Lines that are blank, start with `#`, or do not contain `=` are /// ignored. Surrounding double or single quotes are stripped from the value. /// An optional leading `export ` prefix on the key is also stripped so files /// shared with shell `source` workflows still parse cleanly. pub(crate) fn parse_dotenv(content: &str) -> std::collections::HashMap { let mut values = std::collections::HashMap::new(); for raw_line in content.lines() { let line = raw_line.trim(); if line.is_empty() || line.starts_with('#') { continue; } let Some((raw_key, raw_value)) = line.split_once('=') else { continue; }; let trimmed_key = raw_key.trim(); let key = trimmed_key .strip_prefix("export ") .map_or(trimmed_key, str::trim) .to_string(); if key.is_empty() { continue; } let trimmed_value = raw_value.trim(); let unquoted = if (trimmed_value.starts_with('"') && trimmed_value.ends_with('"') || trimmed_value.starts_with('\'') && trimmed_value.ends_with('\'')) && trimmed_value.len() >= 2 { &trimmed_value[1..trimmed_value.len() - 1] } else { trimmed_value }; values.insert(key, unquoted.to_string()); } values } /// Load and parse a `.env` file from the given path. Missing files yield /// `None` instead of an error so callers can use this as a soft fallback. pub(crate) fn load_dotenv_file( path: &std::path::Path, ) -> Option> { let content = std::fs::read_to_string(path).ok()?; Some(parse_dotenv(&content)) } /// Look up `key` in a `.env` file located in the current working directory. /// Returns `None` when the file is missing, the key is absent, or the value /// is empty. pub(crate) fn dotenv_value(key: &str) -> Option { let cwd = std::env::current_dir().ok()?; let values = load_dotenv_file(&cwd.join(".env"))?; values.get(key).filter(|value| !value.is_empty()).cloned() } #[cfg(test)] mod tests { use serde_json::json; use crate::error::ApiError; use crate::types::{ InputContentBlock, InputMessage, MessageRequest, ToolChoice, ToolDefinition, }; use super::{ detect_provider_kind, load_dotenv_file, max_tokens_for_model, model_token_limit, parse_dotenv, preflight_message_request, resolve_model_alias, ProviderKind, }; #[test] fn resolves_grok_aliases() { assert_eq!(resolve_model_alias("grok"), "grok-3"); assert_eq!(resolve_model_alias("grok-mini"), "grok-3-mini"); assert_eq!(resolve_model_alias("grok-2"), "grok-2"); } #[test] fn detects_provider_from_model_name_first() { assert_eq!(detect_provider_kind("grok"), ProviderKind::Xai); assert_eq!( detect_provider_kind("claude-sonnet-4-6"), ProviderKind::Anthropic ); } #[test] fn keeps_existing_max_token_heuristic() { assert_eq!(max_tokens_for_model("opus"), 32_000); assert_eq!(max_tokens_for_model("grok-3"), 64_000); } #[test] fn returns_context_window_metadata_for_supported_models() { assert_eq!( model_token_limit("claude-sonnet-4-6") .expect("claude-sonnet-4-6 should be registered") .context_window_tokens, 200_000 ); assert_eq!( model_token_limit("grok-mini") .expect("grok-mini should resolve to a registered model") .context_window_tokens, 131_072 ); } #[test] fn preflight_blocks_requests_that_exceed_the_model_context_window() { let request = MessageRequest { model: "claude-sonnet-4-6".to_string(), max_tokens: 64_000, messages: vec![InputMessage { role: "user".to_string(), content: vec![InputContentBlock::Text { text: "x".repeat(600_000), }], }], system: Some("Keep the answer short.".to_string()), tools: Some(vec![ToolDefinition { name: "weather".to_string(), description: Some("Fetches weather".to_string()), input_schema: json!({ "type": "object", "properties": { "city": { "type": "string" } }, }), }]), tool_choice: Some(ToolChoice::Auto), stream: true, }; let error = preflight_message_request(&request) .expect_err("oversized request should be rejected before the provider call"); match error { ApiError::ContextWindowExceeded { model, estimated_input_tokens, requested_output_tokens, estimated_total_tokens, context_window_tokens, } => { assert_eq!(model, "claude-sonnet-4-6"); assert!(estimated_input_tokens > 136_000); assert_eq!(requested_output_tokens, 64_000); assert!(estimated_total_tokens > context_window_tokens); assert_eq!(context_window_tokens, 200_000); } other => panic!("expected context-window preflight failure, got {other:?}"), } } #[test] fn preflight_skips_unknown_models() { let request = MessageRequest { model: "unknown-model".to_string(), max_tokens: 64_000, messages: vec![InputMessage { role: "user".to_string(), content: vec![InputContentBlock::Text { text: "x".repeat(600_000), }], }], system: None, tools: None, tool_choice: None, stream: false, }; preflight_message_request(&request) .expect("models without context metadata should skip the guarded preflight"); } #[test] fn parse_dotenv_extracts_keys_handles_comments_quotes_and_export_prefix() { // given let body = "\ # this is a comment ANTHROPIC_API_KEY=plain-value XAI_API_KEY=\"quoted-value\" OPENAI_API_KEY='single-quoted' export GROK_API_KEY=exported-value PADDED_KEY = padded-value EMPTY_VALUE= NO_EQUALS_LINE "; // when let values = parse_dotenv(body); // then assert_eq!( values.get("ANTHROPIC_API_KEY").map(String::as_str), Some("plain-value") ); assert_eq!( values.get("XAI_API_KEY").map(String::as_str), Some("quoted-value") ); assert_eq!( values.get("OPENAI_API_KEY").map(String::as_str), Some("single-quoted") ); assert_eq!( values.get("GROK_API_KEY").map(String::as_str), Some("exported-value") ); assert_eq!( values.get("PADDED_KEY").map(String::as_str), Some("padded-value") ); assert_eq!(values.get("EMPTY_VALUE").map(String::as_str), Some("")); assert!(!values.contains_key("NO_EQUALS_LINE")); assert!(!values.contains_key("# this is a comment")); } #[test] fn load_dotenv_file_reads_keys_from_disk_and_returns_none_when_missing() { // given let temp_root = std::env::temp_dir().join(format!( "api-dotenv-test-{}-{}", std::process::id(), std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .map_or(0, |duration| duration.as_nanos()) )); std::fs::create_dir_all(&temp_root).expect("create temp dir"); let env_path = temp_root.join(".env"); std::fs::write( &env_path, "ANTHROPIC_API_KEY=secret-from-file\n# comment\nXAI_API_KEY=\"xai-secret\"\n", ) .expect("write .env"); let missing_path = temp_root.join("does-not-exist.env"); // when let loaded = load_dotenv_file(&env_path).expect("file should load"); let missing = load_dotenv_file(&missing_path); // then assert_eq!( loaded.get("ANTHROPIC_API_KEY").map(String::as_str), Some("secret-from-file") ); assert_eq!( loaded.get("XAI_API_KEY").map(String::as_str), Some("xai-secret") ); assert!(missing.is_none()); let _ = std::fs::remove_dir_all(&temp_root); } }