diff --git a/rust/crates/api/src/providers/openai_compat.rs b/rust/crates/api/src/providers/openai_compat.rs index d1db46f..48eec30 100644 --- a/rust/crates/api/src/providers/openai_compat.rs +++ b/rust/crates/api/src/providers/openai_compat.rs @@ -67,6 +67,7 @@ impl OpenAiCompatConfig { pub struct OpenAiCompatClient { http: reqwest::Client, api_key: String, + config: OpenAiCompatConfig, base_url: String, max_retries: u32, initial_backoff: Duration, @@ -74,11 +75,15 @@ pub struct OpenAiCompatClient { } impl OpenAiCompatClient { + const fn config(&self) -> OpenAiCompatConfig { + self.config + } #[must_use] pub fn new(api_key: impl Into<String>, config: OpenAiCompatConfig) -> Self { Self { http: reqwest::Client::new(), api_key: api_key.into(), + config, base_url: read_base_url(config), max_retries: DEFAULT_MAX_RETRIES, initial_backoff: DEFAULT_INITIAL_BACKOFF, @@ -190,7 +195,7 @@ impl OpenAiCompatClient { .post(&request_url) .header("content-type", "application/json") .bearer_auth(&self.api_key) - .json(&build_chat_completion_request(request)) + .json(&build_chat_completion_request(request, self.config())) .send() .await .map_err(ApiError::from) @@ -633,7 +638,7 @@ struct ErrorBody { message: Option<String>, } -fn build_chat_completion_request(request: &MessageRequest) -> Value { +fn build_chat_completion_request(request: &MessageRequest, config: OpenAiCompatConfig) -> Value { let mut messages = Vec::new(); if let Some(system) = request.system.as_ref().filter(|value| !value.is_empty()) { messages.push(json!({ @@ -652,6 +657,10 @@ fn build_chat_completion_request(request: &MessageRequest) -> Value { "stream": request.stream, }); + if request.stream && should_request_stream_usage(config) { + payload["stream_options"] = json!({ "include_usage": true }); + } + if let Some(tools) = &request.tools { payload["tools"] = Value::Array(tools.iter().map(openai_tool_definition).collect::<Vec<_>>()); @@ -749,6 +758,10 @@ fn openai_tool_choice(tool_choice: &ToolChoice) -> Value { } } +fn 
should_request_stream_usage(config: OpenAiCompatConfig) -> bool { + matches!(config.provider_name, "OpenAI") +} + fn normalize_response( model: &str, response: ChatCompletionResponse, @@ -951,33 +964,36 @@ mod tests { #[test] fn request_translation_uses_openai_compatible_shape() { - let payload = build_chat_completion_request(&MessageRequest { - model: "grok-3".to_string(), - max_tokens: 64, - messages: vec![InputMessage { - role: "user".to_string(), - content: vec![ - InputContentBlock::Text { - text: "hello".to_string(), - }, - InputContentBlock::ToolResult { - tool_use_id: "tool_1".to_string(), - content: vec![ToolResultContentBlock::Json { - value: json!({"ok": true}), - }], - is_error: false, - }, - ], - }], - system: Some("be helpful".to_string()), - tools: Some(vec![ToolDefinition { - name: "weather".to_string(), - description: Some("Get weather".to_string()), - input_schema: json!({"type": "object"}), - }]), - tool_choice: Some(ToolChoice::Auto), - stream: false, - }); + let payload = build_chat_completion_request( + &MessageRequest { + model: "grok-3".to_string(), + max_tokens: 64, + messages: vec![InputMessage { + role: "user".to_string(), + content: vec![ + InputContentBlock::Text { + text: "hello".to_string(), + }, + InputContentBlock::ToolResult { + tool_use_id: "tool_1".to_string(), + content: vec![ToolResultContentBlock::Json { + value: json!({"ok": true}), + }], + is_error: false, + }, + ], + }], + system: Some("be helpful".to_string()), + tools: Some(vec![ToolDefinition { + name: "weather".to_string(), + description: Some("Get weather".to_string()), + input_schema: json!({"type": "object"}), + }]), + tool_choice: Some(ToolChoice::Auto), + stream: false, + }, + OpenAiCompatConfig::xai(), + ); assert_eq!(payload["messages"][0]["role"], json!("system")); assert_eq!(payload["messages"][1]["role"], json!("user")); @@ -986,6 +1002,42 @@ mod tests { assert_eq!(payload["tool_choice"], json!("auto")); } + #[test] + fn 
openai_streaming_requests_include_usage_opt_in() { + let payload = build_chat_completion_request( + &MessageRequest { + model: "gpt-5".to_string(), + max_tokens: 64, + messages: vec![InputMessage::user_text("hello")], + system: None, + tools: None, + tool_choice: None, + stream: true, + }, + OpenAiCompatConfig::openai(), + ); + + assert_eq!(payload["stream_options"], json!({"include_usage": true})); + } + + #[test] + fn xai_streaming_requests_skip_openai_specific_usage_opt_in() { + let payload = build_chat_completion_request( + &MessageRequest { + model: "grok-3".to_string(), + max_tokens: 64, + messages: vec![InputMessage::user_text("hello")], + system: None, + tools: None, + tool_choice: None, + stream: true, + }, + OpenAiCompatConfig::xai(), + ); + + assert!(payload.get("stream_options").is_none()); + } + #[test] fn tool_choice_translation_supports_required_function() { assert_eq!(openai_tool_choice(&ToolChoice::Any), json!("required")); diff --git a/rust/crates/api/tests/openai_compat_integration.rs b/rust/crates/api/tests/openai_compat_integration.rs index 116451e..e12e673 100644 --- a/rust/crates/api/tests/openai_compat_integration.rs +++ b/rust/crates/api/tests/openai_compat_integration.rs @@ -5,8 +5,9 @@ use std::sync::{Mutex as StdMutex, OnceLock}; use api::{ ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent, - InputContentBlock, InputMessage, MessageRequest, OpenAiCompatClient, OpenAiCompatConfig, - OutputContentBlock, ProviderClient, StreamEvent, ToolChoice, ToolDefinition, + InputContentBlock, InputMessage, MessageDeltaEvent, MessageRequest, OpenAiCompatClient, + OpenAiCompatConfig, OutputContentBlock, ProviderClient, StreamEvent, ToolChoice, + ToolDefinition, }; use serde_json::json; use tokio::io::{AsyncReadExt, AsyncWriteExt}; @@ -195,6 +196,82 @@ async fn stream_message_normalizes_text_and_multiple_tool_calls() { assert!(request.body.contains("\"stream\":true")); } +#[allow(clippy::await_holding_lock)] 
+#[tokio::test] +async fn openai_streaming_requests_opt_into_usage_chunks() { + let state = Arc::new(Mutex::new(Vec::::new())); + let sse = concat!( + "data: {\"id\":\"chatcmpl_openai_stream\",\"model\":\"gpt-5\",\"choices\":[{\"delta\":{\"content\":\"Hi\"}}]}\n\n", + "data: {\"id\":\"chatcmpl_openai_stream\",\"choices\":[{\"delta\":{},\"finish_reason\":\"stop\"}]}\n\n", + "data: {\"id\":\"chatcmpl_openai_stream\",\"choices\":[],\"usage\":{\"prompt_tokens\":9,\"completion_tokens\":4}}\n\n", + "data: [DONE]\n\n" + ); + let server = spawn_server( + state.clone(), + vec![http_response_with_headers( + "200 OK", + "text/event-stream", + sse, + &[("x-request-id", "req_openai_stream")], + )], + ) + .await; + + let client = OpenAiCompatClient::new("openai-test-key", OpenAiCompatConfig::openai()) + .with_base_url(server.base_url()); + let mut stream = client + .stream_message(&sample_request(false)) + .await + .expect("stream should start"); + + assert_eq!(stream.request_id(), Some("req_openai_stream")); + + let mut events = Vec::new(); + while let Some(event) = stream.next_event().await.expect("event should parse") { + events.push(event); + } + + assert!(matches!(events[0], StreamEvent::MessageStart(_))); + assert!(matches!( + events[1], + StreamEvent::ContentBlockStart(ContentBlockStartEvent { + content_block: OutputContentBlock::Text { .. }, + .. + }) + )); + assert!(matches!( + events[2], + StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent { + delta: ContentBlockDelta::TextDelta { .. }, + .. + }) + )); + assert!(matches!( + events[3], + StreamEvent::ContentBlockStop(ContentBlockStopEvent { index: 0 }) + )); + assert!(matches!( + events[4], + StreamEvent::MessageDelta(MessageDeltaEvent { .. }) + )); + assert!(matches!(events[5], StreamEvent::MessageStop(_))); + + match &events[4] { + StreamEvent::MessageDelta(MessageDeltaEvent { usage, .. 
}) => { + assert_eq!(usage.input_tokens, 9); + assert_eq!(usage.output_tokens, 4); + } + other => panic!("expected message delta, got {other:?}"), + } + + let captured = state.lock().await; + let request = captured.first().expect("captured request"); + assert_eq!(request.path, "/chat/completions"); + let body: serde_json::Value = serde_json::from_str(&request.body).expect("json body"); + assert_eq!(body["stream"], json!(true)); + assert_eq!(body["stream_options"], json!({"include_usage": true})); +} + #[allow(clippy::await_holding_lock)] #[tokio::test] async fn provider_client_dispatches_xai_requests_from_env() {