wip: cache-tracking progress

This commit is contained in:
Yeachan-Heo
2026-04-01 04:40:17 +00:00
parent 0cf2204d43
commit 26344c578b
4 changed files with 113 additions and 58 deletions

View File

@@ -689,7 +689,6 @@ mod tests {
use std::io::{Read, Write}; use std::io::{Read, Write};
use std::net::TcpListener; use std::net::TcpListener;
use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Mutex, OnceLock};
use std::thread; use std::thread;
use std::time::{Duration, SystemTime, UNIX_EPOCH}; use std::time::{Duration, SystemTime, UNIX_EPOCH};
@@ -699,15 +698,9 @@ mod tests {
now_unix_timestamp, oauth_token_is_expired, resolve_saved_oauth_token, now_unix_timestamp, oauth_token_is_expired, resolve_saved_oauth_token,
resolve_startup_auth_source, AnthropicClient, AuthSource, OAuthTokenSet, resolve_startup_auth_source, AnthropicClient, AuthSource, OAuthTokenSet,
}; };
use crate::test_env_lock;
use crate::types::{ContentBlockDelta, MessageRequest}; use crate::types::{ContentBlockDelta, MessageRequest};
fn env_lock() -> std::sync::MutexGuard<'static, ()> {
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
LOCK.get_or_init(|| Mutex::new(()))
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
}
fn temp_config_home() -> std::path::PathBuf { fn temp_config_home() -> std::path::PathBuf {
static NEXT_ID: AtomicU64 = AtomicU64::new(0); static NEXT_ID: AtomicU64 = AtomicU64::new(0);
std::env::temp_dir().join(format!( std::env::temp_dir().join(format!(
@@ -753,7 +746,7 @@ mod tests {
#[test] #[test]
fn read_api_key_requires_presence() { fn read_api_key_requires_presence() {
let _guard = env_lock(); let _guard = test_env_lock();
std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
std::env::remove_var("ANTHROPIC_API_KEY"); std::env::remove_var("ANTHROPIC_API_KEY");
std::env::remove_var("CLAUDE_CONFIG_HOME"); std::env::remove_var("CLAUDE_CONFIG_HOME");
@@ -763,7 +756,7 @@ mod tests {
#[test] #[test]
fn read_api_key_requires_non_empty_value() { fn read_api_key_requires_non_empty_value() {
let _guard = env_lock(); let _guard = test_env_lock();
std::env::set_var("ANTHROPIC_AUTH_TOKEN", ""); std::env::set_var("ANTHROPIC_AUTH_TOKEN", "");
std::env::remove_var("ANTHROPIC_API_KEY"); std::env::remove_var("ANTHROPIC_API_KEY");
let error = super::read_api_key().expect_err("empty key should error"); let error = super::read_api_key().expect_err("empty key should error");
@@ -773,7 +766,7 @@ mod tests {
#[test] #[test]
fn read_api_key_prefers_api_key_env() { fn read_api_key_prefers_api_key_env() {
let _guard = env_lock(); let _guard = test_env_lock();
std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token"); std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token");
std::env::set_var("ANTHROPIC_API_KEY", "legacy-key"); std::env::set_var("ANTHROPIC_API_KEY", "legacy-key");
assert_eq!( assert_eq!(
@@ -786,7 +779,7 @@ mod tests {
#[test] #[test]
fn read_auth_token_reads_auth_token_env() { fn read_auth_token_reads_auth_token_env() {
let _guard = env_lock(); let _guard = test_env_lock();
std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token"); std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token");
assert_eq!(super::read_auth_token().as_deref(), Some("auth-token")); assert_eq!(super::read_auth_token().as_deref(), Some("auth-token"));
std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
@@ -806,7 +799,7 @@ mod tests {
#[test] #[test]
fn auth_source_from_env_combines_api_key_and_bearer_token() { fn auth_source_from_env_combines_api_key_and_bearer_token() {
let _guard = env_lock(); let _guard = test_env_lock();
std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token"); std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token");
std::env::set_var("ANTHROPIC_API_KEY", "legacy-key"); std::env::set_var("ANTHROPIC_API_KEY", "legacy-key");
let auth = AuthSource::from_env().expect("env auth"); let auth = AuthSource::from_env().expect("env auth");
@@ -818,7 +811,7 @@ mod tests {
#[test] #[test]
fn auth_source_from_saved_oauth_when_env_absent() { fn auth_source_from_saved_oauth_when_env_absent() {
let _guard = env_lock(); let _guard = test_env_lock();
let config_home = temp_config_home(); let config_home = temp_config_home();
std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
@@ -857,7 +850,7 @@ mod tests {
#[test] #[test]
fn resolve_saved_oauth_token_refreshes_expired_credentials() { fn resolve_saved_oauth_token_refreshes_expired_credentials() {
let _guard = env_lock(); let _guard = test_env_lock();
let config_home = temp_config_home(); let config_home = temp_config_home();
std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
@@ -889,7 +882,7 @@ mod tests {
#[test] #[test]
fn resolve_startup_auth_source_uses_saved_oauth_without_loading_config() { fn resolve_startup_auth_source_uses_saved_oauth_without_loading_config() {
let _guard = env_lock(); let _guard = test_env_lock();
let config_home = temp_config_home(); let config_home = temp_config_home();
std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
@@ -913,7 +906,7 @@ mod tests {
#[test] #[test]
fn resolve_startup_auth_source_errors_when_refreshable_token_lacks_config() { fn resolve_startup_auth_source_errors_when_refreshable_token_lacks_config() {
let _guard = env_lock(); let _guard = test_env_lock();
let config_home = temp_config_home(); let config_home = temp_config_home();
std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
@@ -945,7 +938,7 @@ mod tests {
#[test] #[test]
fn resolve_saved_oauth_token_preserves_refresh_token_when_refresh_response_omits_it() { fn resolve_saved_oauth_token_preserves_refresh_token_when_refresh_response_omits_it() {
let _guard = env_lock(); let _guard = test_env_lock();
let config_home = temp_config_home(); let config_home = temp_config_home();
std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); std::env::remove_var("ANTHROPIC_AUTH_TOKEN");

View File

@@ -20,3 +20,11 @@ pub use types::{
MessageResponse, MessageStartEvent, MessageStopEvent, OutputContentBlock, StreamEvent, MessageResponse, MessageStartEvent, MessageStopEvent, OutputContentBlock, StreamEvent,
ToolChoice, ToolDefinition, ToolResultContentBlock, Usage, ToolChoice, ToolDefinition, ToolResultContentBlock, Usage,
}; };
#[cfg(test)]
pub(crate) fn test_env_lock() -> std::sync::MutexGuard<'static, ()> {
static LOCK: std::sync::OnceLock<std::sync::Mutex<()>> = std::sync::OnceLock::new();
LOCK.get_or_init(|| std::sync::Mutex::new(()))
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
}

View File

@@ -141,6 +141,7 @@ impl PromptCache {
self.lock().stats.clone() self.lock().stats.clone()
} }
#[must_use]
pub fn lookup_completion(&self, request: &MessageRequest) -> Option<MessageResponse> { pub fn lookup_completion(&self, request: &MessageRequest) -> Option<MessageResponse> {
let request_hash = request_hash_hex(request); let request_hash = request_hash_hex(request);
let (paths, ttl) = { let (paths, ttl) = {
@@ -191,6 +192,7 @@ impl PromptCache {
Some(entry.response) Some(entry.response)
} }
#[must_use]
pub fn record_response( pub fn record_response(
&self, &self,
request: &MessageRequest, request: &MessageRequest,
@@ -199,6 +201,7 @@ impl PromptCache {
self.record_usage_internal(request, &response.usage, Some(response)) self.record_usage_internal(request, &response.usage, Some(response))
} }
#[must_use]
pub fn record_usage(&self, request: &MessageRequest, usage: &Usage) -> PromptCacheRecord { pub fn record_usage(&self, request: &MessageRequest, usage: &Usage) -> PromptCacheRecord {
self.record_usage_internal(request, usage, None) self.record_usage_internal(request, usage, None)
} }
@@ -267,7 +270,6 @@ struct TrackedPromptState {
observed_at_unix_secs: u64, observed_at_unix_secs: u64,
#[serde(default = "current_fingerprint_version")] #[serde(default = "current_fingerprint_version")]
fingerprint_version: u32, fingerprint_version: u32,
request_hash: u64,
model_hash: u64, model_hash: u64,
system_hash: u64, system_hash: u64,
tools_hash: u64, tools_hash: u64,
@@ -277,37 +279,34 @@ struct TrackedPromptState {
impl TrackedPromptState { impl TrackedPromptState {
fn from_usage(request: &MessageRequest, usage: &Usage) -> Self { fn from_usage(request: &MessageRequest, usage: &Usage) -> Self {
let hashes = RequestHashes::from_request(request); let hashes = RequestFingerprints::from_request(request);
Self { Self {
observed_at_unix_secs: now_unix_secs(), observed_at_unix_secs: now_unix_secs(),
fingerprint_version: current_fingerprint_version(), fingerprint_version: current_fingerprint_version(),
request_hash: hashes.request_hash, model_hash: hashes.model,
model_hash: hashes.model_hash, system_hash: hashes.system,
system_hash: hashes.system_hash, tools_hash: hashes.tools,
tools_hash: hashes.tools_hash, messages_hash: hashes.messages,
messages_hash: hashes.messages_hash,
cache_read_input_tokens: usage.cache_read_input_tokens, cache_read_input_tokens: usage.cache_read_input_tokens,
} }
} }
} }
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
struct RequestHashes { struct RequestFingerprints {
request_hash: u64, model: u64,
model_hash: u64, system: u64,
system_hash: u64, tools: u64,
tools_hash: u64, messages: u64,
messages_hash: u64,
} }
impl RequestHashes { impl RequestFingerprints {
fn from_request(request: &MessageRequest) -> Self { fn from_request(request: &MessageRequest) -> Self {
Self { Self {
request_hash: hash_serializable(request), model: hash_serializable(&request.model),
model_hash: hash_serializable(&request.model), system: hash_serializable(&request.system),
system_hash: hash_serializable(&request.system), tools: hash_serializable(&request.tools),
tools_hash: hash_serializable(&request.tools), messages: hash_serializable(&request.messages),
messages_hash: hash_serializable(&request.messages),
} }
} }
} }
@@ -501,22 +500,15 @@ fn stable_hash_bytes(bytes: &[u8]) -> u64 {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::sync::{Mutex, OnceLock}; use std::time::{Duration, SystemTime, UNIX_EPOCH};
use std::time::{SystemTime, UNIX_EPOCH};
use super::{ use super::{
detect_cache_break, read_json, request_hash_hex, sanitize_path_segment, PromptCache, detect_cache_break, read_json, request_hash_hex, sanitize_path_segment, PromptCache,
PromptCacheConfig, PromptCachePaths, TrackedPromptState, REQUEST_FINGERPRINT_PREFIX, PromptCacheConfig, PromptCachePaths, TrackedPromptState, REQUEST_FINGERPRINT_PREFIX,
}; };
use crate::test_env_lock;
use crate::types::{InputMessage, MessageRequest, MessageResponse, OutputContentBlock, Usage}; use crate::types::{InputMessage, MessageRequest, MessageResponse, OutputContentBlock, Usage};
fn env_lock() -> std::sync::MutexGuard<'static, ()> {
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
LOCK.get_or_init(|| Mutex::new(()))
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
}
#[test] #[test]
fn path_builder_sanitizes_session_identifier() { fn path_builder_sanitizes_session_identifier() {
let paths = PromptCachePaths::for_session("session:/with spaces"); let paths = PromptCachePaths::for_session("session:/with spaces");
@@ -588,7 +580,7 @@ mod tests {
#[test] #[test]
fn completion_cache_round_trip_persists_recent_response() { fn completion_cache_round_trip_persists_recent_response() {
let _guard = env_lock(); let _guard = test_env_lock();
let temp_root = std::env::temp_dir().join(format!( let temp_root = std::env::temp_dir().join(format!(
"prompt-cache-test-{}-{}", "prompt-cache-test-{}-{}",
std::process::id(), std::process::id(),
@@ -624,6 +616,62 @@ mod tests {
std::env::remove_var("CLAUDE_CONFIG_HOME"); std::env::remove_var("CLAUDE_CONFIG_HOME");
} }
#[test]
fn distinct_requests_do_not_collide_in_completion_cache() {
let _guard = test_env_lock();
let temp_root = std::env::temp_dir().join(format!(
"prompt-cache-distinct-{}-{}",
std::process::id(),
SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("time")
.as_nanos()
));
std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root);
let cache = PromptCache::new("distinct-request-session");
let first_request = sample_request("first");
let second_request = sample_request("second");
let response = sample_response(42, 12, "cached");
let _ = cache.record_response(&first_request, &response);
assert!(cache.lookup_completion(&second_request).is_none());
std::fs::remove_dir_all(temp_root).expect("cleanup temp root");
std::env::remove_var("CLAUDE_CONFIG_HOME");
}
#[test]
fn expired_completion_entries_are_not_reused() {
let _guard = test_env_lock();
let temp_root = std::env::temp_dir().join(format!(
"prompt-cache-expired-{}-{}",
std::process::id(),
SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("time")
.as_nanos()
));
std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root);
let cache = PromptCache::with_config(PromptCacheConfig {
session_id: "expired-session".to_string(),
completion_ttl: Duration::ZERO,
..PromptCacheConfig::default()
});
let request = sample_request("expire me");
let response = sample_response(7, 3, "stale");
let _ = cache.record_response(&request, &response);
assert!(cache.lookup_completion(&request).is_none());
let stats = cache.stats();
assert_eq!(stats.completion_cache_hits, 0);
assert_eq!(stats.completion_cache_misses, 1);
std::fs::remove_dir_all(temp_root).expect("cleanup temp root");
std::env::remove_var("CLAUDE_CONFIG_HOME");
}
#[test] #[test]
fn sanitize_path_caps_long_values() { fn sanitize_path_caps_long_values() {
let long_value = "x".repeat(200); let long_value = "x".repeat(200);

View File

@@ -84,6 +84,7 @@ async fn send_message_posts_json_and_parses_response() {
} }
#[tokio::test] #[tokio::test]
#[allow(clippy::await_holding_lock)]
async fn stream_message_parses_sse_events_with_tool_use() { async fn stream_message_parses_sse_events_with_tool_use() {
let _guard = env_lock(); let _guard = env_lock();
let temp_root = std::env::temp_dir().join(format!( let temp_root = std::env::temp_dir().join(format!(
@@ -180,12 +181,15 @@ async fn stream_message_parses_sse_events_with_tool_use() {
let request = captured.first().expect("server should capture request"); let request = captured.first().expect("server should capture request");
assert!(request.body.contains("\"stream\":true")); assert!(request.body.contains("\"stream\":true"));
let stats = client let cache_stats = client
.prompt_cache_stats() .prompt_cache_stats()
.expect("prompt cache stats should exist"); .expect("prompt cache stats should exist");
assert_eq!(stats.tracked_requests, 1); assert_eq!(cache_stats.tracked_requests, 1);
assert_eq!(stats.last_cache_read_input_tokens, Some(0)); assert_eq!(cache_stats.last_cache_read_input_tokens, Some(0));
assert_eq!(stats.last_cache_source.as_deref(), Some("api-response")); assert_eq!(
cache_stats.last_cache_source.as_deref(),
Some("api-response")
);
std::fs::remove_dir_all(temp_root).expect("cleanup temp root"); std::fs::remove_dir_all(temp_root).expect("cleanup temp root");
std::env::remove_var("CLAUDE_CONFIG_HOME"); std::env::remove_var("CLAUDE_CONFIG_HOME");
@@ -273,6 +277,7 @@ async fn surfaces_retry_exhaustion_for_persistent_retryable_errors() {
} }
#[tokio::test] #[tokio::test]
#[allow(clippy::await_holding_lock)]
async fn send_message_reuses_recent_completion_cache_entries() { async fn send_message_reuses_recent_completion_cache_entries() {
let _guard = env_lock(); let _guard = env_lock();
let temp_root = std::env::temp_dir().join(format!( let temp_root = std::env::temp_dir().join(format!(
@@ -312,18 +317,19 @@ async fn send_message_reuses_recent_completion_cache_entries() {
assert_eq!(first.content, second.content); assert_eq!(first.content, second.content);
assert_eq!(state.lock().await.len(), 1); assert_eq!(state.lock().await.len(), 1);
let stats = client let cache_stats = client
.prompt_cache_stats() .prompt_cache_stats()
.expect("prompt cache stats should exist"); .expect("prompt cache stats should exist");
assert_eq!(stats.completion_cache_hits, 1); assert_eq!(cache_stats.completion_cache_hits, 1);
assert_eq!(stats.completion_cache_misses, 1); assert_eq!(cache_stats.completion_cache_misses, 1);
assert_eq!(stats.completion_cache_writes, 1); assert_eq!(cache_stats.completion_cache_writes, 1);
std::fs::remove_dir_all(temp_root).expect("cleanup temp root"); std::fs::remove_dir_all(temp_root).expect("cleanup temp root");
std::env::remove_var("CLAUDE_CONFIG_HOME"); std::env::remove_var("CLAUDE_CONFIG_HOME");
} }
#[tokio::test] #[tokio::test]
#[allow(clippy::await_holding_lock)]
async fn send_message_tracks_unexpected_prompt_cache_breaks() { async fn send_message_tracks_unexpected_prompt_cache_breaks() {
let _guard = env_lock(); let _guard = env_lock();
let temp_root = std::env::temp_dir().join(format!( let temp_root = std::env::temp_dir().join(format!(
@@ -372,12 +378,12 @@ async fn send_message_tracks_unexpected_prompt_cache_breaks() {
.await .await
.expect("second response should succeed"); .expect("second response should succeed");
let stats = client let cache_stats = client
.prompt_cache_stats() .prompt_cache_stats()
.expect("prompt cache stats should exist"); .expect("prompt cache stats should exist");
assert_eq!(stats.unexpected_cache_breaks, 1); assert_eq!(cache_stats.unexpected_cache_breaks, 1);
assert_eq!( assert_eq!(
stats.last_break_reason.as_deref(), cache_stats.last_break_reason.as_deref(),
Some("cache read tokens dropped while prompt fingerprint remained stable") Some("cache read tokens dropped while prompt fingerprint remained stable")
); );