use std::collections::{BTreeMap, VecDeque}; use std::sync::atomic::{AtomicU64, Ordering}; use std::time::{Duration, SystemTime, UNIX_EPOCH}; use serde::Deserialize; use serde_json::{json, Value}; use crate::error::ApiError; use crate::http_client::build_http_client_or_default; use crate::types::{ ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent, InputContentBlock, InputMessage, MessageDelta, MessageDeltaEvent, MessageRequest, MessageResponse, MessageStartEvent, MessageStopEvent, OutputContentBlock, StreamEvent, ToolChoice, ToolDefinition, ToolResultContentBlock, Usage, }; use super::{preflight_message_request, Provider, ProviderFuture}; pub const DEFAULT_XAI_BASE_URL: &str = "https://api.x.ai/v1"; pub const DEFAULT_OPENAI_BASE_URL: &str = "https://api.openai.com/v1"; pub const DEFAULT_DASHSCOPE_BASE_URL: &str = "https://dashscope.aliyuncs.com/compatible-mode/v1"; const REQUEST_ID_HEADER: &str = "request-id"; const ALT_REQUEST_ID_HEADER: &str = "x-request-id"; const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_secs(1); const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(128); const DEFAULT_MAX_RETRIES: u32 = 8; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct OpenAiCompatConfig { pub provider_name: &'static str, pub api_key_env: &'static str, pub base_url_env: &'static str, pub default_base_url: &'static str, } const XAI_ENV_VARS: &[&str] = &["XAI_API_KEY"]; const OPENAI_ENV_VARS: &[&str] = &["OPENAI_API_KEY"]; const DASHSCOPE_ENV_VARS: &[&str] = &["DASHSCOPE_API_KEY"]; impl OpenAiCompatConfig { #[must_use] pub const fn xai() -> Self { Self { provider_name: "xAI", api_key_env: "XAI_API_KEY", base_url_env: "XAI_BASE_URL", default_base_url: DEFAULT_XAI_BASE_URL, } } #[must_use] pub const fn openai() -> Self { Self { provider_name: "OpenAI", api_key_env: "OPENAI_API_KEY", base_url_env: "OPENAI_BASE_URL", default_base_url: DEFAULT_OPENAI_BASE_URL, } } /// Alibaba DashScope compatible-mode endpoint (Qwen family models). /// Uses the OpenAI-compatible REST shape at /compatible-mode/v1. /// Requested via Discord #clawcode-get-help: native Alibaba API for /// higher rate limits than going through OpenRouter. #[must_use] pub const fn dashscope() -> Self { Self { provider_name: "DashScope", api_key_env: "DASHSCOPE_API_KEY", base_url_env: "DASHSCOPE_BASE_URL", default_base_url: DEFAULT_DASHSCOPE_BASE_URL, } } #[must_use] pub fn credential_env_vars(self) -> &'static [&'static str] { match self.provider_name { "xAI" => XAI_ENV_VARS, "OpenAI" => OPENAI_ENV_VARS, "DashScope" => DASHSCOPE_ENV_VARS, _ => &[], } } } #[derive(Debug, Clone)] pub struct OpenAiCompatClient { http: reqwest::Client, api_key: String, config: OpenAiCompatConfig, base_url: String, max_retries: u32, initial_backoff: Duration, max_backoff: Duration, } impl OpenAiCompatClient { const fn config(&self) -> OpenAiCompatConfig { self.config } #[must_use] pub fn base_url(&self) -> &str { &self.base_url } #[must_use] pub fn new(api_key: impl Into, config: OpenAiCompatConfig) -> Self { Self { http: build_http_client_or_default(), api_key: api_key.into(), config, base_url: read_base_url(config), max_retries: DEFAULT_MAX_RETRIES, initial_backoff: DEFAULT_INITIAL_BACKOFF, max_backoff: DEFAULT_MAX_BACKOFF, } } pub fn from_env(config: OpenAiCompatConfig) -> Result { let Some(api_key) = read_env_non_empty(config.api_key_env)? else { return Err(ApiError::missing_credentials( config.provider_name, config.credential_env_vars(), )); }; Ok(Self::new(api_key, config)) } #[must_use] pub fn with_base_url(mut self, base_url: impl Into) -> Self { self.base_url = base_url.into(); self } #[must_use] pub fn with_retry_policy( mut self, max_retries: u32, initial_backoff: Duration, max_backoff: Duration, ) -> Self { self.max_retries = max_retries; self.initial_backoff = initial_backoff; self.max_backoff = max_backoff; self } pub async fn send_message( &self, request: &MessageRequest, ) -> Result { let request = MessageRequest { stream: false, ..request.clone() }; preflight_message_request(&request)?; let response = self.send_with_retry(&request).await?; let request_id = request_id_from_headers(response.headers()); let body = response.text().await.map_err(ApiError::from)?; let payload = serde_json::from_str::(&body).map_err(|error| { ApiError::json_deserialize(self.config.provider_name, &request.model, &body, error) })?; let mut normalized = normalize_response(&request.model, payload)?; if normalized.request_id.is_none() { normalized.request_id = request_id; } Ok(normalized) } pub async fn stream_message( &self, request: &MessageRequest, ) -> Result { preflight_message_request(request)?; let response = self .send_with_retry(&request.clone().with_streaming()) .await?; Ok(MessageStream { request_id: request_id_from_headers(response.headers()), response, parser: OpenAiSseParser::with_context(self.config.provider_name, request.model.clone()), pending: VecDeque::new(), done: false, state: StreamState::new(request.model.clone()), }) } async fn send_with_retry( &self, request: &MessageRequest, ) -> Result { let mut attempts = 0; let last_error = loop { attempts += 1; let retryable_error = match self.send_raw_request(request).await { Ok(response) => match expect_success(response).await { Ok(response) => return Ok(response), Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => error, Err(error) => return Err(error), }, Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => error, Err(error) => return Err(error), }; if attempts > self.max_retries { break retryable_error; } tokio::time::sleep(self.jittered_backoff_for_attempt(attempts)?).await; }; Err(ApiError::RetriesExhausted { attempts, last_error: Box::new(last_error), }) } async fn send_raw_request( &self, request: &MessageRequest, ) -> Result { let request_url = chat_completions_endpoint(&self.base_url); self.http .post(&request_url) .header("content-type", "application/json") .bearer_auth(&self.api_key) .json(&build_chat_completion_request(request, self.config())) .send() .await .map_err(ApiError::from) } fn backoff_for_attempt(&self, attempt: u32) -> Result { let Some(multiplier) = 1_u32.checked_shl(attempt.saturating_sub(1)) else { return Err(ApiError::BackoffOverflow { attempt, base_delay: self.initial_backoff, }); }; Ok(self .initial_backoff .checked_mul(multiplier) .map_or(self.max_backoff, |delay| delay.min(self.max_backoff))) } fn jittered_backoff_for_attempt(&self, attempt: u32) -> Result { let base = self.backoff_for_attempt(attempt)?; Ok(base + jitter_for_base(base)) } } /// Process-wide counter that guarantees distinct jitter samples even when /// the system clock resolution is coarser than consecutive retry sleeps. static JITTER_COUNTER: AtomicU64 = AtomicU64::new(0); /// Returns a random additive jitter in `[0, base]` to decorrelate retries /// from multiple concurrent clients. Entropy is drawn from the nanosecond /// wall clock mixed with a monotonic counter and run through a splitmix64 /// finalizer; adequate for retry jitter (no cryptographic requirement). fn jitter_for_base(base: Duration) -> Duration { let base_nanos = u64::try_from(base.as_nanos()).unwrap_or(u64::MAX); if base_nanos == 0 { return Duration::ZERO; } let raw_nanos = SystemTime::now() .duration_since(UNIX_EPOCH) .map(|elapsed| u64::try_from(elapsed.as_nanos()).unwrap_or(u64::MAX)) .unwrap_or(0); let tick = JITTER_COUNTER.fetch_add(1, Ordering::Relaxed); let mut mixed = raw_nanos .wrapping_add(tick) .wrapping_add(0x9E37_79B9_7F4A_7C15); mixed = (mixed ^ (mixed >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9); mixed = (mixed ^ (mixed >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB); mixed ^= mixed >> 31; let jitter_nanos = mixed % base_nanos.saturating_add(1); Duration::from_nanos(jitter_nanos) } impl Provider for OpenAiCompatClient { type Stream = MessageStream; fn send_message<'a>( &'a self, request: &'a MessageRequest, ) -> ProviderFuture<'a, MessageResponse> { Box::pin(async move { self.send_message(request).await }) } fn stream_message<'a>( &'a self, request: &'a MessageRequest, ) -> ProviderFuture<'a, Self::Stream> { Box::pin(async move { self.stream_message(request).await }) } } #[derive(Debug)] pub struct MessageStream { request_id: Option, response: reqwest::Response, parser: OpenAiSseParser, pending: VecDeque, done: bool, state: StreamState, } impl MessageStream { #[must_use] pub fn request_id(&self) -> Option<&str> { self.request_id.as_deref() } pub async fn next_event(&mut self) -> Result, ApiError> { loop { if let Some(event) = self.pending.pop_front() { return Ok(Some(event)); } if self.done { self.pending.extend(self.state.finish()?); if let Some(event) = self.pending.pop_front() { return Ok(Some(event)); } return Ok(None); } match self.response.chunk().await? { Some(chunk) => { for parsed in self.parser.push(&chunk)? { self.pending.extend(self.state.ingest_chunk(parsed)?); } } None => { self.done = true; } } } } } #[derive(Debug, Default)] struct OpenAiSseParser { buffer: Vec, provider: String, model: String, } impl OpenAiSseParser { fn with_context(provider: impl Into, model: impl Into) -> Self { Self { buffer: Vec::new(), provider: provider.into(), model: model.into(), } } fn push(&mut self, chunk: &[u8]) -> Result, ApiError> { self.buffer.extend_from_slice(chunk); let mut events = Vec::new(); while let Some(frame) = next_sse_frame(&mut self.buffer) { if let Some(event) = parse_sse_frame(&frame, &self.provider, &self.model)? { events.push(event); } } Ok(events) } } #[allow(clippy::struct_excessive_bools)] #[derive(Debug)] struct StreamState { model: String, message_started: bool, text_started: bool, text_finished: bool, finished: bool, stop_reason: Option, usage: Option, tool_calls: BTreeMap, } impl StreamState { fn new(model: String) -> Self { Self { model, message_started: false, text_started: false, text_finished: false, finished: false, stop_reason: None, usage: None, tool_calls: BTreeMap::new(), } } fn ingest_chunk(&mut self, chunk: ChatCompletionChunk) -> Result, ApiError> { let mut events = Vec::new(); if !self.message_started { self.message_started = true; events.push(StreamEvent::MessageStart(MessageStartEvent { message: MessageResponse { id: chunk.id.clone(), kind: "message".to_string(), role: "assistant".to_string(), content: Vec::new(), model: chunk.model.clone().unwrap_or_else(|| self.model.clone()), stop_reason: None, stop_sequence: None, usage: Usage { input_tokens: 0, cache_creation_input_tokens: 0, cache_read_input_tokens: 0, output_tokens: 0, }, request_id: None, }, })); } if let Some(usage) = chunk.usage { self.usage = Some(Usage { input_tokens: usage.prompt_tokens, cache_creation_input_tokens: 0, cache_read_input_tokens: 0, output_tokens: usage.completion_tokens, }); } for choice in chunk.choices { if let Some(content) = choice.delta.content.filter(|value| !value.is_empty()) { if !self.text_started { self.text_started = true; events.push(StreamEvent::ContentBlockStart(ContentBlockStartEvent { index: 0, content_block: OutputContentBlock::Text { text: String::new(), }, })); } events.push(StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent { index: 0, delta: ContentBlockDelta::TextDelta { text: content }, })); } for tool_call in choice.delta.tool_calls { let state = self.tool_calls.entry(tool_call.index).or_default(); state.apply(tool_call); let block_index = state.block_index(); if !state.started { if let Some(start_event) = state.start_event()? { state.started = true; events.push(StreamEvent::ContentBlockStart(start_event)); } else { continue; } } if let Some(delta_event) = state.delta_event() { events.push(StreamEvent::ContentBlockDelta(delta_event)); } if choice.finish_reason.as_deref() == Some("tool_calls") && !state.stopped { state.stopped = true; events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent { index: block_index, })); } } if let Some(finish_reason) = choice.finish_reason { self.stop_reason = Some(normalize_finish_reason(&finish_reason)); if finish_reason == "tool_calls" { for state in self.tool_calls.values_mut() { if state.started && !state.stopped { state.stopped = true; events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent { index: state.block_index(), })); } } } } } Ok(events) } fn finish(&mut self) -> Result, ApiError> { if self.finished { return Ok(Vec::new()); } self.finished = true; let mut events = Vec::new(); if self.text_started && !self.text_finished { self.text_finished = true; events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent { index: 0, })); } for state in self.tool_calls.values_mut() { if !state.started { if let Some(start_event) = state.start_event()? { state.started = true; events.push(StreamEvent::ContentBlockStart(start_event)); if let Some(delta_event) = state.delta_event() { events.push(StreamEvent::ContentBlockDelta(delta_event)); } } } if state.started && !state.stopped { state.stopped = true; events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent { index: state.block_index(), })); } } if self.message_started { events.push(StreamEvent::MessageDelta(MessageDeltaEvent { delta: MessageDelta { stop_reason: Some( self.stop_reason .clone() .unwrap_or_else(|| "end_turn".to_string()), ), stop_sequence: None, }, usage: self.usage.clone().unwrap_or(Usage { input_tokens: 0, cache_creation_input_tokens: 0, cache_read_input_tokens: 0, output_tokens: 0, }), })); events.push(StreamEvent::MessageStop(MessageStopEvent {})); } Ok(events) } } #[derive(Debug, Default)] struct ToolCallState { openai_index: u32, id: Option, name: Option, arguments: String, emitted_len: usize, started: bool, stopped: bool, } impl ToolCallState { fn apply(&mut self, tool_call: DeltaToolCall) { self.openai_index = tool_call.index; if let Some(id) = tool_call.id { self.id = Some(id); } if let Some(name) = tool_call.function.name { self.name = Some(name); } if let Some(arguments) = tool_call.function.arguments { self.arguments.push_str(&arguments); } } const fn block_index(&self) -> u32 { self.openai_index + 1 } #[allow(clippy::unnecessary_wraps)] fn start_event(&self) -> Result, ApiError> { let Some(name) = self.name.clone() else { return Ok(None); }; let id = self .id .clone() .unwrap_or_else(|| format!("tool_call_{}", self.openai_index)); Ok(Some(ContentBlockStartEvent { index: self.block_index(), content_block: OutputContentBlock::ToolUse { id, name, input: json!({}), }, })) } fn delta_event(&mut self) -> Option { if self.emitted_len >= self.arguments.len() { return None; } let delta = self.arguments[self.emitted_len..].to_string(); self.emitted_len = self.arguments.len(); Some(ContentBlockDeltaEvent { index: self.block_index(), delta: ContentBlockDelta::InputJsonDelta { partial_json: delta, }, }) } } #[derive(Debug, Deserialize)] struct ChatCompletionResponse { id: String, model: String, choices: Vec, #[serde(default)] usage: Option, } #[derive(Debug, Deserialize)] struct ChatChoice { message: ChatMessage, #[serde(default)] finish_reason: Option, } #[derive(Debug, Deserialize)] struct ChatMessage { role: String, #[serde(default)] content: Option, #[serde(default)] tool_calls: Vec, } #[derive(Debug, Deserialize)] struct ResponseToolCall { id: String, function: ResponseToolFunction, } #[derive(Debug, Deserialize)] struct ResponseToolFunction { name: String, arguments: String, } #[derive(Debug, Deserialize)] struct OpenAiUsage { #[serde(default)] prompt_tokens: u32, #[serde(default)] completion_tokens: u32, } #[derive(Debug, Deserialize)] struct ChatCompletionChunk { id: String, #[serde(default)] model: Option, #[serde(default)] choices: Vec, #[serde(default)] usage: Option, } #[derive(Debug, Deserialize)] struct ChunkChoice { delta: ChunkDelta, #[serde(default)] finish_reason: Option, } #[derive(Debug, Default, Deserialize)] struct ChunkDelta { #[serde(default)] content: Option, #[serde(default)] tool_calls: Vec, } #[derive(Debug, Deserialize)] struct DeltaToolCall { #[serde(default)] index: u32, #[serde(default)] id: Option, #[serde(default)] function: DeltaFunction, } #[derive(Debug, Default, Deserialize)] struct DeltaFunction { #[serde(default)] name: Option, #[serde(default)] arguments: Option, } #[derive(Debug, Deserialize)] struct ErrorEnvelope { error: ErrorBody, } #[derive(Debug, Deserialize)] struct ErrorBody { #[serde(rename = "type")] error_type: Option, message: Option, } /// Returns true for models known to reject tuning parameters like temperature, /// top_p, frequency_penalty, and presence_penalty. These are typically /// reasoning/chain-of-thought models with fixed sampling. fn is_reasoning_model(model: &str) -> bool { let lowered = model.to_ascii_lowercase(); // Strip any provider/ prefix for the check (e.g. qwen/qwen-qwq -> qwen-qwq) let canonical = lowered.rsplit('/').next().unwrap_or(lowered.as_str()); // OpenAI reasoning models canonical.starts_with("o1") || canonical.starts_with("o3") || canonical.starts_with("o4") // xAI reasoning: grok-3-mini always uses reasoning mode || canonical == "grok-3-mini" // Alibaba DashScope reasoning variants (QwQ + Qwen3-Thinking family) || canonical.starts_with("qwen-qwq") || canonical.starts_with("qwq") || canonical.contains("thinking") } /// Strip routing prefix (e.g., "openai/gpt-4" → "gpt-4") for the wire. /// The prefix is used only to select transport; the backend expects the /// bare model id. fn strip_routing_prefix(model: &str) -> &str { if let Some(pos) = model.find('/') { let prefix = &model[..pos]; // Only strip if the prefix before "/" is a known routing prefix, // not if "/" appears in the middle of the model name for other reasons. if matches!(prefix, "openai" | "xai" | "grok" | "qwen") { &model[pos + 1..] } else { model } } else { model } } fn build_chat_completion_request(request: &MessageRequest, config: OpenAiCompatConfig) -> Value { let mut messages = Vec::new(); if let Some(system) = request.system.as_ref().filter(|value| !value.is_empty()) { messages.push(json!({ "role": "system", "content": system, })); } for message in &request.messages { messages.extend(translate_message(message)); } // Strip routing prefix (e.g., "openai/gpt-4" → "gpt-4") for the wire. let wire_model = strip_routing_prefix(&request.model); let mut payload = json!({ "model": wire_model, "max_tokens": request.max_tokens, "messages": messages, "stream": request.stream, }); if request.stream && should_request_stream_usage(config) { payload["stream_options"] = json!({ "include_usage": true }); } if let Some(tools) = &request.tools { payload["tools"] = Value::Array(tools.iter().map(openai_tool_definition).collect::>()); } if let Some(tool_choice) = &request.tool_choice { payload["tool_choice"] = openai_tool_choice(tool_choice); } // OpenAI-compatible tuning parameters — only included when explicitly set. // Reasoning models (o1/o3/o4/grok-3-mini) reject these params with 400; // silently strip them to avoid cryptic provider errors. if !is_reasoning_model(&request.model) { if let Some(temperature) = request.temperature { payload["temperature"] = json!(temperature); } if let Some(top_p) = request.top_p { payload["top_p"] = json!(top_p); } if let Some(frequency_penalty) = request.frequency_penalty { payload["frequency_penalty"] = json!(frequency_penalty); } if let Some(presence_penalty) = request.presence_penalty { payload["presence_penalty"] = json!(presence_penalty); } } // stop is generally safe for all providers if let Some(stop) = &request.stop { if !stop.is_empty() { payload["stop"] = json!(stop); } } // reasoning_effort for OpenAI-compatible reasoning models (o4-mini, o3, etc.) if let Some(effort) = &request.reasoning_effort { payload["reasoning_effort"] = json!(effort); } payload } fn translate_message(message: &InputMessage) -> Vec { match message.role.as_str() { "assistant" => { let mut text = String::new(); let mut tool_calls = Vec::new(); for block in &message.content { match block { InputContentBlock::Text { text: value } => text.push_str(value), InputContentBlock::ToolUse { id, name, input } => tool_calls.push(json!({ "id": id, "type": "function", "function": { "name": name, "arguments": input.to_string(), } })), InputContentBlock::ToolResult { .. } => {} } } if text.is_empty() && tool_calls.is_empty() { Vec::new() } else { vec![json!({ "role": "assistant", "content": (!text.is_empty()).then_some(text), "tool_calls": tool_calls, })] } } _ => message .content .iter() .filter_map(|block| match block { InputContentBlock::Text { text } => Some(json!({ "role": "user", "content": text, })), InputContentBlock::ToolResult { tool_use_id, content, is_error, } => Some(json!({ "role": "tool", "tool_call_id": tool_use_id, "content": flatten_tool_result_content(content), "is_error": is_error, })), InputContentBlock::ToolUse { .. } => None, }) .collect(), } } fn flatten_tool_result_content(content: &[ToolResultContentBlock]) -> String { content .iter() .map(|block| match block { ToolResultContentBlock::Text { text } => text.clone(), ToolResultContentBlock::Json { value } => value.to_string(), }) .collect::>() .join("\n") } /// Recursively ensure every object-type node in a JSON Schema has /// `"properties"` (at least `{}`) and `"additionalProperties": false`. /// The OpenAI `/responses` endpoint validates schemas strictly and rejects /// objects that omit these fields; `/chat/completions` is lenient but also /// accepts them, so we normalise unconditionally. fn normalize_object_schema(schema: &mut Value) { if let Some(obj) = schema.as_object_mut() { if obj.get("type").and_then(Value::as_str) == Some("object") { obj.entry("properties").or_insert_with(|| json!({})); obj.entry("additionalProperties") .or_insert(Value::Bool(false)); } // Recurse into properties values if let Some(props) = obj.get_mut("properties") { if let Some(props_obj) = props.as_object_mut() { let keys: Vec = props_obj.keys().cloned().collect(); for k in keys { if let Some(v) = props_obj.get_mut(&k) { normalize_object_schema(v); } } } } // Recurse into items (arrays) if let Some(items) = obj.get_mut("items") { normalize_object_schema(items); } } } fn openai_tool_definition(tool: &ToolDefinition) -> Value { let mut parameters = tool.input_schema.clone(); normalize_object_schema(&mut parameters); json!({ "type": "function", "function": { "name": tool.name, "description": tool.description, "parameters": parameters, } }) } fn openai_tool_choice(tool_choice: &ToolChoice) -> Value { match tool_choice { ToolChoice::Auto => Value::String("auto".to_string()), ToolChoice::Any => Value::String("required".to_string()), ToolChoice::Tool { name } => json!({ "type": "function", "function": { "name": name }, }), } } fn should_request_stream_usage(config: OpenAiCompatConfig) -> bool { matches!(config.provider_name, "OpenAI") } fn normalize_response( model: &str, response: ChatCompletionResponse, ) -> Result { let choice = response .choices .into_iter() .next() .ok_or(ApiError::InvalidSseFrame( "chat completion response missing choices", ))?; let mut content = Vec::new(); if let Some(text) = choice.message.content.filter(|value| !value.is_empty()) { content.push(OutputContentBlock::Text { text }); } for tool_call in choice.message.tool_calls { content.push(OutputContentBlock::ToolUse { id: tool_call.id, name: tool_call.function.name, input: parse_tool_arguments(&tool_call.function.arguments), }); } Ok(MessageResponse { id: response.id, kind: "message".to_string(), role: choice.message.role, content, model: response.model.if_empty_then(model.to_string()), stop_reason: choice .finish_reason .map(|value| normalize_finish_reason(&value)), stop_sequence: None, usage: Usage { input_tokens: response .usage .as_ref() .map_or(0, |usage| usage.prompt_tokens), cache_creation_input_tokens: 0, cache_read_input_tokens: 0, output_tokens: response .usage .as_ref() .map_or(0, |usage| usage.completion_tokens), }, request_id: None, }) } fn parse_tool_arguments(arguments: &str) -> Value { serde_json::from_str(arguments).unwrap_or_else(|_| json!({ "raw": arguments })) } fn next_sse_frame(buffer: &mut Vec) -> Option { let separator = buffer .windows(2) .position(|window| window == b"\n\n") .map(|position| (position, 2)) .or_else(|| { buffer .windows(4) .position(|window| window == b"\r\n\r\n") .map(|position| (position, 4)) })?; let (position, separator_len) = separator; let frame = buffer.drain(..position + separator_len).collect::>(); let frame_len = frame.len().saturating_sub(separator_len); Some(String::from_utf8_lossy(&frame[..frame_len]).into_owned()) } fn parse_sse_frame( frame: &str, provider: &str, model: &str, ) -> Result, ApiError> { let trimmed = frame.trim(); if trimmed.is_empty() { return Ok(None); } let mut data_lines = Vec::new(); for line in trimmed.lines() { if line.starts_with(':') { continue; } if let Some(data) = line.strip_prefix("data:") { data_lines.push(data.trim_start()); } } if data_lines.is_empty() { return Ok(None); } let payload = data_lines.join("\n"); if payload == "[DONE]" { return Ok(None); } serde_json::from_str::(&payload) .map(Some) .map_err(|error| ApiError::json_deserialize(provider, model, &payload, error)) } fn read_env_non_empty(key: &str) -> Result, ApiError> { match std::env::var(key) { Ok(value) if !value.is_empty() => Ok(Some(value)), Ok(_) | Err(std::env::VarError::NotPresent) => Ok(super::dotenv_value(key)), Err(error) => Err(ApiError::from(error)), } } #[must_use] pub fn has_api_key(key: &str) -> bool { read_env_non_empty(key) .ok() .and_then(std::convert::identity) .is_some() } #[must_use] pub fn read_base_url(config: OpenAiCompatConfig) -> String { std::env::var(config.base_url_env).unwrap_or_else(|_| config.default_base_url.to_string()) } fn chat_completions_endpoint(base_url: &str) -> String { let trimmed = base_url.trim_end_matches('/'); if trimmed.ends_with("/chat/completions") { trimmed.to_string() } else { format!("{trimmed}/chat/completions") } } fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option { headers .get(REQUEST_ID_HEADER) .or_else(|| headers.get(ALT_REQUEST_ID_HEADER)) .and_then(|value| value.to_str().ok()) .map(ToOwned::to_owned) } async fn expect_success(response: reqwest::Response) -> Result { let status = response.status(); if status.is_success() { return Ok(response); } let request_id = request_id_from_headers(response.headers()); let body = response.text().await.unwrap_or_default(); let parsed_error = serde_json::from_str::(&body).ok(); let retryable = is_retryable_status(status); Err(ApiError::Api { status, error_type: parsed_error .as_ref() .and_then(|error| error.error.error_type.clone()), message: parsed_error .as_ref() .and_then(|error| error.error.message.clone()), request_id, body, retryable, }) } const fn is_retryable_status(status: reqwest::StatusCode) -> bool { matches!(status.as_u16(), 408 | 409 | 429 | 500 | 502 | 503 | 504) } fn normalize_finish_reason(value: &str) -> String { match value { "stop" => "end_turn", "tool_calls" => "tool_use", other => other, } .to_string() } trait StringExt { fn if_empty_then(self, fallback: String) -> String; } impl StringExt for String { fn if_empty_then(self, fallback: String) -> String { if self.is_empty() { fallback } else { self } } } #[cfg(test)] mod tests { use super::{ build_chat_completion_request, chat_completions_endpoint, is_reasoning_model, normalize_finish_reason, openai_tool_choice, parse_tool_arguments, OpenAiCompatClient, OpenAiCompatConfig, }; use crate::error::ApiError; use crate::types::{ InputContentBlock, InputMessage, MessageRequest, ToolChoice, ToolDefinition, ToolResultContentBlock, }; use serde_json::json; use std::sync::{Mutex, OnceLock}; #[test] fn request_translation_uses_openai_compatible_shape() { let payload = build_chat_completion_request( &MessageRequest { model: "grok-3".to_string(), max_tokens: 64, messages: vec![InputMessage { role: "user".to_string(), content: vec![ InputContentBlock::Text { text: "hello".to_string(), }, InputContentBlock::ToolResult { tool_use_id: "tool_1".to_string(), content: vec![ToolResultContentBlock::Json { value: json!({"ok": true}), }], is_error: false, }, ], }], system: Some("be helpful".to_string()), tools: Some(vec![ToolDefinition { name: "weather".to_string(), description: Some("Get weather".to_string()), input_schema: json!({"type": "object"}), }]), tool_choice: Some(ToolChoice::Auto), stream: false, ..Default::default() }, OpenAiCompatConfig::xai(), ); assert_eq!(payload["messages"][0]["role"], json!("system")); assert_eq!(payload["messages"][1]["role"], json!("user")); assert_eq!(payload["messages"][2]["role"], json!("tool")); assert_eq!(payload["tools"][0]["type"], json!("function")); assert_eq!(payload["tool_choice"], json!("auto")); } #[test] fn tool_schema_object_gets_strict_fields_for_responses_endpoint() { // OpenAI /responses endpoint rejects object schemas missing // "properties" and "additionalProperties". Verify normalize_object_schema // fills them in so the request shape is strict-validator-safe. use super::normalize_object_schema; // Bare object — no properties at all let mut schema = json!({"type": "object"}); normalize_object_schema(&mut schema); assert_eq!(schema["properties"], json!({})); assert_eq!(schema["additionalProperties"], json!(false)); // Nested object inside properties let mut schema2 = json!({ "type": "object", "properties": { "location": {"type": "object", "properties": {"lat": {"type": "number"}}} } }); normalize_object_schema(&mut schema2); assert_eq!(schema2["additionalProperties"], json!(false)); assert_eq!( schema2["properties"]["location"]["additionalProperties"], json!(false) ); // Existing properties/additionalProperties should not be overwritten let mut schema3 = json!({ "type": "object", "properties": {"x": {"type": "string"}}, "additionalProperties": true }); normalize_object_schema(&mut schema3); assert_eq!( schema3["additionalProperties"], json!(true), "must not overwrite existing" ); } #[test] fn reasoning_effort_is_included_when_set() { let payload = build_chat_completion_request( &MessageRequest { model: "o4-mini".to_string(), max_tokens: 1024, messages: vec![InputMessage::user_text("think hard")], reasoning_effort: Some("high".to_string()), ..Default::default() }, OpenAiCompatConfig::openai(), ); assert_eq!(payload["reasoning_effort"], json!("high")); } #[test] fn reasoning_effort_omitted_when_not_set() { let payload = build_chat_completion_request( &MessageRequest { model: "gpt-4o".to_string(), max_tokens: 64, messages: vec![InputMessage::user_text("hello")], ..Default::default() }, OpenAiCompatConfig::openai(), ); assert!(payload.get("reasoning_effort").is_none()); } #[test] fn openai_streaming_requests_include_usage_opt_in() { let payload = build_chat_completion_request( &MessageRequest { model: "gpt-5".to_string(), max_tokens: 64, messages: vec![InputMessage::user_text("hello")], system: None, tools: None, tool_choice: None, stream: true, ..Default::default() }, OpenAiCompatConfig::openai(), ); assert_eq!(payload["stream_options"], json!({"include_usage": true})); } #[test] fn xai_streaming_requests_skip_openai_specific_usage_opt_in() { let payload = build_chat_completion_request( &MessageRequest { model: "grok-3".to_string(), max_tokens: 64, messages: vec![InputMessage::user_text("hello")], system: None, tools: None, tool_choice: None, stream: true, ..Default::default() }, OpenAiCompatConfig::xai(), ); assert!(payload.get("stream_options").is_none()); } #[test] fn tool_choice_translation_supports_required_function() { assert_eq!(openai_tool_choice(&ToolChoice::Any), json!("required")); assert_eq!( openai_tool_choice(&ToolChoice::Tool { name: "weather".to_string(), }), json!({"type": "function", "function": {"name": "weather"}}) ); } #[test] fn parses_tool_arguments_fallback() { assert_eq!( parse_tool_arguments("{\"city\":\"Paris\"}"), json!({"city": "Paris"}) ); assert_eq!(parse_tool_arguments("not-json"), json!({"raw": "not-json"})); } #[test] fn missing_xai_api_key_is_provider_specific() { let _lock = env_lock(); std::env::remove_var("XAI_API_KEY"); let error = OpenAiCompatClient::from_env(OpenAiCompatConfig::xai()) .expect_err("missing key should error"); assert!(matches!( error, ApiError::MissingCredentials { provider: "xAI", .. } )); } #[test] fn endpoint_builder_accepts_base_urls_and_full_endpoints() { assert_eq!( chat_completions_endpoint("https://api.x.ai/v1"), "https://api.x.ai/v1/chat/completions" ); assert_eq!( chat_completions_endpoint("https://api.x.ai/v1/"), "https://api.x.ai/v1/chat/completions" ); assert_eq!( chat_completions_endpoint("https://api.x.ai/v1/chat/completions"), "https://api.x.ai/v1/chat/completions" ); } fn env_lock() -> std::sync::MutexGuard<'static, ()> { static LOCK: OnceLock> = OnceLock::new(); LOCK.get_or_init(|| Mutex::new(())) .lock() .expect("env lock") } #[test] fn normalizes_stop_reasons() { assert_eq!(normalize_finish_reason("stop"), "end_turn"); assert_eq!(normalize_finish_reason("tool_calls"), "tool_use"); } #[test] fn tuning_params_included_in_payload_when_set() { let request = MessageRequest { model: "gpt-4o".to_string(), max_tokens: 1024, messages: vec![], system: None, tools: None, tool_choice: None, stream: false, temperature: Some(0.7), top_p: Some(0.9), frequency_penalty: Some(0.5), presence_penalty: Some(0.3), stop: Some(vec!["\n".to_string()]), reasoning_effort: None, }; let payload = build_chat_completion_request(&request, OpenAiCompatConfig::openai()); assert_eq!(payload["temperature"], 0.7); assert_eq!(payload["top_p"], 0.9); assert_eq!(payload["frequency_penalty"], 0.5); assert_eq!(payload["presence_penalty"], 0.3); assert_eq!(payload["stop"], json!(["\n"])); } #[test] fn reasoning_model_strips_tuning_params() { let request = MessageRequest { model: "o1-mini".to_string(), max_tokens: 1024, messages: vec![], stream: false, temperature: Some(0.7), top_p: Some(0.9), frequency_penalty: Some(0.5), presence_penalty: Some(0.3), stop: Some(vec!["\n".to_string()]), ..Default::default() }; let payload = build_chat_completion_request(&request, OpenAiCompatConfig::openai()); assert!( payload.get("temperature").is_none(), "reasoning model should strip temperature" ); assert!( payload.get("top_p").is_none(), "reasoning model should strip top_p" ); assert!(payload.get("frequency_penalty").is_none()); assert!(payload.get("presence_penalty").is_none()); // stop is safe for all providers assert_eq!(payload["stop"], json!(["\n"])); } #[test] fn grok_3_mini_is_reasoning_model() { assert!(is_reasoning_model("grok-3-mini")); assert!(is_reasoning_model("o1")); assert!(is_reasoning_model("o1-mini")); assert!(is_reasoning_model("o3-mini")); assert!(!is_reasoning_model("gpt-4o")); assert!(!is_reasoning_model("grok-3")); assert!(!is_reasoning_model("claude-sonnet-4-6")); } #[test] fn qwen_reasoning_variants_are_detected() { // QwQ reasoning model assert!(is_reasoning_model("qwen-qwq-32b")); assert!(is_reasoning_model("qwen/qwen-qwq-32b")); // Qwen3 thinking family assert!(is_reasoning_model("qwen3-30b-a3b-thinking")); assert!(is_reasoning_model("qwen/qwen3-30b-a3b-thinking")); // Bare qwq assert!(is_reasoning_model("qwq-plus")); // Regular Qwen models must NOT be classified as reasoning assert!(!is_reasoning_model("qwen-max")); assert!(!is_reasoning_model("qwen/qwen-plus")); assert!(!is_reasoning_model("qwen-turbo")); } #[test] fn tuning_params_omitted_from_payload_when_none() { let request = MessageRequest { model: "gpt-4o".to_string(), max_tokens: 1024, messages: vec![], stream: false, ..Default::default() }; let payload = build_chat_completion_request(&request, OpenAiCompatConfig::openai()); assert!( payload.get("temperature").is_none(), "temperature should be absent" ); assert!(payload.get("top_p").is_none(), "top_p should be absent"); assert!(payload.get("frequency_penalty").is_none()); assert!(payload.get("presence_penalty").is_none()); assert!(payload.get("stop").is_none()); } }