Use Anthropic count tokens for preflight

This commit is contained in:
Yeachan-Heo
2026-04-06 09:38:21 +00:00
parent c1883d0f66
commit be561bfdeb

View File

@@ -14,7 +14,7 @@ use telemetry::{AnalyticsEvent, AnthropicRequestProfile, ClientIdentity, Session
use crate::error::ApiError;
use crate::prompt_cache::{PromptCache, PromptCacheRecord, PromptCacheStats};
use super::{preflight_message_request, Provider, ProviderFuture};
use super::{model_token_limit, resolve_model_alias, Provider, ProviderFuture};
use crate::sse::SseParser;
use crate::types::{MessageDeltaEvent, MessageRequest, MessageResponse, StreamEvent, Usage};
@@ -294,7 +294,7 @@ impl AnthropicClient {
}
}
preflight_message_request(&request)?;
self.preflight_message_request(&request).await?;
let response = self.send_with_retry(&request).await?;
let request_id = request_id_from_headers(response.headers());
@@ -339,7 +339,7 @@ impl AnthropicClient {
&self,
request: &MessageRequest,
) -> Result<MessageStream, ApiError> {
preflight_message_request(request)?;
self.preflight_message_request(request).await?;
let response = self
.send_with_retry(&request.clone().with_streaming())
.await?;
@@ -466,18 +466,67 @@ impl AnthropicClient {
request: &MessageRequest,
) -> Result<reqwest::Response, ApiError> {
let request_url = format!("{}/v1/messages", self.base_url.trim_end_matches('/'));
let request_body = self.request_profile.render_json_body(request)?;
let request_builder = self.build_request(&request_url).json(&request_body);
request_builder.send().await.map_err(ApiError::from)
}
fn build_request(&self, request_url: &str) -> reqwest::RequestBuilder {
let request_builder = self
.http
.post(&request_url)
.post(request_url)
.header("content-type", "application/json");
let mut request_builder = self.auth.apply(request_builder);
for (header_name, header_value) in self.request_profile.header_pairs() {
request_builder = request_builder.header(header_name, header_value);
}
request_builder
}
async fn preflight_message_request(&self, request: &MessageRequest) -> Result<(), ApiError> {
let Some(limit) = model_token_limit(&request.model) else {
return Ok(());
};
let counted_input_tokens = match self.count_tokens(request).await {
Ok(count) => count,
Err(_) => return Ok(()),
};
let estimated_total_tokens = counted_input_tokens.saturating_add(request.max_tokens);
if estimated_total_tokens > limit.context_window_tokens {
return Err(ApiError::ContextWindowExceeded {
model: resolve_model_alias(&request.model),
estimated_input_tokens: counted_input_tokens,
requested_output_tokens: request.max_tokens,
estimated_total_tokens,
context_window_tokens: limit.context_window_tokens,
});
}
Ok(())
}
async fn count_tokens(&self, request: &MessageRequest) -> Result<u32, ApiError> {
#[derive(serde::Deserialize)]
struct CountTokensResponse {
input_tokens: u32,
}
let request_url = format!("{}/v1/messages/count_tokens", self.base_url.trim_end_matches('/'));
let request_body = self.request_profile.render_json_body(request)?;
request_builder = request_builder.json(&request_body);
request_builder.send().await.map_err(ApiError::from)
let response = self
.build_request(&request_url)
.json(&request_body)
.send()
.await
.map_err(ApiError::from)?;
let parsed = expect_success(response)
.await?
.json::<CountTokensResponse>()
.await
.map_err(ApiError::from)?;
Ok(parsed.input_tokens)
}
fn record_request_failure(&self, attempt: u32, error: &ApiError) {