feat: cache-tracking progress

This commit is contained in:
Yeachan-Heo
2026-04-01 06:15:13 +00:00
parent 26344c578b
commit c9d214c8d1
7 changed files with 238 additions and 52 deletions

View File

@@ -1,4 +1,5 @@
use std::collections::VecDeque;
use std::sync::{Arc, Mutex};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use runtime::{
@@ -8,7 +9,7 @@ use runtime::{
use serde::Deserialize;
use crate::error::ApiError;
use crate::prompt_cache::{PromptCache, PromptCacheStats};
use crate::prompt_cache::{PromptCache, PromptCacheRecord, PromptCacheStats};
use crate::sse::SseParser;
use crate::types::{MessageRequest, MessageResponse, StreamEvent, Usage};
@@ -110,6 +111,7 @@ pub struct AnthropicClient {
initial_backoff: Duration,
max_backoff: Duration,
prompt_cache: Option<PromptCache>,
last_prompt_cache_record: Arc<Mutex<Option<PromptCacheRecord>>>,
}
impl AnthropicClient {
@@ -123,6 +125,7 @@ impl AnthropicClient {
initial_backoff: DEFAULT_INITIAL_BACKOFF,
max_backoff: DEFAULT_MAX_BACKOFF,
prompt_cache: None,
last_prompt_cache_record: Arc::new(Mutex::new(None)),
}
}
@@ -136,6 +139,7 @@ impl AnthropicClient {
initial_backoff: DEFAULT_INITIAL_BACKOFF,
max_backoff: DEFAULT_MAX_BACKOFF,
prompt_cache: None,
last_prompt_cache_record: Arc::new(Mutex::new(None)),
}
}
@@ -209,6 +213,14 @@ impl AnthropicClient {
self.prompt_cache.as_ref().map(PromptCache::stats)
}
/// Removes and returns the record left behind by the most recent request,
/// leaving `None` in its place.
///
/// A poisoned mutex still holds valid data (the record is a plain value, so a
/// panic elsewhere cannot leave it half-written); recover the guard instead of
/// propagating the poison.
#[must_use]
pub fn take_last_prompt_cache_record(&self) -> Option<PromptCacheRecord> {
    let slot = self.last_prompt_cache_record();
    let mut guard = match slot.lock() {
        Ok(guard) => guard,
        // Reuse the inner guard on poison rather than panicking.
        Err(poisoned) => poisoned.into_inner(),
    };
    guard.take()
}
#[must_use]
pub fn auth_source(&self) -> &AuthSource {
&self.auth
@@ -218,12 +230,16 @@ impl AnthropicClient {
&self,
request: &MessageRequest,
) -> Result<MessageResponse, ApiError> {
self.store_last_prompt_cache_record(None);
let request = MessageRequest {
stream: false,
..request.clone()
};
if let Some(prompt_cache) = &self.prompt_cache {
if let Some(response) = prompt_cache.lookup_completion(&request) {
self.store_last_prompt_cache_record(Some(prompt_cache_record_from_stats(
prompt_cache.stats(),
)));
return Ok(response);
}
}
@@ -237,7 +253,8 @@ impl AnthropicClient {
response.request_id = request_id;
}
if let Some(prompt_cache) = &self.prompt_cache {
let _ = prompt_cache.record_response(&request, &response);
let record = prompt_cache.record_response(&request, &response);
self.store_last_prompt_cache_record(Some(record));
}
Ok(response)
}
@@ -246,6 +263,7 @@ impl AnthropicClient {
&self,
request: &MessageRequest,
) -> Result<MessageStream, ApiError> {
self.store_last_prompt_cache_record(None);
let response = self
.send_with_retry(&request.clone().with_streaming())
.await?;
@@ -263,10 +281,22 @@ impl AnthropicClient {
request: request.clone().with_streaming(),
last_usage: None,
finalized: false,
last_record: self.last_prompt_cache_record.clone(),
}),
})
}
/// Replaces the contents of the shared "last prompt-cache record" slot.
///
/// `None` clears the slot (done at the start of each request); `Some` records
/// the outcome of the request that just completed. Poison on the mutex is
/// recovered, matching `take_last_prompt_cache_record`.
fn store_last_prompt_cache_record(&self, record: Option<PromptCacheRecord>) {
    let mut guard = match self.last_prompt_cache_record().lock() {
        Ok(guard) => guard,
        // The slot holds a plain value; a poisoned lock is still usable.
        Err(poisoned) => poisoned.into_inner(),
    };
    *guard = record;
}
/// Borrows the shared slot that tracks the most recent prompt-cache record.
///
/// NOTE(review): returns `&Arc<Mutex<...>>` rather than `&Mutex<...>` —
/// presumably so call sites can clone a handle to hand to the stream-side
/// tracker (the stream construction clones the field directly); confirm the
/// `Arc` exposure is intentional before narrowing the return type.
fn last_prompt_cache_record(&self) -> &Arc<Mutex<Option<PromptCacheRecord>>> {
    &self.last_prompt_cache_record
}
pub async fn exchange_oauth_code(
&self,
config: &OAuthConfig,
@@ -615,6 +645,7 @@ struct StreamCacheTracking {
request: MessageRequest,
last_usage: Option<Usage>,
finalized: bool,
last_record: Arc<Mutex<Option<PromptCacheRecord>>>,
}
impl StreamCacheTracking {
@@ -638,12 +669,23 @@ impl StreamCacheTracking {
return;
}
if let Some(usage) = &self.last_usage {
let _ = self.prompt_cache.record_usage(&self.request, usage);
let record = self.prompt_cache.record_usage(&self.request, usage);
*self
.last_record
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner) = Some(record);
}
self.finalized = true;
}
}
/// Builds a `PromptCacheRecord` carrying only aggregate cache statistics.
///
/// Used on a cache hit, where the response came straight from the prompt
/// cache and no new cache break exists to attach.
fn prompt_cache_record_from_stats(stats: PromptCacheStats) -> PromptCacheRecord {
    let cache_break = None;
    PromptCacheRecord { stats, cache_break }
}
async fn expect_success(response: reqwest::Response) -> Result<reqwest::Response, ApiError> {
let status = response.status();
if status.is_success() {