mirror of https://github.com/instructkr/claw-code.git
feat(api): add tuning params (temperature, top_p, penalties, stop) to MessageRequest
MessageRequest was missing standard OpenAI-compatible generation tuning
parameters. Callers had no way to control temperature, top_p,
frequency_penalty, presence_penalty, or stop sequences.

Changes:
- Added 5 optional fields to MessageRequest (all Option, None by default)
- Wired into build_chat_completion_request: only included in the payload when set
- All existing construction sites updated with ..Default::default()
- MessageRequest now derives Default for ergonomic partial construction

Tests added:
- tuning_params_included_in_payload_when_set: all 5 params flow into the JSON
- tuning_params_omitted_from_payload_when_none: absent params stay absent

83 api lib tests passing, 0 failing. cargo check --workspace: 0 warnings.
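With MessageRequest now deriving Default, a construction site only has to spell out the fields it cares about. A minimal caller-side sketch, assuming the crate's MessageRequest is in scope (the concrete values are illustrative, not taken from the repository):

let request = MessageRequest {
    model: "gpt-4o".to_string(),
    max_tokens: 1024,
    messages: vec![],
    temperature: Some(0.2),               // new tuning knob
    stop: Some(vec!["\n\n".to_string()]), // new tuning knob
    ..Default::default()                  // system, tools, stream, remaining knobs stay at their defaults
};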
@@ -704,6 +704,7 @@ mod tests {
             tools: None,
             tool_choice: None,
             stream: false,
+            ..Default::default()
         }
     }
 
@@ -1259,6 +1259,7 @@ mod tests {
             tools: None,
             tool_choice: None,
             stream: false,
+            ..Default::default()
         };
 
         assert!(request.with_streaming().stream);
@@ -1449,6 +1450,7 @@ mod tests {
             tools: None,
             tool_choice: None,
             stream: false,
+            ..Default::default()
         };
 
         let mut rendered = client
@@ -480,6 +480,7 @@ mod tests {
             }]),
             tool_choice: Some(ToolChoice::Auto),
             stream: true,
+            ..Default::default()
         };
 
         let error = preflight_message_request(&request)
@@ -518,6 +519,7 @@ mod tests {
             tools: None,
             tool_choice: None,
             stream: false,
+            ..Default::default()
         };
 
         preflight_message_request(&request)
@@ -721,6 +721,25 @@ fn build_chat_completion_request(request: &MessageRequest, config: OpenAiCompatC
         payload["tool_choice"] = openai_tool_choice(tool_choice);
     }
 
+    // OpenAI-compatible tuning parameters — only included when explicitly set.
+    if let Some(temperature) = request.temperature {
+        payload["temperature"] = json!(temperature);
+    }
+    if let Some(top_p) = request.top_p {
+        payload["top_p"] = json!(top_p);
+    }
+    if let Some(frequency_penalty) = request.frequency_penalty {
+        payload["frequency_penalty"] = json!(frequency_penalty);
+    }
+    if let Some(presence_penalty) = request.presence_penalty {
+        payload["presence_penalty"] = json!(presence_penalty);
+    }
+    if let Some(stop) = &request.stop {
+        if !stop.is_empty() {
+            payload["stop"] = json!(stop);
+        }
+    }
+
     payload
 }
 
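A rough standalone sketch of the same only-include-when-set pattern, runnable outside the repo (apply_tuning and its trimmed-down signature are hypothetical, not part of claw-code):

use serde_json::{json, Value};

// Hypothetical helper: writes a key into the JSON object only when the option carries a value.
fn apply_tuning(mut payload: Value, temperature: Option<f64>, stop: Option<Vec<String>>) -> Value {
    if let Some(temperature) = temperature {
        payload["temperature"] = json!(temperature);
    }
    if let Some(stop) = stop.filter(|s| !s.is_empty()) {
        payload["stop"] = json!(stop);
    }
    payload
}

fn main() {
    let payload = apply_tuning(json!({"model": "gpt-4o"}), Some(0.7), None);
    assert_eq!(payload["temperature"], 0.7); // set, so the key is present
    assert!(payload.get("stop").is_none()); // None, so the key never appears
}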
@@ -1049,6 +1068,7 @@ mod tests {
                 }]),
                 tool_choice: Some(ToolChoice::Auto),
                 stream: false,
+                ..Default::default()
             },
             OpenAiCompatConfig::xai(),
         );
@@ -1071,6 +1091,7 @@ mod tests {
                 tools: None,
                 tool_choice: None,
                 stream: true,
+                ..Default::default()
             },
             OpenAiCompatConfig::openai(),
         );
@@ -1089,6 +1110,7 @@ mod tests {
                 tools: None,
                 tool_choice: None,
                 stream: true,
+                ..Default::default()
             },
             OpenAiCompatConfig::xai(),
         );
@@ -1159,4 +1181,45 @@ mod tests {
         assert_eq!(normalize_finish_reason("stop"), "end_turn");
         assert_eq!(normalize_finish_reason("tool_calls"), "tool_use");
     }
+
+    #[test]
+    fn tuning_params_included_in_payload_when_set() {
+        let request = MessageRequest {
+            model: "gpt-4o".to_string(),
+            max_tokens: 1024,
+            messages: vec![],
+            system: None,
+            tools: None,
+            tool_choice: None,
+            stream: false,
+            temperature: Some(0.7),
+            top_p: Some(0.9),
+            frequency_penalty: Some(0.5),
+            presence_penalty: Some(0.3),
+            stop: Some(vec!["\n".to_string()]),
+        };
+        let payload = build_chat_completion_request(&request, OpenAiCompatConfig::openai());
+        assert_eq!(payload["temperature"], 0.7);
+        assert_eq!(payload["top_p"], 0.9);
+        assert_eq!(payload["frequency_penalty"], 0.5);
+        assert_eq!(payload["presence_penalty"], 0.3);
+        assert_eq!(payload["stop"], json!(["\n"]));
+    }
+
+    #[test]
+    fn tuning_params_omitted_from_payload_when_none() {
+        let request = MessageRequest {
+            model: "gpt-4o".to_string(),
+            max_tokens: 1024,
+            messages: vec![],
+            stream: false,
+            ..Default::default()
+        };
+        let payload = build_chat_completion_request(&request, OpenAiCompatConfig::openai());
+        assert!(payload.get("temperature").is_none(), "temperature should be absent");
+        assert!(payload.get("top_p").is_none(), "top_p should be absent");
+        assert!(payload.get("frequency_penalty").is_none());
+        assert!(payload.get("presence_penalty").is_none());
+        assert!(payload.get("stop").is_none());
+    }
 }
@@ -2,7 +2,7 @@ use runtime::{pricing_for_model, TokenUsage, UsageCostEstimate};
 use serde::{Deserialize, Serialize};
 use serde_json::Value;
 
-#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
+#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
 pub struct MessageRequest {
     pub model: String,
     pub max_tokens: u32,
@@ -15,6 +15,17 @@ pub struct MessageRequest {
     pub tool_choice: Option<ToolChoice>,
     #[serde(default, skip_serializing_if = "std::ops::Not::not")]
     pub stream: bool,
+    /// OpenAI-compatible tuning parameters. Optional — omitted from payload when None.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub temperature: Option<f64>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub top_p: Option<f64>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub frequency_penalty: Option<f64>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub presence_penalty: Option<f64>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub stop: Option<Vec<String>>,
 }
 
 impl MessageRequest {
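On the wire, skip_serializing_if means an unset option never produces a key at all. A minimal stand-in sketch (the Tuning struct below is hypothetical and only mirrors the attribute pattern, it is not the real MessageRequest):

use serde::Serialize;

#[derive(Serialize, Default)]
struct Tuning {
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop: Option<Vec<String>>,
}

fn main() {
    let set = Tuning { temperature: Some(0.7), ..Default::default() };
    // Only the populated field shows up in the serialized JSON.
    assert_eq!(serde_json::to_string(&set).unwrap(), r#"{"temperature":0.7}"#);
    // With everything None the object serializes to an empty map.
    assert_eq!(serde_json::to_string(&Tuning::default()).unwrap(), "{}");
}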
@@ -6375,6 +6375,7 @@ impl ApiClient for AnthropicRuntimeClient {
                 .then(|| filter_tool_specs(&self.tool_registry, self.allowed_tools.as_ref())),
             tool_choice: self.enable_tools.then_some(ToolChoice::Auto),
             stream: true,
+            ..Default::default()
         };
 
         self.runtime.block_on(async {
@@ -3841,6 +3841,7 @@ impl ApiClient for ProviderRuntimeClient {
             tools: (!tools.is_empty()).then(|| tools.clone()),
             tool_choice: tool_choice.clone(),
             stream: true,
+            ..Default::default()
         };
 
         let attempt = runtime.block_on(stream_with_provider(&entry.client, &message_request));