feat: cache-tracking progress

This commit is contained in:
Yeachan-Heo
2026-04-01 06:15:13 +00:00
parent 26344c578b
commit c9d214c8d1
7 changed files with 238 additions and 52 deletions

View File

@@ -25,9 +25,19 @@ pub enum AssistantEvent {
input: String,
},
Usage(TokenUsage),
PromptCache(PromptCacheEvent),
MessageStop,
}
/// Prompt-cache observation carried by `AssistantEvent::PromptCache` and collected into
/// `TurnSummary::prompt_cache_events`.
///
/// NOTE(review): the per-field notes below are inferred from the field names and from the
/// single test fixture in this file; confirm them against the code that emits these events.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PromptCacheEvent {
/// Whether the event is flagged as unexpected. The test fixture sets this to `true`
/// together with a reason describing a drop in cache-read tokens.
pub unexpected: bool,
/// Free-form, human-readable explanation attached to the event.
pub reason: String,
/// Presumably the cache-read input token count seen on the previous request — TODO confirm.
pub previous_cache_read_input_tokens: u32,
/// Presumably the cache-read input token count seen on the current request — TODO confirm.
pub current_cache_read_input_tokens: u32,
/// Size of the drop. The test fixture uses previous = 6_000, current = 1_000 and
/// token_drop = 5_000, i.e. previous minus current; nothing visible here enforces that.
pub token_drop: u32,
}
/// Abstraction over the model API used by the conversation turn loop in this file.
pub trait ApiClient {
/// Sends `request` and returns the assistant events produced for it as one `Vec`.
///
/// The turn loop passes the returned events to `build_assistant_message`, which folds
/// them into a single assistant message plus optional usage and prompt-cache events.
///
/// # Errors
///
/// Returns a `RuntimeError` when the request fails; the turn loop propagates it with `?`.
fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError>;
}
@@ -84,6 +94,7 @@ impl std::error::Error for RuntimeError {}
/// Result returned to the caller once a conversation turn completes.
pub struct TurnSummary {
/// Assistant messages gathered while the turn ran.
pub assistant_messages: Vec<ConversationMessage>,
/// Tool-result messages gathered while the turn ran.
pub tool_results: Vec<ConversationMessage>,
/// Every `PromptCacheEvent` reported during the turn, appended in the order the
/// per-request event batches were processed.
pub prompt_cache_events: Vec<PromptCacheEvent>,
/// Loop iteration counter for the turn; presumably one per API request — TODO confirm.
pub iterations: usize,
/// Cumulative usage as reported by the runtime's usage tracker when the turn finished.
pub usage: TokenUsage,
}
@@ -118,7 +129,7 @@ where
tool_executor,
permission_policy,
system_prompt,
RuntimeFeatureConfig::default(),
&RuntimeFeatureConfig::default(),
)
}
@@ -129,7 +140,7 @@ where
tool_executor: T,
permission_policy: PermissionPolicy,
system_prompt: Vec<String>,
feature_config: RuntimeFeatureConfig,
feature_config: &RuntimeFeatureConfig,
) -> Self {
let usage_tracker = UsageTracker::from_session(&session);
Self {
@@ -140,7 +151,7 @@ where
system_prompt,
max_iterations: usize::MAX,
usage_tracker,
hook_runner: HookRunner::from_feature_config(&feature_config),
hook_runner: HookRunner::from_feature_config(feature_config),
}
}
@@ -161,6 +172,7 @@ where
let mut assistant_messages = Vec::new();
let mut tool_results = Vec::new();
let mut prompt_cache_events = Vec::new();
let mut iterations = 0;
loop {
@@ -176,10 +188,12 @@ where
messages: self.session.messages.clone(),
};
let events = self.api_client.stream(request)?;
let (assistant_message, usage) = build_assistant_message(events)?;
let (assistant_message, usage, turn_prompt_cache_events) =
build_assistant_message(events)?;
if let Some(usage) = usage {
self.usage_tracker.record(usage);
}
prompt_cache_events.extend(turn_prompt_cache_events);
let pending_tool_uses = assistant_message
.blocks
.iter()
@@ -257,6 +271,7 @@ where
Ok(TurnSummary {
assistant_messages,
tool_results,
prompt_cache_events,
iterations,
usage: self.usage_tracker.cumulative_usage(),
})
@@ -290,9 +305,17 @@ where
fn build_assistant_message(
events: Vec<AssistantEvent>,
) -> Result<(ConversationMessage, Option<TokenUsage>), RuntimeError> {
) -> Result<
(
ConversationMessage,
Option<TokenUsage>,
Vec<PromptCacheEvent>,
),
RuntimeError,
> {
let mut text = String::new();
let mut blocks = Vec::new();
let mut prompt_cache_events = Vec::new();
let mut finished = false;
let mut usage = None;
@@ -304,6 +327,7 @@ fn build_assistant_message(
blocks.push(ContentBlock::ToolUse { id, name, input });
}
AssistantEvent::Usage(value) => usage = Some(value),
AssistantEvent::PromptCache(event) => prompt_cache_events.push(event),
AssistantEvent::MessageStop => {
finished = true;
}
@@ -324,6 +348,7 @@ fn build_assistant_message(
Ok((
ConversationMessage::assistant_with_usage(blocks, usage),
usage,
prompt_cache_events,
))
}
@@ -396,7 +421,7 @@ impl ToolExecutor for StaticToolExecutor {
#[cfg(test)]
mod tests {
use super::{
ApiClient, ApiRequest, AssistantEvent, ConversationRuntime, RuntimeError,
ApiClient, ApiRequest, AssistantEvent, ConversationRuntime, PromptCacheEvent, RuntimeError,
StaticToolExecutor,
};
use crate::compact::CompactionConfig;
@@ -453,6 +478,15 @@ mod tests {
cache_creation_input_tokens: 1,
cache_read_input_tokens: 3,
}),
AssistantEvent::PromptCache(PromptCacheEvent {
unexpected: true,
reason:
"cache read tokens dropped while prompt fingerprint remained stable"
.to_string(),
previous_cache_read_input_tokens: 6_000,
current_cache_read_input_tokens: 1_000,
token_drop: 5_000,
}),
AssistantEvent::MessageStop,
])
}
@@ -506,8 +540,10 @@ mod tests {
assert_eq!(summary.iterations, 2);
assert_eq!(summary.assistant_messages.len(), 2);
assert_eq!(summary.tool_results.len(), 1);
assert_eq!(summary.prompt_cache_events.len(), 1);
assert_eq!(runtime.session().messages.len(), 4);
assert_eq!(summary.usage.output_tokens, 10);
assert!(summary.prompt_cache_events[0].unexpected);
assert!(matches!(
runtime.session().messages[1].blocks[1],
ContentBlock::ToolUse { .. }
@@ -609,7 +645,7 @@ mod tests {
}),
PermissionPolicy::new(PermissionMode::DangerFullAccess),
vec!["system".to_string()],
RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new(
&RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new(
vec![shell_snippet("printf 'blocked by hook'; exit 2")],
Vec::new(),
)),
@@ -675,7 +711,7 @@ mod tests {
StaticToolExecutor::new().register("add", |_input| Ok("4".to_string())),
PermissionPolicy::new(PermissionMode::DangerFullAccess),
vec!["system".to_string()],
RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new(
&RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new(
vec![shell_snippet("printf 'pre hook ran'")],
vec![shell_snippet("printf 'post hook ran'")],
)),
@@ -697,7 +733,7 @@ mod tests {
"post hook should preserve non-error result: {output:?}"
);
assert!(
output.contains("4"),
output.contains('4'),
"tool output missing value: {output:?}"
);
assert!(

View File

@@ -64,7 +64,7 @@ impl HookRunner {
#[must_use]
pub fn run_pre_tool_use(&self, tool_name: &str, tool_input: &str) -> HookRunResult {
self.run_commands(
Self::run_commands(
HookEvent::PreToolUse,
self.config.pre_tool_use(),
tool_name,
@@ -82,7 +82,7 @@ impl HookRunner {
tool_output: &str,
is_error: bool,
) -> HookRunResult {
self.run_commands(
Self::run_commands(
HookEvent::PostToolUse,
self.config.post_tool_use(),
tool_name,
@@ -93,7 +93,6 @@ impl HookRunner {
}
fn run_commands(
&self,
event: HookEvent,
commands: &[String],
tool_name: &str,
@@ -118,7 +117,7 @@ impl HookRunner {
let mut messages = Vec::new();
for command in commands {
match self.run_command(
match Self::run_command(
command,
event,
tool_name,
@@ -150,7 +149,6 @@ impl HookRunner {
}
fn run_command(
&self,
command: &str,
event: HookEvent,
tool_name: &str,

View File

@@ -31,8 +31,8 @@ pub use config::{
ScopedMcpServerConfig, CLAUDE_CODE_SETTINGS_SCHEMA_NAME,
};
pub use conversation::{
ApiClient, ApiRequest, AssistantEvent, ConversationRuntime, RuntimeError, StaticToolExecutor,
ToolError, ToolExecutor, TurnSummary,
ApiClient, ApiRequest, AssistantEvent, ConversationRuntime, PromptCacheEvent, RuntimeError,
StaticToolExecutor, ToolError, ToolExecutor, TurnSummary,
};
pub use file_ops::{
edit_file, glob_search, grep_search, read_file, write_file, EditFileOutput, GlobSearchOutput,