diff --git a/rust/crates/api/src/prompt_cache.rs b/rust/crates/api/src/prompt_cache.rs index ec9b0b4..0ee8663 100644 --- a/rust/crates/api/src/prompt_cache.rs +++ b/rust/crates/api/src/prompt_cache.rs @@ -704,6 +704,7 @@ mod tests { tools: None, tool_choice: None, stream: false, + ..Default::default() } } diff --git a/rust/crates/api/src/providers/anthropic.rs b/rust/crates/api/src/providers/anthropic.rs index 3180169..301b6fc 100644 --- a/rust/crates/api/src/providers/anthropic.rs +++ b/rust/crates/api/src/providers/anthropic.rs @@ -1259,6 +1259,7 @@ mod tests { tools: None, tool_choice: None, stream: false, + ..Default::default() }; assert!(request.with_streaming().stream); @@ -1449,6 +1450,7 @@ mod tests { tools: None, tool_choice: None, stream: false, + ..Default::default() }; let mut rendered = client diff --git a/rust/crates/api/src/providers/mod.rs b/rust/crates/api/src/providers/mod.rs index b494303..9219963 100644 --- a/rust/crates/api/src/providers/mod.rs +++ b/rust/crates/api/src/providers/mod.rs @@ -480,6 +480,7 @@ mod tests { }]), tool_choice: Some(ToolChoice::Auto), stream: true, + ..Default::default() }; let error = preflight_message_request(&request) @@ -518,6 +519,7 @@ mod tests { tools: None, tool_choice: None, stream: false, + ..Default::default() }; preflight_message_request(&request) diff --git a/rust/crates/api/src/providers/openai_compat.rs b/rust/crates/api/src/providers/openai_compat.rs index 57d9ef2..1ba4f03 100644 --- a/rust/crates/api/src/providers/openai_compat.rs +++ b/rust/crates/api/src/providers/openai_compat.rs @@ -721,6 +721,25 @@ fn build_chat_completion_request(request: &MessageRequest, config: OpenAiCompatC payload["tool_choice"] = openai_tool_choice(tool_choice); } + // OpenAI-compatible tuning parameters — only included when explicitly set. + if let Some(temperature) = request.temperature { + payload["temperature"] = json!(temperature); + } + if let Some(top_p) = request.top_p { + payload["top_p"] = json!(top_p); + } + if let Some(frequency_penalty) = request.frequency_penalty { + payload["frequency_penalty"] = json!(frequency_penalty); + } + if let Some(presence_penalty) = request.presence_penalty { + payload["presence_penalty"] = json!(presence_penalty); + } + if let Some(stop) = &request.stop { + if !stop.is_empty() { + payload["stop"] = json!(stop); + } + } + payload } @@ -1049,6 +1068,7 @@ mod tests { }]), tool_choice: Some(ToolChoice::Auto), stream: false, + ..Default::default() }, OpenAiCompatConfig::xai(), ); @@ -1071,6 +1091,7 @@ mod tests { tools: None, tool_choice: None, stream: true, + ..Default::default() }, OpenAiCompatConfig::openai(), ); @@ -1089,6 +1110,7 @@ mod tests { tools: None, tool_choice: None, stream: true, + ..Default::default() }, OpenAiCompatConfig::xai(), ); @@ -1159,4 +1181,45 @@ mod tests { assert_eq!(normalize_finish_reason("stop"), "end_turn"); assert_eq!(normalize_finish_reason("tool_calls"), "tool_use"); } + + #[test] + fn tuning_params_included_in_payload_when_set() { + let request = MessageRequest { + model: "gpt-4o".to_string(), + max_tokens: 1024, + messages: vec![], + system: None, + tools: None, + tool_choice: None, + stream: false, + temperature: Some(0.7), + top_p: Some(0.9), + frequency_penalty: Some(0.5), + presence_penalty: Some(0.3), + stop: Some(vec!["\n".to_string()]), + }; + let payload = build_chat_completion_request(&request, OpenAiCompatConfig::openai()); + assert_eq!(payload["temperature"], 0.7); + assert_eq!(payload["top_p"], 0.9); + assert_eq!(payload["frequency_penalty"], 0.5); + assert_eq!(payload["presence_penalty"], 0.3); + assert_eq!(payload["stop"], json!(["\n"])); + } + + #[test] + fn tuning_params_omitted_from_payload_when_none() { + let request = MessageRequest { + model: "gpt-4o".to_string(), + max_tokens: 1024, + messages: vec![], + stream: false, + ..Default::default() + }; + let payload = build_chat_completion_request(&request, OpenAiCompatConfig::openai()); + assert!(payload.get("temperature").is_none(), "temperature should be absent"); + assert!(payload.get("top_p").is_none(), "top_p should be absent"); + assert!(payload.get("frequency_penalty").is_none()); + assert!(payload.get("presence_penalty").is_none()); + assert!(payload.get("stop").is_none()); + } } diff --git a/rust/crates/api/src/types.rs b/rust/crates/api/src/types.rs index f2e33f6..830b3de 100644 --- a/rust/crates/api/src/types.rs +++ b/rust/crates/api/src/types.rs @@ -2,7 +2,7 @@ use runtime::{pricing_for_model, TokenUsage, UsageCostEstimate}; use serde::{Deserialize, Serialize}; use serde_json::Value; -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)] pub struct MessageRequest { pub model: String, pub max_tokens: u32, @@ -15,6 +15,17 @@ pub struct MessageRequest { pub tool_choice: Option, #[serde(default, skip_serializing_if = "std::ops::Not::not")] pub stream: bool, + /// OpenAI-compatible tuning parameters. Optional — omitted from payload when None. + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub frequency_penalty: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub presence_penalty: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub stop: Option>, } impl MessageRequest { diff --git a/rust/crates/rusty-claude-cli/src/main.rs b/rust/crates/rusty-claude-cli/src/main.rs index 330afb6..3c09cc5 100644 --- a/rust/crates/rusty-claude-cli/src/main.rs +++ b/rust/crates/rusty-claude-cli/src/main.rs @@ -6375,6 +6375,7 @@ impl ApiClient for AnthropicRuntimeClient { .then(|| filter_tool_specs(&self.tool_registry, self.allowed_tools.as_ref())), tool_choice: self.enable_tools.then_some(ToolChoice::Auto), stream: true, + ..Default::default() }; self.runtime.block_on(async { diff --git a/rust/crates/tools/src/lib.rs b/rust/crates/tools/src/lib.rs index a6aec0c..0ccd700 100644 --- a/rust/crates/tools/src/lib.rs +++ b/rust/crates/tools/src/lib.rs @@ -3841,6 +3841,7 @@ impl ApiClient for ProviderRuntimeClient { tools: (!tools.is_empty()).then(|| tools.clone()), tool_choice: tool_choice.clone(), stream: true, + ..Default::default() }; let attempt = runtime.block_on(stream_with_provider(&entry.client, &message_request));