Merge remote-tracking branch 'origin/rcc/cache-tracking' into integration/dori-cleanroom

This commit is contained in:
YeonGyu-Kim
2026-04-02 11:17:13 +09:00
5 changed files with 1037 additions and 7 deletions

View File

@@ -0,0 +1,727 @@
use std::fs;
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use serde::{Deserialize, Serialize};
use crate::types::{MessageRequest, MessageResponse, Usage};
const DEFAULT_COMPLETION_TTL_SECS: u64 = 30;
const DEFAULT_PROMPT_TTL_SECS: u64 = 5 * 60;
const DEFAULT_BREAK_MIN_DROP: u32 = 2_000;
const MAX_SANITIZED_LENGTH: usize = 80;
const REQUEST_FINGERPRINT_VERSION: u32 = 1;
const REQUEST_FINGERPRINT_PREFIX: &str = "v1";
const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;
#[derive(Debug, Clone)]
pub struct PromptCacheConfig {
pub session_id: String,
pub completion_ttl: Duration,
pub prompt_ttl: Duration,
pub cache_break_min_drop: u32,
}
impl PromptCacheConfig {
#[must_use]
pub fn new(session_id: impl Into<String>) -> Self {
Self {
session_id: session_id.into(),
completion_ttl: Duration::from_secs(DEFAULT_COMPLETION_TTL_SECS),
prompt_ttl: Duration::from_secs(DEFAULT_PROMPT_TTL_SECS),
cache_break_min_drop: DEFAULT_BREAK_MIN_DROP,
}
}
}
impl Default for PromptCacheConfig {
fn default() -> Self {
Self::new("default")
}
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct PromptCachePaths {
pub root: PathBuf,
pub session_dir: PathBuf,
pub completion_dir: PathBuf,
pub session_state_path: PathBuf,
pub stats_path: PathBuf,
}
impl PromptCachePaths {
#[must_use]
pub fn for_session(session_id: &str) -> Self {
let root = base_cache_root();
let session_dir = root.join(sanitize_path_segment(session_id));
let completion_dir = session_dir.join("completions");
Self {
root,
session_state_path: session_dir.join("session-state.json"),
stats_path: session_dir.join("stats.json"),
session_dir,
completion_dir,
}
}
#[must_use]
pub fn completion_entry_path(&self, request_hash: &str) -> PathBuf {
self.completion_dir.join(format!("{request_hash}.json"))
}
}
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct PromptCacheStats {
pub tracked_requests: u64,
pub completion_cache_hits: u64,
pub completion_cache_misses: u64,
pub completion_cache_writes: u64,
pub expected_invalidations: u64,
pub unexpected_cache_breaks: u64,
pub total_cache_creation_input_tokens: u64,
pub total_cache_read_input_tokens: u64,
pub last_cache_creation_input_tokens: Option<u32>,
pub last_cache_read_input_tokens: Option<u32>,
pub last_request_hash: Option<String>,
pub last_completion_cache_key: Option<String>,
pub last_break_reason: Option<String>,
pub last_cache_source: Option<String>,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CacheBreakEvent {
pub unexpected: bool,
pub reason: String,
pub previous_cache_read_input_tokens: u32,
pub current_cache_read_input_tokens: u32,
pub token_drop: u32,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PromptCacheRecord {
pub cache_break: Option<CacheBreakEvent>,
pub stats: PromptCacheStats,
}
#[derive(Debug, Clone)]
pub struct PromptCache {
inner: Arc<Mutex<PromptCacheInner>>,
}
impl PromptCache {
#[must_use]
pub fn new(session_id: impl Into<String>) -> Self {
Self::with_config(PromptCacheConfig::new(session_id))
}
#[must_use]
pub fn with_config(config: PromptCacheConfig) -> Self {
let paths = PromptCachePaths::for_session(&config.session_id);
let stats = read_json::<PromptCacheStats>(&paths.stats_path).unwrap_or_default();
let previous = read_json::<TrackedPromptState>(&paths.session_state_path);
Self {
inner: Arc::new(Mutex::new(PromptCacheInner {
config,
paths,
stats,
previous,
})),
}
}
#[must_use]
pub fn paths(&self) -> PromptCachePaths {
self.lock().paths.clone()
}
#[must_use]
pub fn stats(&self) -> PromptCacheStats {
self.lock().stats.clone()
}
#[must_use]
pub fn lookup_completion(&self, request: &MessageRequest) -> Option<MessageResponse> {
let request_hash = request_hash_hex(request);
let (paths, ttl) = {
let inner = self.lock();
(inner.paths.clone(), inner.config.completion_ttl)
};
let entry_path = paths.completion_entry_path(&request_hash);
let entry = read_json::<CompletionCacheEntry>(&entry_path);
let Some(entry) = entry else {
let mut inner = self.lock();
inner.stats.completion_cache_misses += 1;
inner.stats.last_completion_cache_key = Some(request_hash);
persist_state(&inner);
return None;
};
if entry.fingerprint_version != current_fingerprint_version() {
let mut inner = self.lock();
inner.stats.completion_cache_misses += 1;
inner.stats.last_completion_cache_key = Some(request_hash.clone());
let _ = fs::remove_file(entry_path);
persist_state(&inner);
return None;
}
let expired = now_unix_secs().saturating_sub(entry.cached_at_unix_secs) >= ttl.as_secs();
let mut inner = self.lock();
inner.stats.last_completion_cache_key = Some(request_hash.clone());
if expired {
inner.stats.completion_cache_misses += 1;
let _ = fs::remove_file(entry_path);
persist_state(&inner);
return None;
}
inner.stats.completion_cache_hits += 1;
apply_usage_to_stats(
&mut inner.stats,
&entry.response.usage,
&request_hash,
"completion-cache",
);
inner.previous = Some(TrackedPromptState::from_usage(
request,
&entry.response.usage,
));
persist_state(&inner);
Some(entry.response)
}
#[must_use]
pub fn record_response(
&self,
request: &MessageRequest,
response: &MessageResponse,
) -> PromptCacheRecord {
self.record_usage_internal(request, &response.usage, Some(response))
}
#[must_use]
pub fn record_usage(&self, request: &MessageRequest, usage: &Usage) -> PromptCacheRecord {
self.record_usage_internal(request, usage, None)
}
fn record_usage_internal(
&self,
request: &MessageRequest,
usage: &Usage,
response: Option<&MessageResponse>,
) -> PromptCacheRecord {
let request_hash = request_hash_hex(request);
let mut inner = self.lock();
let previous = inner.previous.clone();
let current = TrackedPromptState::from_usage(request, usage);
let cache_break = detect_cache_break(&inner.config, previous.as_ref(), &current);
inner.stats.tracked_requests += 1;
apply_usage_to_stats(&mut inner.stats, usage, &request_hash, "api-response");
if let Some(event) = &cache_break {
if event.unexpected {
inner.stats.unexpected_cache_breaks += 1;
} else {
inner.stats.expected_invalidations += 1;
}
inner.stats.last_break_reason = Some(event.reason.clone());
}
inner.previous = Some(current);
if let Some(response) = response {
write_completion_entry(&inner.paths, &request_hash, response);
inner.stats.completion_cache_writes += 1;
}
persist_state(&inner);
PromptCacheRecord {
cache_break,
stats: inner.stats.clone(),
}
}
fn lock(&self) -> std::sync::MutexGuard<'_, PromptCacheInner> {
self.inner
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
}
}
#[derive(Debug)]
struct PromptCacheInner {
config: PromptCacheConfig,
paths: PromptCachePaths,
stats: PromptCacheStats,
previous: Option<TrackedPromptState>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct CompletionCacheEntry {
cached_at_unix_secs: u64,
#[serde(default = "current_fingerprint_version")]
fingerprint_version: u32,
response: MessageResponse,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
struct TrackedPromptState {
observed_at_unix_secs: u64,
#[serde(default = "current_fingerprint_version")]
fingerprint_version: u32,
model_hash: u64,
system_hash: u64,
tools_hash: u64,
messages_hash: u64,
cache_read_input_tokens: u32,
}
impl TrackedPromptState {
fn from_usage(request: &MessageRequest, usage: &Usage) -> Self {
let hashes = RequestFingerprints::from_request(request);
Self {
observed_at_unix_secs: now_unix_secs(),
fingerprint_version: current_fingerprint_version(),
model_hash: hashes.model,
system_hash: hashes.system,
tools_hash: hashes.tools,
messages_hash: hashes.messages,
cache_read_input_tokens: usage.cache_read_input_tokens,
}
}
}
#[derive(Debug, Clone, Copy)]
struct RequestFingerprints {
model: u64,
system: u64,
tools: u64,
messages: u64,
}
impl RequestFingerprints {
fn from_request(request: &MessageRequest) -> Self {
Self {
model: hash_serializable(&request.model),
system: hash_serializable(&request.system),
tools: hash_serializable(&request.tools),
messages: hash_serializable(&request.messages),
}
}
}
fn detect_cache_break(
config: &PromptCacheConfig,
previous: Option<&TrackedPromptState>,
current: &TrackedPromptState,
) -> Option<CacheBreakEvent> {
let previous = previous?;
if previous.fingerprint_version != current.fingerprint_version {
return Some(CacheBreakEvent {
unexpected: false,
reason: format!(
"fingerprint version changed (v{} -> v{})",
previous.fingerprint_version, current.fingerprint_version
),
previous_cache_read_input_tokens: previous.cache_read_input_tokens,
current_cache_read_input_tokens: current.cache_read_input_tokens,
token_drop: previous
.cache_read_input_tokens
.saturating_sub(current.cache_read_input_tokens),
});
}
let token_drop = previous
.cache_read_input_tokens
.saturating_sub(current.cache_read_input_tokens);
if token_drop < config.cache_break_min_drop {
return None;
}
let mut reasons = Vec::new();
if previous.model_hash != current.model_hash {
reasons.push("model changed");
}
if previous.system_hash != current.system_hash {
reasons.push("system prompt changed");
}
if previous.tools_hash != current.tools_hash {
reasons.push("tool definitions changed");
}
if previous.messages_hash != current.messages_hash {
reasons.push("message payload changed");
}
let elapsed = current
.observed_at_unix_secs
.saturating_sub(previous.observed_at_unix_secs);
let (unexpected, reason) = if reasons.is_empty() {
if elapsed > config.prompt_ttl.as_secs() {
(
false,
format!("possible prompt cache TTL expiry after {elapsed}s"),
)
} else {
(
true,
"cache read tokens dropped while prompt fingerprint remained stable".to_string(),
)
}
} else {
(false, reasons.join(", "))
};
Some(CacheBreakEvent {
unexpected,
reason,
previous_cache_read_input_tokens: previous.cache_read_input_tokens,
current_cache_read_input_tokens: current.cache_read_input_tokens,
token_drop,
})
}
fn apply_usage_to_stats(
stats: &mut PromptCacheStats,
usage: &Usage,
request_hash: &str,
source: &str,
) {
stats.total_cache_creation_input_tokens += u64::from(usage.cache_creation_input_tokens);
stats.total_cache_read_input_tokens += u64::from(usage.cache_read_input_tokens);
stats.last_cache_creation_input_tokens = Some(usage.cache_creation_input_tokens);
stats.last_cache_read_input_tokens = Some(usage.cache_read_input_tokens);
stats.last_request_hash = Some(request_hash.to_string());
stats.last_cache_source = Some(source.to_string());
}
fn persist_state(inner: &PromptCacheInner) {
let _ = ensure_cache_dirs(&inner.paths);
let _ = write_json(&inner.paths.stats_path, &inner.stats);
if let Some(previous) = &inner.previous {
let _ = write_json(&inner.paths.session_state_path, previous);
}
}
fn write_completion_entry(
paths: &PromptCachePaths,
request_hash: &str,
response: &MessageResponse,
) {
let _ = ensure_cache_dirs(paths);
let entry = CompletionCacheEntry {
cached_at_unix_secs: now_unix_secs(),
fingerprint_version: current_fingerprint_version(),
response: response.clone(),
};
let _ = write_json(&paths.completion_entry_path(request_hash), &entry);
}
fn ensure_cache_dirs(paths: &PromptCachePaths) -> std::io::Result<()> {
fs::create_dir_all(&paths.completion_dir)
}
fn write_json<T: Serialize>(path: &Path, value: &T) -> std::io::Result<()> {
let json = serde_json::to_vec_pretty(value)
.map_err(|error| std::io::Error::new(std::io::ErrorKind::InvalidData, error))?;
fs::write(path, json)
}
fn read_json<T: for<'de> Deserialize<'de>>(path: &Path) -> Option<T> {
let bytes = fs::read(path).ok()?;
serde_json::from_slice(&bytes).ok()
}
fn request_hash_hex(request: &MessageRequest) -> String {
format!(
"{REQUEST_FINGERPRINT_PREFIX}-{:016x}",
hash_serializable(request)
)
}
fn hash_serializable<T: Serialize>(value: &T) -> u64 {
let json = serde_json::to_vec(value).unwrap_or_default();
stable_hash_bytes(&json)
}
fn sanitize_path_segment(value: &str) -> String {
let sanitized: String = value
.chars()
.map(|ch| if ch.is_ascii_alphanumeric() { ch } else { '-' })
.collect();
if sanitized.len() <= MAX_SANITIZED_LENGTH {
return sanitized;
}
let suffix = format!("-{:x}", hash_string(value));
format!(
"{}{}",
&sanitized[..MAX_SANITIZED_LENGTH.saturating_sub(suffix.len())],
suffix
)
}
fn hash_string(value: &str) -> u64 {
stable_hash_bytes(value.as_bytes())
}
fn base_cache_root() -> PathBuf {
if let Some(config_home) = std::env::var_os("CLAUDE_CONFIG_HOME") {
return PathBuf::from(config_home)
.join("cache")
.join("prompt-cache");
}
if let Some(home) = std::env::var_os("HOME") {
return PathBuf::from(home)
.join(".claude")
.join("cache")
.join("prompt-cache");
}
std::env::temp_dir().join("claude-prompt-cache")
}
fn now_unix_secs() -> u64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.map_or(0, |duration| duration.as_secs())
}
const fn current_fingerprint_version() -> u32 {
REQUEST_FINGERPRINT_VERSION
}
fn stable_hash_bytes(bytes: &[u8]) -> u64 {
let mut hash = FNV_OFFSET_BASIS;
for byte in bytes {
hash ^= u64::from(*byte);
hash = hash.wrapping_mul(FNV_PRIME);
}
hash
}
#[cfg(test)]
mod tests {
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use super::{
detect_cache_break, read_json, request_hash_hex, sanitize_path_segment, PromptCache,
PromptCacheConfig, PromptCachePaths, TrackedPromptState, REQUEST_FINGERPRINT_PREFIX,
};
use crate::test_env_lock;
use crate::types::{InputMessage, MessageRequest, MessageResponse, OutputContentBlock, Usage};
#[test]
fn path_builder_sanitizes_session_identifier() {
let paths = PromptCachePaths::for_session("session:/with spaces");
let session_dir = paths
.session_dir
.file_name()
.and_then(|value| value.to_str())
.expect("session dir name");
assert_eq!(session_dir, "session--with-spaces");
assert!(paths.completion_dir.ends_with("completions"));
assert!(paths.stats_path.ends_with("stats.json"));
assert!(paths.session_state_path.ends_with("session-state.json"));
}
#[test]
fn request_fingerprint_drives_unexpected_break_detection() {
let request = sample_request("same");
let previous = TrackedPromptState::from_usage(
&request,
&Usage {
input_tokens: 0,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 6_000,
output_tokens: 0,
},
);
let current = TrackedPromptState::from_usage(
&request,
&Usage {
input_tokens: 0,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 1_000,
output_tokens: 0,
},
);
let event = detect_cache_break(&PromptCacheConfig::default(), Some(&previous), &current)
.expect("break should be detected");
assert!(event.unexpected);
assert!(event.reason.contains("stable"));
}
#[test]
fn changed_prompt_marks_break_as_expected() {
let previous_request = sample_request("first");
let current_request = sample_request("second");
let previous = TrackedPromptState::from_usage(
&previous_request,
&Usage {
input_tokens: 0,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 6_000,
output_tokens: 0,
},
);
let current = TrackedPromptState::from_usage(
&current_request,
&Usage {
input_tokens: 0,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 1_000,
output_tokens: 0,
},
);
let event = detect_cache_break(&PromptCacheConfig::default(), Some(&previous), &current)
.expect("break should be detected");
assert!(!event.unexpected);
assert!(event.reason.contains("message payload changed"));
}
#[test]
fn completion_cache_round_trip_persists_recent_response() {
let _guard = test_env_lock();
let temp_root = std::env::temp_dir().join(format!(
"prompt-cache-test-{}-{}",
std::process::id(),
SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("time")
.as_nanos()
));
std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root);
let cache = PromptCache::new("unit-test-session");
let request = sample_request("cache me");
let response = sample_response(42, 12, "cached");
assert!(cache.lookup_completion(&request).is_none());
let record = cache.record_response(&request, &response);
assert!(record.cache_break.is_none());
let cached = cache
.lookup_completion(&request)
.expect("cached response should load");
assert_eq!(cached.content, response.content);
let stats = cache.stats();
assert_eq!(stats.completion_cache_hits, 1);
assert_eq!(stats.completion_cache_misses, 1);
assert_eq!(stats.completion_cache_writes, 1);
let persisted = read_json::<super::PromptCacheStats>(&cache.paths().stats_path)
.expect("stats should persist");
assert_eq!(persisted.completion_cache_hits, 1);
std::fs::remove_dir_all(temp_root).expect("cleanup temp root");
std::env::remove_var("CLAUDE_CONFIG_HOME");
}
#[test]
fn distinct_requests_do_not_collide_in_completion_cache() {
let _guard = test_env_lock();
let temp_root = std::env::temp_dir().join(format!(
"prompt-cache-distinct-{}-{}",
std::process::id(),
SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("time")
.as_nanos()
));
std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root);
let cache = PromptCache::new("distinct-request-session");
let first_request = sample_request("first");
let second_request = sample_request("second");
let response = sample_response(42, 12, "cached");
let _ = cache.record_response(&first_request, &response);
assert!(cache.lookup_completion(&second_request).is_none());
std::fs::remove_dir_all(temp_root).expect("cleanup temp root");
std::env::remove_var("CLAUDE_CONFIG_HOME");
}
#[test]
fn expired_completion_entries_are_not_reused() {
let _guard = test_env_lock();
let temp_root = std::env::temp_dir().join(format!(
"prompt-cache-expired-{}-{}",
std::process::id(),
SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("time")
.as_nanos()
));
std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root);
let cache = PromptCache::with_config(PromptCacheConfig {
session_id: "expired-session".to_string(),
completion_ttl: Duration::ZERO,
..PromptCacheConfig::default()
});
let request = sample_request("expire me");
let response = sample_response(7, 3, "stale");
let _ = cache.record_response(&request, &response);
assert!(cache.lookup_completion(&request).is_none());
let stats = cache.stats();
assert_eq!(stats.completion_cache_hits, 0);
assert_eq!(stats.completion_cache_misses, 1);
std::fs::remove_dir_all(temp_root).expect("cleanup temp root");
std::env::remove_var("CLAUDE_CONFIG_HOME");
}
#[test]
fn sanitize_path_caps_long_values() {
let long_value = "x".repeat(200);
let sanitized = sanitize_path_segment(&long_value);
assert!(sanitized.len() <= 80);
}
#[test]
fn request_hashes_are_versioned_and_stable() {
let request = sample_request("stable");
let first = request_hash_hex(&request);
let second = request_hash_hex(&request);
assert_eq!(first, second);
assert!(first.starts_with(REQUEST_FINGERPRINT_PREFIX));
}
fn sample_request(text: &str) -> MessageRequest {
MessageRequest {
model: "claude-3-7-sonnet-latest".to_string(),
max_tokens: 64,
messages: vec![InputMessage::user_text(text)],
system: Some("system".to_string()),
tools: None,
tool_choice: None,
stream: false,
}
}
fn sample_response(
cache_read_input_tokens: u32,
output_tokens: u32,
text: &str,
) -> MessageResponse {
MessageResponse {
id: "msg_test".to_string(),
kind: "message".to_string(),
role: "assistant".to_string(),
content: vec![OutputContentBlock::Text {
text: text.to_string(),
}],
model: "claude-3-7-sonnet-latest".to_string(),
stop_reason: Some("end_turn".to_string()),
stop_sequence: None,
usage: Usage {
input_tokens: 10,
cache_creation_input_tokens: 5,
cache_read_input_tokens,
output_tokens,
},
request_id: Some("req_test".to_string()),
}
}
}

View File

@@ -1,5 +1,6 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::{Mutex as StdMutex, OnceLock};
use std::time::Duration;
use api::{
@@ -13,6 +14,13 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;
use tokio::sync::Mutex;
fn env_lock() -> std::sync::MutexGuard<'static, ()> {
static LOCK: OnceLock<StdMutex<()>> = OnceLock::new();
LOCK.get_or_init(|| StdMutex::new(()))
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
}
#[tokio::test]
async fn send_message_posts_json_and_parses_response() {
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
@@ -46,6 +54,8 @@ async fn send_message_posts_json_and_parses_response() {
assert_eq!(response.id, "msg_test");
assert_eq!(response.total_tokens(), 16);
assert_eq!(response.request_id.as_deref(), Some("req_body_123"));
assert_eq!(response.usage.cache_creation_input_tokens, 0);
assert_eq!(response.usage.cache_read_input_tokens, 0);
assert_eq!(
response.content,
vec![OutputContentBlock::Text {
@@ -198,11 +208,55 @@ async fn send_message_applies_request_profile_and_records_telemetry() {
}
#[tokio::test]
async fn send_message_parses_prompt_cache_token_usage_from_response() {
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
let body = concat!(
"{",
"\"id\":\"msg_cache_tokens\",",
"\"type\":\"message\",",
"\"role\":\"assistant\",",
"\"content\":[{\"type\":\"text\",\"text\":\"Cache tokens\"}],",
"\"model\":\"claude-3-7-sonnet-latest\",",
"\"stop_reason\":\"end_turn\",",
"\"stop_sequence\":null,",
"\"usage\":{\"input_tokens\":12,\"cache_creation_input_tokens\":321,\"cache_read_input_tokens\":654,\"output_tokens\":4}",
"}"
);
let server = spawn_server(
state,
vec![http_response("200 OK", "application/json", body)],
)
.await;
let client = AnthropicClient::new("test-key").with_base_url(server.base_url());
let response = client
.send_message(&sample_request(false))
.await
.expect("request should succeed");
assert_eq!(response.usage.input_tokens, 12);
assert_eq!(response.usage.cache_creation_input_tokens, 321);
assert_eq!(response.usage.cache_read_input_tokens, 654);
assert_eq!(response.usage.output_tokens, 4);
}
#[tokio::test]
#[allow(clippy::await_holding_lock)]
async fn stream_message_parses_sse_events_with_tool_use() {
let _guard = env_lock();
let temp_root = std::env::temp_dir().join(format!(
"api-stream-cache-{}-{}",
std::process::id(),
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.expect("time")
.as_nanos()
));
std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root);
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
let sse = concat!(
"event: message_start\n",
"data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_stream\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":8,\"output_tokens\":0}}}\n\n",
"data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_stream\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":8,\"cache_creation_input_tokens\":13,\"cache_read_input_tokens\":21,\"output_tokens\":0}}}\n\n",
"event: content_block_start\n",
"data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"tool_use\",\"id\":\"toolu_123\",\"name\":\"get_weather\",\"input\":{}}}\n\n",
"event: content_block_delta\n",
@@ -210,7 +264,7 @@ async fn stream_message_parses_sse_events_with_tool_use() {
"event: content_block_stop\n",
"data: {\"type\":\"content_block_stop\",\"index\":0}\n\n",
"event: message_delta\n",
"data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\",\"stop_sequence\":null},\"usage\":{\"input_tokens\":8,\"output_tokens\":1}}\n\n",
"data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\",\"stop_sequence\":null},\"usage\":{\"input_tokens\":8,\"cache_creation_input_tokens\":34,\"cache_read_input_tokens\":55,\"output_tokens\":1}}\n\n",
"event: message_stop\n",
"data: {\"type\":\"message_stop\"}\n\n",
"data: [DONE]\n\n"
@@ -228,7 +282,8 @@ async fn stream_message_parses_sse_events_with_tool_use() {
let client = ApiClient::new("test-key")
.with_auth_token(Some("proxy-token".to_string()))
.with_base_url(server.base_url());
.with_base_url(server.base_url())
.with_prompt_cache(PromptCache::new("stream-session"));
let mut stream = client
.stream_message(&sample_request(false))
.await
@@ -282,6 +337,20 @@ async fn stream_message_parses_sse_events_with_tool_use() {
let captured = state.lock().await;
let request = captured.first().expect("server should capture request");
assert!(request.body.contains("\"stream\":true"));
let cache_stats = client
.prompt_cache_stats()
.expect("prompt cache stats should exist");
assert_eq!(cache_stats.tracked_requests, 1);
assert_eq!(cache_stats.last_cache_creation_input_tokens, Some(34));
assert_eq!(cache_stats.last_cache_read_input_tokens, Some(55));
assert_eq!(
cache_stats.last_cache_source.as_deref(),
Some("api-response")
);
std::fs::remove_dir_all(temp_root).expect("cleanup temp root");
std::env::remove_var("CLAUDE_CONFIG_HOME");
}
#[tokio::test]
@@ -406,6 +475,121 @@ async fn surfaces_retry_exhaustion_for_persistent_retryable_errors() {
}
}
#[tokio::test]
#[allow(clippy::await_holding_lock)]
async fn send_message_reuses_recent_completion_cache_entries() {
let _guard = env_lock();
let temp_root = std::env::temp_dir().join(format!(
"api-prompt-cache-{}-{}",
std::process::id(),
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.expect("time")
.as_nanos()
));
std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root);
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
let server = spawn_server(
state.clone(),
vec![http_response(
"200 OK",
"application/json",
"{\"id\":\"msg_cached\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Cached once\"}],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"cache_creation_input_tokens\":5,\"cache_read_input_tokens\":4000,\"output_tokens\":2}}",
)],
)
.await;
let client = AnthropicClient::new("test-key")
.with_base_url(server.base_url())
.with_prompt_cache(PromptCache::new("integration-session"));
let first = client
.send_message(&sample_request(false))
.await
.expect("first request should succeed");
let second = client
.send_message(&sample_request(false))
.await
.expect("second request should reuse cache");
assert_eq!(first.content, second.content);
assert_eq!(state.lock().await.len(), 1);
let cache_stats = client
.prompt_cache_stats()
.expect("prompt cache stats should exist");
assert_eq!(cache_stats.completion_cache_hits, 1);
assert_eq!(cache_stats.completion_cache_misses, 1);
assert_eq!(cache_stats.completion_cache_writes, 1);
std::fs::remove_dir_all(temp_root).expect("cleanup temp root");
std::env::remove_var("CLAUDE_CONFIG_HOME");
}
#[tokio::test]
#[allow(clippy::await_holding_lock)]
async fn send_message_tracks_unexpected_prompt_cache_breaks() {
let _guard = env_lock();
let temp_root = std::env::temp_dir().join(format!(
"api-prompt-break-{}-{}",
std::process::id(),
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.expect("time")
.as_nanos()
));
std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root);
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
let server = spawn_server(
state,
vec![
http_response(
"200 OK",
"application/json",
"{\"id\":\"msg_one\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"One\"}],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"cache_creation_input_tokens\":5,\"cache_read_input_tokens\":6000,\"output_tokens\":2}}",
),
http_response(
"200 OK",
"application/json",
"{\"id\":\"msg_two\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Two\"}],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"cache_creation_input_tokens\":0,\"cache_read_input_tokens\":1000,\"output_tokens\":2}}",
),
],
)
.await;
let request = sample_request(false);
let client = AnthropicClient::new("test-key")
.with_base_url(server.base_url())
.with_prompt_cache(PromptCache::with_config(api::PromptCacheConfig {
session_id: "break-session".to_string(),
completion_ttl: Duration::from_secs(0),
..api::PromptCacheConfig::default()
}));
client
.send_message(&request)
.await
.expect("first response should succeed");
client
.send_message(&request)
.await
.expect("second response should succeed");
let cache_stats = client
.prompt_cache_stats()
.expect("prompt cache stats should exist");
assert_eq!(cache_stats.unexpected_cache_breaks, 1);
assert_eq!(
cache_stats.last_break_reason.as_deref(),
Some("cache read tokens dropped while prompt fingerprint remained stable")
);
std::fs::remove_dir_all(temp_root).expect("cleanup temp root");
std::env::remove_var("CLAUDE_CONFIG_HOME");
}
#[tokio::test]
#[ignore = "requires ANTHROPIC_API_KEY and network access"]
async fn live_stream_smoke_test() {

View File

@@ -32,9 +32,19 @@ pub enum AssistantEvent {
input: String,
},
Usage(TokenUsage),
PromptCache(PromptCacheEvent),
MessageStop,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PromptCacheEvent {
pub unexpected: bool,
pub reason: String,
pub previous_cache_read_input_tokens: u32,
pub current_cache_read_input_tokens: u32,
pub token_drop: u32,
}
pub trait ApiClient {
fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError>;
}
@@ -91,6 +101,7 @@ impl std::error::Error for RuntimeError {}
pub struct TurnSummary {
pub assistant_messages: Vec<ConversationMessage>,
pub tool_results: Vec<ConversationMessage>,
pub prompt_cache_events: Vec<PromptCacheEvent>,
pub iterations: usize,
pub usage: TokenUsage,
pub auto_compaction: Option<AutoCompactionEvent>,
@@ -284,6 +295,7 @@ where
let mut assistant_messages = Vec::new();
let mut tool_results = Vec::new();
let mut prompt_cache_events = Vec::new();
let mut iterations = 0;
loop {
@@ -317,6 +329,7 @@ where
if let Some(usage) = usage {
self.usage_tracker.record(usage);
}
prompt_cache_events.extend(turn_prompt_cache_events);
let pending_tool_uses = assistant_message
.blocks
.iter()
@@ -435,6 +448,7 @@ where
Ok(TurnSummary {
assistant_messages,
tool_results,
prompt_cache_events,
iterations,
usage: self.usage_tracker.cumulative_usage(),
auto_compaction,
@@ -534,9 +548,17 @@ fn parse_auto_compaction_threshold(value: Option<&str>) -> u32 {
fn build_assistant_message(
events: Vec<AssistantEvent>,
) -> Result<(ConversationMessage, Option<TokenUsage>), RuntimeError> {
) -> Result<
(
ConversationMessage,
Option<TokenUsage>,
Vec<PromptCacheEvent>,
),
RuntimeError,
> {
let mut text = String::new();
let mut blocks = Vec::new();
let mut prompt_cache_events = Vec::new();
let mut finished = false;
let mut usage = None;
@@ -548,6 +570,7 @@ fn build_assistant_message(
blocks.push(ContentBlock::ToolUse { id, name, input });
}
AssistantEvent::Usage(value) => usage = Some(value),
AssistantEvent::PromptCache(event) => prompt_cache_events.push(event),
AssistantEvent::MessageStop => {
finished = true;
}
@@ -568,6 +591,7 @@ fn build_assistant_message(
Ok((
ConversationMessage::assistant_with_usage(blocks, usage),
usage,
prompt_cache_events,
))
}
@@ -700,6 +724,15 @@ mod tests {
cache_creation_input_tokens: 1,
cache_read_input_tokens: 3,
}),
AssistantEvent::PromptCache(PromptCacheEvent {
unexpected: true,
reason:
"cache read tokens dropped while prompt fingerprint remained stable"
.to_string(),
previous_cache_read_input_tokens: 6_000,
current_cache_read_input_tokens: 1_000,
token_drop: 5_000,
}),
AssistantEvent::MessageStop,
])
}
@@ -753,6 +786,7 @@ mod tests {
assert_eq!(summary.iterations, 2);
assert_eq!(summary.assistant_messages.len(), 2);
assert_eq!(summary.tool_results.len(), 1);
assert_eq!(summary.prompt_cache_events.len(), 1);
assert_eq!(runtime.session().messages.len(), 4);
assert_eq!(summary.usage.output_tokens, 10);
assert_eq!(summary.auto_compaction, None);

View File

@@ -1311,6 +1311,7 @@ impl LiveCli {
})),
"tool_uses": collect_tool_uses(&summary),
"tool_results": collect_tool_results(&summary),
"prompt_cache_events": collect_prompt_cache_events(&summary),
"usage": {
"input_tokens": summary.usage.input_tokens,
"output_tokens": summary.usage.output_tokens,
@@ -3276,7 +3277,8 @@ impl AnthropicRuntimeClient {
Ok(Self {
runtime: tokio::runtime::Runtime::new()?,
client: AnthropicClient::from_auth(resolve_cli_auth_source()?)
.with_base_url(api::read_base_url()),
.with_base_url(api::read_base_url())
.with_prompt_cache(PromptCache::new(session_id)),
model,
enable_tools,
emit_output,
@@ -3413,6 +3415,8 @@ impl ApiClient for AnthropicRuntimeClient {
}
}
push_prompt_cache_record(&self.client, &mut events);
if !saw_stop
&& events.iter().any(|event| {
matches!(event, AssistantEvent::TextDelta(text) if !text.is_empty())
@@ -3437,7 +3441,9 @@ impl ApiClient for AnthropicRuntimeClient {
})
.await
.map_err(|error| RuntimeError::new(error.to_string()))?;
response_to_events(response, out)
let mut events = response_to_events(response, out)?;
push_prompt_cache_record(&self.client, &mut events);
Ok(events)
})
}
}
@@ -3498,6 +3504,39 @@ fn collect_tool_results(summary: &runtime::TurnSummary) -> Vec<serde_json::Value
.collect()
}
fn collect_prompt_cache_events(summary: &runtime::TurnSummary) -> Vec<serde_json::Value> {
summary
.prompt_cache_events
.iter()
.map(|event| {
json!({
"unexpected": event.unexpected,
"reason": event.reason,
"previous_cache_read_input_tokens": event.previous_cache_read_input_tokens,
"current_cache_read_input_tokens": event.current_cache_read_input_tokens,
"token_drop": event.token_drop,
})
})
.collect()
}
fn print_prompt_cache_events(summary: &runtime::TurnSummary) {
for event in &summary.prompt_cache_events {
let label = if event.unexpected {
"Prompt cache break"
} else {
"Prompt cache invalidation"
};
println!(
"{label}: {} (cache read {} -> {}, drop {})",
event.reason,
event.previous_cache_read_input_tokens,
event.current_cache_read_input_tokens,
event.token_drop,
);
}
}
fn slash_command_completion_candidates() -> Vec<String> {
slash_command_specs()
.iter()
@@ -3653,6 +3692,8 @@ fn first_visible_line(text: &str) -> &str {
}
fn format_bash_result(icon: &str, parsed: &serde_json::Value) -> String {
use std::fmt::Write as _;
let mut lines = vec![format!("{icon} \x1b[38;5;245mbash\x1b[0m")];
if let Some(task_id) = parsed
.get("backgroundTaskId")
@@ -3996,6 +4037,26 @@ fn response_to_events(
Ok(events)
}
fn push_prompt_cache_record(client: &AnthropicClient, events: &mut Vec<AssistantEvent>) {
if let Some(event) = client
.take_last_prompt_cache_record()
.and_then(prompt_cache_record_to_runtime_event)
{
events.push(AssistantEvent::PromptCache(event));
}
}
fn prompt_cache_record_to_runtime_event(record: PromptCacheRecord) -> Option<PromptCacheEvent> {
let cache_break = record.cache_break?;
Some(PromptCacheEvent {
unexpected: cache_break.unexpected,
reason: cache_break.reason,
previous_cache_read_input_tokens: cache_break.previous_cache_read_input_tokens,
current_cache_read_input_tokens: cache_break.current_cache_read_input_tokens,
token_drop: cache_break.token_drop,
})
}
struct CliToolExecutor {
renderer: TerminalRenderer,
emit_output: bool,

View File

@@ -1900,6 +1900,8 @@ impl ApiClient for ProviderRuntimeClient {
}
}
push_prompt_cache_record(&self.client, &mut events);
if !saw_stop
&& events.iter().any(|event| {
matches!(event, AssistantEvent::TextDelta(text) if !text.is_empty())
@@ -1924,7 +1926,9 @@ impl ApiClient for ProviderRuntimeClient {
})
.await
.map_err(|error| RuntimeError::new(error.to_string()))?;
Ok(response_to_events(response))
let mut events = response_to_events(response);
push_prompt_cache_record(&self.client, &mut events);
Ok(events)
})
}
}
@@ -2045,6 +2049,26 @@ fn response_to_events(response: MessageResponse) -> Vec<AssistantEvent> {
events
}
fn push_prompt_cache_record(client: &AnthropicClient, events: &mut Vec<AssistantEvent>) {
if let Some(event) = client
.take_last_prompt_cache_record()
.and_then(prompt_cache_record_to_runtime_event)
{
events.push(AssistantEvent::PromptCache(event));
}
}
fn prompt_cache_record_to_runtime_event(record: PromptCacheRecord) -> Option<PromptCacheEvent> {
let cache_break = record.cache_break?;
Some(PromptCacheEvent {
unexpected: cache_break.unexpected,
reason: cache_break.reason,
previous_cache_read_input_tokens: cache_break.previous_cache_read_input_tokens,
current_cache_read_input_tokens: cache_break.current_cache_read_input_tokens,
token_drop: cache_break.token_drop,
})
}
fn final_assistant_text(summary: &runtime::TurnSummary) -> String {
summary
.assistant_messages