From 26344c578b460e230da64bd7afd6399263cd9cf8 Mon Sep 17 00:00:00 2001 From: Yeachan-Heo Date: Wed, 1 Apr 2026 04:40:17 +0000 Subject: [PATCH] wip: cache-tracking progress --- rust/crates/api/src/client.rs | 29 ++---- rust/crates/api/src/lib.rs | 8 ++ rust/crates/api/src/prompt_cache.rs | 106 ++++++++++++++------ rust/crates/api/tests/client_integration.rs | 28 ++++-- 4 files changed, 113 insertions(+), 58 deletions(-) diff --git a/rust/crates/api/src/client.rs b/rust/crates/api/src/client.rs index 4d264b5..f90eaf8 100644 --- a/rust/crates/api/src/client.rs +++ b/rust/crates/api/src/client.rs @@ -689,7 +689,6 @@ mod tests { use std::io::{Read, Write}; use std::net::TcpListener; use std::sync::atomic::{AtomicU64, Ordering}; - use std::sync::{Mutex, OnceLock}; use std::thread; use std::time::{Duration, SystemTime, UNIX_EPOCH}; @@ -699,15 +698,9 @@ mod tests { now_unix_timestamp, oauth_token_is_expired, resolve_saved_oauth_token, resolve_startup_auth_source, AnthropicClient, AuthSource, OAuthTokenSet, }; + use crate::test_env_lock; use crate::types::{ContentBlockDelta, MessageRequest}; - fn env_lock() -> std::sync::MutexGuard<'static, ()> { - static LOCK: OnceLock<Mutex<()>> = OnceLock::new(); - LOCK.get_or_init(|| Mutex::new(())) - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner) - } - fn temp_config_home() -> std::path::PathBuf { static NEXT_ID: AtomicU64 = AtomicU64::new(0); std::env::temp_dir().join(format!( @@ -753,7 +746,7 @@ mod tests { #[test] fn read_api_key_requires_presence() { - let _guard = env_lock(); + let _guard = test_env_lock(); std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); std::env::remove_var("ANTHROPIC_API_KEY"); std::env::remove_var("CLAUDE_CONFIG_HOME"); @@ -763,7 +756,7 @@ mod tests { #[test] fn read_api_key_requires_non_empty_value() { - let _guard = env_lock(); + let _guard = test_env_lock(); std::env::set_var("ANTHROPIC_AUTH_TOKEN", ""); std::env::remove_var("ANTHROPIC_API_KEY"); let error = super::read_api_key().expect_err("empty key 
should error"); @@ -773,7 +766,7 @@ mod tests { #[test] fn read_api_key_prefers_api_key_env() { - let _guard = env_lock(); + let _guard = test_env_lock(); std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token"); std::env::set_var("ANTHROPIC_API_KEY", "legacy-key"); assert_eq!( @@ -786,7 +779,7 @@ mod tests { #[test] fn read_auth_token_reads_auth_token_env() { - let _guard = env_lock(); + let _guard = test_env_lock(); std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token"); assert_eq!(super::read_auth_token().as_deref(), Some("auth-token")); std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); @@ -806,7 +799,7 @@ mod tests { #[test] fn auth_source_from_env_combines_api_key_and_bearer_token() { - let _guard = env_lock(); + let _guard = test_env_lock(); std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token"); std::env::set_var("ANTHROPIC_API_KEY", "legacy-key"); let auth = AuthSource::from_env().expect("env auth"); @@ -818,7 +811,7 @@ mod tests { #[test] fn auth_source_from_saved_oauth_when_env_absent() { - let _guard = env_lock(); + let _guard = test_env_lock(); let config_home = temp_config_home(); std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); @@ -857,7 +850,7 @@ mod tests { #[test] fn resolve_saved_oauth_token_refreshes_expired_credentials() { - let _guard = env_lock(); + let _guard = test_env_lock(); let config_home = temp_config_home(); std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); @@ -889,7 +882,7 @@ mod tests { #[test] fn resolve_startup_auth_source_uses_saved_oauth_without_loading_config() { - let _guard = env_lock(); + let _guard = test_env_lock(); let config_home = temp_config_home(); std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); @@ -913,7 +906,7 @@ mod tests { #[test] fn resolve_startup_auth_source_errors_when_refreshable_token_lacks_config() { - let _guard = env_lock(); + let _guard = test_env_lock(); 
let config_home = temp_config_home(); std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); @@ -945,7 +938,7 @@ mod tests { #[test] fn resolve_saved_oauth_token_preserves_refresh_token_when_refresh_response_omits_it() { - let _guard = env_lock(); + let _guard = test_env_lock(); let config_home = temp_config_home(); std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); diff --git a/rust/crates/api/src/lib.rs b/rust/crates/api/src/lib.rs index 43e2ffa..fc6ab87 100644 --- a/rust/crates/api/src/lib.rs +++ b/rust/crates/api/src/lib.rs @@ -20,3 +20,11 @@ pub use types::{ MessageResponse, MessageStartEvent, MessageStopEvent, OutputContentBlock, StreamEvent, ToolChoice, ToolDefinition, ToolResultContentBlock, Usage, }; + +#[cfg(test)] +pub(crate) fn test_env_lock() -> std::sync::MutexGuard<'static, ()> { + static LOCK: std::sync::OnceLock<std::sync::Mutex<()>> = std::sync::OnceLock::new(); + LOCK.get_or_init(|| std::sync::Mutex::new(())) + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) +} diff --git a/rust/crates/api/src/prompt_cache.rs b/rust/crates/api/src/prompt_cache.rs index 5a6a7da..be7cb83 100644 --- a/rust/crates/api/src/prompt_cache.rs +++ b/rust/crates/api/src/prompt_cache.rs @@ -141,6 +141,7 @@ impl PromptCache { self.lock().stats.clone() } + #[must_use] pub fn lookup_completion(&self, request: &MessageRequest) -> Option<MessageResponse> { let request_hash = request_hash_hex(request); let (paths, ttl) = { @@ -191,6 +192,7 @@ impl PromptCache { Some(entry.response) } + #[must_use] pub fn record_response( &self, request: &MessageRequest, @@ -199,6 +201,7 @@ impl PromptCache { self.record_usage_internal(request, &response.usage, Some(response)) } + #[must_use] pub fn record_usage(&self, request: &MessageRequest, usage: &Usage) -> PromptCacheRecord { self.record_usage_internal(request, usage, None) } @@ -267,7 +270,6 @@ struct TrackedPromptState { observed_at_unix_secs: u64, #[serde(default = 
"current_fingerprint_version")] fingerprint_version: u32, - request_hash: u64, model_hash: u64, system_hash: u64, tools_hash: u64, @@ -277,37 +279,34 @@ struct TrackedPromptState { impl TrackedPromptState { fn from_usage(request: &MessageRequest, usage: &Usage) -> Self { - let hashes = RequestHashes::from_request(request); + let hashes = RequestFingerprints::from_request(request); Self { observed_at_unix_secs: now_unix_secs(), fingerprint_version: current_fingerprint_version(), - request_hash: hashes.request_hash, - model_hash: hashes.model_hash, - system_hash: hashes.system_hash, - tools_hash: hashes.tools_hash, - messages_hash: hashes.messages_hash, + model_hash: hashes.model, + system_hash: hashes.system, + tools_hash: hashes.tools, + messages_hash: hashes.messages, cache_read_input_tokens: usage.cache_read_input_tokens, } } } #[derive(Debug, Clone, Copy)] -struct RequestHashes { - request_hash: u64, - model_hash: u64, - system_hash: u64, - tools_hash: u64, - messages_hash: u64, +struct RequestFingerprints { + model: u64, + system: u64, + tools: u64, + messages: u64, } -impl RequestHashes { +impl RequestFingerprints { fn from_request(request: &MessageRequest) -> Self { Self { - request_hash: hash_serializable(request), - model_hash: hash_serializable(&request.model), - system_hash: hash_serializable(&request.system), - tools_hash: hash_serializable(&request.tools), - messages_hash: hash_serializable(&request.messages), + model: hash_serializable(&request.model), + system: hash_serializable(&request.system), + tools: hash_serializable(&request.tools), + messages: hash_serializable(&request.messages), } } } @@ -501,22 +500,15 @@ fn stable_hash_bytes(bytes: &[u8]) -> u64 { #[cfg(test)] mod tests { - use std::sync::{Mutex, OnceLock}; - use std::time::{SystemTime, UNIX_EPOCH}; + use std::time::{Duration, SystemTime, UNIX_EPOCH}; use super::{ detect_cache_break, read_json, request_hash_hex, sanitize_path_segment, PromptCache, PromptCacheConfig, PromptCachePaths, 
TrackedPromptState, REQUEST_FINGERPRINT_PREFIX, }; + use crate::test_env_lock; use crate::types::{InputMessage, MessageRequest, MessageResponse, OutputContentBlock, Usage}; - fn env_lock() -> std::sync::MutexGuard<'static, ()> { - static LOCK: OnceLock<Mutex<()>> = OnceLock::new(); - LOCK.get_or_init(|| Mutex::new(())) - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner) - } - #[test] fn path_builder_sanitizes_session_identifier() { let paths = PromptCachePaths::for_session("session:/with spaces"); @@ -588,7 +580,7 @@ mod tests { #[test] fn completion_cache_round_trip_persists_recent_response() { - let _guard = env_lock(); + let _guard = test_env_lock(); let temp_root = std::env::temp_dir().join(format!( "prompt-cache-test-{}-{}", std::process::id(), @@ -624,6 +616,62 @@ mod tests { std::env::remove_var("CLAUDE_CONFIG_HOME"); } + #[test] + fn distinct_requests_do_not_collide_in_completion_cache() { + let _guard = test_env_lock(); + let temp_root = std::env::temp_dir().join(format!( + "prompt-cache-distinct-{}-{}", + std::process::id(), + SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("time") + .as_nanos() + )); + std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root); + let cache = PromptCache::new("distinct-request-session"); + let first_request = sample_request("first"); + let second_request = sample_request("second"); + + let response = sample_response(42, 12, "cached"); + let _ = cache.record_response(&first_request, &response); + + assert!(cache.lookup_completion(&second_request).is_none()); + + std::fs::remove_dir_all(temp_root).expect("cleanup temp root"); + std::env::remove_var("CLAUDE_CONFIG_HOME"); + } + + #[test] + fn expired_completion_entries_are_not_reused() { + let _guard = test_env_lock(); + let temp_root = std::env::temp_dir().join(format!( + "prompt-cache-expired-{}-{}", + std::process::id(), + SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("time") + .as_nanos() + )); + std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root); + let 
cache = PromptCache::with_config(PromptCacheConfig { + session_id: "expired-session".to_string(), + completion_ttl: Duration::ZERO, + ..PromptCacheConfig::default() + }); + let request = sample_request("expire me"); + let response = sample_response(7, 3, "stale"); + + let _ = cache.record_response(&request, &response); + + assert!(cache.lookup_completion(&request).is_none()); + let stats = cache.stats(); + assert_eq!(stats.completion_cache_hits, 0); + assert_eq!(stats.completion_cache_misses, 1); + + std::fs::remove_dir_all(temp_root).expect("cleanup temp root"); + std::env::remove_var("CLAUDE_CONFIG_HOME"); + } + #[test] fn sanitize_path_caps_long_values() { let long_value = "x".repeat(200); diff --git a/rust/crates/api/tests/client_integration.rs b/rust/crates/api/tests/client_integration.rs index 9f59710..1444156 100644 --- a/rust/crates/api/tests/client_integration.rs +++ b/rust/crates/api/tests/client_integration.rs @@ -84,6 +84,7 @@ async fn send_message_posts_json_and_parses_response() { } #[tokio::test] +#[allow(clippy::await_holding_lock)] async fn stream_message_parses_sse_events_with_tool_use() { let _guard = env_lock(); let temp_root = std::env::temp_dir().join(format!( @@ -180,12 +181,15 @@ async fn stream_message_parses_sse_events_with_tool_use() { let request = captured.first().expect("server should capture request"); assert!(request.body.contains("\"stream\":true")); - let stats = client + let cache_stats = client .prompt_cache_stats() .expect("prompt cache stats should exist"); - assert_eq!(stats.tracked_requests, 1); - assert_eq!(stats.last_cache_read_input_tokens, Some(0)); - assert_eq!(stats.last_cache_source.as_deref(), Some("api-response")); + assert_eq!(cache_stats.tracked_requests, 1); + assert_eq!(cache_stats.last_cache_read_input_tokens, Some(0)); + assert_eq!( + cache_stats.last_cache_source.as_deref(), + Some("api-response") + ); std::fs::remove_dir_all(temp_root).expect("cleanup temp root"); std::env::remove_var("CLAUDE_CONFIG_HOME"); 
@@ -273,6 +277,7 @@ async fn surfaces_retry_exhaustion_for_persistent_retryable_errors() { } #[tokio::test] +#[allow(clippy::await_holding_lock)] async fn send_message_reuses_recent_completion_cache_entries() { let _guard = env_lock(); let temp_root = std::env::temp_dir().join(format!( @@ -312,18 +317,19 @@ async fn send_message_reuses_recent_completion_cache_entries() { assert_eq!(first.content, second.content); assert_eq!(state.lock().await.len(), 1); - let stats = client + let cache_stats = client .prompt_cache_stats() .expect("prompt cache stats should exist"); - assert_eq!(stats.completion_cache_hits, 1); - assert_eq!(stats.completion_cache_misses, 1); - assert_eq!(stats.completion_cache_writes, 1); + assert_eq!(cache_stats.completion_cache_hits, 1); + assert_eq!(cache_stats.completion_cache_misses, 1); + assert_eq!(cache_stats.completion_cache_writes, 1); std::fs::remove_dir_all(temp_root).expect("cleanup temp root"); std::env::remove_var("CLAUDE_CONFIG_HOME"); } #[tokio::test] +#[allow(clippy::await_holding_lock)] async fn send_message_tracks_unexpected_prompt_cache_breaks() { let _guard = env_lock(); let temp_root = std::env::temp_dir().join(format!( @@ -372,12 +378,12 @@ async fn send_message_tracks_unexpected_prompt_cache_breaks() { .await .expect("second response should succeed"); - let stats = client + let cache_stats = client .prompt_cache_stats() .expect("prompt cache stats should exist"); - assert_eq!(stats.unexpected_cache_breaks, 1); + assert_eq!(cache_stats.unexpected_cache_breaks, 1); assert_eq!( - stats.last_break_reason.as_deref(), + cache_stats.last_break_reason.as_deref(), Some("cache read tokens dropped while prompt fingerprint remained stable") );