From fa72cd665e71f7740b0cf418899a5ebeb29fc265 Mon Sep 17 00:00:00 2001 From: Yeachan-Heo Date: Sun, 5 Apr 2026 16:39:58 +0000 Subject: [PATCH] Block oversized requests before providers hard-fail The runtime already tracked rough token estimates for compaction, but provider-bound requests still relied on naive model output limits and could be sent upstream even when the selected model could not fit the estimated prompt plus requested output. This adds a small model token/context registry in the API layer, estimates request size from the serialized prompt payload, and fails locally with a dedicated context-window error before Anthropic or xAI calls are made. Focused integration coverage asserts the preflight fires before any HTTP request leaves the process. Constraint: Keep the first pass minimal and reusable across both Anthropic and OpenAI-compatible providers Rejected: Auto-compact-and-retry in the same patch | broader control-flow change than the requested minimal preflight Confidence: medium Scope-risk: narrow Reversibility: clean Directive: Expand the model registry before enabling preflight for additional providers or aliases Tested: cargo build -p api -p tools -p rusty-claude-cli; cargo test -p api Not-tested: End-to-end CLI auto-compaction or retry behavior after a local context_window_blocked failure --- rust/crates/api/src/error.rs | 18 ++ rust/crates/api/src/providers/anthropic.rs | 5 +- rust/crates/api/src/providers/mod.rs | 169 +++++++++++++++++- .../crates/api/src/providers/openai_compat.rs | 4 +- rust/crates/api/tests/client_integration.rs | 35 ++++ .../api/tests/openai_compat_integration.rs | 44 ++++- 6 files changed, 264 insertions(+), 11 deletions(-) diff --git a/rust/crates/api/src/error.rs b/rust/crates/api/src/error.rs index 7649889..35c8da2 100644 --- a/rust/crates/api/src/error.rs +++ b/rust/crates/api/src/error.rs @@ -8,6 +8,13 @@ pub enum ApiError { provider: &'static str, env_vars: &'static [&'static str], }, + ContextWindowExceeded { + 
model: String, + estimated_input_tokens: u32, + requested_output_tokens: u32, + estimated_total_tokens: u32, + context_window_tokens: u32, + }, ExpiredOAuthToken, Auth(String), InvalidApiKeyEnv(VarError), @@ -48,6 +55,7 @@ impl ApiError { Self::Api { retryable, .. } => *retryable, Self::RetriesExhausted { last_error, .. } => last_error.is_retryable(), Self::MissingCredentials { .. } + | Self::ContextWindowExceeded { .. } | Self::ExpiredOAuthToken | Self::Auth(_) | Self::InvalidApiKeyEnv(_) @@ -67,6 +75,16 @@ impl Display for ApiError { "missing {provider} credentials; export {} before calling the {provider} API", env_vars.join(" or ") ), + Self::ContextWindowExceeded { + model, + estimated_input_tokens, + requested_output_tokens, + estimated_total_tokens, + context_window_tokens, + } => write!( + f, + "context_window_blocked for {model}: estimated input {estimated_input_tokens} + requested output {requested_output_tokens} = {estimated_total_tokens} tokens exceeds the {context_window_tokens}-token context window; compact the session or reduce request size before retrying" + ), Self::ExpiredOAuthToken => { write!( f, diff --git a/rust/crates/api/src/providers/anthropic.rs b/rust/crates/api/src/providers/anthropic.rs index f8b41ac..a398241 100644 --- a/rust/crates/api/src/providers/anthropic.rs +++ b/rust/crates/api/src/providers/anthropic.rs @@ -14,7 +14,7 @@ use telemetry::{AnalyticsEvent, AnthropicRequestProfile, ClientIdentity, Session use crate::error::ApiError; use crate::prompt_cache::{PromptCache, PromptCacheRecord, PromptCacheStats}; -use super::{Provider, ProviderFuture}; +use super::{preflight_message_request, Provider, ProviderFuture}; use crate::sse::SseParser; use crate::types::{MessageDeltaEvent, MessageRequest, MessageResponse, StreamEvent, Usage}; @@ -294,6 +294,8 @@ impl AnthropicClient { } } + preflight_message_request(&request)?; + let response = self.send_with_retry(&request).await?; let request_id = request_id_from_headers(response.headers()); 
let mut response = response @@ -337,6 +339,7 @@ impl AnthropicClient { &self, request: &MessageRequest, ) -> Result { + preflight_message_request(request)?; let response = self .send_with_retry(&request.clone().with_streaming()) .await?; diff --git a/rust/crates/api/src/providers/mod.rs b/rust/crates/api/src/providers/mod.rs index 6fa3d4d..31a58ed 100644 --- a/rust/crates/api/src/providers/mod.rs +++ b/rust/crates/api/src/providers/mod.rs @@ -1,6 +1,8 @@ use std::future::Future; use std::pin::Pin; +use serde::Serialize; + use crate::error::ApiError; use crate::types::{MessageRequest, MessageResponse}; @@ -40,6 +42,12 @@ pub struct ProviderMetadata { pub default_base_url: &'static str, } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct ModelTokenLimit { + pub max_output_tokens: u32, + pub context_window_tokens: u32, +} + const MODEL_REGISTRY: &[(&str, ProviderMetadata)] = &[ ( "opus", @@ -182,17 +190,86 @@ pub fn detect_provider_kind(model: &str) -> ProviderKind { #[must_use] pub fn max_tokens_for_model(model: &str) -> u32 { + model_token_limit(model).map_or_else( + || { + let canonical = resolve_model_alias(model); + if canonical.contains("opus") { + 32_000 + } else { + 64_000 + } + }, + |limit| limit.max_output_tokens, + ) +} + +#[must_use] +pub fn model_token_limit(model: &str) -> Option<ModelTokenLimit> { let canonical = resolve_model_alias(model); - if canonical.contains("opus") { - 32_000 - } else { - 64_000 + match canonical.as_str() { + "claude-opus-4-6" => Some(ModelTokenLimit { + max_output_tokens: 32_000, + context_window_tokens: 200_000, + }), + "claude-sonnet-4-6" | "claude-haiku-4-5-20251213" => Some(ModelTokenLimit { + max_output_tokens: 64_000, + context_window_tokens: 200_000, + }), + "grok-3" | "grok-3-mini" => Some(ModelTokenLimit { + max_output_tokens: 64_000, + context_window_tokens: 131_072, + }), + _ => None, } } +pub fn preflight_message_request(request: &MessageRequest) -> Result<(), ApiError> { + let Some(limit) = model_token_limit(&request.model) 
else { + return Ok(()); + }; + + let estimated_input_tokens = estimate_message_request_input_tokens(request); + let estimated_total_tokens = estimated_input_tokens.saturating_add(request.max_tokens); + if estimated_total_tokens > limit.context_window_tokens { + return Err(ApiError::ContextWindowExceeded { + model: resolve_model_alias(&request.model), + estimated_input_tokens, + requested_output_tokens: request.max_tokens, + estimated_total_tokens, + context_window_tokens: limit.context_window_tokens, + }); + } + + Ok(()) +} + +fn estimate_message_request_input_tokens(request: &MessageRequest) -> u32 { + let mut estimate = estimate_serialized_tokens(&request.messages); + estimate = estimate.saturating_add(estimate_serialized_tokens(&request.system)); + estimate = estimate.saturating_add(estimate_serialized_tokens(&request.tools)); + estimate = estimate.saturating_add(estimate_serialized_tokens(&request.tool_choice)); + estimate +} + +fn estimate_serialized_tokens<T: Serialize>(value: &T) -> u32 { + serde_json::to_vec(value) + .ok() + .map_or(0, |bytes| (bytes.len() / 4 + 1) as u32) +} + #[cfg(test)] mod tests { - use super::{detect_provider_kind, max_tokens_for_model, resolve_model_alias, ProviderKind}; + use serde_json::json; + + use crate::error::ApiError; + use crate::types::{ + InputContentBlock, InputMessage, MessageRequest, ToolChoice, ToolDefinition, + }; + + use super::{ + detect_provider_kind, max_tokens_for_model, model_token_limit, preflight_message_request, + resolve_model_alias, ProviderKind, + }; #[test] fn resolves_grok_aliases() { @@ -215,4 +292,86 @@ mod tests { assert_eq!(max_tokens_for_model("opus"), 32_000); assert_eq!(max_tokens_for_model("grok-3"), 64_000); } + + #[test] + fn returns_context_window_metadata_for_supported_models() { + assert_eq!( + model_token_limit("claude-sonnet-4-6") + .expect("claude-sonnet-4-6 should be registered") + .context_window_tokens, + 200_000 + ); + assert_eq!( + model_token_limit("grok-mini") + .expect("grok-mini should resolve 
to a registered model") + .context_window_tokens, + 131_072 + ); + } + + #[test] + fn preflight_blocks_requests_that_exceed_the_model_context_window() { + let request = MessageRequest { + model: "claude-sonnet-4-6".to_string(), + max_tokens: 64_000, + messages: vec![InputMessage { + role: "user".to_string(), + content: vec![InputContentBlock::Text { + text: "x".repeat(600_000), + }], + }], + system: Some("Keep the answer short.".to_string()), + tools: Some(vec![ToolDefinition { + name: "weather".to_string(), + description: Some("Fetches weather".to_string()), + input_schema: json!({ + "type": "object", + "properties": { "city": { "type": "string" } }, + }), + }]), + tool_choice: Some(ToolChoice::Auto), + stream: true, + }; + + let error = preflight_message_request(&request) + .expect_err("oversized request should be rejected before the provider call"); + + match error { + ApiError::ContextWindowExceeded { + model, + estimated_input_tokens, + requested_output_tokens, + estimated_total_tokens, + context_window_tokens, + } => { + assert_eq!(model, "claude-sonnet-4-6"); + assert!(estimated_input_tokens > 136_000); + assert_eq!(requested_output_tokens, 64_000); + assert!(estimated_total_tokens > context_window_tokens); + assert_eq!(context_window_tokens, 200_000); + } + other => panic!("expected context-window preflight failure, got {other:?}"), + } + } + + #[test] + fn preflight_skips_unknown_models() { + let request = MessageRequest { + model: "unknown-model".to_string(), + max_tokens: 64_000, + messages: vec![InputMessage { + role: "user".to_string(), + content: vec![InputContentBlock::Text { + text: "x".repeat(600_000), + }], + }], + system: None, + tools: None, + tool_choice: None, + stream: false, + }; + + preflight_message_request(&request) + .expect("models without context metadata should skip the guarded preflight"); + } } diff --git a/rust/crates/api/src/providers/openai_compat.rs b/rust/crates/api/src/providers/openai_compat.rs index 48eec30..bff4e44 100644 
--- a/rust/crates/api/src/providers/openai_compat.rs +++ b/rust/crates/api/src/providers/openai_compat.rs @@ -12,7 +12,7 @@ use crate::types::{ ToolChoice, ToolDefinition, ToolResultContentBlock, Usage, }; -use super::{Provider, ProviderFuture}; +use super::{preflight_message_request, Provider, ProviderFuture}; pub const DEFAULT_XAI_BASE_URL: &str = "https://api.x.ai/v1"; pub const DEFAULT_OPENAI_BASE_URL: &str = "https://api.openai.com/v1"; @@ -128,6 +128,7 @@ impl OpenAiCompatClient { stream: false, ..request.clone() }; + preflight_message_request(&request)?; let response = self.send_with_retry(&request).await?; let request_id = request_id_from_headers(response.headers()); let payload = response.json::().await?; @@ -142,6 +143,7 @@ impl OpenAiCompatClient { &self, request: &MessageRequest, ) -> Result { + preflight_message_request(request)?; let response = self .send_with_retry(&request.clone().with_streaming()) .await?; diff --git a/rust/crates/api/tests/client_integration.rs b/rust/crates/api/tests/client_integration.rs index 01addef..03b2f6d 100644 --- a/rust/crates/api/tests/client_integration.rs +++ b/rust/crates/api/tests/client_integration.rs @@ -103,6 +103,41 @@ async fn send_message_posts_json_and_parses_response() { ); } +#[tokio::test] +async fn send_message_blocks_oversized_requests_before_the_http_call() { + let state = Arc::new(Mutex::new(Vec::::new())); + let server = spawn_server( + state.clone(), + vec![http_response("200 OK", "application/json", "{}")], + ) + .await; + + let client = AnthropicClient::new("test-key").with_base_url(server.base_url()); + let error = client + .send_message(&MessageRequest { + model: "claude-sonnet-4-6".to_string(), + max_tokens: 64_000, + messages: vec![InputMessage { + role: "user".to_string(), + content: vec![InputContentBlock::Text { + text: "x".repeat(600_000), + }], + }], + system: Some("Keep the answer short.".to_string()), + tools: None, + tool_choice: None, + stream: false, + }) + .await + 
.expect_err("oversized request should fail local context-window preflight"); + + assert!(matches!(error, ApiError::ContextWindowExceeded { .. })); + assert!( + state.lock().await.is_empty(), + "preflight failure should avoid any upstream HTTP request" + ); +} + #[tokio::test] async fn send_message_applies_request_profile_and_records_telemetry() { let state = Arc::new(Mutex::new(Vec::::new())); diff --git a/rust/crates/api/tests/openai_compat_integration.rs b/rust/crates/api/tests/openai_compat_integration.rs index e12e673..2dd90a2 100644 --- a/rust/crates/api/tests/openai_compat_integration.rs +++ b/rust/crates/api/tests/openai_compat_integration.rs @@ -4,10 +4,10 @@ use std::sync::Arc; use std::sync::{Mutex as StdMutex, OnceLock}; use api::{ - ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent, - InputContentBlock, InputMessage, MessageDeltaEvent, MessageRequest, OpenAiCompatClient, - OpenAiCompatConfig, OutputContentBlock, ProviderClient, StreamEvent, ToolChoice, - ToolDefinition, + ApiError, ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, + ContentBlockStopEvent, InputContentBlock, InputMessage, MessageDeltaEvent, MessageRequest, + OpenAiCompatClient, OpenAiCompatConfig, OutputContentBlock, ProviderClient, StreamEvent, + ToolChoice, ToolDefinition, }; use serde_json::json; use tokio::io::{AsyncReadExt, AsyncWriteExt}; @@ -63,6 +63,42 @@ async fn send_message_uses_openai_compatible_endpoint_and_auth() { assert_eq!(body["tools"][0]["type"], json!("function")); } +#[tokio::test] +async fn send_message_blocks_oversized_xai_requests_before_the_http_call() { + let state = Arc::new(Mutex::new(Vec::::new())); + let server = spawn_server( + state.clone(), + vec![http_response("200 OK", "application/json", "{}")], + ) + .await; + + let client = OpenAiCompatClient::new("xai-test-key", OpenAiCompatConfig::xai()) + .with_base_url(server.base_url()); + let error = client + .send_message(&MessageRequest { + model: 
"grok-3".to_string(), + max_tokens: 64_000, + messages: vec![InputMessage { + role: "user".to_string(), + content: vec![InputContentBlock::Text { + text: "x".repeat(300_000), + }], + }], + system: Some("Keep the answer short.".to_string()), + tools: None, + tool_choice: None, + stream: false, + }) + .await + .expect_err("oversized request should fail local context-window preflight"); + + assert!(matches!(error, ApiError::ContextWindowExceeded { .. })); + assert!( + state.lock().await.is_empty(), + "preflight failure should avoid any upstream HTTP request" + ); +} + #[tokio::test] async fn send_message_accepts_full_chat_completions_endpoint_override() { let state = Arc::new(Mutex::new(Vec::::new()));