wip: cache-tracking progress

This commit is contained in:
Yeachan-Heo
2026-04-01 04:40:17 +00:00
parent 0cf2204d43
commit 26344c578b
4 changed files with 113 additions and 58 deletions

View File

@@ -689,7 +689,6 @@ mod tests {
use std::io::{Read, Write};
use std::net::TcpListener;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Mutex, OnceLock};
use std::thread;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
@@ -699,15 +698,9 @@ mod tests {
now_unix_timestamp, oauth_token_is_expired, resolve_saved_oauth_token,
resolve_startup_auth_source, AnthropicClient, AuthSource, OAuthTokenSet,
};
use crate::test_env_lock;
use crate::types::{ContentBlockDelta, MessageRequest};
fn env_lock() -> std::sync::MutexGuard<'static, ()> {
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
LOCK.get_or_init(|| Mutex::new(()))
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
}
fn temp_config_home() -> std::path::PathBuf {
static NEXT_ID: AtomicU64 = AtomicU64::new(0);
std::env::temp_dir().join(format!(
@@ -753,7 +746,7 @@ mod tests {
#[test]
fn read_api_key_requires_presence() {
let _guard = env_lock();
let _guard = test_env_lock();
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
std::env::remove_var("ANTHROPIC_API_KEY");
std::env::remove_var("CLAUDE_CONFIG_HOME");
@@ -763,7 +756,7 @@ mod tests {
#[test]
fn read_api_key_requires_non_empty_value() {
let _guard = env_lock();
let _guard = test_env_lock();
std::env::set_var("ANTHROPIC_AUTH_TOKEN", "");
std::env::remove_var("ANTHROPIC_API_KEY");
let error = super::read_api_key().expect_err("empty key should error");
@@ -773,7 +766,7 @@ mod tests {
#[test]
fn read_api_key_prefers_api_key_env() {
let _guard = env_lock();
let _guard = test_env_lock();
std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token");
std::env::set_var("ANTHROPIC_API_KEY", "legacy-key");
assert_eq!(
@@ -786,7 +779,7 @@ mod tests {
#[test]
fn read_auth_token_reads_auth_token_env() {
let _guard = env_lock();
let _guard = test_env_lock();
std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token");
assert_eq!(super::read_auth_token().as_deref(), Some("auth-token"));
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
@@ -806,7 +799,7 @@ mod tests {
#[test]
fn auth_source_from_env_combines_api_key_and_bearer_token() {
let _guard = env_lock();
let _guard = test_env_lock();
std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token");
std::env::set_var("ANTHROPIC_API_KEY", "legacy-key");
let auth = AuthSource::from_env().expect("env auth");
@@ -818,7 +811,7 @@ mod tests {
#[test]
fn auth_source_from_saved_oauth_when_env_absent() {
let _guard = env_lock();
let _guard = test_env_lock();
let config_home = temp_config_home();
std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
@@ -857,7 +850,7 @@ mod tests {
#[test]
fn resolve_saved_oauth_token_refreshes_expired_credentials() {
let _guard = env_lock();
let _guard = test_env_lock();
let config_home = temp_config_home();
std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
@@ -889,7 +882,7 @@ mod tests {
#[test]
fn resolve_startup_auth_source_uses_saved_oauth_without_loading_config() {
let _guard = env_lock();
let _guard = test_env_lock();
let config_home = temp_config_home();
std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
@@ -913,7 +906,7 @@ mod tests {
#[test]
fn resolve_startup_auth_source_errors_when_refreshable_token_lacks_config() {
let _guard = env_lock();
let _guard = test_env_lock();
let config_home = temp_config_home();
std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
@@ -945,7 +938,7 @@ mod tests {
#[test]
fn resolve_saved_oauth_token_preserves_refresh_token_when_refresh_response_omits_it() {
let _guard = env_lock();
let _guard = test_env_lock();
let config_home = temp_config_home();
std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");

View File

@@ -20,3 +20,11 @@ pub use types::{
MessageResponse, MessageStartEvent, MessageStopEvent, OutputContentBlock, StreamEvent,
ToolChoice, ToolDefinition, ToolResultContentBlock, Usage,
};
/// Acquires the crate-wide lock that tests take before touching process
/// environment variables, so that such tests do not run concurrently.
///
/// The returned guard releases the lock when dropped; bind it to a named
/// variable (e.g. `_guard`) so it stays alive for the whole test body.
///
/// A poisoned lock is recovered rather than propagated: the mutex protects
/// only `()`, so a panic in an earlier holder cannot leave bad state behind.
#[cfg(test)]
pub(crate) fn test_env_lock() -> std::sync::MutexGuard<'static, ()> {
    use std::sync::{Mutex, OnceLock};

    // Lazily created on first use and shared by every caller in the crate.
    static ENV_MUTEX: OnceLock<Mutex<()>> = OnceLock::new();

    let mutex = ENV_MUTEX.get_or_init(|| Mutex::new(()));
    match mutex.lock() {
        Ok(guard) => guard,
        // An earlier holder panicked; take the guard out of the poison error.
        Err(poisoned) => poisoned.into_inner(),
    }
}

View File

@@ -141,6 +141,7 @@ impl PromptCache {
self.lock().stats.clone()
}
#[must_use]
pub fn lookup_completion(&self, request: &MessageRequest) -> Option<MessageResponse> {
let request_hash = request_hash_hex(request);
let (paths, ttl) = {
@@ -191,6 +192,7 @@ impl PromptCache {
Some(entry.response)
}
#[must_use]
pub fn record_response(
&self,
request: &MessageRequest,
@@ -199,6 +201,7 @@ impl PromptCache {
self.record_usage_internal(request, &response.usage, Some(response))
}
#[must_use]
pub fn record_usage(&self, request: &MessageRequest, usage: &Usage) -> PromptCacheRecord {
self.record_usage_internal(request, usage, None)
}
@@ -267,7 +270,6 @@ struct TrackedPromptState {
observed_at_unix_secs: u64,
#[serde(default = "current_fingerprint_version")]
fingerprint_version: u32,
request_hash: u64,
model_hash: u64,
system_hash: u64,
tools_hash: u64,
@@ -277,37 +279,34 @@ struct TrackedPromptState {
impl TrackedPromptState {
fn from_usage(request: &MessageRequest, usage: &Usage) -> Self {
let hashes = RequestHashes::from_request(request);
let hashes = RequestFingerprints::from_request(request);
Self {
observed_at_unix_secs: now_unix_secs(),
fingerprint_version: current_fingerprint_version(),
request_hash: hashes.request_hash,
model_hash: hashes.model_hash,
system_hash: hashes.system_hash,
tools_hash: hashes.tools_hash,
messages_hash: hashes.messages_hash,
model_hash: hashes.model,
system_hash: hashes.system,
tools_hash: hashes.tools,
messages_hash: hashes.messages,
cache_read_input_tokens: usage.cache_read_input_tokens,
}
}
}
#[derive(Debug, Clone, Copy)]
struct RequestHashes {
request_hash: u64,
model_hash: u64,
system_hash: u64,
tools_hash: u64,
messages_hash: u64,
struct RequestFingerprints {
model: u64,
system: u64,
tools: u64,
messages: u64,
}
impl RequestHashes {
impl RequestFingerprints {
fn from_request(request: &MessageRequest) -> Self {
Self {
request_hash: hash_serializable(request),
model_hash: hash_serializable(&request.model),
system_hash: hash_serializable(&request.system),
tools_hash: hash_serializable(&request.tools),
messages_hash: hash_serializable(&request.messages),
model: hash_serializable(&request.model),
system: hash_serializable(&request.system),
tools: hash_serializable(&request.tools),
messages: hash_serializable(&request.messages),
}
}
}
@@ -501,22 +500,15 @@ fn stable_hash_bytes(bytes: &[u8]) -> u64 {
#[cfg(test)]
mod tests {
use std::sync::{Mutex, OnceLock};
use std::time::{SystemTime, UNIX_EPOCH};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use super::{
detect_cache_break, read_json, request_hash_hex, sanitize_path_segment, PromptCache,
PromptCacheConfig, PromptCachePaths, TrackedPromptState, REQUEST_FINGERPRINT_PREFIX,
};
use crate::test_env_lock;
use crate::types::{InputMessage, MessageRequest, MessageResponse, OutputContentBlock, Usage};
fn env_lock() -> std::sync::MutexGuard<'static, ()> {
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
LOCK.get_or_init(|| Mutex::new(()))
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
}
#[test]
fn path_builder_sanitizes_session_identifier() {
let paths = PromptCachePaths::for_session("session:/with spaces");
@@ -588,7 +580,7 @@ mod tests {
#[test]
fn completion_cache_round_trip_persists_recent_response() {
let _guard = env_lock();
let _guard = test_env_lock();
let temp_root = std::env::temp_dir().join(format!(
"prompt-cache-test-{}-{}",
std::process::id(),
@@ -624,6 +616,62 @@ mod tests {
std::env::remove_var("CLAUDE_CONFIG_HOME");
}
/// A completion recorded for one request must not be served for a
/// different request in the same cache.
#[test]
fn distinct_requests_do_not_collide_in_completion_cache() {
    // CLAUDE_CONFIG_HOME is process-global, so hold the shared env lock
    // for the entire test.
    let _guard = test_env_lock();

    // Per-run directory name: process id plus a nanosecond timestamp.
    let temp_base = std::env::temp_dir();
    let dir_name = format!(
        "prompt-cache-distinct-{}-{}",
        std::process::id(),
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("time")
            .as_nanos()
    );
    let temp_root = temp_base.join(dir_name);
    std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root);

    let cache = PromptCache::new("distinct-request-session");
    let first_request = sample_request("first");
    let second_request = sample_request("second");
    let cached_response = sample_response(42, 12, "cached");

    // Only the first request is recorded; the result is deliberately ignored.
    let _ = cache.record_response(&first_request, &cached_response);

    assert!(cache.lookup_completion(&second_request).is_none());

    // NOTE(review): cleanup expects the directory to exist, so the cache
    // presumably persisted something under CLAUDE_CONFIG_HOME — confirm.
    std::fs::remove_dir_all(temp_root).expect("cleanup temp root");
    std::env::remove_var("CLAUDE_CONFIG_HOME");
}
/// With a zero completion TTL, a response that was just recorded is not
/// returned by a lookup, and the lookup is counted as a miss.
#[test]
fn expired_completion_entries_are_not_reused() {
    // CLAUDE_CONFIG_HOME is process-global, so hold the shared env lock
    // for the entire test.
    let _guard = test_env_lock();

    // Per-run directory name: process id plus a nanosecond timestamp.
    let temp_base = std::env::temp_dir();
    let dir_name = format!(
        "prompt-cache-expired-{}-{}",
        std::process::id(),
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("time")
            .as_nanos()
    );
    let temp_root = temp_base.join(dir_name);
    std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root);

    // Every setting other than the session id and TTL stays at its default.
    let config = PromptCacheConfig {
        session_id: "expired-session".to_string(),
        completion_ttl: Duration::ZERO,
        ..PromptCacheConfig::default()
    };
    let cache = PromptCache::with_config(config);

    let request = sample_request("expire me");
    let stale_response = sample_response(7, 3, "stale");
    let _ = cache.record_response(&request, &stale_response);

    assert!(cache.lookup_completion(&request).is_none());

    // The failed lookup must register as exactly one miss and no hits.
    let observed = cache.stats();
    assert_eq!(observed.completion_cache_hits, 0);
    assert_eq!(observed.completion_cache_misses, 1);

    std::fs::remove_dir_all(temp_root).expect("cleanup temp root");
    std::env::remove_var("CLAUDE_CONFIG_HOME");
}
#[test]
fn sanitize_path_caps_long_values() {
let long_value = "x".repeat(200);

View File

@@ -84,6 +84,7 @@ async fn send_message_posts_json_and_parses_response() {
}
#[tokio::test]
#[allow(clippy::await_holding_lock)]
async fn stream_message_parses_sse_events_with_tool_use() {
let _guard = env_lock();
let temp_root = std::env::temp_dir().join(format!(
@@ -180,12 +181,15 @@ async fn stream_message_parses_sse_events_with_tool_use() {
let request = captured.first().expect("server should capture request");
assert!(request.body.contains("\"stream\":true"));
let stats = client
let cache_stats = client
.prompt_cache_stats()
.expect("prompt cache stats should exist");
assert_eq!(stats.tracked_requests, 1);
assert_eq!(stats.last_cache_read_input_tokens, Some(0));
assert_eq!(stats.last_cache_source.as_deref(), Some("api-response"));
assert_eq!(cache_stats.tracked_requests, 1);
assert_eq!(cache_stats.last_cache_read_input_tokens, Some(0));
assert_eq!(
cache_stats.last_cache_source.as_deref(),
Some("api-response")
);
std::fs::remove_dir_all(temp_root).expect("cleanup temp root");
std::env::remove_var("CLAUDE_CONFIG_HOME");
@@ -273,6 +277,7 @@ async fn surfaces_retry_exhaustion_for_persistent_retryable_errors() {
}
#[tokio::test]
#[allow(clippy::await_holding_lock)]
async fn send_message_reuses_recent_completion_cache_entries() {
let _guard = env_lock();
let temp_root = std::env::temp_dir().join(format!(
@@ -312,18 +317,19 @@ async fn send_message_reuses_recent_completion_cache_entries() {
assert_eq!(first.content, second.content);
assert_eq!(state.lock().await.len(), 1);
let stats = client
let cache_stats = client
.prompt_cache_stats()
.expect("prompt cache stats should exist");
assert_eq!(stats.completion_cache_hits, 1);
assert_eq!(stats.completion_cache_misses, 1);
assert_eq!(stats.completion_cache_writes, 1);
assert_eq!(cache_stats.completion_cache_hits, 1);
assert_eq!(cache_stats.completion_cache_misses, 1);
assert_eq!(cache_stats.completion_cache_writes, 1);
std::fs::remove_dir_all(temp_root).expect("cleanup temp root");
std::env::remove_var("CLAUDE_CONFIG_HOME");
}
#[tokio::test]
#[allow(clippy::await_holding_lock)]
async fn send_message_tracks_unexpected_prompt_cache_breaks() {
let _guard = env_lock();
let temp_root = std::env::temp_dir().join(format!(
@@ -372,12 +378,12 @@ async fn send_message_tracks_unexpected_prompt_cache_breaks() {
.await
.expect("second response should succeed");
let stats = client
let cache_stats = client
.prompt_cache_stats()
.expect("prompt cache stats should exist");
assert_eq!(stats.unexpected_cache_breaks, 1);
assert_eq!(cache_stats.unexpected_cache_breaks, 1);
assert_eq!(
stats.last_break_reason.as_deref(),
cache_stats.last_break_reason.as_deref(),
Some("cache read tokens dropped while prompt fingerprint remained stable")
);