4 Commits

Author SHA1 Message Date
Yeachan-Heo
799c92eada feat: cache-tracking progress 2026-04-01 06:25:26 +00:00
Yeachan-Heo
c9d214c8d1 feat: cache-tracking progress 2026-04-01 06:15:13 +00:00
Yeachan-Heo
26344c578b wip: cache-tracking progress 2026-04-01 04:40:17 +00:00
Yeachan-Heo
0cf2204d43 wip: cache-tracking progress 2026-04-01 04:30:24 +00:00
14 changed files with 1401 additions and 1157 deletions

View File

@@ -1,4 +1,5 @@
use std::collections::VecDeque; use std::collections::VecDeque;
use std::sync::{Arc, Mutex};
use std::time::{Duration, SystemTime, UNIX_EPOCH}; use std::time::{Duration, SystemTime, UNIX_EPOCH};
use runtime::{ use runtime::{
@@ -8,8 +9,9 @@ use runtime::{
use serde::Deserialize; use serde::Deserialize;
use crate::error::ApiError; use crate::error::ApiError;
use crate::prompt_cache::{PromptCache, PromptCacheRecord, PromptCacheStats};
use crate::sse::SseParser; use crate::sse::SseParser;
use crate::types::{MessageRequest, MessageResponse, StreamEvent}; use crate::types::{MessageRequest, MessageResponse, StreamEvent, Usage};
const DEFAULT_BASE_URL: &str = "https://api.anthropic.com"; const DEFAULT_BASE_URL: &str = "https://api.anthropic.com";
const ANTHROPIC_VERSION: &str = "2023-06-01"; const ANTHROPIC_VERSION: &str = "2023-06-01";
@@ -108,6 +110,8 @@ pub struct AnthropicClient {
max_retries: u32, max_retries: u32,
initial_backoff: Duration, initial_backoff: Duration,
max_backoff: Duration, max_backoff: Duration,
prompt_cache: Option<PromptCache>,
last_prompt_cache_record: Arc<Mutex<Option<PromptCacheRecord>>>,
} }
impl AnthropicClient { impl AnthropicClient {
@@ -120,6 +124,8 @@ impl AnthropicClient {
max_retries: DEFAULT_MAX_RETRIES, max_retries: DEFAULT_MAX_RETRIES,
initial_backoff: DEFAULT_INITIAL_BACKOFF, initial_backoff: DEFAULT_INITIAL_BACKOFF,
max_backoff: DEFAULT_MAX_BACKOFF, max_backoff: DEFAULT_MAX_BACKOFF,
prompt_cache: None,
last_prompt_cache_record: Arc::new(Mutex::new(None)),
} }
} }
@@ -132,6 +138,8 @@ impl AnthropicClient {
max_retries: DEFAULT_MAX_RETRIES, max_retries: DEFAULT_MAX_RETRIES,
initial_backoff: DEFAULT_INITIAL_BACKOFF, initial_backoff: DEFAULT_INITIAL_BACKOFF,
max_backoff: DEFAULT_MAX_BACKOFF, max_backoff: DEFAULT_MAX_BACKOFF,
prompt_cache: None,
last_prompt_cache_record: Arc::new(Mutex::new(None)),
} }
} }
@@ -189,6 +197,30 @@ impl AnthropicClient {
self self
} }
#[must_use]
pub fn with_prompt_cache(mut self, prompt_cache: PromptCache) -> Self {
self.prompt_cache = Some(prompt_cache);
self
}
#[must_use]
pub fn prompt_cache(&self) -> Option<&PromptCache> {
self.prompt_cache.as_ref()
}
#[must_use]
pub fn prompt_cache_stats(&self) -> Option<PromptCacheStats> {
self.prompt_cache.as_ref().map(PromptCache::stats)
}
#[must_use]
pub fn take_last_prompt_cache_record(&self) -> Option<PromptCacheRecord> {
self.last_prompt_cache_record()
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.take()
}
#[must_use] #[must_use]
pub fn auth_source(&self) -> &AuthSource { pub fn auth_source(&self) -> &AuthSource {
&self.auth &self.auth
@@ -198,10 +230,19 @@ impl AnthropicClient {
&self, &self,
request: &MessageRequest, request: &MessageRequest,
) -> Result<MessageResponse, ApiError> { ) -> Result<MessageResponse, ApiError> {
self.store_last_prompt_cache_record(None);
let request = MessageRequest { let request = MessageRequest {
stream: false, stream: false,
..request.clone() ..request.clone()
}; };
if let Some(prompt_cache) = &self.prompt_cache {
if let Some(response) = prompt_cache.lookup_completion(&request) {
self.store_last_prompt_cache_record(Some(prompt_cache_record_from_stats(
prompt_cache.stats(),
)));
return Ok(response);
}
}
let response = self.send_with_retry(&request).await?; let response = self.send_with_retry(&request).await?;
let request_id = request_id_from_headers(response.headers()); let request_id = request_id_from_headers(response.headers());
let mut response = response let mut response = response
@@ -211,6 +252,10 @@ impl AnthropicClient {
if response.request_id.is_none() { if response.request_id.is_none() {
response.request_id = request_id; response.request_id = request_id;
} }
if let Some(prompt_cache) = &self.prompt_cache {
let record = prompt_cache.record_response(&request, &response);
self.store_last_prompt_cache_record(Some(record));
}
Ok(response) Ok(response)
} }
@@ -218,6 +263,7 @@ impl AnthropicClient {
&self, &self,
request: &MessageRequest, request: &MessageRequest,
) -> Result<MessageStream, ApiError> { ) -> Result<MessageStream, ApiError> {
self.store_last_prompt_cache_record(None);
let response = self let response = self
.send_with_retry(&request.clone().with_streaming()) .send_with_retry(&request.clone().with_streaming())
.await?; .await?;
@@ -227,9 +273,30 @@ impl AnthropicClient {
parser: SseParser::new(), parser: SseParser::new(),
pending: VecDeque::new(), pending: VecDeque::new(),
done: false, done: false,
cache_tracking: self
.prompt_cache
.as_ref()
.map(|prompt_cache| StreamCacheTracking {
prompt_cache: prompt_cache.clone(),
request: request.clone().with_streaming(),
last_usage: None,
finalized: false,
last_record: self.last_prompt_cache_record.clone(),
}),
}) })
} }
fn store_last_prompt_cache_record(&self, record: Option<PromptCacheRecord>) {
*self
.last_prompt_cache_record()
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner) = record;
}
fn last_prompt_cache_record(&self) -> &Arc<Mutex<Option<PromptCacheRecord>>> {
&self.last_prompt_cache_record
}
pub async fn exchange_oauth_code( pub async fn exchange_oauth_code(
&self, &self,
config: &OAuthConfig, config: &OAuthConfig,
@@ -527,6 +594,7 @@ pub struct MessageStream {
parser: SseParser, parser: SseParser,
pending: VecDeque<StreamEvent>, pending: VecDeque<StreamEvent>,
done: bool, done: bool,
cache_tracking: Option<StreamCacheTracking>,
} }
impl MessageStream { impl MessageStream {
@@ -538,6 +606,9 @@ impl MessageStream {
pub async fn next_event(&mut self) -> Result<Option<StreamEvent>, ApiError> { pub async fn next_event(&mut self) -> Result<Option<StreamEvent>, ApiError> {
loop { loop {
if let Some(event) = self.pending.pop_front() { if let Some(event) = self.pending.pop_front() {
if let Some(cache_tracking) = &mut self.cache_tracking {
cache_tracking.observe(&event);
}
return Ok(Some(event)); return Ok(Some(event));
} }
@@ -545,8 +616,14 @@ impl MessageStream {
let remaining = self.parser.finish()?; let remaining = self.parser.finish()?;
self.pending.extend(remaining); self.pending.extend(remaining);
if let Some(event) = self.pending.pop_front() { if let Some(event) = self.pending.pop_front() {
if let Some(cache_tracking) = &mut self.cache_tracking {
cache_tracking.observe(&event);
}
return Ok(Some(event)); return Ok(Some(event));
} }
if let Some(cache_tracking) = &mut self.cache_tracking {
cache_tracking.finalize();
}
return Ok(None); return Ok(None);
} }
@@ -562,6 +639,53 @@ impl MessageStream {
} }
} }
#[derive(Debug, Clone)]
struct StreamCacheTracking {
prompt_cache: PromptCache,
request: MessageRequest,
last_usage: Option<Usage>,
finalized: bool,
last_record: Arc<Mutex<Option<PromptCacheRecord>>>,
}
impl StreamCacheTracking {
fn observe(&mut self, event: &StreamEvent) {
match event {
StreamEvent::MessageStart(event) => {
self.last_usage = Some(event.message.usage.clone());
}
StreamEvent::MessageDelta(event) => {
self.last_usage = Some(event.usage.clone());
}
StreamEvent::ContentBlockStart(_)
| StreamEvent::ContentBlockDelta(_)
| StreamEvent::ContentBlockStop(_)
| StreamEvent::MessageStop(_) => {}
}
}
fn finalize(&mut self) {
if self.finalized {
return;
}
if let Some(usage) = &self.last_usage {
let record = self.prompt_cache.record_usage(&self.request, usage);
*self
.last_record
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner) = Some(record);
}
self.finalized = true;
}
}
fn prompt_cache_record_from_stats(stats: PromptCacheStats) -> PromptCacheRecord {
PromptCacheRecord {
cache_break: None,
stats,
}
}
async fn expect_success(response: reqwest::Response) -> Result<reqwest::Response, ApiError> { async fn expect_success(response: reqwest::Response) -> Result<reqwest::Response, ApiError> {
let status = response.status(); let status = response.status();
if status.is_success() { if status.is_success() {
@@ -606,7 +730,7 @@ mod tests {
use super::{ALT_REQUEST_ID_HEADER, REQUEST_ID_HEADER}; use super::{ALT_REQUEST_ID_HEADER, REQUEST_ID_HEADER};
use std::io::{Read, Write}; use std::io::{Read, Write};
use std::net::TcpListener; use std::net::TcpListener;
use std::sync::{Mutex, OnceLock}; use std::sync::atomic::{AtomicU64, Ordering};
use std::thread; use std::thread;
use std::time::{Duration, SystemTime, UNIX_EPOCH}; use std::time::{Duration, SystemTime, UNIX_EPOCH};
@@ -616,19 +740,15 @@ mod tests {
now_unix_timestamp, oauth_token_is_expired, resolve_saved_oauth_token, now_unix_timestamp, oauth_token_is_expired, resolve_saved_oauth_token,
resolve_startup_auth_source, AnthropicClient, AuthSource, OAuthTokenSet, resolve_startup_auth_source, AnthropicClient, AuthSource, OAuthTokenSet,
}; };
use crate::test_env_lock;
use crate::types::{ContentBlockDelta, MessageRequest}; use crate::types::{ContentBlockDelta, MessageRequest};
fn env_lock() -> std::sync::MutexGuard<'static, ()> {
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
LOCK.get_or_init(|| Mutex::new(()))
.lock()
.expect("env lock")
}
fn temp_config_home() -> std::path::PathBuf { fn temp_config_home() -> std::path::PathBuf {
static NEXT_ID: AtomicU64 = AtomicU64::new(0);
std::env::temp_dir().join(format!( std::env::temp_dir().join(format!(
"api-oauth-test-{}-{}", "api-oauth-test-{}-{}-{}",
std::process::id(), std::process::id(),
NEXT_ID.fetch_add(1, Ordering::Relaxed),
SystemTime::now() SystemTime::now()
.duration_since(UNIX_EPOCH) .duration_since(UNIX_EPOCH)
.expect("time") .expect("time")
@@ -668,7 +788,7 @@ mod tests {
#[test] #[test]
fn read_api_key_requires_presence() { fn read_api_key_requires_presence() {
let _guard = env_lock(); let _guard = test_env_lock();
std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
std::env::remove_var("ANTHROPIC_API_KEY"); std::env::remove_var("ANTHROPIC_API_KEY");
std::env::remove_var("CLAUDE_CONFIG_HOME"); std::env::remove_var("CLAUDE_CONFIG_HOME");
@@ -678,7 +798,7 @@ mod tests {
#[test] #[test]
fn read_api_key_requires_non_empty_value() { fn read_api_key_requires_non_empty_value() {
let _guard = env_lock(); let _guard = test_env_lock();
std::env::set_var("ANTHROPIC_AUTH_TOKEN", ""); std::env::set_var("ANTHROPIC_AUTH_TOKEN", "");
std::env::remove_var("ANTHROPIC_API_KEY"); std::env::remove_var("ANTHROPIC_API_KEY");
let error = super::read_api_key().expect_err("empty key should error"); let error = super::read_api_key().expect_err("empty key should error");
@@ -688,7 +808,7 @@ mod tests {
#[test] #[test]
fn read_api_key_prefers_api_key_env() { fn read_api_key_prefers_api_key_env() {
let _guard = env_lock(); let _guard = test_env_lock();
std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token"); std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token");
std::env::set_var("ANTHROPIC_API_KEY", "legacy-key"); std::env::set_var("ANTHROPIC_API_KEY", "legacy-key");
assert_eq!( assert_eq!(
@@ -701,7 +821,7 @@ mod tests {
#[test] #[test]
fn read_auth_token_reads_auth_token_env() { fn read_auth_token_reads_auth_token_env() {
let _guard = env_lock(); let _guard = test_env_lock();
std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token"); std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token");
assert_eq!(super::read_auth_token().as_deref(), Some("auth-token")); assert_eq!(super::read_auth_token().as_deref(), Some("auth-token"));
std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
@@ -721,7 +841,7 @@ mod tests {
#[test] #[test]
fn auth_source_from_env_combines_api_key_and_bearer_token() { fn auth_source_from_env_combines_api_key_and_bearer_token() {
let _guard = env_lock(); let _guard = test_env_lock();
std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token"); std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token");
std::env::set_var("ANTHROPIC_API_KEY", "legacy-key"); std::env::set_var("ANTHROPIC_API_KEY", "legacy-key");
let auth = AuthSource::from_env().expect("env auth"); let auth = AuthSource::from_env().expect("env auth");
@@ -733,7 +853,7 @@ mod tests {
#[test] #[test]
fn auth_source_from_saved_oauth_when_env_absent() { fn auth_source_from_saved_oauth_when_env_absent() {
let _guard = env_lock(); let _guard = test_env_lock();
let config_home = temp_config_home(); let config_home = temp_config_home();
std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
@@ -772,7 +892,7 @@ mod tests {
#[test] #[test]
fn resolve_saved_oauth_token_refreshes_expired_credentials() { fn resolve_saved_oauth_token_refreshes_expired_credentials() {
let _guard = env_lock(); let _guard = test_env_lock();
let config_home = temp_config_home(); let config_home = temp_config_home();
std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
@@ -804,7 +924,7 @@ mod tests {
#[test] #[test]
fn resolve_startup_auth_source_uses_saved_oauth_without_loading_config() { fn resolve_startup_auth_source_uses_saved_oauth_without_loading_config() {
let _guard = env_lock(); let _guard = test_env_lock();
let config_home = temp_config_home(); let config_home = temp_config_home();
std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
@@ -828,7 +948,7 @@ mod tests {
#[test] #[test]
fn resolve_startup_auth_source_errors_when_refreshable_token_lacks_config() { fn resolve_startup_auth_source_errors_when_refreshable_token_lacks_config() {
let _guard = env_lock(); let _guard = test_env_lock();
let config_home = temp_config_home(); let config_home = temp_config_home();
std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
@@ -860,7 +980,7 @@ mod tests {
#[test] #[test]
fn resolve_saved_oauth_token_preserves_refresh_token_when_refresh_response_omits_it() { fn resolve_saved_oauth_token_preserves_refresh_token_when_refresh_response_omits_it() {
let _guard = env_lock(); let _guard = test_env_lock();
let config_home = temp_config_home(); let config_home = temp_config_home();
std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); std::env::remove_var("ANTHROPIC_AUTH_TOKEN");

View File

@@ -1,5 +1,6 @@
mod client; mod client;
mod error; mod error;
mod prompt_cache;
mod sse; mod sse;
mod types; mod types;
@@ -8,6 +9,10 @@ pub use client::{
AnthropicClient, AuthSource, MessageStream, OAuthTokenSet, AnthropicClient, AuthSource, MessageStream, OAuthTokenSet,
}; };
pub use error::ApiError; pub use error::ApiError;
pub use prompt_cache::{
CacheBreakEvent, PromptCache, PromptCacheConfig, PromptCachePaths, PromptCacheRecord,
PromptCacheStats,
};
pub use sse::{parse_frame, SseParser}; pub use sse::{parse_frame, SseParser};
pub use types::{ pub use types::{
ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent, ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent,
@@ -15,3 +20,11 @@ pub use types::{
MessageResponse, MessageStartEvent, MessageStopEvent, OutputContentBlock, StreamEvent, MessageResponse, MessageStartEvent, MessageStopEvent, OutputContentBlock, StreamEvent,
ToolChoice, ToolDefinition, ToolResultContentBlock, Usage, ToolChoice, ToolDefinition, ToolResultContentBlock, Usage,
}; };
#[cfg(test)]
pub(crate) fn test_env_lock() -> std::sync::MutexGuard<'static, ()> {
static LOCK: std::sync::OnceLock<std::sync::Mutex<()>> = std::sync::OnceLock::new();
LOCK.get_or_init(|| std::sync::Mutex::new(()))
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
}

View File

@@ -0,0 +1,727 @@
use std::fs;
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use serde::{Deserialize, Serialize};
use crate::types::{MessageRequest, MessageResponse, Usage};
const DEFAULT_COMPLETION_TTL_SECS: u64 = 30;
const DEFAULT_PROMPT_TTL_SECS: u64 = 5 * 60;
const DEFAULT_BREAK_MIN_DROP: u32 = 2_000;
const MAX_SANITIZED_LENGTH: usize = 80;
const REQUEST_FINGERPRINT_VERSION: u32 = 1;
const REQUEST_FINGERPRINT_PREFIX: &str = "v1";
const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;
#[derive(Debug, Clone)]
pub struct PromptCacheConfig {
pub session_id: String,
pub completion_ttl: Duration,
pub prompt_ttl: Duration,
pub cache_break_min_drop: u32,
}
impl PromptCacheConfig {
#[must_use]
pub fn new(session_id: impl Into<String>) -> Self {
Self {
session_id: session_id.into(),
completion_ttl: Duration::from_secs(DEFAULT_COMPLETION_TTL_SECS),
prompt_ttl: Duration::from_secs(DEFAULT_PROMPT_TTL_SECS),
cache_break_min_drop: DEFAULT_BREAK_MIN_DROP,
}
}
}
impl Default for PromptCacheConfig {
fn default() -> Self {
Self::new("default")
}
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct PromptCachePaths {
pub root: PathBuf,
pub session_dir: PathBuf,
pub completion_dir: PathBuf,
pub session_state_path: PathBuf,
pub stats_path: PathBuf,
}
impl PromptCachePaths {
#[must_use]
pub fn for_session(session_id: &str) -> Self {
let root = base_cache_root();
let session_dir = root.join(sanitize_path_segment(session_id));
let completion_dir = session_dir.join("completions");
Self {
root,
session_state_path: session_dir.join("session-state.json"),
stats_path: session_dir.join("stats.json"),
session_dir,
completion_dir,
}
}
#[must_use]
pub fn completion_entry_path(&self, request_hash: &str) -> PathBuf {
self.completion_dir.join(format!("{request_hash}.json"))
}
}
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct PromptCacheStats {
pub tracked_requests: u64,
pub completion_cache_hits: u64,
pub completion_cache_misses: u64,
pub completion_cache_writes: u64,
pub expected_invalidations: u64,
pub unexpected_cache_breaks: u64,
pub total_cache_creation_input_tokens: u64,
pub total_cache_read_input_tokens: u64,
pub last_cache_creation_input_tokens: Option<u32>,
pub last_cache_read_input_tokens: Option<u32>,
pub last_request_hash: Option<String>,
pub last_completion_cache_key: Option<String>,
pub last_break_reason: Option<String>,
pub last_cache_source: Option<String>,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CacheBreakEvent {
pub unexpected: bool,
pub reason: String,
pub previous_cache_read_input_tokens: u32,
pub current_cache_read_input_tokens: u32,
pub token_drop: u32,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PromptCacheRecord {
pub cache_break: Option<CacheBreakEvent>,
pub stats: PromptCacheStats,
}
#[derive(Debug, Clone)]
pub struct PromptCache {
inner: Arc<Mutex<PromptCacheInner>>,
}
impl PromptCache {
#[must_use]
pub fn new(session_id: impl Into<String>) -> Self {
Self::with_config(PromptCacheConfig::new(session_id))
}
#[must_use]
pub fn with_config(config: PromptCacheConfig) -> Self {
let paths = PromptCachePaths::for_session(&config.session_id);
let stats = read_json::<PromptCacheStats>(&paths.stats_path).unwrap_or_default();
let previous = read_json::<TrackedPromptState>(&paths.session_state_path);
Self {
inner: Arc::new(Mutex::new(PromptCacheInner {
config,
paths,
stats,
previous,
})),
}
}
#[must_use]
pub fn paths(&self) -> PromptCachePaths {
self.lock().paths.clone()
}
#[must_use]
pub fn stats(&self) -> PromptCacheStats {
self.lock().stats.clone()
}
#[must_use]
pub fn lookup_completion(&self, request: &MessageRequest) -> Option<MessageResponse> {
let request_hash = request_hash_hex(request);
let (paths, ttl) = {
let inner = self.lock();
(inner.paths.clone(), inner.config.completion_ttl)
};
let entry_path = paths.completion_entry_path(&request_hash);
let entry = read_json::<CompletionCacheEntry>(&entry_path);
let Some(entry) = entry else {
let mut inner = self.lock();
inner.stats.completion_cache_misses += 1;
inner.stats.last_completion_cache_key = Some(request_hash);
persist_state(&inner);
return None;
};
if entry.fingerprint_version != current_fingerprint_version() {
let mut inner = self.lock();
inner.stats.completion_cache_misses += 1;
inner.stats.last_completion_cache_key = Some(request_hash.clone());
let _ = fs::remove_file(entry_path);
persist_state(&inner);
return None;
}
let expired = now_unix_secs().saturating_sub(entry.cached_at_unix_secs) >= ttl.as_secs();
let mut inner = self.lock();
inner.stats.last_completion_cache_key = Some(request_hash.clone());
if expired {
inner.stats.completion_cache_misses += 1;
let _ = fs::remove_file(entry_path);
persist_state(&inner);
return None;
}
inner.stats.completion_cache_hits += 1;
apply_usage_to_stats(
&mut inner.stats,
&entry.response.usage,
&request_hash,
"completion-cache",
);
inner.previous = Some(TrackedPromptState::from_usage(
request,
&entry.response.usage,
));
persist_state(&inner);
Some(entry.response)
}
#[must_use]
pub fn record_response(
&self,
request: &MessageRequest,
response: &MessageResponse,
) -> PromptCacheRecord {
self.record_usage_internal(request, &response.usage, Some(response))
}
#[must_use]
pub fn record_usage(&self, request: &MessageRequest, usage: &Usage) -> PromptCacheRecord {
self.record_usage_internal(request, usage, None)
}
fn record_usage_internal(
&self,
request: &MessageRequest,
usage: &Usage,
response: Option<&MessageResponse>,
) -> PromptCacheRecord {
let request_hash = request_hash_hex(request);
let mut inner = self.lock();
let previous = inner.previous.clone();
let current = TrackedPromptState::from_usage(request, usage);
let cache_break = detect_cache_break(&inner.config, previous.as_ref(), &current);
inner.stats.tracked_requests += 1;
apply_usage_to_stats(&mut inner.stats, usage, &request_hash, "api-response");
if let Some(event) = &cache_break {
if event.unexpected {
inner.stats.unexpected_cache_breaks += 1;
} else {
inner.stats.expected_invalidations += 1;
}
inner.stats.last_break_reason = Some(event.reason.clone());
}
inner.previous = Some(current);
if let Some(response) = response {
write_completion_entry(&inner.paths, &request_hash, response);
inner.stats.completion_cache_writes += 1;
}
persist_state(&inner);
PromptCacheRecord {
cache_break,
stats: inner.stats.clone(),
}
}
fn lock(&self) -> std::sync::MutexGuard<'_, PromptCacheInner> {
self.inner
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
}
}
#[derive(Debug)]
struct PromptCacheInner {
config: PromptCacheConfig,
paths: PromptCachePaths,
stats: PromptCacheStats,
previous: Option<TrackedPromptState>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct CompletionCacheEntry {
cached_at_unix_secs: u64,
#[serde(default = "current_fingerprint_version")]
fingerprint_version: u32,
response: MessageResponse,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
struct TrackedPromptState {
observed_at_unix_secs: u64,
#[serde(default = "current_fingerprint_version")]
fingerprint_version: u32,
model_hash: u64,
system_hash: u64,
tools_hash: u64,
messages_hash: u64,
cache_read_input_tokens: u32,
}
impl TrackedPromptState {
fn from_usage(request: &MessageRequest, usage: &Usage) -> Self {
let hashes = RequestFingerprints::from_request(request);
Self {
observed_at_unix_secs: now_unix_secs(),
fingerprint_version: current_fingerprint_version(),
model_hash: hashes.model,
system_hash: hashes.system,
tools_hash: hashes.tools,
messages_hash: hashes.messages,
cache_read_input_tokens: usage.cache_read_input_tokens,
}
}
}
#[derive(Debug, Clone, Copy)]
struct RequestFingerprints {
model: u64,
system: u64,
tools: u64,
messages: u64,
}
impl RequestFingerprints {
fn from_request(request: &MessageRequest) -> Self {
Self {
model: hash_serializable(&request.model),
system: hash_serializable(&request.system),
tools: hash_serializable(&request.tools),
messages: hash_serializable(&request.messages),
}
}
}
fn detect_cache_break(
config: &PromptCacheConfig,
previous: Option<&TrackedPromptState>,
current: &TrackedPromptState,
) -> Option<CacheBreakEvent> {
let previous = previous?;
if previous.fingerprint_version != current.fingerprint_version {
return Some(CacheBreakEvent {
unexpected: false,
reason: format!(
"fingerprint version changed (v{} -> v{})",
previous.fingerprint_version, current.fingerprint_version
),
previous_cache_read_input_tokens: previous.cache_read_input_tokens,
current_cache_read_input_tokens: current.cache_read_input_tokens,
token_drop: previous
.cache_read_input_tokens
.saturating_sub(current.cache_read_input_tokens),
});
}
let token_drop = previous
.cache_read_input_tokens
.saturating_sub(current.cache_read_input_tokens);
if token_drop < config.cache_break_min_drop {
return None;
}
let mut reasons = Vec::new();
if previous.model_hash != current.model_hash {
reasons.push("model changed");
}
if previous.system_hash != current.system_hash {
reasons.push("system prompt changed");
}
if previous.tools_hash != current.tools_hash {
reasons.push("tool definitions changed");
}
if previous.messages_hash != current.messages_hash {
reasons.push("message payload changed");
}
let elapsed = current
.observed_at_unix_secs
.saturating_sub(previous.observed_at_unix_secs);
let (unexpected, reason) = if reasons.is_empty() {
if elapsed > config.prompt_ttl.as_secs() {
(
false,
format!("possible prompt cache TTL expiry after {elapsed}s"),
)
} else {
(
true,
"cache read tokens dropped while prompt fingerprint remained stable".to_string(),
)
}
} else {
(false, reasons.join(", "))
};
Some(CacheBreakEvent {
unexpected,
reason,
previous_cache_read_input_tokens: previous.cache_read_input_tokens,
current_cache_read_input_tokens: current.cache_read_input_tokens,
token_drop,
})
}
fn apply_usage_to_stats(
stats: &mut PromptCacheStats,
usage: &Usage,
request_hash: &str,
source: &str,
) {
stats.total_cache_creation_input_tokens += u64::from(usage.cache_creation_input_tokens);
stats.total_cache_read_input_tokens += u64::from(usage.cache_read_input_tokens);
stats.last_cache_creation_input_tokens = Some(usage.cache_creation_input_tokens);
stats.last_cache_read_input_tokens = Some(usage.cache_read_input_tokens);
stats.last_request_hash = Some(request_hash.to_string());
stats.last_cache_source = Some(source.to_string());
}
fn persist_state(inner: &PromptCacheInner) {
let _ = ensure_cache_dirs(&inner.paths);
let _ = write_json(&inner.paths.stats_path, &inner.stats);
if let Some(previous) = &inner.previous {
let _ = write_json(&inner.paths.session_state_path, previous);
}
}
fn write_completion_entry(
paths: &PromptCachePaths,
request_hash: &str,
response: &MessageResponse,
) {
let _ = ensure_cache_dirs(paths);
let entry = CompletionCacheEntry {
cached_at_unix_secs: now_unix_secs(),
fingerprint_version: current_fingerprint_version(),
response: response.clone(),
};
let _ = write_json(&paths.completion_entry_path(request_hash), &entry);
}
fn ensure_cache_dirs(paths: &PromptCachePaths) -> std::io::Result<()> {
fs::create_dir_all(&paths.completion_dir)
}
fn write_json<T: Serialize>(path: &Path, value: &T) -> std::io::Result<()> {
let json = serde_json::to_vec_pretty(value)
.map_err(|error| std::io::Error::new(std::io::ErrorKind::InvalidData, error))?;
fs::write(path, json)
}
fn read_json<T: for<'de> Deserialize<'de>>(path: &Path) -> Option<T> {
let bytes = fs::read(path).ok()?;
serde_json::from_slice(&bytes).ok()
}
fn request_hash_hex(request: &MessageRequest) -> String {
format!(
"{REQUEST_FINGERPRINT_PREFIX}-{:016x}",
hash_serializable(request)
)
}
fn hash_serializable<T: Serialize>(value: &T) -> u64 {
let json = serde_json::to_vec(value).unwrap_or_default();
stable_hash_bytes(&json)
}
fn sanitize_path_segment(value: &str) -> String {
let sanitized: String = value
.chars()
.map(|ch| if ch.is_ascii_alphanumeric() { ch } else { '-' })
.collect();
if sanitized.len() <= MAX_SANITIZED_LENGTH {
return sanitized;
}
let suffix = format!("-{:x}", hash_string(value));
format!(
"{}{}",
&sanitized[..MAX_SANITIZED_LENGTH.saturating_sub(suffix.len())],
suffix
)
}
fn hash_string(value: &str) -> u64 {
stable_hash_bytes(value.as_bytes())
}
fn base_cache_root() -> PathBuf {
if let Some(config_home) = std::env::var_os("CLAUDE_CONFIG_HOME") {
return PathBuf::from(config_home)
.join("cache")
.join("prompt-cache");
}
if let Some(home) = std::env::var_os("HOME") {
return PathBuf::from(home)
.join(".claude")
.join("cache")
.join("prompt-cache");
}
std::env::temp_dir().join("claude-prompt-cache")
}
fn now_unix_secs() -> u64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.map_or(0, |duration| duration.as_secs())
}
const fn current_fingerprint_version() -> u32 {
REQUEST_FINGERPRINT_VERSION
}
fn stable_hash_bytes(bytes: &[u8]) -> u64 {
let mut hash = FNV_OFFSET_BASIS;
for byte in bytes {
hash ^= u64::from(*byte);
hash = hash.wrapping_mul(FNV_PRIME);
}
hash
}
#[cfg(test)]
mod tests {
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use super::{
detect_cache_break, read_json, request_hash_hex, sanitize_path_segment, PromptCache,
PromptCacheConfig, PromptCachePaths, TrackedPromptState, REQUEST_FINGERPRINT_PREFIX,
};
use crate::test_env_lock;
use crate::types::{InputMessage, MessageRequest, MessageResponse, OutputContentBlock, Usage};
#[test]
fn path_builder_sanitizes_session_identifier() {
let paths = PromptCachePaths::for_session("session:/with spaces");
let session_dir = paths
.session_dir
.file_name()
.and_then(|value| value.to_str())
.expect("session dir name");
assert_eq!(session_dir, "session--with-spaces");
assert!(paths.completion_dir.ends_with("completions"));
assert!(paths.stats_path.ends_with("stats.json"));
assert!(paths.session_state_path.ends_with("session-state.json"));
}
#[test]
fn request_fingerprint_drives_unexpected_break_detection() {
let request = sample_request("same");
let previous = TrackedPromptState::from_usage(
&request,
&Usage {
input_tokens: 0,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 6_000,
output_tokens: 0,
},
);
let current = TrackedPromptState::from_usage(
&request,
&Usage {
input_tokens: 0,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 1_000,
output_tokens: 0,
},
);
let event = detect_cache_break(&PromptCacheConfig::default(), Some(&previous), &current)
.expect("break should be detected");
assert!(event.unexpected);
assert!(event.reason.contains("stable"));
}
#[test]
fn changed_prompt_marks_break_as_expected() {
let previous_request = sample_request("first");
let current_request = sample_request("second");
let previous = TrackedPromptState::from_usage(
&previous_request,
&Usage {
input_tokens: 0,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 6_000,
output_tokens: 0,
},
);
let current = TrackedPromptState::from_usage(
&current_request,
&Usage {
input_tokens: 0,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 1_000,
output_tokens: 0,
},
);
let event = detect_cache_break(&PromptCacheConfig::default(), Some(&previous), &current)
.expect("break should be detected");
assert!(!event.unexpected);
assert!(event.reason.contains("message payload changed"));
}
#[test]
fn completion_cache_round_trip_persists_recent_response() {
let _guard = test_env_lock();
let temp_root = std::env::temp_dir().join(format!(
"prompt-cache-test-{}-{}",
std::process::id(),
SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("time")
.as_nanos()
));
std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root);
let cache = PromptCache::new("unit-test-session");
let request = sample_request("cache me");
let response = sample_response(42, 12, "cached");
assert!(cache.lookup_completion(&request).is_none());
let record = cache.record_response(&request, &response);
assert!(record.cache_break.is_none());
let cached = cache
.lookup_completion(&request)
.expect("cached response should load");
assert_eq!(cached.content, response.content);
let stats = cache.stats();
assert_eq!(stats.completion_cache_hits, 1);
assert_eq!(stats.completion_cache_misses, 1);
assert_eq!(stats.completion_cache_writes, 1);
let persisted = read_json::<super::PromptCacheStats>(&cache.paths().stats_path)
.expect("stats should persist");
assert_eq!(persisted.completion_cache_hits, 1);
std::fs::remove_dir_all(temp_root).expect("cleanup temp root");
std::env::remove_var("CLAUDE_CONFIG_HOME");
}
#[test]
fn distinct_requests_do_not_collide_in_completion_cache() {
let _guard = test_env_lock();
let temp_root = std::env::temp_dir().join(format!(
"prompt-cache-distinct-{}-{}",
std::process::id(),
SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("time")
.as_nanos()
));
std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root);
let cache = PromptCache::new("distinct-request-session");
let first_request = sample_request("first");
let second_request = sample_request("second");
let response = sample_response(42, 12, "cached");
let _ = cache.record_response(&first_request, &response);
assert!(cache.lookup_completion(&second_request).is_none());
std::fs::remove_dir_all(temp_root).expect("cleanup temp root");
std::env::remove_var("CLAUDE_CONFIG_HOME");
}
#[test]
fn expired_completion_entries_are_not_reused() {
let _guard = test_env_lock();
let temp_root = std::env::temp_dir().join(format!(
"prompt-cache-expired-{}-{}",
std::process::id(),
SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("time")
.as_nanos()
));
std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root);
let cache = PromptCache::with_config(PromptCacheConfig {
session_id: "expired-session".to_string(),
completion_ttl: Duration::ZERO,
..PromptCacheConfig::default()
});
let request = sample_request("expire me");
let response = sample_response(7, 3, "stale");
let _ = cache.record_response(&request, &response);
assert!(cache.lookup_completion(&request).is_none());
let stats = cache.stats();
assert_eq!(stats.completion_cache_hits, 0);
assert_eq!(stats.completion_cache_misses, 1);
std::fs::remove_dir_all(temp_root).expect("cleanup temp root");
std::env::remove_var("CLAUDE_CONFIG_HOME");
}
#[test]
fn sanitize_path_caps_long_values() {
let long_value = "x".repeat(200);
let sanitized = sanitize_path_segment(&long_value);
assert!(sanitized.len() <= 80);
}
#[test]
fn request_hashes_are_versioned_and_stable() {
let request = sample_request("stable");
let first = request_hash_hex(&request);
let second = request_hash_hex(&request);
assert_eq!(first, second);
assert!(first.starts_with(REQUEST_FINGERPRINT_PREFIX));
}
fn sample_request(text: &str) -> MessageRequest {
MessageRequest {
model: "claude-3-7-sonnet-latest".to_string(),
max_tokens: 64,
messages: vec![InputMessage::user_text(text)],
system: Some("system".to_string()),
tools: None,
tool_choice: None,
stream: false,
}
}
fn sample_response(
cache_read_input_tokens: u32,
output_tokens: u32,
text: &str,
) -> MessageResponse {
MessageResponse {
id: "msg_test".to_string(),
kind: "message".to_string(),
role: "assistant".to_string(),
content: vec![OutputContentBlock::Text {
text: text.to_string(),
}],
model: "claude-3-7-sonnet-latest".to_string(),
stop_reason: Some("end_turn".to_string()),
stop_sequence: None,
usage: Usage {
input_tokens: 10,
cache_creation_input_tokens: 5,
cache_read_input_tokens,
output_tokens,
},
request_id: Some("req_test".to_string()),
}
}
}

View File

@@ -1,17 +1,25 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use std::sync::{Mutex as StdMutex, OnceLock};
use std::time::Duration; use std::time::Duration;
use api::{ use api::{
AnthropicClient, ApiError, ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, AnthropicClient, ApiError, ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent,
InputContentBlock, InputMessage, MessageDeltaEvent, MessageRequest, OutputContentBlock, InputContentBlock, InputMessage, MessageDeltaEvent, MessageRequest, OutputContentBlock,
StreamEvent, ToolChoice, ToolDefinition, PromptCache, StreamEvent, ToolChoice, ToolDefinition,
}; };
use serde_json::json; use serde_json::json;
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio::sync::Mutex; use tokio::sync::Mutex;
fn env_lock() -> std::sync::MutexGuard<'static, ()> {
static LOCK: OnceLock<StdMutex<()>> = OnceLock::new();
LOCK.get_or_init(|| StdMutex::new(()))
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
}
#[tokio::test] #[tokio::test]
async fn send_message_posts_json_and_parses_response() { async fn send_message_posts_json_and_parses_response() {
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new())); let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
@@ -45,6 +53,8 @@ async fn send_message_posts_json_and_parses_response() {
assert_eq!(response.id, "msg_test"); assert_eq!(response.id, "msg_test");
assert_eq!(response.total_tokens(), 16); assert_eq!(response.total_tokens(), 16);
assert_eq!(response.request_id.as_deref(), Some("req_body_123")); assert_eq!(response.request_id.as_deref(), Some("req_body_123"));
assert_eq!(response.usage.cache_creation_input_tokens, 0);
assert_eq!(response.usage.cache_read_input_tokens, 0);
assert_eq!( assert_eq!(
response.content, response.content,
vec![OutputContentBlock::Text { vec![OutputContentBlock::Text {
@@ -76,11 +86,55 @@ async fn send_message_posts_json_and_parses_response() {
} }
#[tokio::test] #[tokio::test]
async fn send_message_parses_prompt_cache_token_usage_from_response() {
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
let body = concat!(
"{",
"\"id\":\"msg_cache_tokens\",",
"\"type\":\"message\",",
"\"role\":\"assistant\",",
"\"content\":[{\"type\":\"text\",\"text\":\"Cache tokens\"}],",
"\"model\":\"claude-3-7-sonnet-latest\",",
"\"stop_reason\":\"end_turn\",",
"\"stop_sequence\":null,",
"\"usage\":{\"input_tokens\":12,\"cache_creation_input_tokens\":321,\"cache_read_input_tokens\":654,\"output_tokens\":4}",
"}"
);
let server = spawn_server(
state,
vec![http_response("200 OK", "application/json", body)],
)
.await;
let client = AnthropicClient::new("test-key").with_base_url(server.base_url());
let response = client
.send_message(&sample_request(false))
.await
.expect("request should succeed");
assert_eq!(response.usage.input_tokens, 12);
assert_eq!(response.usage.cache_creation_input_tokens, 321);
assert_eq!(response.usage.cache_read_input_tokens, 654);
assert_eq!(response.usage.output_tokens, 4);
}
#[tokio::test]
#[allow(clippy::await_holding_lock)]
async fn stream_message_parses_sse_events_with_tool_use() { async fn stream_message_parses_sse_events_with_tool_use() {
let _guard = env_lock();
let temp_root = std::env::temp_dir().join(format!(
"api-stream-cache-{}-{}",
std::process::id(),
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.expect("time")
.as_nanos()
));
std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root);
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new())); let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
let sse = concat!( let sse = concat!(
"event: message_start\n", "event: message_start\n",
"data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_stream\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":8,\"output_tokens\":0}}}\n\n", "data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_stream\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":8,\"cache_creation_input_tokens\":13,\"cache_read_input_tokens\":21,\"output_tokens\":0}}}\n\n",
"event: content_block_start\n", "event: content_block_start\n",
"data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"tool_use\",\"id\":\"toolu_123\",\"name\":\"get_weather\",\"input\":{}}}\n\n", "data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"tool_use\",\"id\":\"toolu_123\",\"name\":\"get_weather\",\"input\":{}}}\n\n",
"event: content_block_delta\n", "event: content_block_delta\n",
@@ -88,7 +142,7 @@ async fn stream_message_parses_sse_events_with_tool_use() {
"event: content_block_stop\n", "event: content_block_stop\n",
"data: {\"type\":\"content_block_stop\",\"index\":0}\n\n", "data: {\"type\":\"content_block_stop\",\"index\":0}\n\n",
"event: message_delta\n", "event: message_delta\n",
"data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\",\"stop_sequence\":null},\"usage\":{\"input_tokens\":8,\"output_tokens\":1}}\n\n", "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\",\"stop_sequence\":null},\"usage\":{\"input_tokens\":8,\"cache_creation_input_tokens\":34,\"cache_read_input_tokens\":55,\"output_tokens\":1}}\n\n",
"event: message_stop\n", "event: message_stop\n",
"data: {\"type\":\"message_stop\"}\n\n", "data: {\"type\":\"message_stop\"}\n\n",
"data: [DONE]\n\n" "data: [DONE]\n\n"
@@ -106,7 +160,8 @@ async fn stream_message_parses_sse_events_with_tool_use() {
let client = AnthropicClient::new("test-key") let client = AnthropicClient::new("test-key")
.with_auth_token(Some("proxy-token".to_string())) .with_auth_token(Some("proxy-token".to_string()))
.with_base_url(server.base_url()); .with_base_url(server.base_url())
.with_prompt_cache(PromptCache::new("stream-session"));
let mut stream = client let mut stream = client
.stream_message(&sample_request(false)) .stream_message(&sample_request(false))
.await .await
@@ -160,6 +215,20 @@ async fn stream_message_parses_sse_events_with_tool_use() {
let captured = state.lock().await; let captured = state.lock().await;
let request = captured.first().expect("server should capture request"); let request = captured.first().expect("server should capture request");
assert!(request.body.contains("\"stream\":true")); assert!(request.body.contains("\"stream\":true"));
let cache_stats = client
.prompt_cache_stats()
.expect("prompt cache stats should exist");
assert_eq!(cache_stats.tracked_requests, 1);
assert_eq!(cache_stats.last_cache_creation_input_tokens, Some(34));
assert_eq!(cache_stats.last_cache_read_input_tokens, Some(55));
assert_eq!(
cache_stats.last_cache_source.as_deref(),
Some("api-response")
);
std::fs::remove_dir_all(temp_root).expect("cleanup temp root");
std::env::remove_var("CLAUDE_CONFIG_HOME");
} }
#[tokio::test] #[tokio::test]
@@ -243,6 +312,121 @@ async fn surfaces_retry_exhaustion_for_persistent_retryable_errors() {
} }
} }
#[tokio::test]
#[allow(clippy::await_holding_lock)]
async fn send_message_reuses_recent_completion_cache_entries() {
let _guard = env_lock();
let temp_root = std::env::temp_dir().join(format!(
"api-prompt-cache-{}-{}",
std::process::id(),
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.expect("time")
.as_nanos()
));
std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root);
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
let server = spawn_server(
state.clone(),
vec![http_response(
"200 OK",
"application/json",
"{\"id\":\"msg_cached\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Cached once\"}],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"cache_creation_input_tokens\":5,\"cache_read_input_tokens\":4000,\"output_tokens\":2}}",
)],
)
.await;
let client = AnthropicClient::new("test-key")
.with_base_url(server.base_url())
.with_prompt_cache(PromptCache::new("integration-session"));
let first = client
.send_message(&sample_request(false))
.await
.expect("first request should succeed");
let second = client
.send_message(&sample_request(false))
.await
.expect("second request should reuse cache");
assert_eq!(first.content, second.content);
assert_eq!(state.lock().await.len(), 1);
let cache_stats = client
.prompt_cache_stats()
.expect("prompt cache stats should exist");
assert_eq!(cache_stats.completion_cache_hits, 1);
assert_eq!(cache_stats.completion_cache_misses, 1);
assert_eq!(cache_stats.completion_cache_writes, 1);
std::fs::remove_dir_all(temp_root).expect("cleanup temp root");
std::env::remove_var("CLAUDE_CONFIG_HOME");
}
#[tokio::test]
#[allow(clippy::await_holding_lock)]
async fn send_message_tracks_unexpected_prompt_cache_breaks() {
let _guard = env_lock();
let temp_root = std::env::temp_dir().join(format!(
"api-prompt-break-{}-{}",
std::process::id(),
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.expect("time")
.as_nanos()
));
std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root);
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
let server = spawn_server(
state,
vec![
http_response(
"200 OK",
"application/json",
"{\"id\":\"msg_one\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"One\"}],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"cache_creation_input_tokens\":5,\"cache_read_input_tokens\":6000,\"output_tokens\":2}}",
),
http_response(
"200 OK",
"application/json",
"{\"id\":\"msg_two\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Two\"}],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"cache_creation_input_tokens\":0,\"cache_read_input_tokens\":1000,\"output_tokens\":2}}",
),
],
)
.await;
let request = sample_request(false);
let client = AnthropicClient::new("test-key")
.with_base_url(server.base_url())
.with_prompt_cache(PromptCache::with_config(api::PromptCacheConfig {
session_id: "break-session".to_string(),
completion_ttl: Duration::from_secs(0),
..api::PromptCacheConfig::default()
}));
client
.send_message(&request)
.await
.expect("first response should succeed");
client
.send_message(&request)
.await
.expect("second response should succeed");
let cache_stats = client
.prompt_cache_stats()
.expect("prompt cache stats should exist");
assert_eq!(cache_stats.unexpected_cache_breaks, 1);
assert_eq!(
cache_stats.last_break_reason.as_deref(),
Some("cache read tokens dropped while prompt fingerprint remained stable")
);
std::fs::remove_dir_all(temp_root).expect("cleanup temp root");
std::env::remove_var("CLAUDE_CONFIG_HOME");
}
#[tokio::test] #[tokio::test]
#[ignore = "requires ANTHROPIC_API_KEY and network access"] #[ignore = "requires ANTHROPIC_API_KEY and network access"]
async fn live_stream_smoke_test() { async fn live_stream_smoke_test() {

View File

@@ -125,8 +125,8 @@ const SLASH_COMMAND_SPECS: &[SlashCommandSpec] = &[
}, },
SlashCommandSpec { SlashCommandSpec {
name: "session", name: "session",
summary: "List, switch, or fork managed local sessions", summary: "List or switch managed local sessions",
argument_hint: Some("[list|switch <session-id>|fork [branch-name]]"), argument_hint: Some("[list|switch <session-id>]"),
resume_supported: false, resume_supported: false,
}, },
]; ];
@@ -229,7 +229,7 @@ pub fn resume_supported_slash_commands() -> Vec<&'static SlashCommandSpec> {
pub fn render_slash_command_help() -> String { pub fn render_slash_command_help() -> String {
let mut lines = vec![ let mut lines = vec![
"Slash commands".to_string(), "Slash commands".to_string(),
" [resume] means the command also works with --resume SESSION.jsonl".to_string(), " [resume] means the command also works with --resume SESSION.json".to_string(),
]; ];
for spec in slash_command_specs() { for spec in slash_command_specs() {
let name = match spec.argument_hint { let name = match spec.argument_hint {
@@ -365,19 +365,12 @@ mod tests {
target: Some("abc123".to_string()) target: Some("abc123".to_string())
}) })
); );
assert_eq!(
SlashCommand::parse("/session fork incident-review"),
Some(SlashCommand::Session {
action: Some("fork".to_string()),
target: Some("incident-review".to_string())
})
);
} }
#[test] #[test]
fn renders_help_from_shared_specs() { fn renders_help_from_shared_specs() {
let help = render_slash_command_help(); let help = render_slash_command_help();
assert!(help.contains("works with --resume SESSION.jsonl")); assert!(help.contains("works with --resume SESSION.json"));
assert!(help.contains("/help")); assert!(help.contains("/help"));
assert!(help.contains("/status")); assert!(help.contains("/status"));
assert!(help.contains("/compact")); assert!(help.contains("/compact"));
@@ -392,24 +385,26 @@ mod tests {
assert!(help.contains("/diff")); assert!(help.contains("/diff"));
assert!(help.contains("/version")); assert!(help.contains("/version"));
assert!(help.contains("/export [file]")); assert!(help.contains("/export [file]"));
assert!(help.contains("/session [list|switch <session-id>|fork [branch-name]]")); assert!(help.contains("/session [list|switch <session-id>]"));
assert_eq!(slash_command_specs().len(), 15); assert_eq!(slash_command_specs().len(), 15);
assert_eq!(resume_supported_slash_commands().len(), 11); assert_eq!(resume_supported_slash_commands().len(), 11);
} }
#[test] #[test]
fn compacts_sessions_via_slash_command() { fn compacts_sessions_via_slash_command() {
let mut session = Session::new(); let session = Session {
session.messages = vec![ version: 1,
ConversationMessage::user_text("a ".repeat(200)), messages: vec![
ConversationMessage::assistant(vec![ContentBlock::Text { ConversationMessage::user_text("a ".repeat(200)),
text: "b ".repeat(200), ConversationMessage::assistant(vec![ContentBlock::Text {
}]), text: "b ".repeat(200),
ConversationMessage::tool_result("1", "bash", "ok ".repeat(200), false), }]),
ConversationMessage::assistant(vec![ContentBlock::Text { ConversationMessage::tool_result("1", "bash", "ok ".repeat(200), false),
text: "recent".to_string(), ConversationMessage::assistant(vec![ContentBlock::Text {
}]), text: "recent".to_string(),
]; }]),
],
};
let result = handle_slash_command( let result = handle_slash_command(
"/compact", "/compact",
@@ -460,12 +455,6 @@ mod tests {
CompactionConfig::default() CompactionConfig::default()
) )
.is_none()); .is_none());
assert!(handle_slash_command(
"/resume session.jsonl",
&session,
CompactionConfig::default()
)
.is_none());
assert!(handle_slash_command("/config", &session, CompactionConfig::default()).is_none()); assert!(handle_slash_command("/config", &session, CompactionConfig::default()).is_none());
assert!( assert!(
handle_slash_command("/config env", &session, CompactionConfig::default()).is_none() handle_slash_command("/config env", &session, CompactionConfig::default()).is_none()

View File

@@ -99,14 +99,13 @@ pub fn compact_session(session: &Session, config: CompactionConfig) -> Compactio
}]; }];
compacted_messages.extend(preserved); compacted_messages.extend(preserved);
let mut compacted_session = session.clone();
compacted_session.messages = compacted_messages;
compacted_session.record_compaction(summary.clone(), removed.len());
CompactionResult { CompactionResult {
summary, summary,
formatted_summary, formatted_summary,
compacted_session, compacted_session: Session {
version: session.version,
messages: compacted_messages,
},
removed_message_count: removed.len(), removed_message_count: removed.len(),
} }
} }
@@ -391,8 +390,10 @@ mod tests {
#[test] #[test]
fn leaves_small_sessions_unchanged() { fn leaves_small_sessions_unchanged() {
let mut session = Session::new(); let session = Session {
session.messages = vec![ConversationMessage::user_text("hello")]; version: 1,
messages: vec![ConversationMessage::user_text("hello")],
};
let result = compact_session(&session, CompactionConfig::default()); let result = compact_session(&session, CompactionConfig::default());
assert_eq!(result.removed_message_count, 0); assert_eq!(result.removed_message_count, 0);
@@ -403,21 +404,23 @@ mod tests {
#[test] #[test]
fn compacts_older_messages_into_a_system_summary() { fn compacts_older_messages_into_a_system_summary() {
let mut session = Session::new(); let session = Session {
session.messages = vec![ version: 1,
ConversationMessage::user_text("one ".repeat(200)), messages: vec![
ConversationMessage::assistant(vec![ContentBlock::Text { ConversationMessage::user_text("one ".repeat(200)),
text: "two ".repeat(200), ConversationMessage::assistant(vec![ContentBlock::Text {
}]), text: "two ".repeat(200),
ConversationMessage::tool_result("1", "bash", "ok ".repeat(200), false), }]),
ConversationMessage { ConversationMessage::tool_result("1", "bash", "ok ".repeat(200), false),
role: MessageRole::Assistant, ConversationMessage {
blocks: vec![ContentBlock::Text { role: MessageRole::Assistant,
text: "recent".to_string(), blocks: vec![ContentBlock::Text {
}], text: "recent".to_string(),
usage: None, }],
}, usage: None,
]; },
],
};
let result = compact_session( let result = compact_session(
&session, &session,

View File

@@ -25,9 +25,19 @@ pub enum AssistantEvent {
input: String, input: String,
}, },
Usage(TokenUsage), Usage(TokenUsage),
PromptCache(PromptCacheEvent),
MessageStop, MessageStop,
} }
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PromptCacheEvent {
pub unexpected: bool,
pub reason: String,
pub previous_cache_read_input_tokens: u32,
pub current_cache_read_input_tokens: u32,
pub token_drop: u32,
}
pub trait ApiClient { pub trait ApiClient {
fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError>; fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError>;
} }
@@ -84,6 +94,7 @@ impl std::error::Error for RuntimeError {}
pub struct TurnSummary { pub struct TurnSummary {
pub assistant_messages: Vec<ConversationMessage>, pub assistant_messages: Vec<ConversationMessage>,
pub tool_results: Vec<ConversationMessage>, pub tool_results: Vec<ConversationMessage>,
pub prompt_cache_events: Vec<PromptCacheEvent>,
pub iterations: usize, pub iterations: usize,
pub usage: TokenUsage, pub usage: TokenUsage,
} }
@@ -118,7 +129,7 @@ where
tool_executor, tool_executor,
permission_policy, permission_policy,
system_prompt, system_prompt,
RuntimeFeatureConfig::default(), &RuntimeFeatureConfig::default(),
) )
} }
@@ -129,7 +140,7 @@ where
tool_executor: T, tool_executor: T,
permission_policy: PermissionPolicy, permission_policy: PermissionPolicy,
system_prompt: Vec<String>, system_prompt: Vec<String>,
feature_config: RuntimeFeatureConfig, feature_config: &RuntimeFeatureConfig,
) -> Self { ) -> Self {
let usage_tracker = UsageTracker::from_session(&session); let usage_tracker = UsageTracker::from_session(&session);
Self { Self {
@@ -140,7 +151,7 @@ where
system_prompt, system_prompt,
max_iterations: usize::MAX, max_iterations: usize::MAX,
usage_tracker, usage_tracker,
hook_runner: HookRunner::from_feature_config(&feature_config), hook_runner: HookRunner::from_feature_config(feature_config),
} }
} }
@@ -156,11 +167,12 @@ where
mut prompter: Option<&mut dyn PermissionPrompter>, mut prompter: Option<&mut dyn PermissionPrompter>,
) -> Result<TurnSummary, RuntimeError> { ) -> Result<TurnSummary, RuntimeError> {
self.session self.session
.push_user_text(user_input.into()) .messages
.map_err(|error| RuntimeError::new(error.to_string()))?; .push(ConversationMessage::user_text(user_input.into()));
let mut assistant_messages = Vec::new(); let mut assistant_messages = Vec::new();
let mut tool_results = Vec::new(); let mut tool_results = Vec::new();
let mut prompt_cache_events = Vec::new();
let mut iterations = 0; let mut iterations = 0;
loop { loop {
@@ -176,10 +188,12 @@ where
messages: self.session.messages.clone(), messages: self.session.messages.clone(),
}; };
let events = self.api_client.stream(request)?; let events = self.api_client.stream(request)?;
let (assistant_message, usage) = build_assistant_message(events)?; let (assistant_message, usage, turn_prompt_cache_events) =
build_assistant_message(events)?;
if let Some(usage) = usage { if let Some(usage) = usage {
self.usage_tracker.record(usage); self.usage_tracker.record(usage);
} }
prompt_cache_events.extend(turn_prompt_cache_events);
let pending_tool_uses = assistant_message let pending_tool_uses = assistant_message
.blocks .blocks
.iter() .iter()
@@ -191,9 +205,7 @@ where
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();
self.session self.session.messages.push(assistant_message.clone());
.push_message(assistant_message.clone())
.map_err(|error| RuntimeError::new(error.to_string()))?;
assistant_messages.push(assistant_message); assistant_messages.push(assistant_message);
if pending_tool_uses.is_empty() { if pending_tool_uses.is_empty() {
@@ -251,9 +263,7 @@ where
ConversationMessage::tool_result(tool_use_id, tool_name, reason, true) ConversationMessage::tool_result(tool_use_id, tool_name, reason, true)
} }
}; };
self.session self.session.messages.push(result_message.clone());
.push_message(result_message.clone())
.map_err(|error| RuntimeError::new(error.to_string()))?;
tool_results.push(result_message); tool_results.push(result_message);
} }
} }
@@ -261,6 +271,7 @@ where
Ok(TurnSummary { Ok(TurnSummary {
assistant_messages, assistant_messages,
tool_results, tool_results,
prompt_cache_events,
iterations, iterations,
usage: self.usage_tracker.cumulative_usage(), usage: self.usage_tracker.cumulative_usage(),
}) })
@@ -286,11 +297,6 @@ where
&self.session &self.session
} }
#[must_use]
pub fn fork_session(&self, branch_name: Option<String>) -> Session {
self.session.fork(branch_name)
}
#[must_use] #[must_use]
pub fn into_session(self) -> Session { pub fn into_session(self) -> Session {
self.session self.session
@@ -299,9 +305,17 @@ where
fn build_assistant_message( fn build_assistant_message(
events: Vec<AssistantEvent>, events: Vec<AssistantEvent>,
) -> Result<(ConversationMessage, Option<TokenUsage>), RuntimeError> { ) -> Result<
(
ConversationMessage,
Option<TokenUsage>,
Vec<PromptCacheEvent>,
),
RuntimeError,
> {
let mut text = String::new(); let mut text = String::new();
let mut blocks = Vec::new(); let mut blocks = Vec::new();
let mut prompt_cache_events = Vec::new();
let mut finished = false; let mut finished = false;
let mut usage = None; let mut usage = None;
@@ -313,6 +327,7 @@ fn build_assistant_message(
blocks.push(ContentBlock::ToolUse { id, name, input }); blocks.push(ContentBlock::ToolUse { id, name, input });
} }
AssistantEvent::Usage(value) => usage = Some(value), AssistantEvent::Usage(value) => usage = Some(value),
AssistantEvent::PromptCache(event) => prompt_cache_events.push(event),
AssistantEvent::MessageStop => { AssistantEvent::MessageStop => {
finished = true; finished = true;
} }
@@ -333,6 +348,7 @@ fn build_assistant_message(
Ok(( Ok((
ConversationMessage::assistant_with_usage(blocks, usage), ConversationMessage::assistant_with_usage(blocks, usage),
usage, usage,
prompt_cache_events,
)) ))
} }
@@ -405,7 +421,7 @@ impl ToolExecutor for StaticToolExecutor {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::{ use super::{
ApiClient, ApiRequest, AssistantEvent, ConversationRuntime, RuntimeError, ApiClient, ApiRequest, AssistantEvent, ConversationRuntime, PromptCacheEvent, RuntimeError,
StaticToolExecutor, StaticToolExecutor,
}; };
use crate::compact::CompactionConfig; use crate::compact::CompactionConfig;
@@ -417,9 +433,7 @@ mod tests {
use crate::prompt::{ProjectContext, SystemPromptBuilder}; use crate::prompt::{ProjectContext, SystemPromptBuilder};
use crate::session::{ContentBlock, MessageRole, Session}; use crate::session::{ContentBlock, MessageRole, Session};
use crate::usage::TokenUsage; use crate::usage::TokenUsage;
use std::fs;
use std::path::PathBuf; use std::path::PathBuf;
use std::time::{SystemTime, UNIX_EPOCH};
struct ScriptedApiClient { struct ScriptedApiClient {
call_count: usize, call_count: usize,
@@ -464,6 +478,15 @@ mod tests {
cache_creation_input_tokens: 1, cache_creation_input_tokens: 1,
cache_read_input_tokens: 3, cache_read_input_tokens: 3,
}), }),
AssistantEvent::PromptCache(PromptCacheEvent {
unexpected: true,
reason:
"cache read tokens dropped while prompt fingerprint remained stable"
.to_string(),
previous_cache_read_input_tokens: 6_000,
current_cache_read_input_tokens: 1_000,
token_drop: 5_000,
}),
AssistantEvent::MessageStop, AssistantEvent::MessageStop,
]) ])
} }
@@ -517,8 +540,10 @@ mod tests {
assert_eq!(summary.iterations, 2); assert_eq!(summary.iterations, 2);
assert_eq!(summary.assistant_messages.len(), 2); assert_eq!(summary.assistant_messages.len(), 2);
assert_eq!(summary.tool_results.len(), 1); assert_eq!(summary.tool_results.len(), 1);
assert_eq!(summary.prompt_cache_events.len(), 1);
assert_eq!(runtime.session().messages.len(), 4); assert_eq!(runtime.session().messages.len(), 4);
assert_eq!(summary.usage.output_tokens, 10); assert_eq!(summary.usage.output_tokens, 10);
assert!(summary.prompt_cache_events[0].unexpected);
assert!(matches!( assert!(matches!(
runtime.session().messages[1].blocks[1], runtime.session().messages[1].blocks[1],
ContentBlock::ToolUse { .. } ContentBlock::ToolUse { .. }
@@ -620,7 +645,7 @@ mod tests {
}), }),
PermissionPolicy::new(PermissionMode::DangerFullAccess), PermissionPolicy::new(PermissionMode::DangerFullAccess),
vec!["system".to_string()], vec!["system".to_string()],
RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new( &RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new(
vec![shell_snippet("printf 'blocked by hook'; exit 2")], vec![shell_snippet("printf 'blocked by hook'; exit 2")],
Vec::new(), Vec::new(),
)), )),
@@ -686,7 +711,7 @@ mod tests {
StaticToolExecutor::new().register("add", |_input| Ok("4".to_string())), StaticToolExecutor::new().register("add", |_input| Ok("4".to_string())),
PermissionPolicy::new(PermissionMode::DangerFullAccess), PermissionPolicy::new(PermissionMode::DangerFullAccess),
vec!["system".to_string()], vec!["system".to_string()],
RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new( &RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new(
vec![shell_snippet("printf 'pre hook ran'")], vec![shell_snippet("printf 'pre hook ran'")],
vec![shell_snippet("printf 'post hook ran'")], vec![shell_snippet("printf 'post hook ran'")],
)), )),
@@ -708,7 +733,7 @@ mod tests {
"post hook should preserve non-error result: {output:?}" "post hook should preserve non-error result: {output:?}"
); );
assert!( assert!(
output.contains("4"), output.contains('4'),
"tool output missing value: {output:?}" "tool output missing value: {output:?}"
); );
assert!( assert!(
@@ -798,86 +823,6 @@ mod tests {
result.compacted_session.messages[0].role, result.compacted_session.messages[0].role,
MessageRole::System MessageRole::System
); );
assert_eq!(
result.compacted_session.session_id,
runtime.session().session_id
);
assert!(result.compacted_session.compaction.is_some());
}
#[test]
fn persists_conversation_turn_messages_to_jsonl_session() {
struct SimpleApi;
impl ApiClient for SimpleApi {
fn stream(
&mut self,
_request: ApiRequest,
) -> Result<Vec<AssistantEvent>, RuntimeError> {
Ok(vec![
AssistantEvent::TextDelta("done".to_string()),
AssistantEvent::MessageStop,
])
}
}
let path = temp_session_path("persisted-turn");
let session = Session::new().with_persistence_path(path.clone());
let mut runtime = ConversationRuntime::new(
session,
SimpleApi,
StaticToolExecutor::new(),
PermissionPolicy::new(PermissionMode::DangerFullAccess),
vec!["system".to_string()],
);
runtime
.run_turn("persist this turn", None)
.expect("turn should succeed");
let restored = Session::load_from_path(&path).expect("persisted session should reload");
fs::remove_file(&path).expect("temp session file should be removable");
assert_eq!(restored.messages.len(), 2);
assert_eq!(restored.messages[0].role, MessageRole::User);
assert_eq!(restored.messages[1].role, MessageRole::Assistant);
assert_eq!(restored.session_id, runtime.session().session_id);
}
#[test]
fn forks_runtime_session_without_mutating_original() {
let mut session = Session::new();
session
.push_user_text("branch me")
.expect("message should append");
let runtime = ConversationRuntime::new(
session.clone(),
ScriptedApiClient { call_count: 0 },
StaticToolExecutor::new(),
PermissionPolicy::new(PermissionMode::DangerFullAccess),
vec!["system".to_string()],
);
let forked = runtime.fork_session(Some("alt-path".to_string()));
assert_eq!(forked.messages, session.messages);
assert_ne!(forked.session_id, session.session_id);
assert_eq!(
forked
.fork
.as_ref()
.map(|fork| (fork.parent_session_id.as_str(), fork.branch_name.as_deref())),
Some((session.session_id.as_str(), Some("alt-path")))
);
assert!(runtime.session().fork.is_none());
}
fn temp_session_path(label: &str) -> PathBuf {
let nanos = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("system time should be after epoch")
.as_nanos();
std::env::temp_dir().join(format!("runtime-conversation-{label}-{nanos}.json"))
} }
#[cfg(windows)] #[cfg(windows)]

View File

@@ -64,7 +64,7 @@ impl HookRunner {
#[must_use] #[must_use]
pub fn run_pre_tool_use(&self, tool_name: &str, tool_input: &str) -> HookRunResult { pub fn run_pre_tool_use(&self, tool_name: &str, tool_input: &str) -> HookRunResult {
self.run_commands( Self::run_commands(
HookEvent::PreToolUse, HookEvent::PreToolUse,
self.config.pre_tool_use(), self.config.pre_tool_use(),
tool_name, tool_name,
@@ -82,7 +82,7 @@ impl HookRunner {
tool_output: &str, tool_output: &str,
is_error: bool, is_error: bool,
) -> HookRunResult { ) -> HookRunResult {
self.run_commands( Self::run_commands(
HookEvent::PostToolUse, HookEvent::PostToolUse,
self.config.post_tool_use(), self.config.post_tool_use(),
tool_name, tool_name,
@@ -93,7 +93,6 @@ impl HookRunner {
} }
fn run_commands( fn run_commands(
&self,
event: HookEvent, event: HookEvent,
commands: &[String], commands: &[String],
tool_name: &str, tool_name: &str,
@@ -118,7 +117,7 @@ impl HookRunner {
let mut messages = Vec::new(); let mut messages = Vec::new();
for command in commands { for command in commands {
match self.run_command( match Self::run_command(
command, command,
event, event,
tool_name, tool_name,
@@ -150,7 +149,6 @@ impl HookRunner {
} }
fn run_command( fn run_command(
&self,
command: &str, command: &str,
event: HookEvent, event: HookEvent,
tool_name: &str, tool_name: &str,

View File

@@ -31,8 +31,8 @@ pub use config::{
ScopedMcpServerConfig, CLAUDE_CODE_SETTINGS_SCHEMA_NAME, ScopedMcpServerConfig, CLAUDE_CODE_SETTINGS_SCHEMA_NAME,
}; };
pub use conversation::{ pub use conversation::{
ApiClient, ApiRequest, AssistantEvent, ConversationRuntime, RuntimeError, StaticToolExecutor, ApiClient, ApiRequest, AssistantEvent, ConversationRuntime, PromptCacheEvent, RuntimeError,
ToolError, ToolExecutor, TurnSummary, StaticToolExecutor, ToolError, ToolExecutor, TurnSummary,
}; };
pub use file_ops::{ pub use file_ops::{
edit_file, glob_search, grep_search, read_file, write_file, EditFileOutput, GlobSearchOutput, edit_file, glob_search, grep_search, read_file, write_file, EditFileOutput, GlobSearchOutput,
@@ -76,10 +76,7 @@ pub use remote::{
RemoteSessionContext, UpstreamProxyBootstrap, UpstreamProxyState, DEFAULT_REMOTE_BASE_URL, RemoteSessionContext, UpstreamProxyBootstrap, UpstreamProxyState, DEFAULT_REMOTE_BASE_URL,
DEFAULT_SESSION_TOKEN_PATH, DEFAULT_SYSTEM_CA_BUNDLE, NO_PROXY_HOSTS, UPSTREAM_PROXY_ENV_KEYS, DEFAULT_SESSION_TOKEN_PATH, DEFAULT_SYSTEM_CA_BUNDLE, NO_PROXY_HOSTS, UPSTREAM_PROXY_ENV_KEYS,
}; };
pub use session::{ pub use session::{ContentBlock, ConversationMessage, MessageRole, Session, SessionError};
ContentBlock, ConversationMessage, MessageRole, Session, SessionCompaction, SessionError,
SessionFork,
};
pub use usage::{ pub use usage::{
format_usd, pricing_for_model, ModelPricing, TokenUsage, UsageCostEstimate, UsageTracker, format_usd, pricing_for_model, ModelPricing, TokenUsage, UsageCostEstimate, UsageTracker,
}; };

View File

@@ -1,19 +1,11 @@
use std::collections::BTreeMap; use std::collections::BTreeMap;
use std::fmt::{Display, Formatter}; use std::fmt::{Display, Formatter};
use std::fs::{self, OpenOptions}; use std::fs;
use std::io::Write; use std::path::Path;
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{SystemTime, UNIX_EPOCH};
use crate::json::{JsonError, JsonValue}; use crate::json::{JsonError, JsonValue};
use crate::usage::TokenUsage; use crate::usage::TokenUsage;
const SESSION_VERSION: u32 = 1;
const ROTATE_AFTER_BYTES: u64 = 256 * 1024;
const MAX_ROTATED_FILES: usize = 3;
static SESSION_ID_COUNTER: AtomicU64 = AtomicU64::new(0);
#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessageRole { pub enum MessageRole {
System, System,
@@ -48,49 +40,11 @@ pub struct ConversationMessage {
} }
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
pub struct SessionCompaction {
pub count: u32,
pub removed_message_count: usize,
pub summary: String,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SessionFork {
pub parent_session_id: String,
pub branch_name: Option<String>,
}
#[derive(Debug, Clone, PartialEq, Eq)]
struct SessionPersistence {
path: PathBuf,
}
#[derive(Debug, Clone)]
pub struct Session { pub struct Session {
pub version: u32, pub version: u32,
pub session_id: String,
pub created_at_ms: u64,
pub updated_at_ms: u64,
pub messages: Vec<ConversationMessage>, pub messages: Vec<ConversationMessage>,
pub compaction: Option<SessionCompaction>,
pub fork: Option<SessionFork>,
persistence: Option<SessionPersistence>,
} }
impl PartialEq for Session {
fn eq(&self, other: &Self) -> bool {
self.version == other.version
&& self.session_id == other.session_id
&& self.created_at_ms == other.created_at_ms
&& self.updated_at_ms == other.updated_at_ms
&& self.messages == other.messages
&& self.compaction == other.compaction
&& self.fork == other.fork
}
}
impl Eq for Session {}
#[derive(Debug)] #[derive(Debug)]
pub enum SessionError { pub enum SessionError {
Io(std::io::Error), Io(std::io::Error),
@@ -125,84 +79,20 @@ impl From<JsonError> for SessionError {
impl Session { impl Session {
#[must_use] #[must_use]
pub fn new() -> Self { pub fn new() -> Self {
let now = current_time_millis();
Self { Self {
version: SESSION_VERSION, version: 1,
session_id: generate_session_id(),
created_at_ms: now,
updated_at_ms: now,
messages: Vec::new(), messages: Vec::new(),
compaction: None,
fork: None,
persistence: None,
} }
} }
#[must_use]
pub fn with_persistence_path(mut self, path: impl Into<PathBuf>) -> Self {
self.persistence = Some(SessionPersistence { path: path.into() });
self
}
#[must_use]
pub fn persistence_path(&self) -> Option<&Path> {
self.persistence.as_ref().map(|value| value.path.as_path())
}
pub fn save_to_path(&self, path: impl AsRef<Path>) -> Result<(), SessionError> { pub fn save_to_path(&self, path: impl AsRef<Path>) -> Result<(), SessionError> {
let path = path.as_ref(); fs::write(path, self.to_json().render())?;
rotate_session_file_if_needed(path)?;
write_atomic(path, &self.render_jsonl_snapshot())?;
cleanup_rotated_logs(path)?;
Ok(()) Ok(())
} }
pub fn load_from_path(path: impl AsRef<Path>) -> Result<Self, SessionError> { pub fn load_from_path(path: impl AsRef<Path>) -> Result<Self, SessionError> {
let path = path.as_ref();
let contents = fs::read_to_string(path)?; let contents = fs::read_to_string(path)?;
let session = match JsonValue::parse(&contents) { Self::from_json(&JsonValue::parse(&contents)?)
Ok(value) => Self::from_json(&value)?,
Err(_) => Self::from_jsonl(&contents)?,
};
Ok(session.with_persistence_path(path.to_path_buf()))
}
pub fn push_message(&mut self, message: ConversationMessage) -> Result<(), SessionError> {
self.touch();
self.messages.push(message.clone());
self.append_persisted_message(&message)
}
pub fn push_user_text(&mut self, text: impl Into<String>) -> Result<(), SessionError> {
self.push_message(ConversationMessage::user_text(text))
}
pub fn record_compaction(&mut self, summary: impl Into<String>, removed_message_count: usize) {
self.touch();
let count = self.compaction.as_ref().map_or(1, |value| value.count + 1);
self.compaction = Some(SessionCompaction {
count,
removed_message_count,
summary: summary.into(),
});
}
#[must_use]
pub fn fork(&self, branch_name: Option<String>) -> Self {
let now = current_time_millis();
Self {
version: self.version,
session_id: generate_session_id(),
created_at_ms: now,
updated_at_ms: now,
messages: self.messages.clone(),
compaction: self.compaction.clone(),
fork: Some(SessionFork {
parent_session_id: self.session_id.clone(),
branch_name: normalize_optional_string(branch_name),
}),
persistence: None,
}
} }
#[must_use] #[must_use]
@@ -212,18 +102,6 @@ impl Session {
"version".to_string(), "version".to_string(),
JsonValue::Number(i64::from(self.version)), JsonValue::Number(i64::from(self.version)),
); );
object.insert(
"session_id".to_string(),
JsonValue::String(self.session_id.clone()),
);
object.insert(
"created_at_ms".to_string(),
JsonValue::Number(i64_from_u64(self.created_at_ms, "created_at_ms")),
);
object.insert(
"updated_at_ms".to_string(),
JsonValue::Number(i64_from_u64(self.updated_at_ms, "updated_at_ms")),
);
object.insert( object.insert(
"messages".to_string(), "messages".to_string(),
JsonValue::Array( JsonValue::Array(
@@ -233,12 +111,6 @@ impl Session {
.collect(), .collect(),
), ),
); );
if let Some(compaction) = &self.compaction {
object.insert("compaction".to_string(), compaction.to_json());
}
if let Some(fork) = &self.fork {
object.insert("fork".to_string(), fork.to_json());
}
JsonValue::Object(object) JsonValue::Object(object)
} }
@@ -259,179 +131,7 @@ impl Session {
.iter() .iter()
.map(ConversationMessage::from_json) .map(ConversationMessage::from_json)
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
let now = current_time_millis(); Ok(Self { version, messages })
let session_id = object
.get("session_id")
.and_then(JsonValue::as_str)
.map(ToOwned::to_owned)
.unwrap_or_else(generate_session_id);
let created_at_ms = object
.get("created_at_ms")
.map(|value| required_u64_from_value(value, "created_at_ms"))
.transpose()?
.unwrap_or(now);
let updated_at_ms = object
.get("updated_at_ms")
.map(|value| required_u64_from_value(value, "updated_at_ms"))
.transpose()?
.unwrap_or(created_at_ms);
let compaction = object
.get("compaction")
.map(SessionCompaction::from_json)
.transpose()?;
let fork = object.get("fork").map(SessionFork::from_json).transpose()?;
Ok(Self {
version,
session_id,
created_at_ms,
updated_at_ms,
messages,
compaction,
fork,
persistence: None,
})
}
fn from_jsonl(contents: &str) -> Result<Self, SessionError> {
let mut version = SESSION_VERSION;
let mut session_id = None;
let mut created_at_ms = None;
let mut updated_at_ms = None;
let mut messages = Vec::new();
let mut compaction = None;
let mut fork = None;
for (line_number, raw_line) in contents.lines().enumerate() {
let line = raw_line.trim();
if line.is_empty() {
continue;
}
let value = JsonValue::parse(line).map_err(|error| {
SessionError::Format(format!(
"invalid JSONL record at line {}: {}",
line_number + 1,
error
))
})?;
let object = value.as_object().ok_or_else(|| {
SessionError::Format(format!(
"JSONL record at line {} must be an object",
line_number + 1
))
})?;
match object
.get("type")
.and_then(JsonValue::as_str)
.ok_or_else(|| {
SessionError::Format(format!(
"JSONL record at line {} missing type",
line_number + 1
))
})? {
"session_meta" => {
version = required_u32(object, "version")?;
session_id = Some(required_string(object, "session_id")?);
created_at_ms = Some(required_u64(object, "created_at_ms")?);
updated_at_ms = Some(required_u64(object, "updated_at_ms")?);
fork = object.get("fork").map(SessionFork::from_json).transpose()?;
}
"message" => {
let message_value = object.get("message").ok_or_else(|| {
SessionError::Format(format!(
"JSONL record at line {} missing message",
line_number + 1
))
})?;
messages.push(ConversationMessage::from_json(message_value)?);
}
"compaction" => {
compaction = Some(SessionCompaction::from_json(&JsonValue::Object(
object.clone(),
))?);
}
other => {
return Err(SessionError::Format(format!(
"unsupported JSONL record type at line {}: {other}",
line_number + 1
)))
}
}
}
let now = current_time_millis();
Ok(Self {
version,
session_id: session_id.unwrap_or_else(generate_session_id),
created_at_ms: created_at_ms.unwrap_or(now),
updated_at_ms: updated_at_ms.unwrap_or(created_at_ms.unwrap_or(now)),
messages,
compaction,
fork,
persistence: None,
})
}
fn render_jsonl_snapshot(&self) -> String {
let mut lines = vec![self.meta_record().render()];
if let Some(compaction) = &self.compaction {
lines.push(compaction.to_jsonl_record().render());
}
lines.extend(
self.messages
.iter()
.map(|message| message_record(message).render()),
);
let mut rendered = lines.join("\n");
rendered.push('\n');
rendered
}
fn append_persisted_message(&self, message: &ConversationMessage) -> Result<(), SessionError> {
let Some(path) = self.persistence_path() else {
return Ok(());
};
let needs_bootstrap = !path.exists() || fs::metadata(path)?.len() == 0;
if needs_bootstrap {
self.save_to_path(path)?;
return Ok(());
}
let mut file = OpenOptions::new().append(true).open(path)?;
writeln!(file, "{}", message_record(message).render())?;
Ok(())
}
fn meta_record(&self) -> JsonValue {
let mut object = BTreeMap::new();
object.insert(
"type".to_string(),
JsonValue::String("session_meta".to_string()),
);
object.insert(
"version".to_string(),
JsonValue::Number(i64::from(self.version)),
);
object.insert(
"session_id".to_string(),
JsonValue::String(self.session_id.clone()),
);
object.insert(
"created_at_ms".to_string(),
JsonValue::Number(i64_from_u64(self.created_at_ms, "created_at_ms")),
);
object.insert(
"updated_at_ms".to_string(),
JsonValue::Number(i64_from_u64(self.updated_at_ms, "updated_at_ms")),
);
if let Some(fork) = &self.fork {
object.insert("fork".to_string(), fork.to_json());
}
JsonValue::Object(object)
}
fn touch(&mut self) {
self.updated_at_ms = current_time_millis();
} }
} }
@@ -624,92 +324,6 @@ impl ContentBlock {
} }
} }
impl SessionCompaction {
#[must_use]
pub fn to_json(&self) -> JsonValue {
let mut object = BTreeMap::new();
object.insert(
"count".to_string(),
JsonValue::Number(i64::from(self.count)),
);
object.insert(
"removed_message_count".to_string(),
JsonValue::Number(i64_from_usize(
self.removed_message_count,
"removed_message_count",
)),
);
object.insert(
"summary".to_string(),
JsonValue::String(self.summary.clone()),
);
JsonValue::Object(object)
}
#[must_use]
pub fn to_jsonl_record(&self) -> JsonValue {
let mut object = self
.to_json()
.as_object()
.cloned()
.expect("compaction should render to object");
object.insert(
"type".to_string(),
JsonValue::String("compaction".to_string()),
);
JsonValue::Object(object)
}
fn from_json(value: &JsonValue) -> Result<Self, SessionError> {
let object = value
.as_object()
.ok_or_else(|| SessionError::Format("compaction must be an object".to_string()))?;
Ok(Self {
count: required_u32(object, "count")?,
removed_message_count: required_usize(object, "removed_message_count")?,
summary: required_string(object, "summary")?,
})
}
}
impl SessionFork {
#[must_use]
pub fn to_json(&self) -> JsonValue {
let mut object = BTreeMap::new();
object.insert(
"parent_session_id".to_string(),
JsonValue::String(self.parent_session_id.clone()),
);
if let Some(branch_name) = &self.branch_name {
object.insert(
"branch_name".to_string(),
JsonValue::String(branch_name.clone()),
);
}
JsonValue::Object(object)
}
fn from_json(value: &JsonValue) -> Result<Self, SessionError> {
let object = value
.as_object()
.ok_or_else(|| SessionError::Format("fork metadata must be an object".to_string()))?;
Ok(Self {
parent_session_id: required_string(object, "parent_session_id")?,
branch_name: object
.get("branch_name")
.and_then(JsonValue::as_str)
.map(ToOwned::to_owned),
})
}
}
fn message_record(message: &ConversationMessage) -> JsonValue {
let mut object = BTreeMap::new();
object.insert("type".to_string(), JsonValue::String("message".to_string()));
object.insert("message".to_string(), message.to_json());
JsonValue::Object(object)
}
fn usage_to_json(usage: TokenUsage) -> JsonValue { fn usage_to_json(usage: TokenUsage) -> JsonValue {
let mut object = BTreeMap::new(); let mut object = BTreeMap::new();
object.insert( object.insert(
@@ -762,155 +376,22 @@ fn required_u32(object: &BTreeMap<String, JsonValue>, key: &str) -> Result<u32,
u32::try_from(value).map_err(|_| SessionError::Format(format!("{key} out of range"))) u32::try_from(value).map_err(|_| SessionError::Format(format!("{key} out of range")))
} }
fn required_u64(object: &BTreeMap<String, JsonValue>, key: &str) -> Result<u64, SessionError> {
let value = object
.get(key)
.ok_or_else(|| SessionError::Format(format!("missing {key}")))?;
required_u64_from_value(value, key)
}
fn required_u64_from_value(value: &JsonValue, key: &str) -> Result<u64, SessionError> {
let value = value
.as_i64()
.ok_or_else(|| SessionError::Format(format!("missing {key}")))?;
u64::try_from(value).map_err(|_| SessionError::Format(format!("{key} out of range")))
}
fn required_usize(object: &BTreeMap<String, JsonValue>, key: &str) -> Result<usize, SessionError> {
let value = object
.get(key)
.and_then(JsonValue::as_i64)
.ok_or_else(|| SessionError::Format(format!("missing {key}")))?;
usize::try_from(value).map_err(|_| SessionError::Format(format!("{key} out of range")))
}
fn i64_from_u64(value: u64, key: &str) -> i64 {
i64::try_from(value).unwrap_or_else(|_| panic!("{key} out of range for JSON number"))
}
fn i64_from_usize(value: usize, key: &str) -> i64 {
i64::try_from(value).unwrap_or_else(|_| panic!("{key} out of range for JSON number"))
}
fn normalize_optional_string(value: Option<String>) -> Option<String> {
value.and_then(|value| {
let trimmed = value.trim();
if trimmed.is_empty() {
None
} else {
Some(trimmed.to_string())
}
})
}
fn current_time_millis() -> u64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.map(|duration| duration.as_millis() as u64)
.unwrap_or_default()
}
fn generate_session_id() -> String {
let millis = current_time_millis();
let counter = SESSION_ID_COUNTER.fetch_add(1, Ordering::Relaxed);
format!("session-{millis}-{counter}")
}
fn write_atomic(path: &Path, contents: &str) -> Result<(), SessionError> {
if let Some(parent) = path.parent() {
fs::create_dir_all(parent)?;
}
let temp_path = temporary_path_for(path);
fs::write(&temp_path, contents)?;
fs::rename(temp_path, path)?;
Ok(())
}
fn temporary_path_for(path: &Path) -> PathBuf {
let file_name = path
.file_name()
.and_then(|value| value.to_str())
.unwrap_or("session");
path.with_file_name(format!(
"{file_name}.tmp-{}-{}",
current_time_millis(),
SESSION_ID_COUNTER.fetch_add(1, Ordering::Relaxed)
))
}
fn rotate_session_file_if_needed(path: &Path) -> Result<(), SessionError> {
let Ok(metadata) = fs::metadata(path) else {
return Ok(());
};
if metadata.len() < ROTATE_AFTER_BYTES {
return Ok(());
}
let rotated_path = rotated_log_path(path);
fs::rename(path, rotated_path)?;
Ok(())
}
fn rotated_log_path(path: &Path) -> PathBuf {
let stem = path
.file_stem()
.and_then(|value| value.to_str())
.unwrap_or("session");
path.with_file_name(format!("{stem}.rot-{}.jsonl", current_time_millis()))
}
fn cleanup_rotated_logs(path: &Path) -> Result<(), SessionError> {
let Some(parent) = path.parent() else {
return Ok(());
};
let stem = path
.file_stem()
.and_then(|value| value.to_str())
.unwrap_or("session");
let prefix = format!("{stem}.rot-");
let mut rotated_paths = fs::read_dir(parent)?
.filter_map(Result::ok)
.map(|entry| entry.path())
.filter(|entry_path| {
entry_path
.file_name()
.and_then(|value| value.to_str())
.is_some_and(|name| name.starts_with(&prefix) && name.ends_with(".jsonl"))
})
.collect::<Vec<_>>();
rotated_paths.sort_by_key(|entry_path| {
fs::metadata(entry_path)
.and_then(|metadata| metadata.modified())
.unwrap_or(UNIX_EPOCH)
});
let remove_count = rotated_paths.len().saturating_sub(MAX_ROTATED_FILES);
for stale_path in rotated_paths.into_iter().take(remove_count) {
fs::remove_file(stale_path)?;
}
Ok(())
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::{ use super::{ContentBlock, ConversationMessage, MessageRole, Session};
cleanup_rotated_logs, rotate_session_file_if_needed, ContentBlock, ConversationMessage,
MessageRole, Session, SessionFork,
};
use crate::json::JsonValue;
use crate::usage::TokenUsage; use crate::usage::TokenUsage;
use std::fs; use std::fs;
use std::path::PathBuf;
use std::time::{SystemTime, UNIX_EPOCH}; use std::time::{SystemTime, UNIX_EPOCH};
#[test] #[test]
fn persists_and_restores_session_jsonl() { fn persists_and_restores_session_json() {
let mut session = Session::new(); let mut session = Session::new();
session session
.push_user_text("hello") .messages
.expect("user message should append"); .push(ConversationMessage::user_text("hello"));
session session
.push_message(ConversationMessage::assistant_with_usage( .messages
.push(ConversationMessage::assistant_with_usage(
vec![ vec![
ContentBlock::Text { ContentBlock::Text {
text: "thinking".to_string(), text: "thinking".to_string(),
@@ -927,15 +408,16 @@ mod tests {
cache_creation_input_tokens: 1, cache_creation_input_tokens: 1,
cache_read_input_tokens: 2, cache_read_input_tokens: 2,
}), }),
)) ));
.expect("assistant message should append"); session.messages.push(ConversationMessage::tool_result(
session "tool-1", "bash", "hi", false,
.push_message(ConversationMessage::tool_result( ));
"tool-1", "bash", "hi", false,
))
.expect("tool result should append");
let path = temp_session_path("jsonl"); let nanos = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("system time should be after epoch")
.as_nanos();
let path = std::env::temp_dir().join(format!("runtime-session-{nanos}.json"));
session.save_to_path(&path).expect("session should save"); session.save_to_path(&path).expect("session should save");
let restored = Session::load_from_path(&path).expect("session should load"); let restored = Session::load_from_path(&path).expect("session should load");
fs::remove_file(&path).expect("temp file should be removable"); fs::remove_file(&path).expect("temp file should be removable");
@@ -946,157 +428,5 @@ mod tests {
restored.messages[1].usage.expect("usage").total_tokens(), restored.messages[1].usage.expect("usage").total_tokens(),
17 17
); );
assert_eq!(restored.session_id, session.session_id);
}
#[test]
fn loads_legacy_session_json_object() {
let path = temp_session_path("legacy");
let legacy = JsonValue::Object(
[
("version".to_string(), JsonValue::Number(1)),
(
"messages".to_string(),
JsonValue::Array(vec![ConversationMessage::user_text("legacy").to_json()]),
),
]
.into_iter()
.collect(),
);
fs::write(&path, legacy.render()).expect("legacy file should write");
let restored = Session::load_from_path(&path).expect("legacy session should load");
fs::remove_file(&path).expect("temp file should be removable");
assert_eq!(restored.messages.len(), 1);
assert_eq!(
restored.messages[0],
ConversationMessage::user_text("legacy")
);
assert!(!restored.session_id.is_empty());
}
#[test]
fn appends_messages_to_persisted_jsonl_session() {
let path = temp_session_path("append");
let mut session = Session::new().with_persistence_path(path.clone());
session
.save_to_path(&path)
.expect("initial save should succeed");
session
.push_user_text("hi")
.expect("user append should succeed");
session
.push_message(ConversationMessage::assistant(vec![ContentBlock::Text {
text: "hello".to_string(),
}]))
.expect("assistant append should succeed");
let restored = Session::load_from_path(&path).expect("session should replay from jsonl");
fs::remove_file(&path).expect("temp file should be removable");
assert_eq!(restored.messages.len(), 2);
assert_eq!(restored.messages[0], ConversationMessage::user_text("hi"));
}
#[test]
fn persists_compaction_metadata() {
let path = temp_session_path("compaction");
let mut session = Session::new();
session
.push_user_text("before")
.expect("message should append");
session.record_compaction("summarized earlier work", 4);
session.save_to_path(&path).expect("session should save");
let restored = Session::load_from_path(&path).expect("session should load");
fs::remove_file(&path).expect("temp file should be removable");
let compaction = restored.compaction.expect("compaction metadata");
assert_eq!(compaction.count, 1);
assert_eq!(compaction.removed_message_count, 4);
assert!(compaction.summary.contains("summarized"));
}
#[test]
fn forks_sessions_with_branch_metadata_and_persists_it() {
let path = temp_session_path("fork");
let mut session = Session::new();
session
.push_user_text("before fork")
.expect("message should append");
let forked = session
.fork(Some("investigation".to_string()))
.with_persistence_path(path.clone());
forked
.save_to_path(&path)
.expect("forked session should save");
let restored = Session::load_from_path(&path).expect("forked session should load");
fs::remove_file(&path).expect("temp file should be removable");
assert_ne!(restored.session_id, session.session_id);
assert_eq!(
restored.fork,
Some(SessionFork {
parent_session_id: session.session_id,
branch_name: Some("investigation".to_string()),
})
);
assert_eq!(restored.messages, forked.messages);
}
#[test]
fn rotates_and_cleans_up_large_session_logs() {
let path = temp_session_path("rotation");
fs::write(&path, "x".repeat((super::ROTATE_AFTER_BYTES + 10) as usize))
.expect("oversized file should write");
rotate_session_file_if_needed(&path).expect("rotation should succeed");
assert!(
!path.exists(),
"original path should be rotated away before rewrite"
);
for _ in 0..5 {
let rotated = super::rotated_log_path(&path);
fs::write(&rotated, "old").expect("rotated file should write");
}
cleanup_rotated_logs(&path).expect("cleanup should succeed");
let rotated_count = rotation_files(&path).len();
assert!(rotated_count <= super::MAX_ROTATED_FILES);
for rotated in rotation_files(&path) {
fs::remove_file(rotated).expect("rotated file should be removable");
}
}
fn temp_session_path(label: &str) -> PathBuf {
let nanos = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("system time should be after epoch")
.as_nanos();
std::env::temp_dir().join(format!("runtime-session-{label}-{nanos}.json"))
}
fn rotation_files(path: &PathBuf) -> Vec<PathBuf> {
let stem = path
.file_stem()
.and_then(|value| value.to_str())
.expect("temp path should have file stem")
.to_string();
fs::read_dir(path.parent().expect("temp path should have parent"))
.expect("temp dir should read")
.filter_map(Result::ok)
.map(|entry| entry.path())
.filter(|entry_path| {
entry_path
.file_name()
.and_then(|value| value.to_str())
.is_some_and(|name| {
name.starts_with(&format!("{stem}.rot-")) && name.ends_with(".jsonl")
})
})
.collect()
} }
} }

View File

@@ -286,19 +286,21 @@ mod tests {
#[test] #[test]
fn reconstructs_usage_from_session_messages() { fn reconstructs_usage_from_session_messages() {
let mut session = Session::new(); let session = Session {
session.messages = vec![ConversationMessage { version: 1,
role: MessageRole::Assistant, messages: vec![ConversationMessage {
blocks: vec![ContentBlock::Text { role: MessageRole::Assistant,
text: "done".to_string(), blocks: vec![ContentBlock::Text {
text: "done".to_string(),
}],
usage: Some(TokenUsage {
input_tokens: 5,
output_tokens: 2,
cache_creation_input_tokens: 1,
cache_read_input_tokens: 0,
}),
}], }],
usage: Some(TokenUsage { };
input_tokens: 5,
output_tokens: 2,
cache_creation_input_tokens: 1,
cache_read_input_tokens: 0,
}),
}];
let tracker = UsageTracker::from_session(&session); let tracker = UsageTracker::from_session(&session);
assert_eq!(tracker.turns(), 1); assert_eq!(tracker.turns(), 1);

View File

@@ -9,12 +9,13 @@ use std::io::{self, Read, Write};
use std::net::TcpListener; use std::net::TcpListener;
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use std::process::Command; use std::process::Command;
use std::time::UNIX_EPOCH; use std::time::{SystemTime, UNIX_EPOCH};
use api::{ use api::{
resolve_startup_auth_source, AnthropicClient, AuthSource, ContentBlockDelta, InputContentBlock, resolve_startup_auth_source, AnthropicClient, AuthSource, ContentBlockDelta, InputContentBlock,
InputMessage, MessageRequest, MessageResponse, OutputContentBlock, InputMessage, MessageRequest, MessageResponse, OutputContentBlock, PromptCache,
StreamEvent as ApiStreamEvent, ToolChoice, ToolDefinition, ToolResultContentBlock, PromptCacheRecord, StreamEvent as ApiStreamEvent, ToolChoice, ToolDefinition,
ToolResultContentBlock,
}; };
use commands::{ use commands::{
@@ -28,8 +29,8 @@ use runtime::{
parse_oauth_callback_request_target, save_oauth_credentials, ApiClient, ApiRequest, parse_oauth_callback_request_target, save_oauth_credentials, ApiClient, ApiRequest,
AssistantEvent, CompactionConfig, ConfigLoader, ConfigSource, ContentBlock, AssistantEvent, CompactionConfig, ConfigLoader, ConfigSource, ContentBlock,
ConversationMessage, ConversationRuntime, MessageRole, OAuthAuthorizationRequest, OAuthConfig, ConversationMessage, ConversationRuntime, MessageRole, OAuthAuthorizationRequest, OAuthConfig,
OAuthTokenExchangeRequest, PermissionMode, PermissionPolicy, ProjectContext, RuntimeError, OAuthTokenExchangeRequest, PermissionMode, PermissionPolicy, ProjectContext, PromptCacheEvent,
Session, TokenUsage, ToolError, ToolExecutor, UsageTracker, RuntimeError, Session, TokenUsage, ToolError, ToolExecutor, UsageTracker,
}; };
use serde_json::json; use serde_json::json;
use tools::{execute_tool, mvp_tool_specs, ToolSpec}; use tools::{execute_tool, mvp_tool_specs, ToolSpec};
@@ -47,8 +48,6 @@ const DEFAULT_OAUTH_CALLBACK_PORT: u16 = 4545;
const VERSION: &str = env!("CARGO_PKG_VERSION"); const VERSION: &str = env!("CARGO_PKG_VERSION");
const BUILD_TARGET: Option<&str> = option_env!("TARGET"); const BUILD_TARGET: Option<&str> = option_env!("TARGET");
const GIT_SHA: Option<&str> = option_env!("GIT_SHA"); const GIT_SHA: Option<&str> = option_env!("GIT_SHA");
const PRIMARY_SESSION_EXTENSION: &str = "jsonl";
const LEGACY_SESSION_EXTENSION: &str = "json";
type AllowedToolSet = BTreeSet<String>; type AllowedToolSet = BTreeSet<String>;
@@ -591,19 +590,7 @@ fn print_version() {
} }
fn resume_session(session_path: &Path, commands: &[String]) { fn resume_session(session_path: &Path, commands: &[String]) {
let resolved_path = if session_path.exists() { let session = match Session::load_from_path(session_path) {
session_path.to_path_buf()
} else {
match resolve_session_reference(&session_path.display().to_string()) {
Ok(handle) => handle.path,
Err(error) => {
eprintln!("failed to restore session: {error}");
std::process::exit(1);
}
}
};
let session = match Session::load_from_path(&resolved_path) {
Ok(session) => session, Ok(session) => session,
Err(error) => { Err(error) => {
eprintln!("failed to restore session: {error}"); eprintln!("failed to restore session: {error}");
@@ -614,7 +601,7 @@ fn resume_session(session_path: &Path, commands: &[String]) {
if commands.is_empty() { if commands.is_empty() {
println!( println!(
"Restored session from {} ({} messages).", "Restored session from {} ({} messages).",
resolved_path.display(), session_path.display(),
session.messages.len() session.messages.len()
); );
return; return;
@@ -626,7 +613,7 @@ fn resume_session(session_path: &Path, commands: &[String]) {
eprintln!("unsupported resumed command: {raw_command}"); eprintln!("unsupported resumed command: {raw_command}");
std::process::exit(2); std::process::exit(2);
}; };
match run_resume_command(&resolved_path, &session, &command) { match run_resume_command(session_path, &session, &command) {
Ok(ResumeCommandOutcome { Ok(ResumeCommandOutcome {
session: next_session, session: next_session,
message, message,
@@ -987,8 +974,6 @@ struct ManagedSessionSummary {
path: PathBuf, path: PathBuf,
modified_epoch_secs: u64, modified_epoch_secs: u64,
message_count: usize, message_count: usize,
parent_session_id: Option<String>,
branch_name: Option<String>,
} }
struct LiveCli { struct LiveCli {
@@ -1008,10 +993,10 @@ impl LiveCli {
permission_mode: PermissionMode, permission_mode: PermissionMode,
) -> Result<Self, Box<dyn std::error::Error>> { ) -> Result<Self, Box<dyn std::error::Error>> {
let system_prompt = build_system_prompt()?; let system_prompt = build_system_prompt()?;
let session_state = Session::new(); let session = create_managed_session_handle()?;
let session = create_managed_session_handle(&session_state.session_id)?;
let runtime = build_runtime( let runtime = build_runtime(
session_state.with_persistence_path(session.path.clone()), Session::new(),
session.id.clone(),
model.clone(), model.clone(),
system_prompt.clone(), system_prompt.clone(),
enable_tools, enable_tools,
@@ -1067,13 +1052,14 @@ impl LiveCli {
let mut permission_prompter = CliPermissionPrompter::new(self.permission_mode); let mut permission_prompter = CliPermissionPrompter::new(self.permission_mode);
let result = self.runtime.run_turn(input, Some(&mut permission_prompter)); let result = self.runtime.run_turn(input, Some(&mut permission_prompter));
match result { match result {
Ok(_) => { Ok(summary) => {
spinner.finish( spinner.finish(
"✨ Done", "✨ Done",
TerminalRenderer::new().color_theme(), TerminalRenderer::new().color_theme(),
&mut stdout, &mut stdout,
)?; )?;
println!(); println!();
print_prompt_cache_events(&summary);
self.persist_session()?; self.persist_session()?;
Ok(()) Ok(())
} }
@@ -1103,6 +1089,7 @@ impl LiveCli {
let session = self.runtime.session().clone(); let session = self.runtime.session().clone();
let mut runtime = build_runtime( let mut runtime = build_runtime(
session, session,
self.session.id.clone(),
self.model.clone(), self.model.clone(),
self.system_prompt.clone(), self.system_prompt.clone(),
true, true,
@@ -1122,6 +1109,7 @@ impl LiveCli {
"iterations": summary.iterations, "iterations": summary.iterations,
"tool_uses": collect_tool_uses(&summary), "tool_uses": collect_tool_uses(&summary),
"tool_results": collect_tool_results(&summary), "tool_results": collect_tool_results(&summary),
"prompt_cache_events": collect_prompt_cache_events(&summary),
"usage": { "usage": {
"input_tokens": summary.usage.input_tokens, "input_tokens": summary.usage.input_tokens,
"output_tokens": summary.usage.output_tokens, "output_tokens": summary.usage.output_tokens,
@@ -1249,6 +1237,7 @@ impl LiveCli {
let message_count = session.messages.len(); let message_count = session.messages.len();
self.runtime = build_runtime( self.runtime = build_runtime(
session, session,
self.session.id.clone(),
model.clone(), model.clone(),
self.system_prompt.clone(), self.system_prompt.clone(),
true, true,
@@ -1292,6 +1281,7 @@ impl LiveCli {
self.permission_mode = permission_mode_from_label(normalized); self.permission_mode = permission_mode_from_label(normalized);
self.runtime = build_runtime( self.runtime = build_runtime(
session, session,
self.session.id.clone(),
self.model.clone(), self.model.clone(),
self.system_prompt.clone(), self.system_prompt.clone(),
true, true,
@@ -1314,10 +1304,10 @@ impl LiveCli {
return Ok(false); return Ok(false);
} }
let session_state = Session::new(); self.session = create_managed_session_handle()?;
self.session = create_managed_session_handle(&session_state.session_id)?;
self.runtime = build_runtime( self.runtime = build_runtime(
session_state.with_persistence_path(self.session.path.clone()), Session::new(),
self.session.id.clone(),
self.model.clone(), self.model.clone(),
self.system_prompt.clone(), self.system_prompt.clone(),
true, true,
@@ -1351,9 +1341,9 @@ impl LiveCli {
let handle = resolve_session_reference(&session_ref)?; let handle = resolve_session_reference(&session_ref)?;
let session = Session::load_from_path(&handle.path)?; let session = Session::load_from_path(&handle.path)?;
let message_count = session.messages.len(); let message_count = session.messages.len();
let session_id = session.session_id.clone();
self.runtime = build_runtime( self.runtime = build_runtime(
session, session,
handle.id.clone(),
self.model.clone(), self.model.clone(),
self.system_prompt.clone(), self.system_prompt.clone(),
true, true,
@@ -1361,10 +1351,7 @@ impl LiveCli {
self.allowed_tools.clone(), self.allowed_tools.clone(),
self.permission_mode, self.permission_mode,
)?; )?;
self.session = SessionHandle { self.session = handle;
id: session_id,
path: handle.path,
};
println!( println!(
"{}", "{}",
format_resume_report( format_resume_report(
@@ -1427,41 +1414,9 @@ impl LiveCli {
let handle = resolve_session_reference(target)?; let handle = resolve_session_reference(target)?;
let session = Session::load_from_path(&handle.path)?; let session = Session::load_from_path(&handle.path)?;
let message_count = session.messages.len(); let message_count = session.messages.len();
let session_id = session.session_id.clone();
self.runtime = build_runtime( self.runtime = build_runtime(
session, session,
self.model.clone(), handle.id.clone(),
self.system_prompt.clone(),
true,
true,
self.allowed_tools.clone(),
self.permission_mode,
)?;
self.session = SessionHandle {
id: session_id,
path: handle.path,
};
println!(
"Session switched\n Active session {}\n File {}\n Messages {}",
self.session.id,
self.session.path.display(),
message_count,
);
Ok(true)
}
Some("fork") => {
let forked = self.runtime.fork_session(target.map(ToOwned::to_owned));
let parent_session_id = self.session.id.clone();
let handle = create_managed_session_handle(&forked.session_id)?;
let branch_name = forked
.fork
.as_ref()
.and_then(|fork| fork.branch_name.clone());
let forked = forked.with_persistence_path(handle.path.clone());
let message_count = forked.messages.len();
forked.save_to_path(&handle.path)?;
self.runtime = build_runtime(
forked,
self.model.clone(), self.model.clone(),
self.system_prompt.clone(), self.system_prompt.clone(),
true, true,
@@ -1471,19 +1426,15 @@ impl LiveCli {
)?; )?;
self.session = handle; self.session = handle;
println!( println!(
"Session forked\n Parent session {}\n Active session {}\n Branch {}\n File {}\n Messages {}", "Session switched\n Active session {}\n File {}\n Messages {}",
parent_session_id,
self.session.id, self.session.id,
branch_name.as_deref().unwrap_or("(unnamed)"),
self.session.path.display(), self.session.path.display(),
message_count, message_count,
); );
Ok(true) Ok(true)
} }
Some(other) => { Some(other) => {
println!( println!("Unknown /session action '{other}'. Use /session list or /session switch <session-id>.");
"Unknown /session action '{other}'. Use /session list, /session switch <session-id>, or /session fork [branch-name]."
);
Ok(false) Ok(false)
} }
} }
@@ -1496,6 +1447,7 @@ impl LiveCli {
let skipped = removed == 0; let skipped = removed == 0;
self.runtime = build_runtime( self.runtime = build_runtime(
result.compacted_session, result.compacted_session,
self.session.id.clone(),
self.model.clone(), self.model.clone(),
self.system_prompt.clone(), self.system_prompt.clone(),
true, true,
@@ -1516,61 +1468,44 @@ fn sessions_dir() -> Result<PathBuf, Box<dyn std::error::Error>> {
Ok(path) Ok(path)
} }
fn create_managed_session_handle( fn create_managed_session_handle() -> Result<SessionHandle, Box<dyn std::error::Error>> {
session_id: &str, let id = generate_session_id();
) -> Result<SessionHandle, Box<dyn std::error::Error>> { let path = sessions_dir()?.join(format!("{id}.json"));
let id = session_id.to_string();
let path = sessions_dir()?.join(format!("{id}.{PRIMARY_SESSION_EXTENSION}"));
Ok(SessionHandle { id, path }) Ok(SessionHandle { id, path })
} }
fn generate_session_id() -> String {
let millis = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map(|duration| duration.as_millis())
.unwrap_or_default();
format!("session-{millis}")
}
fn resolve_session_reference(reference: &str) -> Result<SessionHandle, Box<dyn std::error::Error>> { fn resolve_session_reference(reference: &str) -> Result<SessionHandle, Box<dyn std::error::Error>> {
let direct = PathBuf::from(reference); let direct = PathBuf::from(reference);
let looks_like_path = direct.extension().is_some() || direct.components().count() > 1;
let path = if direct.exists() { let path = if direct.exists() {
direct direct
} else if looks_like_path {
return Err(format!("session not found: {reference}").into());
} else { } else {
resolve_managed_session_path(reference)? sessions_dir()?.join(format!("{reference}.json"))
}; };
if !path.exists() {
return Err(format!("session not found: {reference}").into());
}
let id = path let id = path
.file_name() .file_stem()
.and_then(|value| value.to_str()) .and_then(|value| value.to_str())
.and_then(|name| {
name.strip_suffix(&format!(".{PRIMARY_SESSION_EXTENSION}"))
.or_else(|| name.strip_suffix(&format!(".{LEGACY_SESSION_EXTENSION}")))
})
.unwrap_or(reference) .unwrap_or(reference)
.to_string(); .to_string();
Ok(SessionHandle { id, path }) Ok(SessionHandle { id, path })
} }
fn resolve_managed_session_path(session_id: &str) -> Result<PathBuf, Box<dyn std::error::Error>> {
let directory = sessions_dir()?;
for extension in [PRIMARY_SESSION_EXTENSION, LEGACY_SESSION_EXTENSION] {
let path = directory.join(format!("{session_id}.{extension}"));
if path.exists() {
return Ok(path);
}
}
Err(format!("session not found: {session_id}").into())
}
fn is_managed_session_file(path: &Path) -> bool {
path.extension()
.and_then(|ext| ext.to_str())
.is_some_and(|extension| {
extension == PRIMARY_SESSION_EXTENSION || extension == LEGACY_SESSION_EXTENSION
})
}
fn list_managed_sessions() -> Result<Vec<ManagedSessionSummary>, Box<dyn std::error::Error>> { fn list_managed_sessions() -> Result<Vec<ManagedSessionSummary>, Box<dyn std::error::Error>> {
let mut sessions = Vec::new(); let mut sessions = Vec::new();
for entry in fs::read_dir(sessions_dir()?)? { for entry in fs::read_dir(sessions_dir()?)? {
let entry = entry?; let entry = entry?;
let path = entry.path(); let path = entry.path();
if !is_managed_session_file(&path) { if path.extension().and_then(|ext| ext.to_str()) != Some("json") {
continue; continue;
} }
let metadata = entry.metadata()?; let metadata = entry.metadata()?;
@@ -1580,41 +1515,19 @@ fn list_managed_sessions() -> Result<Vec<ManagedSessionSummary>, Box<dyn std::er
.and_then(|time| time.duration_since(UNIX_EPOCH).ok()) .and_then(|time| time.duration_since(UNIX_EPOCH).ok())
.map(|duration| duration.as_secs()) .map(|duration| duration.as_secs())
.unwrap_or_default(); .unwrap_or_default();
let (id, message_count, parent_session_id, branch_name) = Session::load_from_path(&path) let message_count = Session::load_from_path(&path)
.map(|session| { .map(|session| session.messages.len())
let parent_session_id = session .unwrap_or_default();
.fork let id = path
.as_ref() .file_stem()
.map(|fork| fork.parent_session_id.clone()); .and_then(|value| value.to_str())
let branch_name = session .unwrap_or("unknown")
.fork .to_string();
.as_ref()
.and_then(|fork| fork.branch_name.clone());
(
session.session_id,
session.messages.len(),
parent_session_id,
branch_name,
)
})
.unwrap_or_else(|_| {
(
path.file_stem()
.and_then(|value| value.to_str())
.unwrap_or("unknown")
.to_string(),
0,
None,
None,
)
});
sessions.push(ManagedSessionSummary { sessions.push(ManagedSessionSummary {
id, id,
path, path,
modified_epoch_secs, modified_epoch_secs,
message_count, message_count,
parent_session_id,
branch_name,
}); });
} }
sessions.sort_by(|left, right| right.modified_epoch_secs.cmp(&left.modified_epoch_secs)); sessions.sort_by(|left, right| right.modified_epoch_secs.cmp(&left.modified_epoch_secs));
@@ -1637,23 +1550,11 @@ fn render_session_list(active_session_id: &str) -> Result<String, Box<dyn std::e
} else { } else {
"○ saved" "○ saved"
}; };
let lineage = match (
session.branch_name.as_deref(),
session.parent_session_id.as_deref(),
) {
(Some(branch_name), Some(parent_session_id)) => {
format!(" branch={branch_name} from={parent_session_id}")
}
(None, Some(parent_session_id)) => format!(" from={parent_session_id}"),
(Some(branch_name), None) => format!(" branch={branch_name}"),
(None, None) => String::new(),
};
lines.push(format!( lines.push(format!(
" {id:<20} {marker:<10} msgs={msgs:<4} modified={modified}{lineage} path={path}", " {id:<20} {marker:<10} msgs={msgs:<4} modified={modified} path={path}",
id = session.id, id = session.id,
msgs = session.message_count, msgs = session.message_count,
modified = session.modified_epoch_secs, modified = session.modified_epoch_secs,
lineage = lineage,
path = session.path.display(), path = session.path.display(),
)); ));
} }
@@ -2022,8 +1923,10 @@ fn build_runtime_feature_config(
.clone()) .clone())
} }
#[allow(clippy::too_many_arguments)]
fn build_runtime( fn build_runtime(
session: Session, session: Session,
session_id: String,
model: String, model: String,
system_prompt: Vec<String>, system_prompt: Vec<String>,
enable_tools: bool, enable_tools: bool,
@@ -2034,11 +1937,17 @@ fn build_runtime(
{ {
Ok(ConversationRuntime::new_with_features( Ok(ConversationRuntime::new_with_features(
session, session,
AnthropicRuntimeClient::new(model, enable_tools, emit_output, allowed_tools.clone())?, AnthropicRuntimeClient::new(
model,
enable_tools,
emit_output,
allowed_tools.clone(),
session_id,
)?,
CliToolExecutor::new(allowed_tools, emit_output), CliToolExecutor::new(allowed_tools, emit_output),
permission_policy(permission_mode), permission_policy(permission_mode),
system_prompt, system_prompt,
build_runtime_feature_config()?, &build_runtime_feature_config()?,
)) ))
} }
@@ -2103,11 +2012,13 @@ impl AnthropicRuntimeClient {
enable_tools: bool, enable_tools: bool,
emit_output: bool, emit_output: bool,
allowed_tools: Option<AllowedToolSet>, allowed_tools: Option<AllowedToolSet>,
session_id: impl Into<String>,
) -> Result<Self, Box<dyn std::error::Error>> { ) -> Result<Self, Box<dyn std::error::Error>> {
Ok(Self { Ok(Self {
runtime: tokio::runtime::Runtime::new()?, runtime: tokio::runtime::Runtime::new()?,
client: AnthropicClient::from_auth(resolve_cli_auth_source()?) client: AnthropicClient::from_auth(resolve_cli_auth_source()?)
.with_base_url(api::read_base_url()), .with_base_url(api::read_base_url())
.with_prompt_cache(PromptCache::new(session_id)),
model, model,
enable_tools, enable_tools,
emit_output, emit_output,
@@ -2222,8 +2133,8 @@ impl ApiClient for AnthropicRuntimeClient {
events.push(AssistantEvent::Usage(TokenUsage { events.push(AssistantEvent::Usage(TokenUsage {
input_tokens: delta.usage.input_tokens, input_tokens: delta.usage.input_tokens,
output_tokens: delta.usage.output_tokens, output_tokens: delta.usage.output_tokens,
cache_creation_input_tokens: 0, cache_creation_input_tokens: delta.usage.cache_creation_input_tokens,
cache_read_input_tokens: 0, cache_read_input_tokens: delta.usage.cache_read_input_tokens,
})); }));
} }
ApiStreamEvent::MessageStop(_) => { ApiStreamEvent::MessageStop(_) => {
@@ -2238,6 +2149,8 @@ impl ApiClient for AnthropicRuntimeClient {
} }
} }
push_prompt_cache_record(&self.client, &mut events);
if !saw_stop if !saw_stop
&& events.iter().any(|event| { && events.iter().any(|event| {
matches!(event, AssistantEvent::TextDelta(text) if !text.is_empty()) matches!(event, AssistantEvent::TextDelta(text) if !text.is_empty())
@@ -2262,7 +2175,9 @@ impl ApiClient for AnthropicRuntimeClient {
}) })
.await .await
.map_err(|error| RuntimeError::new(error.to_string()))?; .map_err(|error| RuntimeError::new(error.to_string()))?;
response_to_events(response, out) let mut events = response_to_events(response, out)?;
push_prompt_cache_record(&self.client, &mut events);
Ok(events)
}) })
} }
} }
@@ -2323,6 +2238,39 @@ fn collect_tool_results(summary: &runtime::TurnSummary) -> Vec<serde_json::Value
.collect() .collect()
} }
fn collect_prompt_cache_events(summary: &runtime::TurnSummary) -> Vec<serde_json::Value> {
summary
.prompt_cache_events
.iter()
.map(|event| {
json!({
"unexpected": event.unexpected,
"reason": event.reason,
"previous_cache_read_input_tokens": event.previous_cache_read_input_tokens,
"current_cache_read_input_tokens": event.current_cache_read_input_tokens,
"token_drop": event.token_drop,
})
})
.collect()
}
fn print_prompt_cache_events(summary: &runtime::TurnSummary) {
for event in &summary.prompt_cache_events {
let label = if event.unexpected {
"Prompt cache break"
} else {
"Prompt cache invalidation"
};
println!(
"{label}: {} (cache read {} -> {}, drop {})",
event.reason,
event.previous_cache_read_input_tokens,
event.current_cache_read_input_tokens,
event.token_drop,
);
}
}
fn slash_command_completion_candidates() -> Vec<String> { fn slash_command_completion_candidates() -> Vec<String> {
slash_command_specs() slash_command_specs()
.iter() .iter()
@@ -2469,18 +2417,20 @@ fn first_visible_line(text: &str) -> &str {
} }
fn format_bash_result(icon: &str, parsed: &serde_json::Value) -> String { fn format_bash_result(icon: &str, parsed: &serde_json::Value) -> String {
use std::fmt::Write as _;
let mut lines = vec![format!("{icon} \x1b[38;5;245mbash\x1b[0m")]; let mut lines = vec![format!("{icon} \x1b[38;5;245mbash\x1b[0m")];
if let Some(task_id) = parsed if let Some(task_id) = parsed
.get("backgroundTaskId") .get("backgroundTaskId")
.and_then(|value| value.as_str()) .and_then(|value| value.as_str())
{ {
lines[0].push_str(&format!(" backgrounded ({task_id})")); let _ = write!(lines[0], " backgrounded ({task_id})");
} else if let Some(status) = parsed } else if let Some(status) = parsed
.get("returnCodeInterpretation") .get("returnCodeInterpretation")
.and_then(|value| value.as_str()) .and_then(|value| value.as_str())
.filter(|status| !status.is_empty()) .filter(|status| !status.is_empty())
{ {
lines[0].push_str(&format!(" {status}")); let _ = write!(lines[0], " {status}");
} }
if let Some(stdout) = parsed.get("stdout").and_then(|value| value.as_str()) { if let Some(stdout) = parsed.get("stdout").and_then(|value| value.as_str()) {
@@ -2502,15 +2452,15 @@ fn format_read_result(icon: &str, parsed: &serde_json::Value) -> String {
let path = extract_tool_path(file); let path = extract_tool_path(file);
let start_line = file let start_line = file
.get("startLine") .get("startLine")
.and_then(|value| value.as_u64()) .and_then(serde_json::Value::as_u64)
.unwrap_or(1); .unwrap_or(1);
let num_lines = file let num_lines = file
.get("numLines") .get("numLines")
.and_then(|value| value.as_u64()) .and_then(serde_json::Value::as_u64)
.unwrap_or(0); .unwrap_or(0);
let total_lines = file let total_lines = file
.get("totalLines") .get("totalLines")
.and_then(|value| value.as_u64()) .and_then(serde_json::Value::as_u64)
.unwrap_or(num_lines); .unwrap_or(num_lines);
let content = file let content = file
.get("content") .get("content")
@@ -2536,8 +2486,7 @@ fn format_write_result(icon: &str, parsed: &serde_json::Value) -> String {
let line_count = parsed let line_count = parsed
.get("content") .get("content")
.and_then(|value| value.as_str()) .and_then(|value| value.as_str())
.map(|content| content.lines().count()) .map_or(0, |content| content.lines().count());
.unwrap_or(0);
format!( format!(
"{icon} \x1b[1;32m✏ {} {path}\x1b[0m \x1b[2m({line_count} lines)\x1b[0m", "{icon} \x1b[1;32m✏ {} {path}\x1b[0m \x1b[2m({line_count} lines)\x1b[0m",
if kind == "create" { "Wrote" } else { "Updated" }, if kind == "create" { "Wrote" } else { "Updated" },
@@ -2568,7 +2517,7 @@ fn format_edit_result(icon: &str, parsed: &serde_json::Value) -> String {
let path = extract_tool_path(parsed); let path = extract_tool_path(parsed);
let suffix = if parsed let suffix = if parsed
.get("replaceAll") .get("replaceAll")
.and_then(|value| value.as_bool()) .and_then(serde_json::Value::as_bool)
.unwrap_or(false) .unwrap_or(false)
{ {
" (replace all)" " (replace all)"
@@ -2596,7 +2545,7 @@ fn format_edit_result(icon: &str, parsed: &serde_json::Value) -> String {
fn format_glob_result(icon: &str, parsed: &serde_json::Value) -> String { fn format_glob_result(icon: &str, parsed: &serde_json::Value) -> String {
let num_files = parsed let num_files = parsed
.get("numFiles") .get("numFiles")
.and_then(|value| value.as_u64()) .and_then(serde_json::Value::as_u64)
.unwrap_or(0); .unwrap_or(0);
let filenames = parsed let filenames = parsed
.get("filenames") .get("filenames")
@@ -2620,11 +2569,11 @@ fn format_glob_result(icon: &str, parsed: &serde_json::Value) -> String {
fn format_grep_result(icon: &str, parsed: &serde_json::Value) -> String { fn format_grep_result(icon: &str, parsed: &serde_json::Value) -> String {
let num_matches = parsed let num_matches = parsed
.get("numMatches") .get("numMatches")
.and_then(|value| value.as_u64()) .and_then(serde_json::Value::as_u64)
.unwrap_or(0); .unwrap_or(0);
let num_files = parsed let num_files = parsed
.get("numFiles") .get("numFiles")
.and_then(|value| value.as_u64()) .and_then(serde_json::Value::as_u64)
.unwrap_or(0); .unwrap_or(0);
let content = parsed let content = parsed
.get("content") .get("content")
@@ -2731,6 +2680,26 @@ fn response_to_events(
Ok(events) Ok(events)
} }
fn push_prompt_cache_record(client: &AnthropicClient, events: &mut Vec<AssistantEvent>) {
if let Some(event) = client
.take_last_prompt_cache_record()
.and_then(prompt_cache_record_to_runtime_event)
{
events.push(AssistantEvent::PromptCache(event));
}
}
fn prompt_cache_record_to_runtime_event(record: PromptCacheRecord) -> Option<PromptCacheEvent> {
let cache_break = record.cache_break?;
Some(PromptCacheEvent {
unexpected: cache_break.unexpected,
reason: cache_break.reason,
previous_cache_read_input_tokens: cache_break.previous_cache_read_input_tokens,
current_cache_read_input_tokens: cache_break.current_cache_read_input_tokens,
token_drop: cache_break.token_drop,
})
}
struct CliToolExecutor { struct CliToolExecutor {
renderer: TerminalRenderer, renderer: TerminalRenderer,
emit_output: bool, emit_output: bool,
@@ -2857,7 +2826,7 @@ fn print_help_to(out: &mut impl Write) -> io::Result<()> {
writeln!(out, " Shorthand non-interactive prompt mode")?; writeln!(out, " Shorthand non-interactive prompt mode")?;
writeln!( writeln!(
out, out,
" claw --resume SESSION.jsonl [/status] [/compact] [...]" " claw --resume SESSION.json [/status] [/compact] [...]"
)?; )?;
writeln!( writeln!(
out, out,
@@ -2917,7 +2886,7 @@ fn print_help_to(out: &mut impl Write) -> io::Result<()> {
)?; )?;
writeln!( writeln!(
out, out,
" claw --resume session.jsonl /status /diff /export notes.txt" " claw --resume session.json /status /diff /export notes.txt"
)?; )?;
writeln!(out, " claw login")?; writeln!(out, " claw login")?;
writeln!(out, " claw init")?; writeln!(out, " claw init")?;
@@ -2931,23 +2900,18 @@ fn print_help() {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::{ use super::{
create_managed_session_handle, filter_tool_specs, format_compact_report, filter_tool_specs, format_compact_report, format_cost_report, format_model_report,
format_cost_report, format_model_report, format_model_switch_report, format_model_switch_report, format_permissions_report, format_permissions_switch_report,
format_permissions_report, format_permissions_switch_report, format_resume_report, format_resume_report, format_status_report, format_tool_call_start, format_tool_result,
format_status_report, format_tool_call_start, format_tool_result,
normalize_permission_mode, parse_args, parse_git_status_metadata, print_help_to, normalize_permission_mode, parse_args, parse_git_status_metadata, print_help_to,
push_output_block, render_config_report, render_memory_report, render_repl_help, push_output_block, render_config_report, render_memory_report, render_repl_help,
resolve_model_alias, resolve_session_reference, response_to_events, resolve_model_alias, response_to_events, resume_supported_slash_commands, status_context,
resume_supported_slash_commands, status_context, CliAction, CliOutputFormat, SlashCommand, CliAction, CliOutputFormat, SlashCommand, StatusUsage, DEFAULT_MODEL,
StatusUsage, DEFAULT_MODEL,
}; };
use api::{MessageResponse, OutputContentBlock, Usage}; use api::{MessageResponse, OutputContentBlock, Usage};
use runtime::{ use runtime::{AssistantEvent, ContentBlock, ConversationMessage, MessageRole, PermissionMode};
AssistantEvent, ContentBlock, ConversationMessage, MessageRole, PermissionMode, Session,
};
use serde_json::json; use serde_json::json;
use std::path::PathBuf; use std::path::PathBuf;
use std::sync::{Mutex, OnceLock};
#[test] #[test]
fn defaults_to_repl_when_no_args() { fn defaults_to_repl_when_no_args() {
@@ -3121,13 +3085,13 @@ mod tests {
fn parses_resume_flag_with_slash_command() { fn parses_resume_flag_with_slash_command() {
let args = vec![ let args = vec![
"--resume".to_string(), "--resume".to_string(),
"session.jsonl".to_string(), "session.json".to_string(),
"/compact".to_string(), "/compact".to_string(),
]; ];
assert_eq!( assert_eq!(
parse_args(&args).expect("args should parse"), parse_args(&args).expect("args should parse"),
CliAction::ResumeSession { CliAction::ResumeSession {
session_path: PathBuf::from("session.jsonl"), session_path: PathBuf::from("session.json"),
commands: vec!["/compact".to_string()], commands: vec!["/compact".to_string()],
} }
); );
@@ -3137,7 +3101,7 @@ mod tests {
fn parses_resume_flag_with_multiple_slash_commands() { fn parses_resume_flag_with_multiple_slash_commands() {
let args = vec![ let args = vec![
"--resume".to_string(), "--resume".to_string(),
"session.jsonl".to_string(), "session.json".to_string(),
"/status".to_string(), "/status".to_string(),
"/compact".to_string(), "/compact".to_string(),
"/cost".to_string(), "/cost".to_string(),
@@ -3145,7 +3109,7 @@ mod tests {
assert_eq!( assert_eq!(
parse_args(&args).expect("args should parse"), parse_args(&args).expect("args should parse"),
CliAction::ResumeSession { CliAction::ResumeSession {
session_path: PathBuf::from("session.jsonl"), session_path: PathBuf::from("session.json"),
commands: vec![ commands: vec![
"/status".to_string(), "/status".to_string(),
"/compact".to_string(), "/compact".to_string(),
@@ -3173,7 +3137,7 @@ mod tests {
fn shared_help_uses_resume_annotation_copy() { fn shared_help_uses_resume_annotation_copy() {
let help = commands::render_slash_command_help(); let help = commands::render_slash_command_help();
assert!(help.contains("Slash commands")); assert!(help.contains("Slash commands"));
assert!(help.contains("works with --resume SESSION.jsonl")); assert!(help.contains("works with --resume SESSION.json"));
} }
#[test] #[test]
@@ -3193,7 +3157,7 @@ mod tests {
assert!(help.contains("/diff")); assert!(help.contains("/diff"));
assert!(help.contains("/version")); assert!(help.contains("/version"));
assert!(help.contains("/export [file]")); assert!(help.contains("/export [file]"));
assert!(help.contains("/session [list|switch <session-id>|fork [branch-name]]")); assert!(help.contains("/session [list|switch <session-id>]"));
assert!(help.contains("/exit")); assert!(help.contains("/exit"));
} }
@@ -3214,9 +3178,9 @@ mod tests {
#[test] #[test]
fn resume_report_uses_sectioned_layout() { fn resume_report_uses_sectioned_layout() {
let report = format_resume_report("session.jsonl", 14, 6); let report = format_resume_report("session.json", 14, 6);
assert!(report.contains("Session resumed")); assert!(report.contains("Session resumed"));
assert!(report.contains("Session file session.jsonl")); assert!(report.contains("Session file session.json"));
assert!(report.contains("Messages 14")); assert!(report.contains("Messages 14"));
assert!(report.contains("Turns 6")); assert!(report.contains("Turns 6"));
} }
@@ -3318,7 +3282,7 @@ mod tests {
"workspace-write", "workspace-write",
&super::StatusContext { &super::StatusContext {
cwd: PathBuf::from("/tmp/project"), cwd: PathBuf::from("/tmp/project"),
session_path: Some(PathBuf::from("session.jsonl")), session_path: Some(PathBuf::from("session.json")),
loaded_config_files: 2, loaded_config_files: 2,
discovered_config_files: 3, discovered_config_files: 3,
memory_file_count: 4, memory_file_count: 4,
@@ -3335,7 +3299,7 @@ mod tests {
assert!(status.contains("Cwd /tmp/project")); assert!(status.contains("Cwd /tmp/project"));
assert!(status.contains("Project root /tmp")); assert!(status.contains("Project root /tmp"));
assert!(status.contains("Git branch main")); assert!(status.contains("Git branch main"));
assert!(status.contains("Session session.jsonl")); assert!(status.contains("Session session.json"));
assert!(status.contains("Config files loaded 2/3")); assert!(status.contains("Config files loaded 2/3"));
assert!(status.contains("Memory files 4")); assert!(status.contains("Memory files 4"));
} }
@@ -3410,9 +3374,9 @@ mod tests {
#[test] #[test]
fn parses_resume_and_config_slash_commands() { fn parses_resume_and_config_slash_commands() {
assert_eq!( assert_eq!(
SlashCommand::parse("/resume saved-session.jsonl"), SlashCommand::parse("/resume saved-session.json"),
Some(SlashCommand::Resume { Some(SlashCommand::Resume {
session_path: Some("saved-session.jsonl".to_string()) session_path: Some("saved-session.json".to_string())
}) })
); );
assert_eq!( assert_eq!(
@@ -3431,65 +3395,6 @@ mod tests {
); );
assert_eq!(SlashCommand::parse("/memory"), Some(SlashCommand::Memory)); assert_eq!(SlashCommand::parse("/memory"), Some(SlashCommand::Memory));
assert_eq!(SlashCommand::parse("/init"), Some(SlashCommand::Init)); assert_eq!(SlashCommand::parse("/init"), Some(SlashCommand::Init));
assert_eq!(
SlashCommand::parse("/session fork incident-review"),
Some(SlashCommand::Session {
action: Some("fork".to_string()),
target: Some("incident-review".to_string())
})
);
}
#[test]
fn help_mentions_jsonl_resume_examples() {
let mut help = Vec::new();
print_help_to(&mut help).expect("help should render");
let help = String::from_utf8(help).expect("help should be utf8");
assert!(help.contains("claw --resume SESSION.jsonl"));
assert!(help.contains("claw --resume session.jsonl /status /diff /export notes.txt"));
}
#[test]
fn managed_sessions_default_to_jsonl_and_resolve_legacy_json() {
let _guard = cwd_lock().lock().expect("cwd lock");
let workspace = temp_workspace("session-resolution");
std::fs::create_dir_all(&workspace).expect("workspace should create");
let previous = std::env::current_dir().expect("cwd");
std::env::set_current_dir(&workspace).expect("switch cwd");
let handle = create_managed_session_handle("session-alpha").expect("jsonl handle");
assert!(handle.path.ends_with("session-alpha.jsonl"));
let legacy_path = workspace.join(".claude/sessions/legacy.json");
std::fs::create_dir_all(
legacy_path
.parent()
.expect("legacy path should have parent directory"),
)
.expect("session dir should exist");
Session::new()
.with_persistence_path(legacy_path.clone())
.save_to_path(&legacy_path)
.expect("legacy session should save");
let resolved = resolve_session_reference("legacy").expect("legacy session should resolve");
assert_eq!(resolved.path, legacy_path);
std::env::set_current_dir(previous).expect("restore cwd");
std::fs::remove_dir_all(workspace).expect("workspace should clean up");
}
fn cwd_lock() -> &'static Mutex<()> {
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
LOCK.get_or_init(|| Mutex::new(()))
}
fn temp_workspace(label: &str) -> PathBuf {
let nanos = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.expect("system time should be after epoch")
.as_nanos();
std::env::temp_dir().join(format!("claw-cli-{label}-{nanos}"))
} }
#[test] #[test]

View File

@@ -286,7 +286,7 @@ impl TerminalRenderer {
) { ) {
match event { match event {
Event::Start(Tag::Heading { level, .. }) => { Event::Start(Tag::Heading { level, .. }) => {
self.start_heading(state, level as u8, output) Self::start_heading(state, level as u8, output);
} }
Event::End(TagEnd::Paragraph) => output.push_str("\n\n"), Event::End(TagEnd::Paragraph) => output.push_str("\n\n"),
Event::Start(Tag::BlockQuote(..)) => self.start_quote(state, output), Event::Start(Tag::BlockQuote(..)) => self.start_quote(state, output),
@@ -426,7 +426,7 @@ impl TerminalRenderer {
} }
} }
fn start_heading(&self, state: &mut RenderState, level: u8, output: &mut String) { fn start_heading(state: &mut RenderState, level: u8, output: &mut String) {
state.heading_level = Some(level); state.heading_level = Some(level);
if !output.is_empty() { if !output.is_empty() {
output.push('\n'); output.push('\n');

View File

@@ -5,15 +5,15 @@ use std::time::{Duration, Instant};
use api::{ use api::{
read_base_url, AnthropicClient, ContentBlockDelta, InputContentBlock, InputMessage, read_base_url, AnthropicClient, ContentBlockDelta, InputContentBlock, InputMessage,
MessageRequest, MessageResponse, OutputContentBlock, StreamEvent as ApiStreamEvent, ToolChoice, MessageRequest, MessageResponse, OutputContentBlock, PromptCache, PromptCacheRecord,
ToolDefinition, ToolResultContentBlock, StreamEvent as ApiStreamEvent, ToolChoice, ToolDefinition, ToolResultContentBlock,
}; };
use reqwest::blocking::Client; use reqwest::blocking::Client;
use runtime::{ use runtime::{
edit_file, execute_bash, glob_search, grep_search, load_system_prompt, read_file, write_file, edit_file, execute_bash, glob_search, grep_search, load_system_prompt, read_file, write_file,
ApiClient, ApiRequest, AssistantEvent, BashCommandInput, ContentBlock, ConversationMessage, ApiClient, ApiRequest, AssistantEvent, BashCommandInput, ContentBlock, ConversationMessage,
ConversationRuntime, GrepSearchInput, MessageRole, PermissionMode, PermissionPolicy, ConversationRuntime, GrepSearchInput, MessageRole, PermissionMode, PermissionPolicy,
RuntimeError, Session, TokenUsage, ToolError, ToolExecutor, PromptCacheEvent, RuntimeError, Session, TokenUsage, ToolError, ToolExecutor,
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::{json, Value}; use serde_json::{json, Value};
@@ -1466,7 +1466,8 @@ fn build_agent_runtime(
.clone() .clone()
.unwrap_or_else(|| DEFAULT_AGENT_MODEL.to_string()); .unwrap_or_else(|| DEFAULT_AGENT_MODEL.to_string());
let allowed_tools = job.allowed_tools.clone(); let allowed_tools = job.allowed_tools.clone();
let api_client = AnthropicRuntimeClient::new(model, allowed_tools.clone())?; let api_client =
AnthropicRuntimeClient::new(model, allowed_tools.clone(), job.manifest.agent_id.clone())?;
let tool_executor = SubagentToolExecutor::new(allowed_tools); let tool_executor = SubagentToolExecutor::new(allowed_tools);
Ok(ConversationRuntime::new( Ok(ConversationRuntime::new(
Session::new(), Session::new(),
@@ -1643,10 +1644,15 @@ struct AnthropicRuntimeClient {
} }
impl AnthropicRuntimeClient { impl AnthropicRuntimeClient {
fn new(model: String, allowed_tools: BTreeSet<String>) -> Result<Self, String> { fn new(
model: String,
allowed_tools: BTreeSet<String>,
session_id: impl Into<String>,
) -> Result<Self, String> {
let client = AnthropicClient::from_env() let client = AnthropicClient::from_env()
.map_err(|error| error.to_string())? .map_err(|error| error.to_string())?
.with_base_url(read_base_url()); .with_base_url(read_base_url())
.with_prompt_cache(PromptCache::new(session_id));
Ok(Self { Ok(Self {
runtime: tokio::runtime::Runtime::new().map_err(|error| error.to_string())?, runtime: tokio::runtime::Runtime::new().map_err(|error| error.to_string())?,
client, client,
@@ -1657,6 +1663,7 @@ impl AnthropicRuntimeClient {
} }
impl ApiClient for AnthropicRuntimeClient { impl ApiClient for AnthropicRuntimeClient {
#[allow(clippy::too_many_lines)]
fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> { fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
let tools = tool_specs_for_allowed_tools(Some(&self.allowed_tools)) let tools = tool_specs_for_allowed_tools(Some(&self.allowed_tools))
.into_iter() .into_iter()
@@ -1726,8 +1733,8 @@ impl ApiClient for AnthropicRuntimeClient {
events.push(AssistantEvent::Usage(TokenUsage { events.push(AssistantEvent::Usage(TokenUsage {
input_tokens: delta.usage.input_tokens, input_tokens: delta.usage.input_tokens,
output_tokens: delta.usage.output_tokens, output_tokens: delta.usage.output_tokens,
cache_creation_input_tokens: 0, cache_creation_input_tokens: delta.usage.cache_creation_input_tokens,
cache_read_input_tokens: 0, cache_read_input_tokens: delta.usage.cache_read_input_tokens,
})); }));
} }
ApiStreamEvent::MessageStop(_) => { ApiStreamEvent::MessageStop(_) => {
@@ -1737,6 +1744,8 @@ impl ApiClient for AnthropicRuntimeClient {
} }
} }
push_prompt_cache_record(&self.client, &mut events);
if !saw_stop if !saw_stop
&& events.iter().any(|event| { && events.iter().any(|event| {
matches!(event, AssistantEvent::TextDelta(text) if !text.is_empty()) matches!(event, AssistantEvent::TextDelta(text) if !text.is_empty())
@@ -1761,7 +1770,9 @@ impl ApiClient for AnthropicRuntimeClient {
}) })
.await .await
.map_err(|error| RuntimeError::new(error.to_string()))?; .map_err(|error| RuntimeError::new(error.to_string()))?;
Ok(response_to_events(response)) let mut events = response_to_events(response);
push_prompt_cache_record(&self.client, &mut events);
Ok(events)
}) })
} }
} }
@@ -1884,6 +1895,26 @@ fn response_to_events(response: MessageResponse) -> Vec<AssistantEvent> {
events events
} }
fn push_prompt_cache_record(client: &AnthropicClient, events: &mut Vec<AssistantEvent>) {
if let Some(event) = client
.take_last_prompt_cache_record()
.and_then(prompt_cache_record_to_runtime_event)
{
events.push(AssistantEvent::PromptCache(event));
}
}
fn prompt_cache_record_to_runtime_event(record: PromptCacheRecord) -> Option<PromptCacheEvent> {
let cache_break = record.cache_break?;
Some(PromptCacheEvent {
unexpected: cache_break.unexpected,
reason: cache_break.reason,
previous_cache_read_input_tokens: cache_break.previous_cache_read_input_tokens,
current_cache_read_input_tokens: cache_break.current_cache_read_input_tokens,
token_drop: cache_break.token_drop,
})
}
fn final_assistant_text(summary: &runtime::TurnSummary) -> String { fn final_assistant_text(summary: &runtime::TurnSummary) -> String {
summary summary
.assistant_messages .assistant_messages