From 0cf2204d43700cc53edf4108b735ea83c8e76f8c Mon Sep 17 00:00:00 2001 From: Yeachan-Heo Date: Wed, 1 Apr 2026 04:30:24 +0000 Subject: [PATCH 1/4] wip: cache-tracking progress --- rust/crates/api/src/client.rs | 91 ++- rust/crates/api/src/lib.rs | 5 + rust/crates/api/src/prompt_cache.rs | 679 ++++++++++++++++++++ rust/crates/api/tests/client_integration.rs | 146 ++++- 4 files changed, 916 insertions(+), 5 deletions(-) create mode 100644 rust/crates/api/src/prompt_cache.rs diff --git a/rust/crates/api/src/client.rs b/rust/crates/api/src/client.rs index 7ef7e83..4d264b5 100644 --- a/rust/crates/api/src/client.rs +++ b/rust/crates/api/src/client.rs @@ -8,8 +8,9 @@ use runtime::{ use serde::Deserialize; use crate::error::ApiError; +use crate::prompt_cache::{PromptCache, PromptCacheStats}; use crate::sse::SseParser; -use crate::types::{MessageRequest, MessageResponse, StreamEvent}; +use crate::types::{MessageRequest, MessageResponse, StreamEvent, Usage}; const DEFAULT_BASE_URL: &str = "https://api.anthropic.com"; const ANTHROPIC_VERSION: &str = "2023-06-01"; @@ -108,6 +109,7 @@ pub struct AnthropicClient { max_retries: u32, initial_backoff: Duration, max_backoff: Duration, + prompt_cache: Option, } impl AnthropicClient { @@ -120,6 +122,7 @@ impl AnthropicClient { max_retries: DEFAULT_MAX_RETRIES, initial_backoff: DEFAULT_INITIAL_BACKOFF, max_backoff: DEFAULT_MAX_BACKOFF, + prompt_cache: None, } } @@ -132,6 +135,7 @@ impl AnthropicClient { max_retries: DEFAULT_MAX_RETRIES, initial_backoff: DEFAULT_INITIAL_BACKOFF, max_backoff: DEFAULT_MAX_BACKOFF, + prompt_cache: None, } } @@ -189,6 +193,22 @@ impl AnthropicClient { self } + #[must_use] + pub fn with_prompt_cache(mut self, prompt_cache: PromptCache) -> Self { + self.prompt_cache = Some(prompt_cache); + self + } + + #[must_use] + pub fn prompt_cache(&self) -> Option<&PromptCache> { + self.prompt_cache.as_ref() + } + + #[must_use] + pub fn prompt_cache_stats(&self) -> Option { + self.prompt_cache.as_ref().map(PromptCache::stats) + } + #[must_use] pub fn auth_source(&self) -> &AuthSource { &self.auth @@ -202,6 +222,11 @@ impl AnthropicClient { stream: false, ..request.clone() }; + if let Some(prompt_cache) = &self.prompt_cache { + if let Some(response) = prompt_cache.lookup_completion(&request) { + return Ok(response); + } + } let response = self.send_with_retry(&request).await?; let request_id = request_id_from_headers(response.headers()); let mut response = response @@ -211,6 +236,9 @@ impl AnthropicClient { if response.request_id.is_none() { response.request_id = request_id; } + if let Some(prompt_cache) = &self.prompt_cache { + let _ = prompt_cache.record_response(&request, &response); + } Ok(response) } @@ -227,6 +255,15 @@ impl AnthropicClient { parser: SseParser::new(), pending: VecDeque::new(), done: false, + cache_tracking: self + .prompt_cache + .as_ref() + .map(|prompt_cache| StreamCacheTracking { + prompt_cache: prompt_cache.clone(), + request: request.clone().with_streaming(), + last_usage: None, + finalized: false, + }), }) } @@ -527,6 +564,7 @@ pub struct MessageStream { parser: SseParser, pending: VecDeque, done: bool, + cache_tracking: Option, } impl MessageStream { @@ -538,6 +576,9 @@ impl MessageStream { pub async fn next_event(&mut self) -> Result, ApiError> { loop { if let Some(event) = self.pending.pop_front() { + if let Some(cache_tracking) = &mut self.cache_tracking { + cache_tracking.observe(&event); + } return Ok(Some(event)); } @@ -545,8 +586,14 @@ impl MessageStream { let remaining = self.parser.finish()?; self.pending.extend(remaining); if let Some(event) = self.pending.pop_front() { + if let Some(cache_tracking) = &mut self.cache_tracking { + cache_tracking.observe(&event); + } return Ok(Some(event)); } + if let Some(cache_tracking) = &mut self.cache_tracking { + cache_tracking.finalize(); + } return Ok(None); } @@ -562,6 +609,41 @@ impl MessageStream { } } +#[derive(Debug, Clone)] +struct StreamCacheTracking { + prompt_cache: PromptCache, + request: MessageRequest, + last_usage: Option, + finalized: bool, +} + +impl StreamCacheTracking { + fn observe(&mut self, event: &StreamEvent) { + match event { + StreamEvent::MessageStart(event) => { + self.last_usage = Some(event.message.usage.clone()); + } + StreamEvent::MessageDelta(event) => { + self.last_usage = Some(event.usage.clone()); + } + StreamEvent::ContentBlockStart(_) + | StreamEvent::ContentBlockDelta(_) + | StreamEvent::ContentBlockStop(_) + | StreamEvent::MessageStop(_) => {} + } + } + + fn finalize(&mut self) { + if self.finalized { + return; + } + if let Some(usage) = &self.last_usage { + let _ = self.prompt_cache.record_usage(&self.request, usage); + } + self.finalized = true; + } +} + async fn expect_success(response: reqwest::Response) -> Result { let status = response.status(); if status.is_success() { @@ -606,6 +688,7 @@ mod tests { use super::{ALT_REQUEST_ID_HEADER, REQUEST_ID_HEADER}; use std::io::{Read, Write}; use std::net::TcpListener; + use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::{Mutex, OnceLock}; use std::thread; use std::time::{Duration, SystemTime, UNIX_EPOCH}; @@ -622,13 +705,15 @@ mod tests { static LOCK: OnceLock> = OnceLock::new(); LOCK.get_or_init(|| Mutex::new(())) .lock() - .expect("env lock") + .unwrap_or_else(std::sync::PoisonError::into_inner) } fn temp_config_home() -> std::path::PathBuf { + static NEXT_ID: AtomicU64 = AtomicU64::new(0); std::env::temp_dir().join(format!( - "api-oauth-test-{}-{}", + "api-oauth-test-{}-{}-{}", std::process::id(), + NEXT_ID.fetch_add(1, Ordering::Relaxed), SystemTime::now() .duration_since(UNIX_EPOCH) .expect("time") diff --git a/rust/crates/api/src/lib.rs b/rust/crates/api/src/lib.rs index 4108187..43e2ffa 100644 --- a/rust/crates/api/src/lib.rs +++ b/rust/crates/api/src/lib.rs @@ -1,5 +1,6 @@ mod client; mod error; +mod prompt_cache; mod sse; mod types; @@ -8,6 +9,10 @@ pub use client::{ AnthropicClient, AuthSource, MessageStream, OAuthTokenSet, }; pub use error::ApiError; +pub use prompt_cache::{ + CacheBreakEvent, PromptCache, PromptCacheConfig, PromptCachePaths, PromptCacheRecord, + PromptCacheStats, +}; pub use sse::{parse_frame, SseParser}; pub use types::{ ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent, diff --git a/rust/crates/api/src/prompt_cache.rs b/rust/crates/api/src/prompt_cache.rs new file mode 100644 index 0000000..5a6a7da --- /dev/null +++ b/rust/crates/api/src/prompt_cache.rs @@ -0,0 +1,679 @@ +use std::fs; +use std::path::{Path, PathBuf}; +use std::sync::{Arc, Mutex}; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; + +use serde::{Deserialize, Serialize}; + +use crate::types::{MessageRequest, MessageResponse, Usage}; + +const DEFAULT_COMPLETION_TTL_SECS: u64 = 30; +const DEFAULT_PROMPT_TTL_SECS: u64 = 5 * 60; +const DEFAULT_BREAK_MIN_DROP: u32 = 2_000; +const MAX_SANITIZED_LENGTH: usize = 80; +const REQUEST_FINGERPRINT_VERSION: u32 = 1; +const REQUEST_FINGERPRINT_PREFIX: &str = "v1"; +const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325; +const FNV_PRIME: u64 = 0x0000_0100_0000_01b3; + +#[derive(Debug, Clone)] +pub struct PromptCacheConfig { + pub session_id: String, + pub completion_ttl: Duration, + pub prompt_ttl: Duration, + pub cache_break_min_drop: u32, +} + +impl PromptCacheConfig { + #[must_use] + pub fn new(session_id: impl Into) -> Self { + Self { + session_id: session_id.into(), + completion_ttl: Duration::from_secs(DEFAULT_COMPLETION_TTL_SECS), + prompt_ttl: Duration::from_secs(DEFAULT_PROMPT_TTL_SECS), + cache_break_min_drop: DEFAULT_BREAK_MIN_DROP, + } + } +} + +impl Default for PromptCacheConfig { + fn default() -> Self { + Self::new("default") + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct PromptCachePaths { + pub root: PathBuf, + pub session_dir: PathBuf, + pub completion_dir: PathBuf, + pub session_state_path: PathBuf, + pub stats_path: PathBuf, +} + +impl PromptCachePaths { + #[must_use] + pub fn for_session(session_id: &str) -> Self { + let root = base_cache_root(); + let session_dir = root.join(sanitize_path_segment(session_id)); + let completion_dir = session_dir.join("completions"); + Self { + root, + session_state_path: session_dir.join("session-state.json"), + stats_path: session_dir.join("stats.json"), + session_dir, + completion_dir, + } + } + + #[must_use] + pub fn completion_entry_path(&self, request_hash: &str) -> PathBuf { + self.completion_dir.join(format!("{request_hash}.json")) + } +} + +#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)] +pub struct PromptCacheStats { + pub tracked_requests: u64, + pub completion_cache_hits: u64, + pub completion_cache_misses: u64, + pub completion_cache_writes: u64, + pub expected_invalidations: u64, + pub unexpected_cache_breaks: u64, + pub total_cache_creation_input_tokens: u64, + pub total_cache_read_input_tokens: u64, + pub last_cache_creation_input_tokens: Option, + pub last_cache_read_input_tokens: Option, + pub last_request_hash: Option, + pub last_completion_cache_key: Option, + pub last_break_reason: Option, + pub last_cache_source: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct CacheBreakEvent { + pub unexpected: bool, + pub reason: String, + pub previous_cache_read_input_tokens: u32, + pub current_cache_read_input_tokens: u32, + pub token_drop: u32, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PromptCacheRecord { + pub cache_break: Option, + pub stats: PromptCacheStats, +} + +#[derive(Debug, Clone)] +pub struct PromptCache { + inner: Arc>, +} + +impl PromptCache { + #[must_use] + pub fn new(session_id: impl Into) -> Self { + Self::with_config(PromptCacheConfig::new(session_id)) + } + + #[must_use] + pub fn with_config(config: PromptCacheConfig) -> Self { + let paths = PromptCachePaths::for_session(&config.session_id); + let stats = read_json::(&paths.stats_path).unwrap_or_default(); + let previous = read_json::(&paths.session_state_path); + Self { + inner: Arc::new(Mutex::new(PromptCacheInner { + config, + paths, + stats, + previous, + })), + } + } + + #[must_use] + pub fn paths(&self) -> PromptCachePaths { + self.lock().paths.clone() + } + + #[must_use] + pub fn stats(&self) -> PromptCacheStats { + self.lock().stats.clone() + } + + pub fn lookup_completion(&self, request: &MessageRequest) -> Option { + let request_hash = request_hash_hex(request); + let (paths, ttl) = { + let inner = self.lock(); + (inner.paths.clone(), inner.config.completion_ttl) + }; + let entry_path = paths.completion_entry_path(&request_hash); + let entry = read_json::(&entry_path); + let Some(entry) = entry else { + let mut inner = self.lock(); + inner.stats.completion_cache_misses += 1; + inner.stats.last_completion_cache_key = Some(request_hash); + persist_state(&inner); + return None; + }; + + if entry.fingerprint_version != current_fingerprint_version() { + let mut inner = self.lock(); + inner.stats.completion_cache_misses += 1; + inner.stats.last_completion_cache_key = Some(request_hash.clone()); + let _ = fs::remove_file(entry_path); + persist_state(&inner); + return None; + } + + let expired = now_unix_secs().saturating_sub(entry.cached_at_unix_secs) >= ttl.as_secs(); + let mut inner = self.lock(); + inner.stats.last_completion_cache_key = Some(request_hash.clone()); + if expired { + inner.stats.completion_cache_misses += 1; + let _ = fs::remove_file(entry_path); + persist_state(&inner); + return None; + } + + inner.stats.completion_cache_hits += 1; + apply_usage_to_stats( + &mut inner.stats, + &entry.response.usage, + &request_hash, + "completion-cache", + ); + inner.previous = Some(TrackedPromptState::from_usage( + request, + &entry.response.usage, + )); + persist_state(&inner); + Some(entry.response) + } + + pub fn record_response( + &self, + request: &MessageRequest, + response: &MessageResponse, + ) -> PromptCacheRecord { + self.record_usage_internal(request, &response.usage, Some(response)) + } + + pub fn record_usage(&self, request: &MessageRequest, usage: &Usage) -> PromptCacheRecord { + self.record_usage_internal(request, usage, None) + } + + fn record_usage_internal( + &self, + request: &MessageRequest, + usage: &Usage, + response: Option<&MessageResponse>, + ) -> PromptCacheRecord { + let request_hash = request_hash_hex(request); + let mut inner = self.lock(); + let previous = inner.previous.clone(); + let current = TrackedPromptState::from_usage(request, usage); + let cache_break = detect_cache_break(&inner.config, previous.as_ref(), ¤t); + + inner.stats.tracked_requests += 1; + apply_usage_to_stats(&mut inner.stats, usage, &request_hash, "api-response"); + if let Some(event) = &cache_break { + if event.unexpected { + inner.stats.unexpected_cache_breaks += 1; + } else { + inner.stats.expected_invalidations += 1; + } + inner.stats.last_break_reason = Some(event.reason.clone()); + } + + inner.previous = Some(current); + if let Some(response) = response { + write_completion_entry(&inner.paths, &request_hash, response); + inner.stats.completion_cache_writes += 1; + } + persist_state(&inner); + + PromptCacheRecord { + cache_break, + stats: inner.stats.clone(), + } + } + + fn lock(&self) -> std::sync::MutexGuard<'_, PromptCacheInner> { + self.inner + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + } +} + +#[derive(Debug)] +struct PromptCacheInner { + config: PromptCacheConfig, + paths: PromptCachePaths, + stats: PromptCacheStats, + previous: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct CompletionCacheEntry { + cached_at_unix_secs: u64, + #[serde(default = "current_fingerprint_version")] + fingerprint_version: u32, + response: MessageResponse, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +struct TrackedPromptState { + observed_at_unix_secs: u64, + #[serde(default = "current_fingerprint_version")] + fingerprint_version: u32, + request_hash: u64, + model_hash: u64, + system_hash: u64, + tools_hash: u64, + messages_hash: u64, + cache_read_input_tokens: u32, +} + +impl TrackedPromptState { + fn from_usage(request: &MessageRequest, usage: &Usage) -> Self { + let hashes = RequestHashes::from_request(request); + Self { + observed_at_unix_secs: now_unix_secs(), + fingerprint_version: current_fingerprint_version(), + request_hash: hashes.request_hash, + model_hash: hashes.model_hash, + system_hash: hashes.system_hash, + tools_hash: hashes.tools_hash, + messages_hash: hashes.messages_hash, + cache_read_input_tokens: usage.cache_read_input_tokens, + } + } +} + +#[derive(Debug, Clone, Copy)] +struct RequestHashes { + request_hash: u64, + model_hash: u64, + system_hash: u64, + tools_hash: u64, + messages_hash: u64, +} + +impl RequestHashes { + fn from_request(request: &MessageRequest) -> Self { + Self { + request_hash: hash_serializable(request), + model_hash: hash_serializable(&request.model), + system_hash: hash_serializable(&request.system), + tools_hash: hash_serializable(&request.tools), + messages_hash: hash_serializable(&request.messages), + } + } +} + +fn detect_cache_break( + config: &PromptCacheConfig, + previous: Option<&TrackedPromptState>, + current: &TrackedPromptState, +) -> Option { + let previous = previous?; + if previous.fingerprint_version != current.fingerprint_version { + return Some(CacheBreakEvent { + unexpected: false, + reason: format!( + "fingerprint version changed (v{} -> v{})", + previous.fingerprint_version, current.fingerprint_version + ), + previous_cache_read_input_tokens: previous.cache_read_input_tokens, + current_cache_read_input_tokens: current.cache_read_input_tokens, + token_drop: previous + .cache_read_input_tokens + .saturating_sub(current.cache_read_input_tokens), + }); + } + let token_drop = previous + .cache_read_input_tokens + .saturating_sub(current.cache_read_input_tokens); + if token_drop < config.cache_break_min_drop { + return None; + } + + let mut reasons = Vec::new(); + if previous.model_hash != current.model_hash { + reasons.push("model changed"); + } + if previous.system_hash != current.system_hash { + reasons.push("system prompt changed"); + } + if previous.tools_hash != current.tools_hash { + reasons.push("tool definitions changed"); + } + if previous.messages_hash != current.messages_hash { + reasons.push("message payload changed"); + } + + let elapsed = current + .observed_at_unix_secs + .saturating_sub(previous.observed_at_unix_secs); + + let (unexpected, reason) = if reasons.is_empty() { + if elapsed > config.prompt_ttl.as_secs() { + ( + false, + format!("possible prompt cache TTL expiry after {elapsed}s"), + ) + } else { + ( + true, + "cache read tokens dropped while prompt fingerprint remained stable".to_string(), + ) + } + } else { + (false, reasons.join(", ")) + }; + + Some(CacheBreakEvent { + unexpected, + reason, + previous_cache_read_input_tokens: previous.cache_read_input_tokens, + current_cache_read_input_tokens: current.cache_read_input_tokens, + token_drop, + }) +} + +fn apply_usage_to_stats( + stats: &mut PromptCacheStats, + usage: &Usage, + request_hash: &str, + source: &str, +) { + stats.total_cache_creation_input_tokens += u64::from(usage.cache_creation_input_tokens); + stats.total_cache_read_input_tokens += u64::from(usage.cache_read_input_tokens); + stats.last_cache_creation_input_tokens = Some(usage.cache_creation_input_tokens); + stats.last_cache_read_input_tokens = Some(usage.cache_read_input_tokens); + stats.last_request_hash = Some(request_hash.to_string()); + stats.last_cache_source = Some(source.to_string()); +} + +fn persist_state(inner: &PromptCacheInner) { + let _ = ensure_cache_dirs(&inner.paths); + let _ = write_json(&inner.paths.stats_path, &inner.stats); + if let Some(previous) = &inner.previous { + let _ = write_json(&inner.paths.session_state_path, previous); + } +} + +fn write_completion_entry( + paths: &PromptCachePaths, + request_hash: &str, + response: &MessageResponse, +) { + let _ = ensure_cache_dirs(paths); + let entry = CompletionCacheEntry { + cached_at_unix_secs: now_unix_secs(), + fingerprint_version: current_fingerprint_version(), + response: response.clone(), + }; + let _ = write_json(&paths.completion_entry_path(request_hash), &entry); +} + +fn ensure_cache_dirs(paths: &PromptCachePaths) -> std::io::Result<()> { + fs::create_dir_all(&paths.completion_dir) +} + +fn write_json(path: &Path, value: &T) -> std::io::Result<()> { + let json = serde_json::to_vec_pretty(value) + .map_err(|error| std::io::Error::new(std::io::ErrorKind::InvalidData, error))?; + fs::write(path, json) +} + +fn read_json Deserialize<'de>>(path: &Path) -> Option { + let bytes = fs::read(path).ok()?; + serde_json::from_slice(&bytes).ok() +} + +fn request_hash_hex(request: &MessageRequest) -> String { + format!( + "{REQUEST_FINGERPRINT_PREFIX}-{:016x}", + hash_serializable(request) + ) +} + +fn hash_serializable(value: &T) -> u64 { + let json = serde_json::to_vec(value).unwrap_or_default(); + stable_hash_bytes(&json) +} + +fn sanitize_path_segment(value: &str) -> String { + let sanitized: String = value + .chars() + .map(|ch| if ch.is_ascii_alphanumeric() { ch } else { '-' }) + .collect(); + if sanitized.len() <= MAX_SANITIZED_LENGTH { + return sanitized; + } + let suffix = format!("-{:x}", hash_string(value)); + format!( + "{}{}", + &sanitized[..MAX_SANITIZED_LENGTH.saturating_sub(suffix.len())], + suffix + ) +} + +fn hash_string(value: &str) -> u64 { + stable_hash_bytes(value.as_bytes()) +} + +fn base_cache_root() -> PathBuf { + if let Some(config_home) = std::env::var_os("CLAUDE_CONFIG_HOME") { + return PathBuf::from(config_home) + .join("cache") + .join("prompt-cache"); + } + if let Some(home) = std::env::var_os("HOME") { + return PathBuf::from(home) + .join(".claude") + .join("cache") + .join("prompt-cache"); + } + std::env::temp_dir().join("claude-prompt-cache") +} + +fn now_unix_secs() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .map_or(0, |duration| duration.as_secs()) +} + +const fn current_fingerprint_version() -> u32 { + REQUEST_FINGERPRINT_VERSION +} + +fn stable_hash_bytes(bytes: &[u8]) -> u64 { + let mut hash = FNV_OFFSET_BASIS; + for byte in bytes { + hash ^= u64::from(*byte); + hash = hash.wrapping_mul(FNV_PRIME); + } + hash +} + +#[cfg(test)] +mod tests { + use std::sync::{Mutex, OnceLock}; + use std::time::{SystemTime, UNIX_EPOCH}; + + use super::{ + detect_cache_break, read_json, request_hash_hex, sanitize_path_segment, PromptCache, + PromptCacheConfig, PromptCachePaths, TrackedPromptState, REQUEST_FINGERPRINT_PREFIX, + }; + use crate::types::{InputMessage, MessageRequest, MessageResponse, OutputContentBlock, Usage}; + + fn env_lock() -> std::sync::MutexGuard<'static, ()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())) + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + } + + #[test] + fn path_builder_sanitizes_session_identifier() { + let paths = PromptCachePaths::for_session("session:/with spaces"); + let session_dir = paths + .session_dir + .file_name() + .and_then(|value| value.to_str()) + .expect("session dir name"); + assert_eq!(session_dir, "session--with-spaces"); + assert!(paths.completion_dir.ends_with("completions")); + assert!(paths.stats_path.ends_with("stats.json")); + assert!(paths.session_state_path.ends_with("session-state.json")); + } + + #[test] + fn request_fingerprint_drives_unexpected_break_detection() { + let request = sample_request("same"); + let previous = TrackedPromptState::from_usage( + &request, + &Usage { + input_tokens: 0, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 6_000, + output_tokens: 0, + }, + ); + let current = TrackedPromptState::from_usage( + &request, + &Usage { + input_tokens: 0, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 1_000, + output_tokens: 0, + }, + ); + let event = detect_cache_break(&PromptCacheConfig::default(), Some(&previous), ¤t) + .expect("break should be detected"); + assert!(event.unexpected); + assert!(event.reason.contains("stable")); + } + + #[test] + fn changed_prompt_marks_break_as_expected() { + let previous_request = sample_request("first"); + let current_request = sample_request("second"); + let previous = TrackedPromptState::from_usage( + &previous_request, + &Usage { + input_tokens: 0, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 6_000, + output_tokens: 0, + }, + ); + let current = TrackedPromptState::from_usage( + ¤t_request, + &Usage { + input_tokens: 0, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 1_000, + output_tokens: 0, + }, + ); + let event = detect_cache_break(&PromptCacheConfig::default(), Some(&previous), ¤t) + .expect("break should be detected"); + assert!(!event.unexpected); + assert!(event.reason.contains("message payload changed")); + } + + #[test] + fn completion_cache_round_trip_persists_recent_response() { + let _guard = env_lock(); + let temp_root = std::env::temp_dir().join(format!( + "prompt-cache-test-{}-{}", + std::process::id(), + SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("time") + .as_nanos() + )); + std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root); + let cache = PromptCache::new("unit-test-session"); + let request = sample_request("cache me"); + let response = sample_response(42, 12, "cached"); + + assert!(cache.lookup_completion(&request).is_none()); + let record = cache.record_response(&request, &response); + assert!(record.cache_break.is_none()); + + let cached = cache + .lookup_completion(&request) + .expect("cached response should load"); + assert_eq!(cached.content, response.content); + + let stats = cache.stats(); + assert_eq!(stats.completion_cache_hits, 1); + assert_eq!(stats.completion_cache_misses, 1); + assert_eq!(stats.completion_cache_writes, 1); + + let persisted = read_json::(&cache.paths().stats_path) + .expect("stats should persist"); + assert_eq!(persisted.completion_cache_hits, 1); + + std::fs::remove_dir_all(temp_root).expect("cleanup temp root"); + std::env::remove_var("CLAUDE_CONFIG_HOME"); + } + + #[test] + fn sanitize_path_caps_long_values() { + let long_value = "x".repeat(200); + let sanitized = sanitize_path_segment(&long_value); + assert!(sanitized.len() <= 80); + } + + #[test] + fn request_hashes_are_versioned_and_stable() { + let request = sample_request("stable"); + let first = request_hash_hex(&request); + let second = request_hash_hex(&request); + assert_eq!(first, second); + assert!(first.starts_with(REQUEST_FINGERPRINT_PREFIX)); + } + + fn sample_request(text: &str) -> MessageRequest { + MessageRequest { + model: "claude-3-7-sonnet-latest".to_string(), + max_tokens: 64, + messages: vec![InputMessage::user_text(text)], + system: Some("system".to_string()), + tools: None, + tool_choice: None, + stream: false, + } + } + + fn sample_response( + cache_read_input_tokens: u32, + output_tokens: u32, + text: &str, + ) -> MessageResponse { + MessageResponse { + id: "msg_test".to_string(), + kind: "message".to_string(), + role: "assistant".to_string(), + content: vec![OutputContentBlock::Text { + text: text.to_string(), + }], + model: "claude-3-7-sonnet-latest".to_string(), + stop_reason: Some("end_turn".to_string()), + stop_sequence: None, + usage: Usage { + input_tokens: 10, + cache_creation_input_tokens: 5, + cache_read_input_tokens, + output_tokens, + }, + request_id: Some("req_test".to_string()), + } + } +} diff --git a/rust/crates/api/tests/client_integration.rs b/rust/crates/api/tests/client_integration.rs index c37fa99..9f59710 100644 --- a/rust/crates/api/tests/client_integration.rs +++ b/rust/crates/api/tests/client_integration.rs @@ -1,17 +1,25 @@ use std::collections::HashMap; use std::sync::Arc; +use std::sync::{Mutex as StdMutex, OnceLock}; use std::time::Duration; use api::{ AnthropicClient, ApiError, ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, InputContentBlock, InputMessage, MessageDeltaEvent, MessageRequest, OutputContentBlock, - StreamEvent, ToolChoice, ToolDefinition, + PromptCache, StreamEvent, ToolChoice, ToolDefinition, }; use serde_json::json; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpListener; use tokio::sync::Mutex; +fn env_lock() -> std::sync::MutexGuard<'static, ()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| StdMutex::new(())) + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) +} + #[tokio::test] async fn send_message_posts_json_and_parses_response() { let state = Arc::new(Mutex::new(Vec::::new())); @@ -77,6 +85,16 @@ async fn send_message_posts_json_and_parses_response() { #[tokio::test] async fn stream_message_parses_sse_events_with_tool_use() { + let _guard = env_lock(); + let temp_root = std::env::temp_dir().join(format!( + "api-stream-cache-{}-{}", + std::process::id(), + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .expect("time") + .as_nanos() + )); + std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root); let state = Arc::new(Mutex::new(Vec::::new())); let sse = concat!( "event: message_start\n", @@ -106,7 +124,8 @@ async fn stream_message_parses_sse_events_with_tool_use() { let client = AnthropicClient::new("test-key") .with_auth_token(Some("proxy-token".to_string())) - .with_base_url(server.base_url()); + .with_base_url(server.base_url()) + .with_prompt_cache(PromptCache::new("stream-session")); let mut stream = client .stream_message(&sample_request(false)) .await @@ -160,6 +179,16 @@ async fn stream_message_parses_sse_events_with_tool_use() { let captured = state.lock().await; let request = captured.first().expect("server should capture request"); assert!(request.body.contains("\"stream\":true")); + + let stats = client + .prompt_cache_stats() + .expect("prompt cache stats should exist"); + assert_eq!(stats.tracked_requests, 1); + assert_eq!(stats.last_cache_read_input_tokens, Some(0)); + assert_eq!(stats.last_cache_source.as_deref(), Some("api-response")); + + std::fs::remove_dir_all(temp_root).expect("cleanup temp root"); + std::env::remove_var("CLAUDE_CONFIG_HOME"); } #[tokio::test] @@ -243,6 +272,119 @@ async fn surfaces_retry_exhaustion_for_persistent_retryable_errors() { } } +#[tokio::test] +async fn send_message_reuses_recent_completion_cache_entries() { + let _guard = env_lock(); + let temp_root = std::env::temp_dir().join(format!( + "api-prompt-cache-{}-{}", + std::process::id(), + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .expect("time") + .as_nanos() + )); + std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root); + + let state = Arc::new(Mutex::new(Vec::::new())); + let server = spawn_server( + state.clone(), + vec![http_response( + "200 OK", + "application/json", + "{\"id\":\"msg_cached\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Cached once\"}],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"cache_creation_input_tokens\":5,\"cache_read_input_tokens\":4000,\"output_tokens\":2}}", + )], + ) + .await; + + let client = AnthropicClient::new("test-key") + .with_base_url(server.base_url()) + .with_prompt_cache(PromptCache::new("integration-session")); + + let first = client + .send_message(&sample_request(false)) + .await + .expect("first request should succeed"); + let second = client + .send_message(&sample_request(false)) + .await + .expect("second request should reuse cache"); + + assert_eq!(first.content, second.content); + assert_eq!(state.lock().await.len(), 1); + + let stats = client + .prompt_cache_stats() + .expect("prompt cache stats should exist"); + assert_eq!(stats.completion_cache_hits, 1); + assert_eq!(stats.completion_cache_misses, 1); + assert_eq!(stats.completion_cache_writes, 1); + + std::fs::remove_dir_all(temp_root).expect("cleanup temp root"); + std::env::remove_var("CLAUDE_CONFIG_HOME"); +} + +#[tokio::test] +async fn send_message_tracks_unexpected_prompt_cache_breaks() { + let _guard = env_lock(); + let temp_root = std::env::temp_dir().join(format!( + "api-prompt-break-{}-{}", + std::process::id(), + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .expect("time") + .as_nanos() + )); + std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root); + + let state = Arc::new(Mutex::new(Vec::::new())); + let server = spawn_server( + state, + vec![ + http_response( + "200 OK", + "application/json", + "{\"id\":\"msg_one\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"One\"}],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"cache_creation_input_tokens\":5,\"cache_read_input_tokens\":6000,\"output_tokens\":2}}", + ), + http_response( + "200 OK", + "application/json", + "{\"id\":\"msg_two\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Two\"}],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"cache_creation_input_tokens\":0,\"cache_read_input_tokens\":1000,\"output_tokens\":2}}", + ), + ], + ) + .await; + + let request = sample_request(false); + let client = AnthropicClient::new("test-key") + .with_base_url(server.base_url()) + .with_prompt_cache(PromptCache::with_config(api::PromptCacheConfig { + session_id: "break-session".to_string(), + completion_ttl: Duration::from_secs(0), + ..api::PromptCacheConfig::default() + })); + + client + .send_message(&request) + .await + .expect("first response should succeed"); + client + .send_message(&request) + .await + .expect("second response should succeed"); + + let stats = client + .prompt_cache_stats() + .expect("prompt cache stats should exist"); + assert_eq!(stats.unexpected_cache_breaks, 1); + assert_eq!( + stats.last_break_reason.as_deref(), + Some("cache read tokens dropped while prompt fingerprint remained stable") + ); + + std::fs::remove_dir_all(temp_root).expect("cleanup temp root"); + std::env::remove_var("CLAUDE_CONFIG_HOME"); +} + #[tokio::test] #[ignore = "requires ANTHROPIC_API_KEY and network access"] async fn live_stream_smoke_test() { From 26344c578b460e230da64bd7afd6399263cd9cf8 Mon Sep 17 00:00:00 2001 From: Yeachan-Heo Date: Wed, 1 Apr 2026 04:40:17 +0000 Subject: [PATCH 2/4] wip: cache-tracking progress --- rust/crates/api/src/client.rs | 29 ++---- rust/crates/api/src/lib.rs | 8 ++ rust/crates/api/src/prompt_cache.rs | 106 ++++++++++++++------ rust/crates/api/tests/client_integration.rs | 28 ++++-- 4 files changed, 113 insertions(+), 58 deletions(-) diff --git a/rust/crates/api/src/client.rs b/rust/crates/api/src/client.rs index 4d264b5..f90eaf8 100644 --- a/rust/crates/api/src/client.rs +++ b/rust/crates/api/src/client.rs @@ -689,7 +689,6 @@ mod tests { use std::io::{Read, Write}; use std::net::TcpListener; use std::sync::atomic::{AtomicU64, Ordering}; - use std::sync::{Mutex, OnceLock}; use std::thread; use std::time::{Duration, SystemTime, UNIX_EPOCH}; @@ -699,15 +698,9 @@ mod tests { now_unix_timestamp, oauth_token_is_expired, resolve_saved_oauth_token, resolve_startup_auth_source, AnthropicClient, AuthSource, OAuthTokenSet, }; + use crate::test_env_lock; use crate::types::{ContentBlockDelta, MessageRequest}; - fn env_lock() -> std::sync::MutexGuard<'static, ()> { - static LOCK: OnceLock> = OnceLock::new(); - LOCK.get_or_init(|| Mutex::new(())) - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner) - } - fn temp_config_home() -> std::path::PathBuf { static NEXT_ID: AtomicU64 = AtomicU64::new(0); std::env::temp_dir().join(format!( @@ -753,7 +746,7 @@ mod tests { #[test] fn read_api_key_requires_presence() { - let _guard = env_lock(); + let _guard = test_env_lock(); std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); std::env::remove_var("ANTHROPIC_API_KEY"); std::env::remove_var("CLAUDE_CONFIG_HOME"); @@ -763,7 +756,7 @@ mod tests { #[test] fn read_api_key_requires_non_empty_value() { - let _guard = env_lock(); + let _guard = test_env_lock(); std::env::set_var("ANTHROPIC_AUTH_TOKEN", ""); std::env::remove_var("ANTHROPIC_API_KEY"); let error = super::read_api_key().expect_err("empty key should error"); @@ -773,7 +766,7 @@ mod tests { #[test] fn read_api_key_prefers_api_key_env() { - let _guard = env_lock(); + let _guard = test_env_lock(); std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token"); std::env::set_var("ANTHROPIC_API_KEY", "legacy-key"); assert_eq!( @@ -786,7 +779,7 @@ mod tests { #[test] fn read_auth_token_reads_auth_token_env() { - let _guard = env_lock(); + let _guard = test_env_lock(); std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token"); assert_eq!(super::read_auth_token().as_deref(), Some("auth-token")); std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); @@ -806,7 +799,7 @@ mod tests { #[test] fn auth_source_from_env_combines_api_key_and_bearer_token() { - let _guard = env_lock(); + let _guard = test_env_lock(); std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token"); std::env::set_var("ANTHROPIC_API_KEY", "legacy-key"); let auth = AuthSource::from_env().expect("env auth"); @@ -818,7 +811,7 @@ mod tests { #[test] fn auth_source_from_saved_oauth_when_env_absent() { - let _guard = env_lock(); + let _guard = test_env_lock(); let config_home = temp_config_home(); std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); @@ -857,7 +850,7 @@ mod tests { #[test] fn resolve_saved_oauth_token_refreshes_expired_credentials() { - let _guard = env_lock(); + let _guard = test_env_lock(); let config_home = temp_config_home(); std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); @@ -889,7 +882,7 @@ mod tests { #[test] fn resolve_startup_auth_source_uses_saved_oauth_without_loading_config() { - let _guard = env_lock(); + let _guard = test_env_lock(); let config_home = temp_config_home(); std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); @@ -913,7 +906,7 @@ mod tests { #[test] fn resolve_startup_auth_source_errors_when_refreshable_token_lacks_config() { - let _guard = env_lock(); + let _guard = test_env_lock(); let config_home = temp_config_home(); std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); @@ -945,7 +938,7 @@ mod tests { #[test] fn resolve_saved_oauth_token_preserves_refresh_token_when_refresh_response_omits_it() { - let _guard = env_lock(); + let _guard = test_env_lock(); let config_home = temp_config_home(); std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); diff --git a/rust/crates/api/src/lib.rs b/rust/crates/api/src/lib.rs index 43e2ffa..fc6ab87 100644 --- a/rust/crates/api/src/lib.rs +++ b/rust/crates/api/src/lib.rs @@ -20,3 +20,11 @@ pub use types::{ MessageResponse, MessageStartEvent, MessageStopEvent, OutputContentBlock, StreamEvent, ToolChoice, ToolDefinition, ToolResultContentBlock, Usage, }; + +#[cfg(test)] +pub(crate) fn test_env_lock() -> std::sync::MutexGuard<'static, ()> { + static LOCK: std::sync::OnceLock> = std::sync::OnceLock::new(); + LOCK.get_or_init(|| std::sync::Mutex::new(())) + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) +} diff --git a/rust/crates/api/src/prompt_cache.rs b/rust/crates/api/src/prompt_cache.rs index 5a6a7da..be7cb83 100644 --- a/rust/crates/api/src/prompt_cache.rs +++ b/rust/crates/api/src/prompt_cache.rs @@ -141,6 +141,7 @@ impl PromptCache { self.lock().stats.clone() } + #[must_use] pub fn lookup_completion(&self, request: &MessageRequest) -> Option { let request_hash = request_hash_hex(request); let (paths, ttl) = { @@ -191,6 +192,7 @@ impl PromptCache { Some(entry.response) } + #[must_use] pub fn record_response( &self, request: &MessageRequest, @@ -199,6 +201,7 @@ impl PromptCache { self.record_usage_internal(request, &response.usage, Some(response)) } + #[must_use] pub fn record_usage(&self, request: &MessageRequest, usage: &Usage) -> PromptCacheRecord { self.record_usage_internal(request, usage, None) } @@ -267,7 +270,6 @@ struct TrackedPromptState { observed_at_unix_secs: u64, #[serde(default = "current_fingerprint_version")] fingerprint_version: u32, - request_hash: u64, model_hash: u64, system_hash: u64, tools_hash: u64, @@ -277,37 +279,34 @@ struct TrackedPromptState { impl TrackedPromptState { fn from_usage(request: &MessageRequest, usage: &Usage) -> Self { - let hashes = RequestHashes::from_request(request); + let hashes = RequestFingerprints::from_request(request); Self { observed_at_unix_secs: now_unix_secs(), fingerprint_version: current_fingerprint_version(), - request_hash: hashes.request_hash, - model_hash: hashes.model_hash, - system_hash: hashes.system_hash, - tools_hash: hashes.tools_hash, - messages_hash: hashes.messages_hash, + model_hash: hashes.model, + system_hash: hashes.system, + tools_hash: hashes.tools, + messages_hash: hashes.messages, cache_read_input_tokens: usage.cache_read_input_tokens, } } } #[derive(Debug, Clone, Copy)] -struct RequestHashes { - request_hash: u64, - model_hash: u64, - system_hash: u64, - tools_hash: u64, - messages_hash: u64, +struct RequestFingerprints { + model: u64, + system: u64, + tools: u64, + messages: u64, } -impl RequestHashes { +impl RequestFingerprints { fn from_request(request: &MessageRequest) -> Self { Self { - request_hash: hash_serializable(request), - model_hash: hash_serializable(&request.model), - system_hash: hash_serializable(&request.system), - tools_hash: hash_serializable(&request.tools), - messages_hash: hash_serializable(&request.messages), + model: hash_serializable(&request.model), + system: hash_serializable(&request.system), + tools: hash_serializable(&request.tools), + messages: hash_serializable(&request.messages), } } } @@ -501,22 +500,15 @@ fn stable_hash_bytes(bytes: &[u8]) -> u64 { #[cfg(test)] mod tests { - use std::sync::{Mutex, OnceLock}; - use std::time::{SystemTime, UNIX_EPOCH}; + use std::time::{Duration, SystemTime, UNIX_EPOCH}; use super::{ detect_cache_break, read_json, request_hash_hex, sanitize_path_segment, PromptCache, PromptCacheConfig, PromptCachePaths, TrackedPromptState, REQUEST_FINGERPRINT_PREFIX, }; + use crate::test_env_lock; use crate::types::{InputMessage, MessageRequest, MessageResponse, OutputContentBlock, Usage}; - fn env_lock() -> std::sync::MutexGuard<'static, ()> { - static LOCK: OnceLock> = OnceLock::new(); - LOCK.get_or_init(|| Mutex::new(())) - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner) - } - #[test] fn path_builder_sanitizes_session_identifier() { let paths = PromptCachePaths::for_session("session:/with spaces"); @@ -588,7 +580,7 @@ mod tests { #[test] fn completion_cache_round_trip_persists_recent_response() { - let _guard = env_lock(); + let _guard = test_env_lock(); let temp_root = std::env::temp_dir().join(format!( "prompt-cache-test-{}-{}", std::process::id(), @@ -624,6 +616,62 @@ mod tests { std::env::remove_var("CLAUDE_CONFIG_HOME"); } + #[test] + fn distinct_requests_do_not_collide_in_completion_cache() { + let _guard = test_env_lock(); + let temp_root = std::env::temp_dir().join(format!( + "prompt-cache-distinct-{}-{}", + std::process::id(), + SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("time") + .as_nanos() + )); + std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root); + let cache = PromptCache::new("distinct-request-session"); + let first_request = sample_request("first"); + let second_request = sample_request("second"); + + let response = sample_response(42, 12, "cached"); + let _ = cache.record_response(&first_request, &response); + + assert!(cache.lookup_completion(&second_request).is_none()); + + std::fs::remove_dir_all(temp_root).expect("cleanup temp root"); + std::env::remove_var("CLAUDE_CONFIG_HOME"); + } + + #[test] + fn expired_completion_entries_are_not_reused() { + let _guard = test_env_lock(); + let temp_root = std::env::temp_dir().join(format!( + "prompt-cache-expired-{}-{}", + std::process::id(), + SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("time") + .as_nanos() + )); + std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root); + let cache = PromptCache::with_config(PromptCacheConfig { + session_id: "expired-session".to_string(), + completion_ttl: Duration::ZERO, + ..PromptCacheConfig::default() + }); + let request = sample_request("expire me"); + let response = sample_response(7, 3, "stale"); + + let _ = cache.record_response(&request, &response); + + assert!(cache.lookup_completion(&request).is_none()); + let stats = cache.stats(); + assert_eq!(stats.completion_cache_hits, 0); + assert_eq!(stats.completion_cache_misses, 1); + + std::fs::remove_dir_all(temp_root).expect("cleanup temp root"); + std::env::remove_var("CLAUDE_CONFIG_HOME"); + } + #[test] fn sanitize_path_caps_long_values() { let long_value = "x".repeat(200); diff --git a/rust/crates/api/tests/client_integration.rs b/rust/crates/api/tests/client_integration.rs index 9f59710..1444156 100644 --- a/rust/crates/api/tests/client_integration.rs +++ b/rust/crates/api/tests/client_integration.rs @@ -84,6 +84,7 @@ async fn send_message_posts_json_and_parses_response() { } #[tokio::test] +#[allow(clippy::await_holding_lock)] async fn stream_message_parses_sse_events_with_tool_use() { let _guard = env_lock(); let temp_root = std::env::temp_dir().join(format!( @@ -180,12 +181,15 @@ async fn stream_message_parses_sse_events_with_tool_use() { let request = captured.first().expect("server should capture request"); assert!(request.body.contains("\"stream\":true")); - let stats = client + let cache_stats = client .prompt_cache_stats() .expect("prompt cache stats should exist"); - assert_eq!(stats.tracked_requests, 1); - assert_eq!(stats.last_cache_read_input_tokens, Some(0)); - assert_eq!(stats.last_cache_source.as_deref(), Some("api-response")); + assert_eq!(cache_stats.tracked_requests, 1); + assert_eq!(cache_stats.last_cache_read_input_tokens, Some(0)); + assert_eq!( + cache_stats.last_cache_source.as_deref(), + Some("api-response") + ); std::fs::remove_dir_all(temp_root).expect("cleanup temp root"); std::env::remove_var("CLAUDE_CONFIG_HOME"); @@ -273,6 +277,7 @@ async fn surfaces_retry_exhaustion_for_persistent_retryable_errors() { } #[tokio::test] +#[allow(clippy::await_holding_lock)] async fn send_message_reuses_recent_completion_cache_entries() { let _guard = env_lock(); let temp_root = std::env::temp_dir().join(format!( @@ -312,18 +317,19 @@ async fn send_message_reuses_recent_completion_cache_entries() { assert_eq!(first.content, second.content); assert_eq!(state.lock().await.len(), 1); - let stats = client + let cache_stats = client .prompt_cache_stats() .expect("prompt cache stats should exist"); - assert_eq!(stats.completion_cache_hits, 1); - assert_eq!(stats.completion_cache_misses, 1); - assert_eq!(stats.completion_cache_writes, 1); + assert_eq!(cache_stats.completion_cache_hits, 1); + assert_eq!(cache_stats.completion_cache_misses, 1); + assert_eq!(cache_stats.completion_cache_writes, 1); std::fs::remove_dir_all(temp_root).expect("cleanup temp root"); std::env::remove_var("CLAUDE_CONFIG_HOME"); } #[tokio::test] +#[allow(clippy::await_holding_lock)] async fn send_message_tracks_unexpected_prompt_cache_breaks() { let _guard = env_lock(); let temp_root = std::env::temp_dir().join(format!( @@ -372,12 +378,12 @@ async fn send_message_tracks_unexpected_prompt_cache_breaks() { .await .expect("second response should succeed"); - let stats = client + let cache_stats = client .prompt_cache_stats() .expect("prompt cache stats should exist"); - assert_eq!(stats.unexpected_cache_breaks, 1); + assert_eq!(cache_stats.unexpected_cache_breaks, 1); assert_eq!( - stats.last_break_reason.as_deref(), + cache_stats.last_break_reason.as_deref(), Some("cache read tokens dropped while prompt fingerprint remained stable") ); From c9d214c8d142a306710132f5c84b8cba53a5e0e6 Mon Sep 17 00:00:00 2001 From: Yeachan-Heo Date: Wed, 1 Apr 2026 06:15:13 +0000 Subject: [PATCH 3/4] feat: cache-tracking progress --- rust/crates/api/src/client.rs | 48 +++++++- rust/crates/runtime/src/conversation.rs | 54 +++++++-- rust/crates/runtime/src/hooks.rs | 8 +- rust/crates/runtime/src/lib.rs | 4 +- rust/crates/rusty-claude-cli/src/main.rs | 123 +++++++++++++++++---- rust/crates/rusty-claude-cli/src/render.rs | 4 +- rust/crates/tools/src/lib.rs | 49 ++++++-- 7 files changed, 238 insertions(+), 52 deletions(-) diff --git a/rust/crates/api/src/client.rs b/rust/crates/api/src/client.rs index f90eaf8..c7aca3f 100644 --- a/rust/crates/api/src/client.rs +++ b/rust/crates/api/src/client.rs @@ -1,4 +1,5 @@ use std::collections::VecDeque; +use std::sync::{Arc, Mutex}; use std::time::{Duration, SystemTime, UNIX_EPOCH}; use runtime::{ @@ -8,7 +9,7 @@ use runtime::{ use serde::Deserialize; use crate::error::ApiError; -use crate::prompt_cache::{PromptCache, PromptCacheStats}; +use crate::prompt_cache::{PromptCache, PromptCacheRecord, PromptCacheStats}; use crate::sse::SseParser; use crate::types::{MessageRequest, MessageResponse, StreamEvent, Usage}; @@ -110,6 +111,7 @@ pub struct AnthropicClient { initial_backoff: Duration, max_backoff: Duration, prompt_cache: Option, + last_prompt_cache_record: Arc>>, } impl AnthropicClient { @@ -123,6 +125,7 @@ impl AnthropicClient { initial_backoff: DEFAULT_INITIAL_BACKOFF, max_backoff: DEFAULT_MAX_BACKOFF, prompt_cache: None, + last_prompt_cache_record: Arc::new(Mutex::new(None)), } } @@ -136,6 +139,7 @@ impl AnthropicClient { initial_backoff: DEFAULT_INITIAL_BACKOFF, max_backoff: DEFAULT_MAX_BACKOFF, prompt_cache: None, + last_prompt_cache_record: Arc::new(Mutex::new(None)), } } @@ -209,6 +213,14 @@ impl AnthropicClient { self.prompt_cache.as_ref().map(PromptCache::stats) } + #[must_use] + pub fn take_last_prompt_cache_record(&self) -> Option { + self.last_prompt_cache_record() + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .take() + } + #[must_use] pub fn auth_source(&self) -> &AuthSource { &self.auth @@ -218,12 +230,16 @@ impl AnthropicClient { &self, request: &MessageRequest, ) -> Result { + self.store_last_prompt_cache_record(None); let request = MessageRequest { stream: false, ..request.clone() }; if let Some(prompt_cache) = &self.prompt_cache { if let Some(response) = prompt_cache.lookup_completion(&request) { + self.store_last_prompt_cache_record(Some(prompt_cache_record_from_stats( + prompt_cache.stats(), + ))); return Ok(response); } } @@ -237,7 +253,8 @@ impl AnthropicClient { response.request_id = request_id; } if let Some(prompt_cache) = &self.prompt_cache { - let _ = prompt_cache.record_response(&request, &response); + let record = prompt_cache.record_response(&request, &response); + self.store_last_prompt_cache_record(Some(record)); } Ok(response) } @@ -246,6 +263,7 @@ impl AnthropicClient { &self, request: &MessageRequest, ) -> Result { + self.store_last_prompt_cache_record(None); let response = self .send_with_retry(&request.clone().with_streaming()) .await?; @@ -263,10 +281,22 @@ impl AnthropicClient { request: request.clone().with_streaming(), last_usage: None, finalized: false, + last_record: self.last_prompt_cache_record.clone(), }), }) } + fn store_last_prompt_cache_record(&self, record: Option) { + *self + .last_prompt_cache_record() + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) = record; + } + + fn last_prompt_cache_record(&self) -> &Arc>> { + &self.last_prompt_cache_record + } + pub async fn exchange_oauth_code( &self, config: &OAuthConfig, @@ -615,6 +645,7 @@ struct StreamCacheTracking { request: MessageRequest, last_usage: Option, finalized: bool, + last_record: Arc>>, } impl StreamCacheTracking { @@ -638,12 +669,23 @@ impl StreamCacheTracking { return; } if let Some(usage) = &self.last_usage { - let _ = self.prompt_cache.record_usage(&self.request, usage); + let record = self.prompt_cache.record_usage(&self.request, usage); + *self + .last_record + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) = Some(record); } self.finalized = true; } } +fn prompt_cache_record_from_stats(stats: PromptCacheStats) -> PromptCacheRecord { + PromptCacheRecord { + cache_break: None, + stats, + } +} + async fn expect_success(response: reqwest::Response) -> Result { let status = response.status(); if status.is_success() { diff --git a/rust/crates/runtime/src/conversation.rs b/rust/crates/runtime/src/conversation.rs index 4ffbabc..00dbf54 100644 --- a/rust/crates/runtime/src/conversation.rs +++ b/rust/crates/runtime/src/conversation.rs @@ -25,9 +25,19 @@ pub enum AssistantEvent { input: String, }, Usage(TokenUsage), + PromptCache(PromptCacheEvent), MessageStop, } +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PromptCacheEvent { + pub unexpected: bool, + pub reason: String, + pub previous_cache_read_input_tokens: u32, + pub current_cache_read_input_tokens: u32, + pub token_drop: u32, +} + pub trait ApiClient { fn stream(&mut self, request: ApiRequest) -> Result, RuntimeError>; } @@ -84,6 +94,7 @@ impl std::error::Error for RuntimeError {} pub struct TurnSummary { pub assistant_messages: Vec, pub tool_results: Vec, + pub prompt_cache_events: Vec, pub iterations: usize, pub usage: TokenUsage, } @@ -118,7 +129,7 @@ where tool_executor, permission_policy, system_prompt, - RuntimeFeatureConfig::default(), + &RuntimeFeatureConfig::default(), ) } @@ -129,7 +140,7 @@ where tool_executor: T, permission_policy: PermissionPolicy, system_prompt: Vec, - feature_config: RuntimeFeatureConfig, + feature_config: &RuntimeFeatureConfig, ) -> Self { let usage_tracker = UsageTracker::from_session(&session); Self { @@ -140,7 +151,7 @@ where system_prompt, max_iterations: usize::MAX, usage_tracker, - hook_runner: HookRunner::from_feature_config(&feature_config), + hook_runner: HookRunner::from_feature_config(feature_config), } } @@ -161,6 +172,7 @@ where let mut assistant_messages = Vec::new(); let mut tool_results = Vec::new(); + let mut prompt_cache_events = Vec::new(); let mut iterations = 0; loop { @@ -176,10 +188,12 @@ where messages: self.session.messages.clone(), }; let events = self.api_client.stream(request)?; - let (assistant_message, usage) = build_assistant_message(events)?; + let (assistant_message, usage, turn_prompt_cache_events) = + build_assistant_message(events)?; if let Some(usage) = usage { self.usage_tracker.record(usage); } + prompt_cache_events.extend(turn_prompt_cache_events); let pending_tool_uses = assistant_message .blocks .iter() @@ -257,6 +271,7 @@ where Ok(TurnSummary { assistant_messages, tool_results, + prompt_cache_events, iterations, usage: self.usage_tracker.cumulative_usage(), }) @@ -290,9 +305,17 @@ where fn build_assistant_message( events: Vec, -) -> Result<(ConversationMessage, Option), RuntimeError> { +) -> Result< + ( + ConversationMessage, + Option, + Vec, + ), + RuntimeError, +> { let mut text = String::new(); let mut blocks = Vec::new(); + let mut prompt_cache_events = Vec::new(); let mut finished = false; let mut usage = None; @@ -304,6 +327,7 @@ fn build_assistant_message( blocks.push(ContentBlock::ToolUse { id, name, input }); } AssistantEvent::Usage(value) => usage = Some(value), + AssistantEvent::PromptCache(event) => prompt_cache_events.push(event), AssistantEvent::MessageStop => { finished = true; } @@ -324,6 +348,7 @@ fn build_assistant_message( Ok(( ConversationMessage::assistant_with_usage(blocks, usage), usage, + prompt_cache_events, )) } @@ -396,7 +421,7 @@ impl ToolExecutor for StaticToolExecutor { #[cfg(test)] mod tests { use super::{ - ApiClient, ApiRequest, AssistantEvent, ConversationRuntime, RuntimeError, + ApiClient, ApiRequest, AssistantEvent, ConversationRuntime, PromptCacheEvent, RuntimeError, StaticToolExecutor, }; use crate::compact::CompactionConfig; @@ -453,6 +478,15 @@ mod tests { cache_creation_input_tokens: 1, cache_read_input_tokens: 3, }), + AssistantEvent::PromptCache(PromptCacheEvent { + unexpected: true, + reason: + "cache read tokens dropped while prompt fingerprint remained stable" + .to_string(), + previous_cache_read_input_tokens: 6_000, + current_cache_read_input_tokens: 1_000, + token_drop: 5_000, + }), AssistantEvent::MessageStop, ]) } @@ -506,8 +540,10 @@ mod tests { assert_eq!(summary.iterations, 2); assert_eq!(summary.assistant_messages.len(), 2); assert_eq!(summary.tool_results.len(), 1); + assert_eq!(summary.prompt_cache_events.len(), 1); assert_eq!(runtime.session().messages.len(), 4); assert_eq!(summary.usage.output_tokens, 10); + assert!(summary.prompt_cache_events[0].unexpected); assert!(matches!( runtime.session().messages[1].blocks[1], ContentBlock::ToolUse { .. } @@ -609,7 +645,7 @@ mod tests { }), PermissionPolicy::new(PermissionMode::DangerFullAccess), vec!["system".to_string()], - RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new( + &RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new( vec![shell_snippet("printf 'blocked by hook'; exit 2")], Vec::new(), )), @@ -675,7 +711,7 @@ mod tests { StaticToolExecutor::new().register("add", |_input| Ok("4".to_string())), PermissionPolicy::new(PermissionMode::DangerFullAccess), vec!["system".to_string()], - RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new( + &RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new( vec![shell_snippet("printf 'pre hook ran'")], vec![shell_snippet("printf 'post hook ran'")], )), @@ -697,7 +733,7 @@ mod tests { "post hook should preserve non-error result: {output:?}" ); assert!( - output.contains("4"), + output.contains('4'), "tool output missing value: {output:?}" ); assert!( diff --git a/rust/crates/runtime/src/hooks.rs b/rust/crates/runtime/src/hooks.rs index 36756a0..80770ba 100644 --- a/rust/crates/runtime/src/hooks.rs +++ b/rust/crates/runtime/src/hooks.rs @@ -64,7 +64,7 @@ impl HookRunner { #[must_use] pub fn run_pre_tool_use(&self, tool_name: &str, tool_input: &str) -> HookRunResult { - self.run_commands( + Self::run_commands( HookEvent::PreToolUse, self.config.pre_tool_use(), tool_name, @@ -82,7 +82,7 @@ impl HookRunner { tool_output: &str, is_error: bool, ) -> HookRunResult { - self.run_commands( + Self::run_commands( HookEvent::PostToolUse, self.config.post_tool_use(), tool_name, @@ -93,7 +93,6 @@ impl HookRunner { } fn run_commands( - &self, event: HookEvent, commands: &[String], tool_name: &str, @@ -118,7 +117,7 @@ impl HookRunner { let mut messages = Vec::new(); for command in commands { - match self.run_command( + match Self::run_command( command, event, tool_name, @@ -150,7 +149,6 @@ impl HookRunner { } fn run_command( - &self, command: &str, event: HookEvent, tool_name: &str, diff --git a/rust/crates/runtime/src/lib.rs b/rust/crates/runtime/src/lib.rs index da745e5..856f9f5 100644 --- a/rust/crates/runtime/src/lib.rs +++ b/rust/crates/runtime/src/lib.rs @@ -31,8 +31,8 @@ pub use config::{ ScopedMcpServerConfig, CLAUDE_CODE_SETTINGS_SCHEMA_NAME, }; pub use conversation::{ - ApiClient, ApiRequest, AssistantEvent, ConversationRuntime, RuntimeError, StaticToolExecutor, - ToolError, ToolExecutor, TurnSummary, + ApiClient, ApiRequest, AssistantEvent, ConversationRuntime, PromptCacheEvent, RuntimeError, + StaticToolExecutor, ToolError, ToolExecutor, TurnSummary, }; pub use file_ops::{ edit_file, glob_search, grep_search, read_file, write_file, EditFileOutput, GlobSearchOutput, diff --git a/rust/crates/rusty-claude-cli/src/main.rs b/rust/crates/rusty-claude-cli/src/main.rs index 5f8a7a6..dcce2b0 100644 --- a/rust/crates/rusty-claude-cli/src/main.rs +++ b/rust/crates/rusty-claude-cli/src/main.rs @@ -13,8 +13,9 @@ use std::time::{SystemTime, UNIX_EPOCH}; use api::{ resolve_startup_auth_source, AnthropicClient, AuthSource, ContentBlockDelta, InputContentBlock, - InputMessage, MessageRequest, MessageResponse, OutputContentBlock, - StreamEvent as ApiStreamEvent, ToolChoice, ToolDefinition, ToolResultContentBlock, + InputMessage, MessageRequest, MessageResponse, OutputContentBlock, PromptCache, + PromptCacheRecord, StreamEvent as ApiStreamEvent, ToolChoice, ToolDefinition, + ToolResultContentBlock, }; use commands::{ @@ -28,8 +29,8 @@ use runtime::{ parse_oauth_callback_request_target, save_oauth_credentials, ApiClient, ApiRequest, AssistantEvent, CompactionConfig, ConfigLoader, ConfigSource, ContentBlock, ConversationMessage, ConversationRuntime, MessageRole, OAuthAuthorizationRequest, OAuthConfig, - OAuthTokenExchangeRequest, PermissionMode, PermissionPolicy, ProjectContext, RuntimeError, - Session, TokenUsage, ToolError, ToolExecutor, UsageTracker, + OAuthTokenExchangeRequest, PermissionMode, PermissionPolicy, ProjectContext, PromptCacheEvent, + RuntimeError, Session, TokenUsage, ToolError, ToolExecutor, UsageTracker, }; use serde_json::json; use tools::{execute_tool, mvp_tool_specs, ToolSpec}; @@ -995,6 +996,7 @@ impl LiveCli { let session = create_managed_session_handle()?; let runtime = build_runtime( Session::new(), + session.id.clone(), model.clone(), system_prompt.clone(), enable_tools, @@ -1050,13 +1052,14 @@ impl LiveCli { let mut permission_prompter = CliPermissionPrompter::new(self.permission_mode); let result = self.runtime.run_turn(input, Some(&mut permission_prompter)); match result { - Ok(_) => { + Ok(summary) => { spinner.finish( "✨ Done", TerminalRenderer::new().color_theme(), &mut stdout, )?; println!(); + print_prompt_cache_events(&summary); self.persist_session()?; Ok(()) } @@ -1086,6 +1089,7 @@ impl LiveCli { let session = self.runtime.session().clone(); let mut runtime = build_runtime( session, + self.session.id.clone(), self.model.clone(), self.system_prompt.clone(), true, @@ -1105,6 +1109,7 @@ impl LiveCli { "iterations": summary.iterations, "tool_uses": collect_tool_uses(&summary), "tool_results": collect_tool_results(&summary), + "prompt_cache_events": collect_prompt_cache_events(&summary), "usage": { "input_tokens": summary.usage.input_tokens, "output_tokens": summary.usage.output_tokens, @@ -1232,6 +1237,7 @@ impl LiveCli { let message_count = session.messages.len(); self.runtime = build_runtime( session, + self.session.id.clone(), model.clone(), self.system_prompt.clone(), true, @@ -1275,6 +1281,7 @@ impl LiveCli { self.permission_mode = permission_mode_from_label(normalized); self.runtime = build_runtime( session, + self.session.id.clone(), self.model.clone(), self.system_prompt.clone(), true, @@ -1300,6 +1307,7 @@ impl LiveCli { self.session = create_managed_session_handle()?; self.runtime = build_runtime( Session::new(), + self.session.id.clone(), self.model.clone(), self.system_prompt.clone(), true, @@ -1335,6 +1343,7 @@ impl LiveCli { let message_count = session.messages.len(); self.runtime = build_runtime( session, + handle.id.clone(), self.model.clone(), self.system_prompt.clone(), true, @@ -1407,6 +1416,7 @@ impl LiveCli { let message_count = session.messages.len(); self.runtime = build_runtime( session, + handle.id.clone(), self.model.clone(), self.system_prompt.clone(), true, @@ -1437,6 +1447,7 @@ impl LiveCli { let skipped = removed == 0; self.runtime = build_runtime( result.compacted_session, + self.session.id.clone(), self.model.clone(), self.system_prompt.clone(), true, @@ -1912,8 +1923,10 @@ fn build_runtime_feature_config( .clone()) } +#[allow(clippy::too_many_arguments)] fn build_runtime( session: Session, + session_id: String, model: String, system_prompt: Vec, enable_tools: bool, @@ -1924,11 +1937,17 @@ fn build_runtime( { Ok(ConversationRuntime::new_with_features( session, - AnthropicRuntimeClient::new(model, enable_tools, emit_output, allowed_tools.clone())?, + AnthropicRuntimeClient::new( + model, + enable_tools, + emit_output, + allowed_tools.clone(), + session_id, + )?, CliToolExecutor::new(allowed_tools, emit_output), permission_policy(permission_mode), system_prompt, - build_runtime_feature_config()?, + &build_runtime_feature_config()?, )) } @@ -1993,11 +2012,13 @@ impl AnthropicRuntimeClient { enable_tools: bool, emit_output: bool, allowed_tools: Option, + session_id: impl Into, ) -> Result> { Ok(Self { runtime: tokio::runtime::Runtime::new()?, client: AnthropicClient::from_auth(resolve_cli_auth_source()?) - .with_base_url(api::read_base_url()), + .with_base_url(api::read_base_url()) + .with_prompt_cache(PromptCache::new(session_id)), model, enable_tools, emit_output, @@ -2112,8 +2133,8 @@ impl ApiClient for AnthropicRuntimeClient { events.push(AssistantEvent::Usage(TokenUsage { input_tokens: delta.usage.input_tokens, output_tokens: delta.usage.output_tokens, - cache_creation_input_tokens: 0, - cache_read_input_tokens: 0, + cache_creation_input_tokens: delta.usage.cache_creation_input_tokens, + cache_read_input_tokens: delta.usage.cache_read_input_tokens, })); } ApiStreamEvent::MessageStop(_) => { @@ -2128,6 +2149,8 @@ impl ApiClient for AnthropicRuntimeClient { } } + push_prompt_cache_record(&self.client, &mut events); + if !saw_stop && events.iter().any(|event| { matches!(event, AssistantEvent::TextDelta(text) if !text.is_empty()) @@ -2152,7 +2175,9 @@ impl ApiClient for AnthropicRuntimeClient { }) .await .map_err(|error| RuntimeError::new(error.to_string()))?; - response_to_events(response, out) + let mut events = response_to_events(response, out)?; + push_prompt_cache_record(&self.client, &mut events); + Ok(events) }) } } @@ -2213,6 +2238,39 @@ fn collect_tool_results(summary: &runtime::TurnSummary) -> Vec Vec { + summary + .prompt_cache_events + .iter() + .map(|event| { + json!({ + "unexpected": event.unexpected, + "reason": event.reason, + "previous_cache_read_input_tokens": event.previous_cache_read_input_tokens, + "current_cache_read_input_tokens": event.current_cache_read_input_tokens, + "token_drop": event.token_drop, + }) + }) + .collect() +} + +fn print_prompt_cache_events(summary: &runtime::TurnSummary) { + for event in &summary.prompt_cache_events { + let label = if event.unexpected { + "Prompt cache break" + } else { + "Prompt cache invalidation" + }; + println!( + "{label}: {} (cache read {} -> {}, drop {})", + event.reason, + event.previous_cache_read_input_tokens, + event.current_cache_read_input_tokens, + event.token_drop, + ); + } +} + fn slash_command_completion_candidates() -> Vec { slash_command_specs() .iter() @@ -2359,18 +2417,20 @@ fn first_visible_line(text: &str) -> &str { } fn format_bash_result(icon: &str, parsed: &serde_json::Value) -> String { + use std::fmt::Write as _; + let mut lines = vec![format!("{icon} \x1b[38;5;245mbash\x1b[0m")]; if let Some(task_id) = parsed .get("backgroundTaskId") .and_then(|value| value.as_str()) { - lines[0].push_str(&format!(" backgrounded ({task_id})")); + let _ = write!(lines[0], " backgrounded ({task_id})"); } else if let Some(status) = parsed .get("returnCodeInterpretation") .and_then(|value| value.as_str()) .filter(|status| !status.is_empty()) { - lines[0].push_str(&format!(" {status}")); + let _ = write!(lines[0], " {status}"); } if let Some(stdout) = parsed.get("stdout").and_then(|value| value.as_str()) { @@ -2392,15 +2452,15 @@ fn format_read_result(icon: &str, parsed: &serde_json::Value) -> String { let path = extract_tool_path(file); let start_line = file .get("startLine") - .and_then(|value| value.as_u64()) + .and_then(serde_json::Value::as_u64) .unwrap_or(1); let num_lines = file .get("numLines") - .and_then(|value| value.as_u64()) + .and_then(serde_json::Value::as_u64) .unwrap_or(0); let total_lines = file .get("totalLines") - .and_then(|value| value.as_u64()) + .and_then(serde_json::Value::as_u64) .unwrap_or(num_lines); let content = file .get("content") @@ -2426,8 +2486,7 @@ fn format_write_result(icon: &str, parsed: &serde_json::Value) -> String { let line_count = parsed .get("content") .and_then(|value| value.as_str()) - .map(|content| content.lines().count()) - .unwrap_or(0); + .map_or(0, |content| content.lines().count()); format!( "{icon} \x1b[1;32m✏️ {} {path}\x1b[0m \x1b[2m({line_count} lines)\x1b[0m", if kind == "create" { "Wrote" } else { "Updated" }, @@ -2458,7 +2517,7 @@ fn format_edit_result(icon: &str, parsed: &serde_json::Value) -> String { let path = extract_tool_path(parsed); let suffix = if parsed .get("replaceAll") - .and_then(|value| value.as_bool()) + .and_then(serde_json::Value::as_bool) .unwrap_or(false) { " (replace all)" @@ -2486,7 +2545,7 @@ fn format_edit_result(icon: &str, parsed: &serde_json::Value) -> String { fn format_glob_result(icon: &str, parsed: &serde_json::Value) -> String { let num_files = parsed .get("numFiles") - .and_then(|value| value.as_u64()) + .and_then(serde_json::Value::as_u64) .unwrap_or(0); let filenames = parsed .get("filenames") @@ -2510,11 +2569,11 @@ fn format_glob_result(icon: &str, parsed: &serde_json::Value) -> String { fn format_grep_result(icon: &str, parsed: &serde_json::Value) -> String { let num_matches = parsed .get("numMatches") - .and_then(|value| value.as_u64()) + .and_then(serde_json::Value::as_u64) .unwrap_or(0); let num_files = parsed .get("numFiles") - .and_then(|value| value.as_u64()) + .and_then(serde_json::Value::as_u64) .unwrap_or(0); let content = parsed .get("content") @@ -2621,6 +2680,26 @@ fn response_to_events( Ok(events) } +fn push_prompt_cache_record(client: &AnthropicClient, events: &mut Vec) { + if let Some(event) = client + .take_last_prompt_cache_record() + .and_then(prompt_cache_record_to_runtime_event) + { + events.push(AssistantEvent::PromptCache(event)); + } +} + +fn prompt_cache_record_to_runtime_event(record: PromptCacheRecord) -> Option { + let cache_break = record.cache_break?; + Some(PromptCacheEvent { + unexpected: cache_break.unexpected, + reason: cache_break.reason, + previous_cache_read_input_tokens: cache_break.previous_cache_read_input_tokens, + current_cache_read_input_tokens: cache_break.current_cache_read_input_tokens, + token_drop: cache_break.token_drop, + }) +} + struct CliToolExecutor { renderer: TerminalRenderer, emit_output: bool, diff --git a/rust/crates/rusty-claude-cli/src/render.rs b/rust/crates/rusty-claude-cli/src/render.rs index 465c5a4..d8d8796 100644 --- a/rust/crates/rusty-claude-cli/src/render.rs +++ b/rust/crates/rusty-claude-cli/src/render.rs @@ -286,7 +286,7 @@ impl TerminalRenderer { ) { match event { Event::Start(Tag::Heading { level, .. }) => { - self.start_heading(state, level as u8, output) + Self::start_heading(state, level as u8, output); } Event::End(TagEnd::Paragraph) => output.push_str("\n\n"), Event::Start(Tag::BlockQuote(..)) => self.start_quote(state, output), @@ -426,7 +426,7 @@ impl TerminalRenderer { } } - fn start_heading(&self, state: &mut RenderState, level: u8, output: &mut String) { + fn start_heading(state: &mut RenderState, level: u8, output: &mut String) { state.heading_level = Some(level); if !output.is_empty() { output.push('\n'); diff --git a/rust/crates/tools/src/lib.rs b/rust/crates/tools/src/lib.rs index 8dcd33d..be11e6b 100644 --- a/rust/crates/tools/src/lib.rs +++ b/rust/crates/tools/src/lib.rs @@ -5,15 +5,15 @@ use std::time::{Duration, Instant}; use api::{ read_base_url, AnthropicClient, ContentBlockDelta, InputContentBlock, InputMessage, - MessageRequest, MessageResponse, OutputContentBlock, StreamEvent as ApiStreamEvent, ToolChoice, - ToolDefinition, ToolResultContentBlock, + MessageRequest, MessageResponse, OutputContentBlock, PromptCache, PromptCacheRecord, + StreamEvent as ApiStreamEvent, ToolChoice, ToolDefinition, ToolResultContentBlock, }; use reqwest::blocking::Client; use runtime::{ edit_file, execute_bash, glob_search, grep_search, load_system_prompt, read_file, write_file, ApiClient, ApiRequest, AssistantEvent, BashCommandInput, ContentBlock, ConversationMessage, ConversationRuntime, GrepSearchInput, MessageRole, PermissionMode, PermissionPolicy, - RuntimeError, Session, TokenUsage, ToolError, ToolExecutor, + PromptCacheEvent, RuntimeError, Session, TokenUsage, ToolError, ToolExecutor, }; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; @@ -1466,7 +1466,8 @@ fn build_agent_runtime( .clone() .unwrap_or_else(|| DEFAULT_AGENT_MODEL.to_string()); let allowed_tools = job.allowed_tools.clone(); - let api_client = AnthropicRuntimeClient::new(model, allowed_tools.clone())?; + let api_client = + AnthropicRuntimeClient::new(model, allowed_tools.clone(), job.manifest.agent_id.clone())?; let tool_executor = SubagentToolExecutor::new(allowed_tools); Ok(ConversationRuntime::new( Session::new(), @@ -1643,10 +1644,15 @@ struct AnthropicRuntimeClient { } impl AnthropicRuntimeClient { - fn new(model: String, allowed_tools: BTreeSet) -> Result { + fn new( + model: String, + allowed_tools: BTreeSet, + session_id: impl Into, + ) -> Result { let client = AnthropicClient::from_env() .map_err(|error| error.to_string())? - .with_base_url(read_base_url()); + .with_base_url(read_base_url()) + .with_prompt_cache(PromptCache::new(session_id)); Ok(Self { runtime: tokio::runtime::Runtime::new().map_err(|error| error.to_string())?, client, @@ -1657,6 +1663,7 @@ impl AnthropicRuntimeClient { } impl ApiClient for AnthropicRuntimeClient { + #[allow(clippy::too_many_lines)] fn stream(&mut self, request: ApiRequest) -> Result, RuntimeError> { let tools = tool_specs_for_allowed_tools(Some(&self.allowed_tools)) .into_iter() @@ -1726,8 +1733,8 @@ impl ApiClient for AnthropicRuntimeClient { events.push(AssistantEvent::Usage(TokenUsage { input_tokens: delta.usage.input_tokens, output_tokens: delta.usage.output_tokens, - cache_creation_input_tokens: 0, - cache_read_input_tokens: 0, + cache_creation_input_tokens: delta.usage.cache_creation_input_tokens, + cache_read_input_tokens: delta.usage.cache_read_input_tokens, })); } ApiStreamEvent::MessageStop(_) => { @@ -1737,6 +1744,8 @@ impl ApiClient for AnthropicRuntimeClient { } } + push_prompt_cache_record(&self.client, &mut events); + if !saw_stop && events.iter().any(|event| { matches!(event, AssistantEvent::TextDelta(text) if !text.is_empty()) @@ -1761,7 +1770,9 @@ impl ApiClient for AnthropicRuntimeClient { }) .await .map_err(|error| RuntimeError::new(error.to_string()))?; - Ok(response_to_events(response)) + let mut events = response_to_events(response); + push_prompt_cache_record(&self.client, &mut events); + Ok(events) }) } } @@ -1884,6 +1895,26 @@ fn response_to_events(response: MessageResponse) -> Vec { events } +fn push_prompt_cache_record(client: &AnthropicClient, events: &mut Vec) { + if let Some(event) = client + .take_last_prompt_cache_record() + .and_then(prompt_cache_record_to_runtime_event) + { + events.push(AssistantEvent::PromptCache(event)); + } +} + +fn prompt_cache_record_to_runtime_event(record: PromptCacheRecord) -> Option { + let cache_break = record.cache_break?; + Some(PromptCacheEvent { + unexpected: cache_break.unexpected, + reason: cache_break.reason, + previous_cache_read_input_tokens: cache_break.previous_cache_read_input_tokens, + current_cache_read_input_tokens: cache_break.current_cache_read_input_tokens, + token_drop: cache_break.token_drop, + }) +} + fn final_assistant_text(summary: &runtime::TurnSummary) -> String { summary .assistant_messages From 799c92eadace8296fc3d94a844cf787707e73fdd Mon Sep 17 00:00:00 2001 From: Yeachan-Heo Date: Wed, 1 Apr 2026 06:25:26 +0000 Subject: [PATCH 4/4] feat: cache-tracking progress --- rust/crates/api/tests/client_integration.rs | 42 +++++++++++++++++++-- 1 file changed, 39 insertions(+), 3 deletions(-) diff --git a/rust/crates/api/tests/client_integration.rs b/rust/crates/api/tests/client_integration.rs index 1444156..69208f1 100644 --- a/rust/crates/api/tests/client_integration.rs +++ b/rust/crates/api/tests/client_integration.rs @@ -53,6 +53,8 @@ async fn send_message_posts_json_and_parses_response() { assert_eq!(response.id, "msg_test"); assert_eq!(response.total_tokens(), 16); assert_eq!(response.request_id.as_deref(), Some("req_body_123")); + assert_eq!(response.usage.cache_creation_input_tokens, 0); + assert_eq!(response.usage.cache_read_input_tokens, 0); assert_eq!( response.content, vec![OutputContentBlock::Text { @@ -83,6 +85,39 @@ async fn send_message_posts_json_and_parses_response() { assert_eq!(body["tool_choice"]["type"], json!("auto")); } +#[tokio::test] +async fn send_message_parses_prompt_cache_token_usage_from_response() { + let state = Arc::new(Mutex::new(Vec::::new())); + let body = concat!( + "{", + "\"id\":\"msg_cache_tokens\",", + "\"type\":\"message\",", + "\"role\":\"assistant\",", + "\"content\":[{\"type\":\"text\",\"text\":\"Cache tokens\"}],", + "\"model\":\"claude-3-7-sonnet-latest\",", + "\"stop_reason\":\"end_turn\",", + "\"stop_sequence\":null,", + "\"usage\":{\"input_tokens\":12,\"cache_creation_input_tokens\":321,\"cache_read_input_tokens\":654,\"output_tokens\":4}", + "}" + ); + let server = spawn_server( + state, + vec![http_response("200 OK", "application/json", body)], + ) + .await; + + let client = AnthropicClient::new("test-key").with_base_url(server.base_url()); + let response = client + .send_message(&sample_request(false)) + .await + .expect("request should succeed"); + + assert_eq!(response.usage.input_tokens, 12); + assert_eq!(response.usage.cache_creation_input_tokens, 321); + assert_eq!(response.usage.cache_read_input_tokens, 654); + assert_eq!(response.usage.output_tokens, 4); +} + #[tokio::test] #[allow(clippy::await_holding_lock)] async fn stream_message_parses_sse_events_with_tool_use() { @@ -99,7 +134,7 @@ async fn stream_message_parses_sse_events_with_tool_use() { let state = Arc::new(Mutex::new(Vec::::new())); let sse = concat!( "event: message_start\n", - "data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_stream\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":8,\"output_tokens\":0}}}\n\n", + "data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_stream\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":8,\"cache_creation_input_tokens\":13,\"cache_read_input_tokens\":21,\"output_tokens\":0}}}\n\n", "event: content_block_start\n", "data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"tool_use\",\"id\":\"toolu_123\",\"name\":\"get_weather\",\"input\":{}}}\n\n", "event: content_block_delta\n", @@ -107,7 +142,7 @@ async fn stream_message_parses_sse_events_with_tool_use() { "event: content_block_stop\n", "data: {\"type\":\"content_block_stop\",\"index\":0}\n\n", "event: message_delta\n", - "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\",\"stop_sequence\":null},\"usage\":{\"input_tokens\":8,\"output_tokens\":1}}\n\n", + "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\",\"stop_sequence\":null},\"usage\":{\"input_tokens\":8,\"cache_creation_input_tokens\":34,\"cache_read_input_tokens\":55,\"output_tokens\":1}}\n\n", "event: message_stop\n", "data: {\"type\":\"message_stop\"}\n\n", "data: [DONE]\n\n" @@ -185,7 +220,8 @@ async fn stream_message_parses_sse_events_with_tool_use() { .prompt_cache_stats() .expect("prompt cache stats should exist"); assert_eq!(cache_stats.tracked_requests, 1); - assert_eq!(cache_stats.last_cache_read_input_tokens, Some(0)); + assert_eq!(cache_stats.last_cache_creation_input_tokens, Some(34)); + assert_eq!(cache_stats.last_cache_read_input_tokens, Some(55)); assert_eq!( cache_stats.last_cache_source.as_deref(), Some("api-response")